From b9abb8cabb395d5f86d52ca6e691234d0263c140 Mon Sep 17 00:00:00 2001 From: Yuri Tseretyan Date: Fri, 22 Mar 2024 15:37:10 -0400 Subject: [PATCH] Alerting: Update provisioning API to support regular permissions (#77007) * allow users with regular actions access provisioning API paths * update methods that read rules skip new authorization logic if user CanReadAllRules to avoid performance impact on file-provisioning update all methods to accept identity.Requester that contains all permissions and is required by access control. * create deltas for single rul e * update modify methods skip new authorization logic if user CanWriteAllRules to avoid performance impact on file-provisioning update all methods to accept identity.Requester that contains all permissions and is required by access control. * implement RuleAccessControlService in provisioning * update file provisioning user to have all permissions to bypass authz * update provisioning API to return errutil errors correctly --------- Co-authored-by: Alexander Weaver --- .../ngalert/accesscontrol/fakes/rules.go | 4 +- pkg/services/ngalert/api/api_provisioning.go | 20 +- .../ngalert/api/api_provisioning_test.go | 7 +- pkg/services/ngalert/api/authorization.go | 94 +- pkg/services/ngalert/api/tooling/api.json | 2 +- .../definitions/provisioning_alert_rules.go | 2 +- pkg/services/ngalert/api/tooling/post.json | 2 +- pkg/services/ngalert/api/tooling/spec.json | 2 +- pkg/services/ngalert/ngalert.go | 4 +- .../ngalert/provisioning/accesscontrol.go | 76 ++ .../provisioning/accesscontrol_test.go | 231 ++++ .../ngalert/provisioning/alert_rules.go | 326 ++++-- .../ngalert/provisioning/alert_rules_test.go | 988 +++++++++++++++++- pkg/services/ngalert/provisioning/testing.go | 61 ++ pkg/services/ngalert/store/deltas.go | 119 ++- pkg/services/ngalert/store/deltas_test.go | 168 +++ pkg/services/ngalert/tests/fakes/rules.go | 13 +- .../alerting/rules_provisioner.go | 16 +- pkg/services/provisioning/provisioning.go | 6 +- public/api-merged.json | 2 +- public/openapi3.json | 2 +- 21 files changed, 2038 insertions(+), 107 deletions(-) create mode 100644 pkg/services/ngalert/provisioning/accesscontrol.go create mode 100644 pkg/services/ngalert/provisioning/accesscontrol_test.go diff --git a/pkg/services/ngalert/accesscontrol/fakes/rules.go b/pkg/services/ngalert/accesscontrol/fakes/rules.go index 059f08e9157..0514a69cb9a 100644 --- a/pkg/services/ngalert/accesscontrol/fakes/rules.go +++ b/pkg/services/ngalert/accesscontrol/fakes/rules.go @@ -58,7 +58,7 @@ func (s *FakeRuleService) HasAccessToRuleGroup(ctx context.Context, user identit } func (s *FakeRuleService) AuthorizeAccessToRuleGroup(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { - s.Calls = append(s.Calls, Call{"AuthorizeAccessToRuleGroup", []interface{}{ctx, user, rules}}) + s.Calls = append(s.Calls, Call{"AuthorizeRuleGroupRead", []interface{}{ctx, user, rules}}) if s.AuthorizeAccessToRuleGroupFunc != nil { return s.AuthorizeAccessToRuleGroupFunc(ctx, user, rules) } @@ -66,7 +66,7 @@ func (s *FakeRuleService) AuthorizeAccessToRuleGroup(ctx context.Context, user i } func (s *FakeRuleService) AuthorizeRuleChanges(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { - s.Calls = append(s.Calls, Call{"AuthorizeRuleChanges", []interface{}{ctx, user, change}}) + s.Calls = append(s.Calls, Call{"AuthorizeRuleGroupWrite", []interface{}{ctx, user, change}}) if s.AuthorizeRuleChangesFunc != nil { return s.AuthorizeRuleChangesFunc(ctx, user, change) } diff --git a/pkg/services/ngalert/api/api_provisioning.go b/pkg/services/ngalert/api/api_provisioning.go index d4425cf678a..7ec9b7b3da9 100644 --- a/pkg/services/ngalert/api/api_provisioning.go +++ b/pkg/services/ngalert/api/api_provisioning.go @@ -311,7 +311,7 @@ func (srv *ProvisioningSrv) RouteDeleteMuteTiming(c *contextmodel.ReqContext, na func (srv *ProvisioningSrv) RouteGetAlertRules(c *contextmodel.ReqContext) response.Response { rules, provenances, err := srv.alertRules.GetAlertRules(c.Req.Context(), c.SignedInUser) if err != nil { - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "", err) } return response.JSON(http.StatusOK, ProvisionedAlertRuleFromAlertRules(rules, provenances)) } @@ -322,7 +322,7 @@ func (srv *ProvisioningSrv) RouteRouteGetAlertRule(c *contextmodel.ReqContext, U if errors.Is(err, alerting_models.ErrAlertRuleNotFound) { return response.Empty(http.StatusNotFound) } - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "failed to get rule by UID", err) } return response.JSON(http.StatusOK, ProvisionedAlertRuleFromAlertRule(rule, provenace)) } @@ -348,7 +348,7 @@ func (srv *ProvisioningSrv) RoutePostAlertRule(c *contextmodel.ReqContext, ar de if errors.Is(err, alerting_models.ErrQuotaReached) { return ErrResp(http.StatusForbidden, err, "") } - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "", err) } resp := ProvisionedAlertRuleFromAlertRule(createdAlertRule, alerting_models.Provenance(provenance)) @@ -377,7 +377,7 @@ func (srv *ProvisioningSrv) RoutePutAlertRule(c *contextmodel.ReqContext, ar def if errors.Is(err, store.ErrOptimisticLock) { return ErrResp(http.StatusConflict, err, "") } - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "", err) } resp := ProvisionedAlertRuleFromAlertRule(updatedAlertRule, alerting_models.Provenance(provenance)) @@ -388,7 +388,7 @@ func (srv *ProvisioningSrv) RouteDeleteAlertRule(c *contextmodel.ReqContext, UID provenance := determineProvenance(c) err := srv.alertRules.DeleteAlertRule(c.Req.Context(), c.SignedInUser, UID, alerting_models.Provenance(provenance)) if err != nil { - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "", err) } return response.JSON(http.StatusNoContent, "") } @@ -424,7 +424,7 @@ func (srv *ProvisioningSrv) RouteGetAlertRulesExport(c *contextmodel.ReqContext) groupsWithTitle, err := srv.alertRules.GetAlertGroupsWithFolderTitle(c.Req.Context(), c.SignedInUser, folderUIDs) if err != nil { - return ErrResp(http.StatusInternalServerError, err, "failed to get alert rules") + return response.ErrOrFallback(http.StatusInternalServerError, "failed to get alert rules", err) } if len(groupsWithTitle) == 0 { return response.Empty(http.StatusNotFound) @@ -432,7 +432,7 @@ func (srv *ProvisioningSrv) RouteGetAlertRulesExport(c *contextmodel.ReqContext) e, err := AlertingFileExportFromAlertRuleGroupWithFolderTitle(groupsWithTitle) if err != nil { - return ErrResp(http.StatusInternalServerError, err, "failed to create alerting file export") + return response.ErrOrFallback(http.StatusInternalServerError, "failed to create alerting file export", err) } return exportResponse(c, e) @@ -447,7 +447,7 @@ func (srv *ProvisioningSrv) RouteGetAlertRuleGroupExport(c *contextmodel.ReqCont e, err := AlertingFileExportFromAlertRuleGroupWithFolderTitle([]alerting_models.AlertRuleGroupWithFolderTitle{g}) if err != nil { - return ErrResp(http.StatusInternalServerError, err, "failed to create alerting file export") + return response.ErrOrFallback(http.StatusInternalServerError, "failed to create alerting file export", err) } return exportResponse(c, e) @@ -460,7 +460,7 @@ func (srv *ProvisioningSrv) RouteGetAlertRuleExport(c *contextmodel.ReqContext, if errors.Is(err, alerting_models.ErrAlertRuleNotFound) { return ErrResp(http.StatusNotFound, err, "") } - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "failed to get alert rules", err) } e, err := AlertingFileExportFromAlertRuleGroupWithFolderTitle([]alerting_models.AlertRuleGroupWithFolderTitle{ @@ -492,7 +492,7 @@ func (srv *ProvisioningSrv) RoutePutAlertRuleGroup(c *contextmodel.ReqContext, a if errors.Is(err, store.ErrOptimisticLock) { return ErrResp(http.StatusConflict, err, "") } - return ErrResp(http.StatusInternalServerError, err, "") + return response.ErrOrFallback(http.StatusInternalServerError, "", err) } return response.JSON(http.StatusOK, ag) } diff --git a/pkg/services/ngalert/api/api_provisioning_test.go b/pkg/services/ngalert/api/api_provisioning_test.go index 361d3a651a9..ad304c47ee5 100644 --- a/pkg/services/ngalert/api/api_provisioning_test.go +++ b/pkg/services/ngalert/api/api_provisioning_test.go @@ -28,6 +28,7 @@ import ( "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/folder" "github.com/grafana/grafana/pkg/services/folder/foldertest" + "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol/fakes" "github.com/grafana/grafana/pkg/services/ngalert/api/tooling/definitions" "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/notifier" @@ -1613,6 +1614,7 @@ type testEnvironment struct { quotas provisioning.QuotaChecker prov provisioning.ProvisioningStore ac *recordingAccessControlFake + rulesAuthz *fakes.FakeRuleService } func createTestEnv(t *testing.T, testConfig string) testEnvironment { @@ -1674,6 +1676,8 @@ func createTestEnv(t *testing.T, testConfig string) testEnvironment { ac := &recordingAccessControlFake{} + ruleAuthz := &fakes.FakeRuleService{} + return testEnvironment{ secrets: secretsService, log: log, @@ -1685,6 +1689,7 @@ func createTestEnv(t *testing.T, testConfig string) testEnvironment { prov: prov, quotas: quotas, ac: ac, + rulesAuthz: ruleAuthz, } } @@ -1705,7 +1710,7 @@ func createProvisioningSrvSutFromEnv(t *testing.T, env *testEnvironment) Provisi contactPointService: provisioning.NewContactPointService(env.configs, env.secrets, env.prov, env.xact, receiverSvc, env.log, env.store), templates: provisioning.NewTemplateService(env.configs, env.prov, env.xact, env.log), muteTimings: provisioning.NewMuteTimingService(env.configs, env.prov, env.xact, env.log), - alertRules: provisioning.NewAlertRuleService(env.store, env.prov, env.folderService, env.dashboardService, env.quotas, env.xact, 60, 10, 100, env.log, &provisioning.NotificationSettingsValidatorProviderFake{}), + alertRules: provisioning.NewAlertRuleService(env.store, env.prov, env.folderService, env.dashboardService, env.quotas, env.xact, 60, 10, 100, env.log, &provisioning.NotificationSettingsValidatorProviderFake{}, env.rulesAuthz), } } diff --git a/pkg/services/ngalert/api/authorization.go b/pkg/services/ngalert/api/authorization.go index f8676d6a798..58440bb9495 100644 --- a/pkg/services/ngalert/api/authorization.go +++ b/pkg/services/ngalert/api/authorization.go @@ -223,19 +223,90 @@ func (api *API) authorize(method, path string) web.Handler { ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), // organization scope ) + case http.MethodGet + "/api/v1/provisioning/alert-rules", + http.MethodGet + "/api/v1/provisioning/alert-rules/export": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningRead), + ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), + ac.EvalAll( // scopes are enforced in the handler + ac.EvalPermission(ac.ActionAlertingRuleRead), + ac.EvalPermission(dashboards.ActionFoldersRead), + ), + ) + case http.MethodGet + "/api/v1/provisioning/alert-rules/{UID}", + http.MethodGet + "/api/v1/provisioning/alert-rules/{UID}/export": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningRead), + ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), + ac.EvalAll( + ac.EvalPermission(ac.ActionAlertingRuleRead), + ac.EvalPermission(dashboards.ActionFoldersRead), + ), + ) + + case http.MethodGet + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}", + http.MethodGet + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}/export": + scope := dashboards.ScopeFoldersProvider.GetResourceScopeUID(ac.Parameter(":FolderUID")) + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningRead), + ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), + ac.EvalAll( + ac.EvalPermission(ac.ActionAlertingRuleRead, scope), + ac.EvalPermission(dashboards.ActionFoldersRead, scope), + ), + ) + case http.MethodGet + "/api/v1/provisioning/policies", http.MethodGet + "/api/v1/provisioning/contact-points", http.MethodGet + "/api/v1/provisioning/templates", http.MethodGet + "/api/v1/provisioning/templates/{name}", http.MethodGet + "/api/v1/provisioning/mute-timings", - http.MethodGet + "/api/v1/provisioning/mute-timings/{name}", - http.MethodGet + "/api/v1/provisioning/alert-rules", - http.MethodGet + "/api/v1/provisioning/alert-rules/{UID}", - http.MethodGet + "/api/v1/provisioning/alert-rules/export", - http.MethodGet + "/api/v1/provisioning/alert-rules/{UID}/export", - http.MethodGet + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}", - http.MethodGet + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}/export": - eval = ac.EvalAny(ac.EvalPermission(ac.ActionAlertingProvisioningRead), ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets)) // organization scope + http.MethodGet + "/api/v1/provisioning/mute-timings/{name}": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningRead), + ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), + ) + + // Grafana-only Provisioning Write Paths + case http.MethodPost + "/api/v1/provisioning/alert-rules": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningWrite), + ac.EvalPermission(ac.ActionAlertingRuleCreate), // more granular permissions are enforced by the handler via "authorizeRuleChanges" + ) + case http.MethodPut + "/api/v1/provisioning/alert-rules/{UID}": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningWrite), + ac.EvalPermission(ac.ActionAlertingRuleUpdate), // more granular permissions are enforced by the handler via "authorizeRuleChanges" + ) + case http.MethodDelete + "/api/v1/provisioning/alert-rules/{UID}": + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningWrite), + ac.EvalPermission(ac.ActionAlertingRuleDelete), // more granular permissions are enforced by the handler via "authorizeRuleChanges" + ) + case http.MethodDelete + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}": + scope := dashboards.ScopeFoldersProvider.GetResourceScopeUID(ac.Parameter(":FolderUID")) + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningWrite), + ac.EvalAll( + ac.EvalPermission(ac.ActionAlertingRuleDelete, scope), + ac.EvalPermission(ac.ActionAlertingRuleRead, scope), + ac.EvalPermission(dashboards.ActionFoldersRead, scope), + ), + ) + case http.MethodPut + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}": + scope := dashboards.ScopeFoldersProvider.GetResourceScopeUID(ac.Parameter(":FolderUID")) + eval = ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningWrite), + ac.EvalAll( + ac.EvalPermission(ac.ActionAlertingRuleRead, scope), + ac.EvalPermission(dashboards.ActionFoldersRead, scope), + ac.EvalAny( // the exact permissions will be checked after the operations are determined + ac.EvalPermission(ac.ActionAlertingRuleUpdate, scope), + ac.EvalPermission(ac.ActionAlertingRuleCreate, scope), + ac.EvalPermission(ac.ActionAlertingRuleDelete, scope), + ), + ), + ) case http.MethodPut + "/api/v1/provisioning/policies", http.MethodDelete + "/api/v1/provisioning/policies", @@ -246,12 +317,7 @@ func (api *API) authorize(method, path string) web.Handler { http.MethodDelete + "/api/v1/provisioning/templates/{name}", http.MethodPost + "/api/v1/provisioning/mute-timings", http.MethodPut + "/api/v1/provisioning/mute-timings/{name}", - http.MethodDelete + "/api/v1/provisioning/mute-timings/{name}", - http.MethodPost + "/api/v1/provisioning/alert-rules", - http.MethodPut + "/api/v1/provisioning/alert-rules/{UID}", - http.MethodDelete + "/api/v1/provisioning/alert-rules/{UID}", - http.MethodPut + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}", - http.MethodDelete + "/api/v1/provisioning/folder/{FolderUID}/rule-groups/{Group}": + http.MethodDelete + "/api/v1/provisioning/mute-timings/{name}": eval = ac.EvalPermission(ac.ActionAlertingProvisioningWrite) // organization scope case http.MethodGet + "/api/v1/notifications/time-intervals/{name}", http.MethodGet + "/api/v1/notifications/time-intervals": diff --git a/pkg/services/ngalert/api/tooling/api.json b/pkg/services/ngalert/api/tooling/api.json index a83c274673c..5e25f7b6702 100644 --- a/pkg/services/ngalert/api/tooling/api.json +++ b/pkg/services/ngalert/api/tooling/api.json @@ -5557,7 +5557,7 @@ } } }, - "summary": "Update the interval of a rule group.", + "summary": "Create or update alert rule group.", "tags": [ "provisioning" ] diff --git a/pkg/services/ngalert/api/tooling/definitions/provisioning_alert_rules.go b/pkg/services/ngalert/api/tooling/definitions/provisioning_alert_rules.go index ccf5c11da04..bd8e1548861 100644 --- a/pkg/services/ngalert/api/tooling/definitions/provisioning_alert_rules.go +++ b/pkg/services/ngalert/api/tooling/definitions/provisioning_alert_rules.go @@ -203,7 +203,7 @@ type ProvisionedAlertRule struct { // swagger:route PUT /v1/provisioning/folder/{FolderUID}/rule-groups/{Group} provisioning stable RoutePutAlertRuleGroup // -// Update the interval of a rule group. +// Create or update alert rule group. // // Consumes: // - application/json diff --git a/pkg/services/ngalert/api/tooling/post.json b/pkg/services/ngalert/api/tooling/post.json index 5616075cc2e..687d223c558 100644 --- a/pkg/services/ngalert/api/tooling/post.json +++ b/pkg/services/ngalert/api/tooling/post.json @@ -7706,7 +7706,7 @@ } } }, - "summary": "Update the interval of a rule group.", + "summary": "Create or update alert rule group.", "tags": [ "provisioning" ] diff --git a/pkg/services/ngalert/api/tooling/spec.json b/pkg/services/ngalert/api/tooling/spec.json index a268d42e5ea..c58984d9c4a 100644 --- a/pkg/services/ngalert/api/tooling/spec.json +++ b/pkg/services/ngalert/api/tooling/spec.json @@ -2698,7 +2698,7 @@ "provisioning", "stable" ], - "summary": "Update the interval of a rule group.", + "summary": "Create or update alert rule group.", "operationId": "RoutePutAlertRuleGroup", "parameters": [ { diff --git a/pkg/services/ngalert/ngalert.go b/pkg/services/ngalert/ngalert.go index 0547084a667..065d4ddfd47 100644 --- a/pkg/services/ngalert/ngalert.go +++ b/pkg/services/ngalert/ngalert.go @@ -26,6 +26,7 @@ import ( "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/folder" + ac "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol" "github.com/grafana/grafana/pkg/services/ngalert/api" "github.com/grafana/grafana/pkg/services/ngalert/eval" "github.com/grafana/grafana/pkg/services/ngalert/image" @@ -316,7 +317,8 @@ func (ng *AlertNG) init() error { alertRuleService := provisioning.NewAlertRuleService(ng.store, ng.store, ng.folderService, ng.dashboardService, ng.QuotaService, ng.store, int64(ng.Cfg.UnifiedAlerting.DefaultRuleEvaluationInterval.Seconds()), int64(ng.Cfg.UnifiedAlerting.BaseInterval.Seconds()), - ng.Cfg.UnifiedAlerting.RulesPerRuleGroupLimit, ng.Log, notifier.NewNotificationSettingsValidationService(ng.store)) + ng.Cfg.UnifiedAlerting.RulesPerRuleGroupLimit, ng.Log, notifier.NewNotificationSettingsValidationService(ng.store), + ac.NewRuleService(ng.accesscontrol)) ng.api = &api.API{ Cfg: ng.Cfg, diff --git a/pkg/services/ngalert/provisioning/accesscontrol.go b/pkg/services/ngalert/provisioning/accesscontrol.go new file mode 100644 index 00000000000..91567e6b1f3 --- /dev/null +++ b/pkg/services/ngalert/provisioning/accesscontrol.go @@ -0,0 +1,76 @@ +package provisioning + +import ( + "context" + + ac "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/ngalert/models" + "github.com/grafana/grafana/pkg/services/ngalert/store" +) + +type RuleAccessControlService interface { + HasAccess(ctx context.Context, user identity.Requester, evaluator ac.Evaluator) (bool, error) + AuthorizeAccessToRuleGroup(ctx context.Context, user identity.Requester, rules models.RulesGroup) error + AuthorizeRuleChanges(ctx context.Context, user identity.Requester, change *store.GroupDelta) error +} + +func newRuleAccessControlService(ac RuleAccessControlService) *provisioningRuleAccessControl { + return &provisioningRuleAccessControl{ + RuleAccessControlService: ac, + } +} + +type provisioningRuleAccessControl struct { + RuleAccessControlService +} + +var _ ruleAccessControlService = &provisioningRuleAccessControl{} + +// AuthorizeRuleGroupRead authorizes the read access to a group of rules for a user. +// It first checks if the user has permission to read all rules. If yes, it bypasses the authorization. +// If not, it calls the RuleAccessControlService to authorize access to the rule group. +// It returns an error if the authorization fails or if there is an error during permission check. +func (p *provisioningRuleAccessControl) AuthorizeRuleGroupRead(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + can, err := p.CanReadAllRules(ctx, user) + if err != nil { + return err + } + if !can { + return p.RuleAccessControlService.AuthorizeAccessToRuleGroup(ctx, user, rules) + } + return nil +} + +// AuthorizeRuleGroupWrite authorizes the write access to a group of rules for a user. +// It first checks if the user has permission to write all rules. If yes, it bypasses the authorization. +// If not, it calls the RuleAccessControlService to authorize the rule changes. +// It returns an error if the authorization fails or if there is an error during permission check. +func (p *provisioningRuleAccessControl) AuthorizeRuleGroupWrite(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + can, err := p.CanWriteAllRules(ctx, user) + if err != nil { + return err + } + if !can { + return p.RuleAccessControlService.AuthorizeRuleChanges(ctx, user, change) + } + return nil +} + +// CanReadAllRules checks if the user has permission to read all rules. +// It evaluates if the user has either "alert.provisioning:read" or "alert.provisioning.secrets:read" permissions. +// It returns true if the user has the required permissions, otherwise it returns false. +func (p *provisioningRuleAccessControl) CanReadAllRules(ctx context.Context, user identity.Requester) (bool, error) { + return p.HasAccess(ctx, user, ac.EvalAny( + ac.EvalPermission(ac.ActionAlertingProvisioningRead), + ac.EvalPermission(ac.ActionAlertingProvisioningReadSecrets), + )) +} + +// CanWriteAllRules is a method that checks if a user has permission to write all rules. +// It calls the HasAccess method with the provided action "alert.provisioning:write". +// It returns true if the user has permission, false otherwise. +// It returns an error if there is a problem checking the permission. +func (p *provisioningRuleAccessControl) CanWriteAllRules(ctx context.Context, user identity.Requester) (bool, error) { + return p.HasAccess(ctx, user, ac.EvalPermission(ac.ActionAlertingProvisioningWrite)) +} diff --git a/pkg/services/ngalert/provisioning/accesscontrol_test.go b/pkg/services/ngalert/provisioning/accesscontrol_test.go new file mode 100644 index 00000000000..04d9000c0ec --- /dev/null +++ b/pkg/services/ngalert/provisioning/accesscontrol_test.go @@ -0,0 +1,231 @@ +package provisioning + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" + + "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol/fakes" + "github.com/grafana/grafana/pkg/services/ngalert/models" + "github.com/grafana/grafana/pkg/services/ngalert/store" + "github.com/grafana/grafana/pkg/services/user" +) + +func TestCanReadAllRules(t *testing.T) { + testUser := &user.SignedInUser{} + + t.Run("should check for provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + expected := rand.Int()%2 == 1 + rs.HasAccessFunc = func(ctx context.Context, requester identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return expected, nil + } + p := &provisioningRuleAccessControl{rs} + res, err := p.CanReadAllRules(context.Background(), testUser) + require.NoError(t, err) + require.Equal(t, expected, res) + + require.Len(t, rs.Calls, 1) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + require.Equal(t, accesscontrol.EvalAny( + accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningRead), + accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningReadSecrets), + ).GoString(), rs.Calls[0].Arguments[2].(accesscontrol.Evaluator).GoString()) + }) + + t.Run("should return error", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + expected := errors.New("test") + rs.HasAccessFunc = func(ctx context.Context, requester identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, expected + } + p := &provisioningRuleAccessControl{rs} + _, err := p.CanReadAllRules(context.Background(), testUser) + require.ErrorIs(t, err, expected) + }) +} + +func TestCanWriteAllRules(t *testing.T) { + testUser := &user.SignedInUser{} + + t.Run("should check for provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + expected := rand.Int()%2 == 1 + rs.HasAccessFunc = func(ctx context.Context, requester identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return expected, nil + } + p := &provisioningRuleAccessControl{rs} + res, err := p.CanWriteAllRules(context.Background(), testUser) + require.NoError(t, err) + require.Equal(t, expected, res) + + require.Len(t, rs.Calls, 1) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + require.Equal(t, accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningWrite).GoString(), rs.Calls[0].Arguments[2].(accesscontrol.Evaluator).GoString()) + }) + + t.Run("should return error", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + expected := errors.New("test") + rs.HasAccessFunc = func(ctx context.Context, requester identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, expected + } + p := &provisioningRuleAccessControl{rs} + _, err := p.CanWriteAllRules(context.Background(), testUser) + require.ErrorIs(t, err, expected) + }) +} + +func TestAuthorizeAccessToRuleGroup(t *testing.T) { + testUser := &user.SignedInUser{} + rules := models.GenerateAlertRules(1, models.AlertRuleGen()) + + t.Run("should return nil when user has provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return true, nil + } + + err := provisioner.AuthorizeRuleGroupRead(context.Background(), testUser, rules) + require.NoError(t, err) + + require.Len(t, rs.Calls, 1) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + assert.Equal(t, accesscontrol.EvalAny( + accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningRead), + accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningReadSecrets), + ).GoString(), rs.Calls[0].Arguments[2].(accesscontrol.Evaluator).GoString()) + assert.Equal(t, testUser, rs.Calls[0].Arguments[1]) + }) + + t.Run("should call upstream method if no provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, nil + } + rs.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, requester identity.Requester, group models.RulesGroup) error { + return nil + } + + err := provisioner.AuthorizeRuleGroupRead(context.Background(), testUser, rules) + require.NoError(t, err) + + require.Len(t, rs.Calls, 2) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + require.Equal(t, "AuthorizeRuleGroupRead", rs.Calls[1].MethodName) + require.Equal(t, models.RulesGroup(rules), rs.Calls[1].Arguments[2]) + }) + + t.Run("should propagate error", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + expected := errors.New("test1") + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, expected + } + + err := provisioner.AuthorizeRuleGroupRead(context.Background(), testUser, rules) + require.ErrorIs(t, err, expected) + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, nil + } + expected = errors.New("test2") + rs.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, requester identity.Requester, group models.RulesGroup) error { + return expected + } + + err = provisioner.AuthorizeRuleGroupRead(context.Background(), testUser, rules) + require.ErrorIs(t, err, expected) + }) +} + +func TestAuthorizeRuleChanges(t *testing.T) { + testUser := &user.SignedInUser{} + change := &store.GroupDelta{} + + t.Run("should return nil when user has provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return true, nil + } + + err := provisioner.AuthorizeRuleGroupWrite(context.Background(), testUser, change) + require.NoError(t, err) + + require.Len(t, rs.Calls, 1) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + assert.Equal(t, accesscontrol.EvalPermission(accesscontrol.ActionAlertingProvisioningWrite).GoString(), rs.Calls[0].Arguments[2].(accesscontrol.Evaluator).GoString()) + assert.Equal(t, testUser, rs.Calls[0].Arguments[1]) + }) + + t.Run("should call upstream method if no provisioning permissions", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, nil + } + rs.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, delta *store.GroupDelta) error { + return nil + } + + err := provisioner.AuthorizeRuleGroupWrite(context.Background(), testUser, change) + require.NoError(t, err) + + require.Len(t, rs.Calls, 2) + require.Equal(t, "HasAccess", rs.Calls[0].MethodName) + require.Equal(t, "AuthorizeRuleGroupWrite", rs.Calls[1].MethodName) + require.Equal(t, testUser, rs.Calls[1].Arguments[1]) + require.Equal(t, change, rs.Calls[1].Arguments[2]) + }) + + t.Run("should propagate error", func(t *testing.T) { + rs := &fakes.FakeRuleService{} + provisioner := provisioningRuleAccessControl{ + RuleAccessControlService: rs, + } + + expected := errors.New("test1") + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, expected + } + + err := provisioner.AuthorizeRuleGroupWrite(context.Background(), testUser, change) + require.ErrorIs(t, err, expected) + + rs.HasAccessFunc = func(ctx context.Context, user identity.Requester, evaluator accesscontrol.Evaluator) (bool, error) { + return false, nil + } + expected = errors.New("test2") + rs.AuthorizeRuleChangesFunc = func(ctx context.Context, requester identity.Requester, delta *store.GroupDelta) error { + return expected + } + + err = provisioner.AuthorizeRuleGroupWrite(context.Background(), testUser, change) + require.ErrorIs(t, err, expected) + }) +} diff --git a/pkg/services/ngalert/provisioning/alert_rules.go b/pkg/services/ngalert/provisioning/alert_rules.go index 4aab1d60f41..df227f061eb 100644 --- a/pkg/services/ngalert/provisioning/alert_rules.go +++ b/pkg/services/ngalert/provisioning/alert_rules.go @@ -10,6 +10,7 @@ import ( "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/folder" + "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol" "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/notifier" "github.com/grafana/grafana/pkg/services/ngalert/store" @@ -17,6 +18,15 @@ import ( "github.com/grafana/grafana/pkg/util" ) +type ruleAccessControlService interface { + AuthorizeRuleGroupRead(ctx context.Context, user identity.Requester, rules models.RulesGroup) error + AuthorizeRuleGroupWrite(ctx context.Context, user identity.Requester, change *store.GroupDelta) error + // CanReadAllRules returns true if the user has full access to read rules via provisioning API and bypass regular checks + CanReadAllRules(ctx context.Context, user identity.Requester) (bool, error) + // CanWriteAllRules returns true if the user has full access to write rules via provisioning API and bypass regular checks + CanWriteAllRules(ctx context.Context, user identity.Requester) (bool, error) +} + type NotificationSettingsValidatorProvider interface { Validator(ctx context.Context, orgID int64) (notifier.NotificationSettingsValidator, error) } @@ -33,6 +43,7 @@ type AlertRuleService struct { xact TransactionManager log log.Logger nsValidatorProvider NotificationSettingsValidatorProvider + authz ruleAccessControlService } func NewAlertRuleService(ruleStore RuleStore, @@ -46,6 +57,7 @@ func NewAlertRuleService(ruleStore RuleStore, rulesPerRuleGroupLimit int64, log log.Logger, ns NotificationSettingsValidatorProvider, + authz RuleAccessControlService, ) *AlertRuleService { return &AlertRuleService{ defaultIntervalSeconds: defaultIntervalSeconds, @@ -59,6 +71,7 @@ func NewAlertRuleService(ruleStore RuleStore, xact: xact, log: log, nsValidatorProvider: ns, + authz: newRuleAccessControlService(authz), } } @@ -78,23 +91,89 @@ func (service *AlertRuleService) GetAlertRules(ctx context.Context, user identit return nil, nil, err } } - return rules, provenances, nil + + can, err := service.authz.CanReadAllRules(ctx, user) + if err != nil { + return nil, nil, err + } + if can { + return rules, provenances, nil + } + // If user does not have blanket privilege to read rules, remove all rules that are not allowed to the user. + groups := models.GroupByAlertRuleGroupKey(rules) + result := make([]*models.AlertRule, 0, len(rules)) + for _, group := range groups { + if err := service.authz.AuthorizeRuleGroupRead(ctx, user, group); err != nil { + if errors.Is(err, accesscontrol.ErrAuthorizationBase) { + // remove provenances for rules that will not be added to the output + for _, rule := range group { + delete(provenances, rule.ResourceID()) + } + continue + } + return nil, nil, err + } + result = append(result, group...) + } + return result, provenances, nil } -func (service *AlertRuleService) GetAlertRule(ctx context.Context, user identity.Requester, ruleUID string) (models.AlertRule, models.Provenance, error) { - query := &models.GetAlertRuleByUIDQuery{ - OrgID: user.GetOrgID(), +func (service *AlertRuleService) getAlertRuleAuthorized(ctx context.Context, user identity.Requester, ruleUID string) (models.AlertRule, error) { + // check if the user can read all rules. If it cannot, pull the entire group and verify access to the entire group. + can, err := service.authz.CanReadAllRules(ctx, user) + if err != nil { + return models.AlertRule{}, err + } + // if user has blanket access to all rules, just read a single rule from database + if can { + query := &models.GetAlertRuleByUIDQuery{ + OrgID: user.GetOrgID(), + UID: ruleUID, + } + rule, err := service.ruleStore.GetAlertRuleByUID(ctx, query) + if err != nil { + return models.AlertRule{}, err + } + if rule == nil { + return models.AlertRule{}, models.ErrAlertRuleNotFound + } + return *rule, nil + } + + // if user does not have privilege to access all rules, check that the user can read this rule by fetching entire group and + // checking that user has access to it. + q := &models.GetAlertRulesGroupByRuleUIDQuery{ UID: ruleUID, + OrgID: user.GetOrgID(), } - rule, err := service.ruleStore.GetAlertRuleByUID(ctx, query) + group, err := service.ruleStore.GetAlertRulesGroupByRuleUID(ctx, q) + if err != nil { + return models.AlertRule{}, err + } + if len(group) == 0 { + return models.AlertRule{}, models.ErrAlertRuleNotFound + } + if err := service.authz.AuthorizeRuleGroupRead(ctx, user, group); err != nil { + return models.AlertRule{}, err + } + for _, rule := range group { + if rule.UID == ruleUID { + return *rule, nil + } + } + return models.AlertRule{}, models.ErrAlertRuleNotFound +} + +func (service *AlertRuleService) GetAlertRule(ctx context.Context, user identity.Requester, ruleUID string) (models.AlertRule, models.Provenance, error) { + rule, err := service.getAlertRuleAuthorized(ctx, user, ruleUID) if err != nil { return models.AlertRule{}, models.ProvenanceNone, err } - provenance, err := service.provenanceStore.GetProvenance(ctx, rule, user.GetOrgID()) + provenance, err := service.provenanceStore.GetProvenance(ctx, &rule, user.GetOrgID()) if err != nil { return models.AlertRule{}, models.ProvenanceNone, err } - return *rule, provenance, nil + return rule, provenance, nil } type AlertRuleWithFolderTitle struct { @@ -104,11 +183,7 @@ type AlertRuleWithFolderTitle struct { // GetAlertRuleWithFolderTitle returns a single alert rule with its folder title. func (service *AlertRuleService) GetAlertRuleWithFolderTitle(ctx context.Context, user identity.Requester, ruleUID string) (AlertRuleWithFolderTitle, error) { - query := &models.GetAlertRuleByUIDQuery{ - OrgID: user.GetOrgID(), - UID: ruleUID, - } - rule, err := service.ruleStore.GetAlertRuleByUID(ctx, query) + rule, err := service.getAlertRuleAuthorized(ctx, user, ruleUID) if err != nil { return AlertRuleWithFolderTitle{}, err } @@ -124,7 +199,7 @@ func (service *AlertRuleService) GetAlertRuleWithFolderTitle(ctx context.Context } return AlertRuleWithFolderTitle{ - AlertRule: *rule, + AlertRule: rule, FolderTitle: dash.Title, }, nil } @@ -138,13 +213,33 @@ func (service *AlertRuleService) CreateAlertRule(ctx context.Context, user ident } else if err := util.ValidateUID(rule.UID); err != nil { return models.AlertRule{}, errors.Join(models.ErrAlertRuleFailedValidation, fmt.Errorf("cannot create rule with UID '%s': %w", rule.UID, err)) } - interval, err := service.ruleStore.GetRuleGroupInterval(ctx, rule.OrgID, rule.NamespaceUID, rule.RuleGroup) - // if the alert group does not exist we just use the default interval - if err != nil && errors.Is(err, models.ErrAlertRuleGroupNotFound) { - interval = service.defaultIntervalSeconds - } else if err != nil { + var interval = service.defaultIntervalSeconds + // check if user can bypass fine-grained rule authorization checks. If it cannot, verfiy that the user can add rules to the group + canWriteAllRules, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { return models.AlertRule{}, err } + if canWriteAllRules { + groupInterval, err := service.ruleStore.GetRuleGroupInterval(ctx, rule.OrgID, rule.NamespaceUID, rule.RuleGroup) + // if the alert group does not exist we just use the default interval + if err == nil { + interval = groupInterval + } else if !errors.Is(err, models.ErrAlertRuleGroupNotFound) { + return models.AlertRule{}, err + } + } else { + delta, err := store.CalculateRuleCreate(ctx, service.ruleStore, &rule) + if err != nil { + return models.AlertRule{}, fmt.Errorf("failed to calculate delta: %w", err) + } + if err := service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { + return models.AlertRule{}, err + } + existingGroup := delta.AffectedGroups[rule.GetGroupKey()] + if len(existingGroup) > 0 { + interval = existingGroup[0].IntervalSeconds + } + } rule.IntervalSeconds = interval err = rule.SetDashboardAndPanelFromAnnotations() if err != nil { @@ -209,11 +304,21 @@ func (service *AlertRuleService) GetRuleGroup(ctx context.Context, user identity if len(ruleList) == 0 { return models.AlertRuleGroup{}, models.ErrAlertRuleGroupNotFound.Errorf("") } + + can, err := service.authz.CanReadAllRules(ctx, user) + if err != nil { + return models.AlertRuleGroup{}, err + } + if !can { + if err := service.authz.AuthorizeRuleGroupRead(ctx, user, ruleList); err != nil { + return models.AlertRuleGroup{}, err + } + } res := models.AlertRuleGroup{ Title: ruleList[0].RuleGroup, FolderUID: ruleList[0].NamespaceUID, Interval: ruleList[0].IntervalSeconds, - Rules: []models.AlertRule{}, + Rules: make([]models.AlertRule, 0, len(ruleList)), } for _, r := range ruleList { if r != nil { @@ -250,6 +355,39 @@ func (service *AlertRuleService) UpdateRuleGroup(ctx context.Context, user ident New: newRule, }) } + + // check if user has write access to all rules and can bypass the regular checks. + can, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { + return err + } + // If it cannot, check that the user is authorized to perform all the changes caused by this request + if !can { + groupKey := models.AlertRuleGroupKey{ + OrgID: user.GetOrgID(), + NamespaceUID: namespaceUID, + RuleGroup: ruleGroup, + } + ruleDeltas := make([]store.RuleDelta, 0, len(ruleList)) + for _, upd := range updateRules { + updNew := upd.New + ruleDeltas = append(ruleDeltas, store.RuleDelta{ + Existing: upd.Existing, + New: &updNew, + }) + } + delta := &store.GroupDelta{ + GroupKey: groupKey, + AffectedGroups: map[models.AlertRuleGroupKey]models.RulesGroup{ + groupKey: ruleList, + }, + Update: ruleDeltas, + } + if err := service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { + return err + } + } + return service.ruleStore.UpdateAlertRules(ctx, updateRules) }) } @@ -264,10 +402,22 @@ func (service *AlertRuleService) ReplaceRuleGroup(ctx context.Context, user iden return err } - if len(delta.New) == 0 && len(delta.Update) == 0 && len(delta.Delete) == 0 { + if delta.IsEmpty() { return nil } + // check if the current user has permissions to all rules and can bypass the regular authorization validation. + can, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { + return err + } + + if !can { + if err := service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { + return err + } + } + newOrUpdatedNotificationSettings := delta.NewOrUpdatedNotificationSettings() if len(newOrUpdatedNotificationSettings) > 0 { validator, err := service.nsValidatorProvider.Validator(ctx, delta.GroupKey.OrgID) @@ -285,35 +435,27 @@ func (service *AlertRuleService) ReplaceRuleGroup(ctx context.Context, user iden } func (service *AlertRuleService) DeleteRuleGroup(ctx context.Context, user identity.Requester, namespaceUID, group string, provenance models.Provenance) error { - // List all rules in the group. - q := models.ListAlertRulesQuery{ - OrgID: user.GetOrgID(), - NamespaceUIDs: []string{namespaceUID}, - RuleGroup: group, - } - ruleList, err := service.ruleStore.ListAlertRules(ctx, &q) + delta, err := store.CalculateRuleGroupDelete(ctx, service.ruleStore, models.AlertRuleGroupKey{ + OrgID: user.GetOrgID(), + NamespaceUID: namespaceUID, + RuleGroup: group, + }) if err != nil { return err } - if len(ruleList) == 0 { - return models.ErrAlertRuleGroupNotFound.Errorf("") - } - // Check provenance for all rules in the group. Fail to delete if any deletions aren't allowed. - for _, rule := range ruleList { - storedProvenance, err := service.provenanceStore.GetProvenance(ctx, rule, rule.OrgID) - if err != nil { + // check if the current user has permissions to all rules and can bypass the regular authorization validation. + can, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { + return err + } + if !can { + if err := service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { return err } - if storedProvenance != provenance && storedProvenance != models.ProvenanceNone { - return fmt.Errorf("cannot delete with provided provenance '%s', needs '%s'", provenance, storedProvenance) - } } - // Delete all rules. - return service.xact.InTransaction(ctx, func(ctx context.Context) error { - return service.deleteRules(ctx, user.GetOrgID(), ruleList...) - }) + return service.persistDelta(ctx, user, delta, provenance) } func (service *AlertRuleService) calcDelta(ctx context.Context, user identity.Requester, group models.AlertRuleGroup) (*store.GroupDelta, error) { @@ -346,7 +488,7 @@ func (service *AlertRuleService) calcDelta(ctx context.Context, user identity.Re NamespaceUID: group.FolderUID, RuleGroup: group.Title, } - rules := make([]*models.AlertRuleWithOptionals, len(group.Rules)) + rules := make([]*models.AlertRuleWithOptionals, 0, len(group.Rules)) group = *syncGroupRuleFields(&group, user.GetOrgID()) for i := range group.Rules { if err := group.Rules[i].SetDashboardAndPanelFromAnnotations(); err != nil { @@ -374,7 +516,7 @@ func (service *AlertRuleService) persistDelta(ctx context.Context, user identity return err } if canUpdate := canUpdateProvenanceInRuleGroup(storedProvenance, provenance); !canUpdate { - return fmt.Errorf("cannot update with provided provenance '%s', needs '%s'", provenance, storedProvenance) + return fmt.Errorf("cannot delete with provided provenance '%s', needs '%s'", provenance, storedProvenance) } } if err := service.deleteRules(ctx, user.GetOrgID(), delta.Delete...); err != nil { @@ -430,7 +572,41 @@ func (service *AlertRuleService) persistDelta(ctx context.Context, user identity // UpdateAlertRule updates an alert rule. func (service *AlertRuleService) UpdateAlertRule(ctx context.Context, user identity.Requester, rule models.AlertRule, provenance models.Provenance) (models.AlertRule, error) { - storedRule, storedProvenance, err := service.GetAlertRule(ctx, user, rule.UID) + var storedRule *models.AlertRule + // check if the user has full access to all rules and can bypass the regular authorization validations. + // If it cannot, calculate the changes to the group caused by this update and authorize them. + canWriteAllRules, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { + return models.AlertRule{}, err + } + if canWriteAllRules { + query := &models.GetAlertRuleByUIDQuery{ + OrgID: rule.OrgID, + UID: rule.UID, + } + existing, err := service.ruleStore.GetAlertRuleByUID(ctx, query) + if err != nil { + return models.AlertRule{}, err + } + storedRule = existing + } else { + delta, err := store.CalculateRuleUpdate(ctx, service.ruleStore, &models.AlertRuleWithOptionals{AlertRule: rule}) + if err != nil { + return models.AlertRule{}, err + } + if err = service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { + return models.AlertRule{}, err + } + for _, d := range delta.Update { + if d.Existing.GetKey() == rule.GetKey() { + storedRule = d.Existing + } + } + if storedRule == nil { // this should not happen but we better catch it to avoid panic + return models.AlertRule{}, fmt.Errorf("cannot find rule in the delta") + } + } + storedProvenance, err := service.provenanceStore.GetProvenance(ctx, storedRule, storedRule.OrgID) if err != nil { return models.AlertRule{}, err } @@ -458,7 +634,7 @@ func (service *AlertRuleService) UpdateAlertRule(ctx context.Context, user ident err = service.xact.InTransaction(ctx, func(ctx context.Context) error { err := service.ruleStore.UpdateAlertRules(ctx, []models.UpdateRule{ { - Existing: &storedRule, + Existing: storedRule, New: rule, }, }) @@ -486,6 +662,24 @@ func (service *AlertRuleService) DeleteAlertRule(ctx context.Context, user ident if storedProvenance != provenance && storedProvenance != models.ProvenanceNone { return fmt.Errorf("cannot delete with provided provenance '%s', needs '%s'", provenance, storedProvenance) } + + can, err := service.authz.CanWriteAllRules(ctx, user) + if err != nil { + return err + } + if !can { + delta, err := store.CalculateRuleDelete(ctx, service.ruleStore, rule.GetKey()) + if err != nil { + return err + } + if err = service.authz.AuthorizeRuleGroupWrite(ctx, user, delta); err != nil { + return err + } + } + + // The single delete is idempotent, and doesn't error when deleting a group that already doesn't exist. + // This is different from deleting groups. We delete the rules directly rather than persisting a delta here to keep the semantics the same. + // TODO: Either persist a delta here as a breaking change, or deprecate this endpoint in favor of the group endpoint. return service.xact.InTransaction(ctx, func(ctx context.Context) error { return service.deleteRules(ctx, user.GetOrgID(), rule) }) @@ -536,18 +730,10 @@ func (service *AlertRuleService) deleteRules(ctx context.Context, orgID int64, t // GetAlertRuleGroupWithFolderTitle returns the alert rule group with folder title. func (service *AlertRuleService) GetAlertRuleGroupWithFolderTitle(ctx context.Context, user identity.Requester, namespaceUID, group string) (models.AlertRuleGroupWithFolderTitle, error) { - q := models.ListAlertRulesQuery{ - OrgID: user.GetOrgID(), - NamespaceUIDs: []string{namespaceUID}, - RuleGroup: group, - } - ruleList, err := service.ruleStore.ListAlertRules(ctx, &q) + ruleList, err := service.GetRuleGroup(ctx, user, namespaceUID, group) if err != nil { return models.AlertRuleGroupWithFolderTitle{}, err } - if len(ruleList) == 0 { - return models.AlertRuleGroupWithFolderTitle{}, models.ErrAlertRuleGroupNotFound.Errorf("") - } dq := dashboards.GetDashboardQuery{ OrgID: user.GetOrgID(), @@ -558,7 +744,7 @@ func (service *AlertRuleService) GetAlertRuleGroupWithFolderTitle(ctx context.Co return models.AlertRuleGroupWithFolderTitle{}, err } - res := models.NewAlertRuleGroupWithFolderTitleFromRulesGroup(ruleList[0].GetGroupKey(), ruleList, dash.Title) + res := models.NewAlertRuleGroupWithFolderTitle(ruleList.Rules[0].GetGroupKey(), ruleList.Rules, dash.Title) return res, nil } @@ -576,16 +762,28 @@ func (service *AlertRuleService) GetAlertGroupsWithFolderTitle(ctx context.Conte if err != nil { return nil, err } + groups := models.GroupByAlertRuleGroupKey(ruleList) - groups := make(map[models.AlertRuleGroupKey][]models.AlertRule) - namespaces := make(map[string][]*models.AlertRuleGroupKey) - for _, r := range ruleList { - groupKey := r.GetGroupKey() - group := groups[groupKey] - group = append(group, *r) - groups[groupKey] = group + can, err := service.authz.CanReadAllRules(ctx, user) + if err != nil { + return nil, err + } + if !can { + // if user cannot read all rules, check read access to each group and remove groups that the user does not have access to + for key, group := range groups { + if err := service.authz.AuthorizeRuleGroupRead(ctx, user, group); err != nil { + if errors.Is(err, accesscontrol.ErrAuthorizationBase) { + delete(groups, key) + continue + } + return nil, err + } + } + } - namespaces[r.NamespaceUID] = append(namespaces[r.NamespaceUID], &groupKey) + namespaces := make(map[string][]*models.AlertRuleGroupKey) + for groupKey := range groups { + namespaces[groupKey.NamespaceUID] = append(namespaces[groupKey.NamespaceUID], util.Pointer(groupKey)) } if len(namespaces) == 0 { @@ -615,7 +813,7 @@ func (service *AlertRuleService) GetAlertGroupsWithFolderTitle(ctx context.Conte if !ok { return nil, fmt.Errorf("cannot find title for folder with uid '%s'", groupKey.NamespaceUID) } - result = append(result, models.NewAlertRuleGroupWithFolderTitle(groupKey, rules, title)) + result = append(result, models.NewAlertRuleGroupWithFolderTitleFromRulesGroup(groupKey, rules, title)) } // Return results in a stable manner. diff --git a/pkg/services/ngalert/provisioning/alert_rules_test.go b/pkg/services/ngalert/provisioning/alert_rules_test.go index f44a94f86cc..fea81b19053 100644 --- a/pkg/services/ngalert/provisioning/alert_rules_test.go +++ b/pkg/services/ngalert/provisioning/alert_rules_test.go @@ -3,14 +3,20 @@ package provisioning import ( "context" "encoding/json" + "errors" + "math/rand" "strconv" "strings" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/expr" + "github.com/grafana/grafana/pkg/services/auth/identity" + "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol" + "github.com/grafana/grafana/pkg/services/ngalert/tests/fakes" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util" @@ -541,10 +547,216 @@ func TestAlertRuleService(t *testing.T) { } func TestCreateAlertRule(t *testing.T) { - ruleService := createAlertRuleService(t) - var orgID int64 = 1 + orgID := rand.Int63() u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + groupIntervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(groupIntervalSeconds)*time.Second))) + groupProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, groupProvenance)) + } + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user can write all rules", func(t *testing.T) { + t.Run("and a new rule creates a new group", func(t *testing.T) { + rule := models.AlertRuleGen(models.WithOrgID(orgID))() + service, ruleStore, provenanceStore, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + actualRule, err := service.CreateAlertRule(context.Background(), u, *rule, models.ProvenanceFile) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + t.Run("it should assign default interval", func(t *testing.T) { + require.Equal(t, service.defaultIntervalSeconds, actualRule.IntervalSeconds) + }) + + t.Run("inserts to database", func(t *testing.T) { + inserts := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.AlertRule) + return a, ok + }) + require.Len(t, inserts, 1) + cmd := inserts[0].([]models.AlertRule) + require.Len(t, cmd, 1) + }) + + t.Run("set correct provenance", func(t *testing.T) { + p, err := provenanceStore.GetProvenance(context.Background(), &actualRule, orgID) + require.NoError(t, err) + require.Equal(t, models.ProvenanceFile, p) + }) + }) + t.Run("and it adds a rule to a group", func(t *testing.T) { + rule := models.AlertRuleGen(models.WithGroupKey(groupKey))() + service, ruleStore, provenanceStore, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + actualRule, err := service.CreateAlertRule(context.Background(), u, *rule, models.ProvenanceNone) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + t.Run("it should assign group interval", func(t *testing.T) { + require.Equal(t, groupIntervalSeconds, actualRule.IntervalSeconds) + }) + + t.Run("inserts to database", func(t *testing.T) { + inserts := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.AlertRule) + return a, ok + }) + require.Len(t, inserts, 1) + cmd := inserts[0].([]models.AlertRule) + require.Len(t, cmd, 1) + }) + + t.Run("set correct provenance", func(t *testing.T) { + p, err := provenanceStore.GetProvenance(context.Background(), &actualRule, orgID) + require.NoError(t, err) + require.Equal(t, models.ProvenanceNone, p) + }) + }) + }) + t.Run("when user cannot write all rules", func(t *testing.T) { + t.Run("and it creates a new group", func(t *testing.T) { + rule := models.AlertRuleGen(models.WithOrgID(orgID))() + t.Run("it should authorize the change", func(t *testing.T) { + service, ruleStore, provenanceStore, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + assert.Equal(t, u, user) + assert.Equal(t, rule.GetGroupKey(), change.GroupKey) + assert.Len(t, change.New, 1) + assert.Empty(t, change.Update) + assert.Empty(t, change.Delete) + assert.Empty(t, change.AffectedGroups) + return nil + } + + actualRule, err := service.CreateAlertRule(context.Background(), u, *rule, models.ProvenanceFile) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + t.Run("it should assign default interval", func(t *testing.T) { + require.Equal(t, service.defaultIntervalSeconds, actualRule.IntervalSeconds) + }) + + t.Run("inserts to database", func(t *testing.T) { + inserts := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.AlertRule) + return a, ok + }) + require.Len(t, inserts, 1) + cmd := inserts[0].([]models.AlertRule) + require.Len(t, cmd, 1) + }) + + t.Run("set correct provenance", func(t *testing.T) { + p, err := provenanceStore.GetProvenance(context.Background(), &actualRule, orgID) + require.NoError(t, err) + require.Equal(t, models.ProvenanceFile, p) + }) + }) + }) + t.Run("and it adds a rule to a group", func(t *testing.T) { + rule := models.AlertRuleGen(models.WithGroupKey(groupKey))() + t.Run("it should authorize the change to whole group", func(t *testing.T) { + service, ruleStore, provenanceStore, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + assert.Equal(t, u, user) + assert.Equal(t, rule.GetGroupKey(), change.GroupKey) + assert.Contains(t, change.AffectedGroups, change.GroupKey) + assert.EqualValues(t, change.AffectedGroups[change.GroupKey], rules) + assert.Len(t, change.New, 1) + assert.Empty(t, change.Update) + assert.Empty(t, change.Delete) + return nil + } + + actualRule, err := service.CreateAlertRule(context.Background(), u, *rule, models.ProvenanceNone) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + t.Run("it should assign group interval", func(t *testing.T) { + require.Equal(t, groupIntervalSeconds, actualRule.IntervalSeconds) + }) + + t.Run("inserts to database", func(t *testing.T) { + inserts := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.AlertRule) + return a, ok + }) + require.Len(t, inserts, 1) + cmd := inserts[0].([]models.AlertRule) + require.Len(t, cmd, 1) + }) + + t.Run("set correct provenance", func(t *testing.T) { + p, err := provenanceStore.GetProvenance(context.Background(), &actualRule, orgID) + require.NoError(t, err) + require.Equal(t, models.ProvenanceNone, p) + }) + }) + }) + t.Run("it should not insert if not authorized", func(t *testing.T) { + rule := models.AlertRuleGen(models.WithGroupKey(groupKey))() + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test error") + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return expectedErr + } + + _, err := service.CreateAlertRule(context.Background(), u, *rule, models.ProvenanceFile) + require.ErrorIs(t, expectedErr, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + inserts := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.AlertRule) + return a, ok + }) + require.Empty(t, inserts) + }) + }) + ruleService := createAlertRuleService(t) t.Run("should return the created id", func(t *testing.T) { rule, err := ruleService.CreateAlertRule(context.Background(), u, dummyRule("test#1", orgID), models.ProvenanceNone) require.NoError(t, err) @@ -580,6 +792,749 @@ func TestCreateAlertRule(t *testing.T) { }) } +func TestUpdateAlertRule(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + groupIntervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(groupIntervalSeconds)*time.Second))) + groupProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, groupProvenance)) + } + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user can write all rules", func(t *testing.T) { + rule := models.CopyRule(rules[0]) + rule.RuleGroup = rule.RuleGroup + "_new" + rule.Title = rule.Title + "_new" + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + _, err := service.UpdateAlertRule(context.Background(), u, *rule, models.ProvenanceAPI) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Len(t, updates, 1) + }) + t.Run("when user cannot write all rules", func(t *testing.T) { + rule := models.CopyRule(rules[0]) + rule.Title = rule.Title + "_new" + + t.Run("it should authorize the change to whole group", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + assert.Equal(t, u, user) + assert.Equal(t, groupKey, change.GroupKey) + assert.Contains(t, change.AffectedGroups, groupKey) + assert.EqualValues(t, rules, change.AffectedGroups[groupKey]) + assert.Len(t, change.Update, 1) + assert.Empty(t, change.New) + assert.Empty(t, change.Delete) + return nil + } + + _, err := service.UpdateAlertRule(context.Background(), u, *rule, groupProvenance) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Len(t, updates, 1) + }) + t.Run("it should not update if not authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test error") + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return expectedErr + } + + _, err := service.UpdateAlertRule(context.Background(), u, *rule, groupProvenance) + require.ErrorIs(t, expectedErr, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Empty(t, updates) + }) + }) +} + +func TestDeleteAlertRule(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + groupIntervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(groupIntervalSeconds)*time.Second))) + groupProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, groupProvenance)) + } + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user can write all rules", func(t *testing.T) { + rule := rules[0] + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + err := service.DeleteAlertRule(context.Background(), u, rule.UID, groupProvenance) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + deletes := getDeleteQueries(ruleStore) + require.Len(t, deletes, 1) + }) + t.Run("when user cannot write all rules", func(t *testing.T) { + rule := models.CopyRule(rules[0]) + rule.Title = rule.Title + "_new" + + t.Run("it should authorize the change to whole group", func(t *testing.T) { + rule := rules[0] + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + assert.Equal(t, u, user) + assert.Equal(t, groupKey, change.GroupKey) + assert.Contains(t, change.AffectedGroups, groupKey) + assert.EqualValues(t, rules, change.AffectedGroups[groupKey]) + assert.Empty(t, change.Update) + assert.Empty(t, change.New) + assert.Len(t, change.Delete, 1) + return nil + } + + err := service.DeleteAlertRule(context.Background(), u, rule.UID, groupProvenance) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + deletes := getDeleteQueries(ruleStore) + require.Len(t, deletes, 1) + }) + t.Run("it should not delete if not authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test error") + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return expectedErr + } + + _, err := service.UpdateAlertRule(context.Background(), u, *rule, groupProvenance) + require.ErrorIs(t, expectedErr, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + deletes := getDeleteQueries(ruleStore) + require.Empty(t, deletes) + }) + }) +} + +func TestGetAlertRule(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey))) + rule := rules[0] + expectedProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, expectedProvenance)) + + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user cannot read all rules", func(t *testing.T) { + t.Run("should authorize access to entire group", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + + expected := errors.New("test") + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, r models.RulesGroup) error { + assert.Equal(t, u, user) + assert.EqualValues(t, rules, r) + return expected + } + + _, _, err := service.GetAlertRule(context.Background(), u, rule.UID) + require.Error(t, err) + require.Equal(t, expected, err) + + assert.Len(t, ac.Calls, 2) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupRead", ac.Calls[1].Method) + + ac.Calls = nil + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + return nil + } + + actual, provenance, err := service.GetAlertRule(context.Background(), u, rule.UID) + require.NoError(t, err) + assert.Equal(t, *rule, actual) + assert.Equal(t, expectedProvenance, provenance) + }) + + t.Run("should return ErrAlertRuleNotFound if rule does not exist", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + + _, _, err := service.GetAlertRule(context.Background(), u, "no-rule-uid") + require.ErrorIs(t, err, models.ErrAlertRuleNotFound) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + require.IsType(t, ruleStore.RecordedOps[0], models.GetAlertRulesGroupByRuleUIDQuery{}) + query := ruleStore.RecordedOps[0].(models.GetAlertRulesGroupByRuleUIDQuery) + assert.Equal(t, models.GetAlertRulesGroupByRuleUIDQuery{ + OrgID: orgID, + UID: "no-rule-uid", + }, query) + }) + }) + + t.Run("when user can read all rules", func(t *testing.T) { + t.Run("should query rule by UID and do not check any permissions", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + assert.Equal(t, u, user) + return true, nil + } + + actual, provenance, err := service.GetAlertRule(context.Background(), u, rule.UID) + require.NoError(t, err) + assert.Equal(t, *rule, actual) + assert.Equal(t, expectedProvenance, provenance) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + + require.Len(t, ruleStore.RecordedOps, 1) + require.IsType(t, ruleStore.RecordedOps[0], models.GetAlertRuleByUIDQuery{}) + query := ruleStore.RecordedOps[0].(models.GetAlertRuleByUIDQuery) + assert.Equal(t, models.GetAlertRuleByUIDQuery{ + OrgID: rule.OrgID, + UID: rule.UID, + }, query) + }) + + t.Run("should return ErrAlertRuleNotFound if rule does not exist", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + _, _, err := service.GetAlertRule(context.Background(), u, "no-rule-uid") + require.ErrorIs(t, err, models.ErrAlertRuleNotFound) + }) + }) + + t.Run("return error immediately when CanReadAllRules returns error", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + expectedErr := errors.New("test") + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, expectedErr + } + + _, _, err := service.GetAlertRule(context.Background(), u, rule.UID) + require.Error(t, err) + require.Equal(t, expectedErr, err) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + + assert.Empty(t, ruleStore.RecordedOps) + }) +} + +func TestGetRuleGroup(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + intervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(intervalSeconds)*time.Second))) + derefRules := make([]models.AlertRule, 0, len(rules)) + for _, rule := range rules { + derefRules = append(derefRules, *rule) + } + expectedProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, expectedProvenance)) + } + + return service, ruleStore, provenanceStore, ac + } + + t.Run("return ErrAlertRuleGroupNotFound when rule group does not exist", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + + _, err := service.GetRuleGroup(context.Background(), u, groupKey.NamespaceUID, "no-rule-group") + require.ErrorIs(t, err, models.ErrAlertRuleGroupNotFound) + require.Empty(t, ac.Calls) + }) + + t.Run("when user cannot read all rules", func(t *testing.T) { + t.Run("it should authorize access to entire group", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + assert.Equal(t, u, user) + return false, nil + } + expectedErr := errors.New("error") + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, r models.RulesGroup) error { + assert.Equal(t, u, user) + assert.EqualValues(t, rules, r) + return expectedErr + } + + _, err := service.GetRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup) + require.Error(t, err) + require.Equal(t, expectedErr, err) + + assert.Len(t, ac.Calls, 2) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupRead", ac.Calls[1].Method) + + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + return nil + } + + group, err := service.GetRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup) + require.NoError(t, err) + + assert.Equal(t, groupKey.RuleGroup, group.Title) + assert.Equal(t, groupKey.NamespaceUID, group.FolderUID) + assert.Equal(t, intervalSeconds, group.Interval) + assert.Equal(t, derefRules, group.Rules) + }) + }) + + t.Run("when user can read all rules", func(t *testing.T) { + t.Run("it should skip AuthorizeRuleGroupRead", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + assert.Equal(t, u, user) + return true, nil + } + + group, err := service.GetRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup) + require.NoError(t, err) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + + assert.Equal(t, groupKey.RuleGroup, group.Title) + assert.Equal(t, groupKey.NamespaceUID, group.FolderUID) + assert.Equal(t, intervalSeconds, group.Interval) + assert.Equal(t, derefRules, group.Rules) + }) + }) + + t.Run("return error immediately when CanReadAllRules returns error", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + + expectedErr := errors.New("test") + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, expectedErr + } + + _, err := service.GetRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup) + require.Error(t, err) + require.Equal(t, expectedErr, err) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + }) +} + +func TestGetAlertRules(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey1 := models.GenerateGroupKey(orgID) + groupKey2 := models.GenerateGroupKey(orgID) + rules1 := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey1))) + models.RulesGroup(rules1).SortByGroupIndex() + rules2 := models.GenerateAlertRules(4, models.AlertRuleGen(models.WithGroupKey(groupKey2))) + models.RulesGroup(rules2).SortByGroupIndex() + allRules := append(rules1, rules2...) + expectedProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: allRules, + } + for _, rule := range rules1 { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, expectedProvenance)) + } + + return service, ruleStore, provenanceStore, ac + } + + t.Run("return error when CanReadAllRules return error", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + expectedErr := errors.New("test") + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, expectedErr + } + + _, _, err := service.GetAlertRules(context.Background(), u) + require.ErrorIs(t, err, expectedErr) + }) + + t.Run("when user can read all rules", func(t *testing.T) { + t.Run("should skip AuthorizeRuleGroupRead and return all rules", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + rules, provenance, err := service.GetAlertRules(context.Background(), u) + require.NoError(t, err) + require.Equal(t, allRules, rules) + require.Len(t, provenance, len(rules1)) + + assert.Len(t, ac.Calls, 1) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + }) + }) + + t.Run("when user cannot read all rules", func(t *testing.T) { + t.Run("should group rules and check AuthorizeRuleGroupRead and return only available rules", func(t *testing.T) { + t.Run("should remove group from output if AuthorizeRuleGroupRead returns authorization error", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + if rules[0].GetGroupKey() == groupKey1 { + return accesscontrol.NewAuthorizationErrorGeneric("test") + } + return nil + } + + rules, provenance, err := service.GetAlertRules(context.Background(), u) + require.NoError(t, err) + + assert.Equal(t, rules2, rules) + assert.Empty(t, provenance) + + assert.Len(t, ac.Calls, 3) + assert.Equal(t, "CanReadAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupRead", ac.Calls[1].Method) + assert.Equal(t, "AuthorizeRuleGroupRead", ac.Calls[2].Method) + + group1 := ac.Calls[1].Args[2].(models.RulesGroup) + group2 := ac.Calls[2].Args[2].(models.RulesGroup) + require.Len(t, append(group1, group2...), len(allRules)) + }) + + t.Run("should immediately exist if AuthorizeRuleGroupRead returns another error", func(t *testing.T) { + service, _, _, ac := initServiceWithData(t) + ac.CanReadAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test") + ac.AuthorizeAccessToRuleGroupFunc = func(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + return expectedErr + } + + _, _, err := service.GetAlertRules(context.Background(), u) + require.ErrorIs(t, err, expectedErr) + }) + }) + }) +} + +func TestReplaceGroup(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + groupIntervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(groupIntervalSeconds)*time.Second))) + groupProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, groupProvenance)) + } + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user can write all rules", func(t *testing.T) { + group := models.AlertRuleGroup{ + Title: groupKey.RuleGroup, + FolderUID: groupKey.NamespaceUID, + Interval: groupIntervalSeconds, + Provenance: groupProvenance, + } + for _, rule := range rules { + r := models.CopyRule(rule) + r.Title = r.Title + "_new" + group.Rules = append(group.Rules, *r) + } + + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + err := service.ReplaceRuleGroup(context.Background(), u, group, models.ProvenanceAPI) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Len(t, updates, 1) + }) + t.Run("when user cannot write all rules", func(t *testing.T) { + group := models.AlertRuleGroup{ + Title: groupKey.RuleGroup, + FolderUID: groupKey.NamespaceUID, + Interval: groupIntervalSeconds, + Provenance: groupProvenance, + } + for _, rule := range rules { + r := models.CopyRule(rule) + r.Title = r.Title + "_new" + group.Rules = append(group.Rules, *r) + } + + t.Run("it should not update if not authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test error") + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return expectedErr + } + + err := service.ReplaceRuleGroup(context.Background(), u, group, models.ProvenanceAPI) + require.ErrorIs(t, err, expectedErr) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Empty(t, updates) + }) + t.Run("it should update if authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return nil + } + + err := service.ReplaceRuleGroup(context.Background(), u, group, models.ProvenanceAPI) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + updates := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.([]models.UpdateRule) + return a, ok + }) + require.Len(t, updates, 1) + }) + }) +} + +func TestDeleteRuleGroup(t *testing.T) { + orgID := rand.Int63() + u := &user.SignedInUser{OrgID: orgID} + groupKey := models.GenerateGroupKey(orgID) + groupIntervalSeconds := int64(30) + rules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey), models.WithInterval(time.Duration(groupIntervalSeconds)*time.Second))) + groupProvenance := models.ProvenanceAPI + + initServiceWithData := func(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + service, ruleStore, provenanceStore, ac := initService(t) + ruleStore.Rules = map[int64][]*models.AlertRule{ + orgID: rules, + } + for _, rule := range rules { + require.NoError(t, provenanceStore.SetProvenance(context.Background(), rule, orgID, groupProvenance)) + } + return service, ruleStore, provenanceStore, ac + } + + t.Run("when user can write all rules", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return true, nil + } + + err := service.DeleteRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup, groupProvenance) + require.NoError(t, err) + + require.Len(t, ac.Calls, 1) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + + deletes := getDeleteQueries(ruleStore) + require.Len(t, deletes, 1) + }) + t.Run("when user cannot write all rules", func(t *testing.T) { + t.Run("it should not update if not authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + expectedErr := errors.New("test error") + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + return expectedErr + } + + err := service.DeleteRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup, groupProvenance) + require.ErrorIs(t, err, expectedErr) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + deletes := getDeleteQueries(ruleStore) + require.Empty(t, deletes) + }) + t.Run("it should update if authorized", func(t *testing.T) { + service, ruleStore, _, ac := initServiceWithData(t) + + ac.CanWriteAllRulesFunc = func(ctx context.Context, user identity.Requester) (bool, error) { + return false, nil + } + ac.AuthorizeRuleChangesFunc = func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + assert.Equal(t, u, user) + assert.Equal(t, groupKey, change.GroupKey) + assert.Contains(t, change.AffectedGroups, groupKey) + assert.EqualValues(t, rules, change.AffectedGroups[groupKey]) + assert.Empty(t, change.Update) + assert.Empty(t, change.New) + assert.Len(t, change.Delete, len(rules)) + return nil + } + + err := service.DeleteRuleGroup(context.Background(), u, groupKey.NamespaceUID, groupKey.RuleGroup, groupProvenance) + require.NoError(t, err) + + require.Len(t, ac.Calls, 2) + assert.Equal(t, "CanWriteAllRules", ac.Calls[0].Method) + assert.Equal(t, "AuthorizeRuleGroupWrite", ac.Calls[1].Method) + + deletes := getDeleteQueries(ruleStore) + require.Len(t, deletes, 1) + }) + }) +} + +func getDeleteQueries(ruleStore *fakes.RuleStore) []fakes.GenericRecordedQuery { + generic := ruleStore.GetRecordedCommands(func(cmd any) (any, bool) { + a, ok := cmd.(fakes.GenericRecordedQuery) + if !ok || a.Name != "DeleteAlertRulesByUID" { + return nil, false + } + return a, ok + }) + result := make([]fakes.GenericRecordedQuery, 0, len(generic)) + for _, g := range generic { + result = append(result, g.(fakes.GenericRecordedQuery)) + } + return result +} + func createAlertRuleService(t *testing.T) AlertRuleService { t.Helper() sqlStore := db.InitTestDB(t) @@ -607,6 +1562,8 @@ func createAlertRuleService(t *testing.T) AlertRuleService { baseIntervalSeconds: 10, defaultIntervalSeconds: 60, folderService: folderService, + authz: &fakeRuleAccessControlService{}, + nsValidatorProvider: &NotificationSettingsValidatorProviderFake{}, } } @@ -650,3 +1607,30 @@ func createDummyGroup(title string, orgID int64) models.AlertRuleGroup { }, } } + +func initService(t *testing.T) (*AlertRuleService, *fakes.RuleStore, *fakes.FakeProvisioningStore, *fakeRuleAccessControlService) { + t.Helper() + + ac := &fakeRuleAccessControlService{} + ruleStore := fakes.NewRuleStore(t) + provenanceStore := fakes.NewFakeProvisioningStore() + folderService := foldertest.NewFakeService() + + quotas := MockQuotaChecker{} + quotas.EXPECT().LimitOK() + + service := &AlertRuleService{ + folderService: folderService, + ruleStore: ruleStore, + provenanceStore: provenanceStore, + quotas: "as, + xact: newNopTransactionManager(), + log: log.New("testing"), + baseIntervalSeconds: 10, + defaultIntervalSeconds: 60, + authz: ac, + nsValidatorProvider: &NotificationSettingsValidatorProviderFake{}, + } + + return service, ruleStore, provenanceStore, ac +} diff --git a/pkg/services/ngalert/provisioning/testing.go b/pkg/services/ngalert/provisioning/testing.go index 921da0498a8..62e6128319e 100644 --- a/pkg/services/ngalert/provisioning/testing.go +++ b/pkg/services/ngalert/provisioning/testing.go @@ -2,13 +2,16 @@ package provisioning import ( "context" + "sync" "testing" "github.com/stretchr/testify/assert" mock "github.com/stretchr/testify/mock" + "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/notifier" + "github.com/grafana/grafana/pkg/services/ngalert/store" ) const defaultAlertmanagerConfigJSON = ` @@ -147,3 +150,61 @@ type NotificationSettingsValidatorProviderFake struct { func (n *NotificationSettingsValidatorProviderFake) Validator(ctx context.Context, orgID int64) (notifier.NotificationSettingsValidator, error) { return notifier.NoValidation{}, nil } + +type call struct { + Method string + Args []interface{} +} + +type fakeRuleAccessControlService struct { + mu sync.Mutex + Calls []call + AuthorizeAccessToRuleGroupFunc func(ctx context.Context, user identity.Requester, rules models.RulesGroup) error + AuthorizeRuleChangesFunc func(ctx context.Context, user identity.Requester, change *store.GroupDelta) error + CanReadAllRulesFunc func(ctx context.Context, user identity.Requester) (bool, error) + CanWriteAllRulesFunc func(ctx context.Context, user identity.Requester) (bool, error) +} + +func (s *fakeRuleAccessControlService) RecordCall(method string, args ...interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + + call := call{ + Method: method, + Args: args, + } + + s.Calls = append(s.Calls, call) +} + +func (s *fakeRuleAccessControlService) AuthorizeRuleGroupRead(ctx context.Context, user identity.Requester, rules models.RulesGroup) error { + s.RecordCall("AuthorizeRuleGroupRead", ctx, user, rules) + if s.AuthorizeAccessToRuleGroupFunc != nil { + return s.AuthorizeAccessToRuleGroupFunc(ctx, user, rules) + } + return nil +} + +func (s *fakeRuleAccessControlService) AuthorizeRuleGroupWrite(ctx context.Context, user identity.Requester, change *store.GroupDelta) error { + s.RecordCall("AuthorizeRuleGroupWrite", ctx, user, change) + if s.AuthorizeRuleChangesFunc != nil { + return s.AuthorizeRuleChangesFunc(ctx, user, change) + } + return nil +} + +func (s *fakeRuleAccessControlService) CanReadAllRules(ctx context.Context, user identity.Requester) (bool, error) { + s.RecordCall("CanReadAllRules", ctx, user) + if s.CanReadAllRulesFunc != nil { + return s.CanReadAllRulesFunc(ctx, user) + } + return false, nil +} + +func (s *fakeRuleAccessControlService) CanWriteAllRules(ctx context.Context, user identity.Requester) (bool, error) { + s.RecordCall("CanWriteAllRules", ctx, user) + if s.CanWriteAllRulesFunc != nil { + return s.CanWriteAllRulesFunc(ctx, user) + } + return false, nil +} diff --git a/pkg/services/ngalert/store/deltas.go b/pkg/services/ngalert/store/deltas.go index 27fe5d86289..2b24bf52ea1 100644 --- a/pkg/services/ngalert/store/deltas.go +++ b/pkg/services/ngalert/store/deltas.go @@ -60,7 +60,6 @@ type RuleReader interface { // CalculateChanges calculates the difference between rules in the group in the database and the submitted rules. If a submitted rule has UID it tries to find it in the database (in other groups). // returns a list of rules that need to be added, updated and deleted. Deleted considered rules in the database that belong to the group but do not exist in the list of submitted rules. func CalculateChanges(ctx context.Context, ruleReader RuleReader, groupKey models.AlertRuleGroupKey, submittedRules []*models.AlertRuleWithOptionals) (*GroupDelta, error) { - affectedGroups := make(map[models.AlertRuleGroupKey]models.RulesGroup) q := &models.ListAlertRulesQuery{ OrgID: groupKey.OrgID, NamespaceUIDs: []string{groupKey.NamespaceUID}, @@ -70,6 +69,13 @@ func CalculateChanges(ctx context.Context, ruleReader RuleReader, groupKey model if err != nil { return nil, fmt.Errorf("failed to query database for rules in the group %s: %w", groupKey, err) } + + return calculateChanges(ctx, ruleReader, groupKey, existingGroupRules, submittedRules) +} + +func calculateChanges(ctx context.Context, ruleReader RuleReader, groupKey models.AlertRuleGroupKey, existingGroupRules []*models.AlertRule, submittedRules []*models.AlertRuleWithOptionals) (*GroupDelta, error) { + affectedGroups := make(map[models.AlertRuleGroupKey]models.RulesGroup) + if len(existingGroupRules) > 0 { affectedGroups[groupKey] = existingGroupRules } @@ -191,3 +197,114 @@ func UpdateCalculatedRuleFields(ch *GroupDelta) *GroupDelta { Delete: ch.Delete, } } + +// CalculateRuleUpdate calculates GroupDelta for rule update operation +func CalculateRuleUpdate(ctx context.Context, ruleReader RuleReader, rule *models.AlertRuleWithOptionals) (*GroupDelta, error) { + q := &models.ListAlertRulesQuery{ + OrgID: rule.OrgID, + NamespaceUIDs: []string{rule.NamespaceUID}, + RuleGroup: rule.RuleGroup, + } + existingGroupRules, err := ruleReader.ListAlertRules(ctx, q) + if err != nil { + return nil, err + } + + newGroup := make([]*models.AlertRuleWithOptionals, 0, len(existingGroupRules)+1) + added := false + for _, alertRule := range existingGroupRules { + if alertRule.GetKey() == rule.GetKey() { + newGroup = append(newGroup, rule) + added = true + } + newGroup = append(newGroup, &models.AlertRuleWithOptionals{AlertRule: *alertRule}) + } + if !added { + newGroup = append(newGroup, rule) + } + + return calculateChanges(ctx, ruleReader, rule.GetGroupKey(), existingGroupRules, newGroup) +} + +// CalculateRuleGroupDelete calculates GroupDelta that reflects an operation of removing entire group +func CalculateRuleGroupDelete(ctx context.Context, ruleReader RuleReader, groupKey models.AlertRuleGroupKey) (*GroupDelta, error) { + // List all rules in the group. + q := models.ListAlertRulesQuery{ + OrgID: groupKey.OrgID, + NamespaceUIDs: []string{groupKey.NamespaceUID}, + RuleGroup: groupKey.RuleGroup, + } + ruleList, err := ruleReader.ListAlertRules(ctx, &q) + if err != nil { + return nil, err + } + if len(ruleList) == 0 { + return nil, models.ErrAlertRuleGroupNotFound.Errorf("") + } + + delta := &GroupDelta{ + GroupKey: groupKey, + Delete: ruleList, + AffectedGroups: map[models.AlertRuleGroupKey]models.RulesGroup{ + groupKey: ruleList, + }, + } + return delta, nil +} + +// CalculateRuleDelete calculates GroupDelta that reflects an operation of removing a rule from the group. +func CalculateRuleDelete(ctx context.Context, ruleReader RuleReader, ruleKey models.AlertRuleKey) (*GroupDelta, error) { + q := &models.GetAlertRulesGroupByRuleUIDQuery{ + UID: ruleKey.UID, + OrgID: ruleKey.OrgID, + } + group, err := ruleReader.GetAlertRulesGroupByRuleUID(ctx, q) + if err != nil { + return nil, err + } + var toDelete *models.AlertRule + for _, rule := range group { + if rule.GetKey() == ruleKey { + toDelete = rule + break + } + } + if toDelete == nil { // should not happen if rule exists. + return nil, models.ErrAlertRuleNotFound + } + groupKey := group[0].GetGroupKey() + delta := &GroupDelta{ + GroupKey: groupKey, + Delete: []*models.AlertRule{toDelete}, + AffectedGroups: map[models.AlertRuleGroupKey]models.RulesGroup{ + groupKey: group, + }, + } + return delta, nil +} + +// CalculateRuleCreate calculates GroupDelta that reflects an operation of adding a new rule to the group. +func CalculateRuleCreate(ctx context.Context, ruleReader RuleReader, rule *models.AlertRule) (*GroupDelta, error) { + q := &models.ListAlertRulesQuery{ + OrgID: rule.OrgID, + NamespaceUIDs: []string{rule.NamespaceUID}, + RuleGroup: rule.RuleGroup, + } + group, err := ruleReader.ListAlertRules(ctx, q) + if err != nil { + return nil, err + } + + delta := &GroupDelta{ + GroupKey: rule.GetGroupKey(), + AffectedGroups: make(map[models.AlertRuleGroupKey]models.RulesGroup), + New: []*models.AlertRule{rule}, + Update: nil, + Delete: nil, + } + + if len(group) > 0 { + delta.AffectedGroups[rule.GetGroupKey()] = group + } + return delta, nil +} diff --git a/pkg/services/ngalert/store/deltas_test.go b/pkg/services/ngalert/store/deltas_test.go index 8d29c5db099..5a39e4cc787 100644 --- a/pkg/services/ngalert/store/deltas_test.go +++ b/pkg/services/ngalert/store/deltas_test.go @@ -416,6 +416,174 @@ func TestCalculateAutomaticChanges(t *testing.T) { }) } +func TestCalculateRuleGroupDelete(t *testing.T) { + fakeStore := fakes.NewRuleStore(t) + groupKey := models.GenerateGroupKey(1) + otherRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithOrgID(groupKey.OrgID), models.WithNamespaceUIDNotIn(groupKey.NamespaceUID))) + fakeStore.Rules[groupKey.OrgID] = otherRules + + t.Run("NotFound when group does not exist", func(t *testing.T) { + delta, err := CalculateRuleGroupDelete(context.Background(), fakeStore, groupKey) + require.ErrorIs(t, err, models.ErrAlertRuleGroupNotFound, "expected ErrAlertRuleGroupNotFound but got %s", err) + require.Nil(t, delta) + }) + + t.Run("set AffectedGroups when a rule refers to an existing group", func(t *testing.T) { + groupRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(groupKey))) + fakeStore.Rules[groupKey.OrgID] = append(fakeStore.Rules[groupKey.OrgID], groupRules...) + + delta, err := CalculateRuleGroupDelete(context.Background(), fakeStore, groupKey) + require.NoError(t, err) + + assert.Equal(t, groupKey, delta.GroupKey) + assert.EqualValues(t, groupRules, delta.Delete) + + assert.Empty(t, delta.Update) + assert.Empty(t, delta.New) + + assert.Len(t, delta.AffectedGroups, 1) + assert.Equal(t, models.RulesGroup(groupRules), delta.AffectedGroups[delta.GroupKey]) + }) +} + +func TestCalculateRuleDelete(t *testing.T) { + fakeStore := fakes.NewRuleStore(t) + rule := models.AlertRuleGen()() + otherRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithOrgID(rule.OrgID), models.WithNamespaceUIDNotIn(rule.NamespaceUID))) + fakeStore.Rules[rule.OrgID] = otherRules + + t.Run("nil when a rule does not exist", func(t *testing.T) { + delta, err := CalculateRuleDelete(context.Background(), fakeStore, rule.GetKey()) + require.ErrorIs(t, err, models.ErrAlertRuleNotFound) + require.Nil(t, delta) + }) + + t.Run("set AffectedGroups when a rule refers to an existing group", func(t *testing.T) { + groupRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(rule.GetGroupKey()))) + groupRules = append(groupRules, rule) + fakeStore.Rules[rule.OrgID] = append(fakeStore.Rules[rule.OrgID], groupRules...) + + delta, err := CalculateRuleDelete(context.Background(), fakeStore, rule.GetKey()) + require.NoError(t, err) + + assert.Equal(t, rule.GetGroupKey(), delta.GroupKey) + assert.Len(t, delta.Delete, 1) + assert.Equal(t, rule, delta.Delete[0]) + + assert.Empty(t, delta.Update) + assert.Empty(t, delta.New) + + assert.Len(t, delta.AffectedGroups, 1) + assert.Equal(t, models.RulesGroup(groupRules), delta.AffectedGroups[delta.GroupKey]) + }) +} + +func TestCalculateRuleUpdate(t *testing.T) { + fakeStore := fakes.NewRuleStore(t) + rule := models.AlertRuleGen()() + otherRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithOrgID(rule.OrgID), models.WithNamespaceUIDNotIn(rule.NamespaceUID))) + groupRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(rule.GetGroupKey()))) + groupRules = append(groupRules, rule) + fakeStore.Rules[rule.OrgID] = append(otherRules, groupRules...) + + t.Run("when a rule is not changed", func(t *testing.T) { + cp := models.CopyRule(rule) + delta, err := CalculateRuleUpdate(context.Background(), fakeStore, &models.AlertRuleWithOptionals{ + AlertRule: *cp, + HasPause: false, + }) + require.NoError(t, err) + require.True(t, delta.IsEmpty()) + }) + + t.Run("when a rule is updated", func(t *testing.T) { + cp := models.CopyRule(rule) + cp.For = cp.For + 1*time.Minute // cause any diff + + delta, err := CalculateRuleUpdate(context.Background(), fakeStore, &models.AlertRuleWithOptionals{ + AlertRule: *cp, + HasPause: false, + }) + require.NoError(t, err) + + assert.Equal(t, rule.GetGroupKey(), delta.GroupKey) + assert.Empty(t, delta.New) + assert.Empty(t, delta.Delete) + assert.Len(t, delta.Update, 1) + assert.Equal(t, cp, delta.Update[0].New) + assert.Equal(t, rule, delta.Update[0].Existing) + + require.Contains(t, delta.AffectedGroups, delta.GroupKey) + assert.Equal(t, models.RulesGroup(groupRules), delta.AffectedGroups[delta.GroupKey]) + }) + + t.Run("when a rule is moved between groups", func(t *testing.T) { + sourceGroupKey := rule.GetGroupKey() + targetGroupKey := models.GenerateGroupKey(rule.OrgID) + targetGroup := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(targetGroupKey))) + fakeStore.Rules[rule.OrgID] = append(fakeStore.Rules[rule.OrgID], targetGroup...) + + cp := models.CopyRule(rule) + cp.NamespaceUID = targetGroupKey.NamespaceUID + cp.RuleGroup = targetGroupKey.RuleGroup + + delta, err := CalculateRuleUpdate(context.Background(), fakeStore, &models.AlertRuleWithOptionals{ + AlertRule: *cp, + HasPause: false, + }) + require.NoError(t, err) + + assert.Equal(t, targetGroupKey, delta.GroupKey) + assert.Empty(t, delta.New) + assert.Empty(t, delta.Delete) + assert.Len(t, delta.Update, 1) + assert.Equal(t, cp, delta.Update[0].New) + assert.Equal(t, rule, delta.Update[0].Existing) + + require.Contains(t, delta.AffectedGroups, sourceGroupKey) + assert.Equal(t, models.RulesGroup(groupRules), delta.AffectedGroups[sourceGroupKey]) + require.Contains(t, delta.AffectedGroups, targetGroupKey) + assert.Equal(t, models.RulesGroup(targetGroup), delta.AffectedGroups[targetGroupKey]) + }) +} + +func TestCalculateRuleCreate(t *testing.T) { + t.Run("when a rule refers to a new group", func(t *testing.T) { + fakeStore := fakes.NewRuleStore(t) + rule := models.AlertRuleGen()() + + delta, err := CalculateRuleCreate(context.Background(), fakeStore, rule) + require.NoError(t, err) + + assert.Equal(t, rule.GetGroupKey(), delta.GroupKey) + assert.Empty(t, delta.AffectedGroups) + assert.Empty(t, delta.Delete) + assert.Empty(t, delta.Update) + assert.Len(t, delta.New, 1) + assert.Equal(t, rule, delta.New[0]) + }) + + t.Run("when a rule refers to an existing group", func(t *testing.T) { + fakeStore := fakes.NewRuleStore(t) + rule := models.AlertRuleGen()() + + groupRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithGroupKey(rule.GetGroupKey()))) + otherRules := models.GenerateAlertRules(3, models.AlertRuleGen(models.WithOrgID(rule.OrgID), models.WithNamespaceUIDNotIn(rule.NamespaceUID))) + fakeStore.Rules[rule.OrgID] = append(groupRules, otherRules...) + + delta, err := CalculateRuleCreate(context.Background(), fakeStore, rule) + require.NoError(t, err) + + assert.Equal(t, rule.GetGroupKey(), delta.GroupKey) + assert.Len(t, delta.AffectedGroups, 1) + assert.Equal(t, models.RulesGroup(groupRules), delta.AffectedGroups[delta.GroupKey]) + assert.Empty(t, delta.Delete) + assert.Empty(t, delta.Update) + assert.Len(t, delta.New, 1) + assert.Equal(t, rule, delta.New[0]) + }) +} + // simulateSubmitted resets some fields of the structure that are not populated by API model to model conversion func simulateSubmitted(rule *models.AlertRule) { rule.ID = 0 diff --git a/pkg/services/ngalert/tests/fakes/rules.go b/pkg/services/ngalert/tests/fakes/rules.go index 9185047c989..1714011d0bf 100644 --- a/pkg/services/ngalert/tests/fakes/rules.go +++ b/pkg/services/ngalert/tests/fakes/rules.go @@ -2,7 +2,6 @@ package fakes import ( "context" - "errors" "fmt" "math/rand" "sync" @@ -283,6 +282,12 @@ func (f *RuleStore) InsertAlertRules(_ context.Context, q []models.AlertRule) ([ defer f.mtx.Unlock() f.RecordedOps = append(f.RecordedOps, q) ids := make([]models.AlertRuleKeyWithId, 0, len(q)) + for _, rule := range q { + ids = append(ids, models.AlertRuleKeyWithId{ + AlertRuleKey: rule.GetKey(), + ID: rand.Int63(), + }) + } if err := f.Hook(q); err != nil { return ids, err } @@ -296,12 +301,16 @@ func (f *RuleStore) InTransaction(ctx context.Context, fn func(c context.Context func (f *RuleStore) GetRuleGroupInterval(ctx context.Context, orgID int64, namespaceUID string, ruleGroup string) (int64, error) { f.mtx.Lock() defer f.mtx.Unlock() + f.RecordedOps = append(f.RecordedOps, GenericRecordedQuery{ + Name: "GetRuleGroupInterval", + Params: []any{orgID, namespaceUID, ruleGroup}, + }) for _, rule := range f.Rules[orgID] { if rule.RuleGroup == ruleGroup && rule.NamespaceUID == namespaceUID { return rule.IntervalSeconds, nil } } - return 0, errors.New("rule group not found") + return 0, models.ErrAlertRuleGroupNotFound.Errorf("") } func (f *RuleStore) UpdateRuleGroup(ctx context.Context, orgID int64, namespaceUID string, ruleGroup string, interval int64) error { diff --git a/pkg/services/provisioning/alerting/rules_provisioner.go b/pkg/services/provisioning/alerting/rules_provisioner.go index 205018c7ef9..f975d6f729b 100644 --- a/pkg/services/provisioning/alerting/rules_provisioner.go +++ b/pkg/services/provisioning/alerting/rules_provisioner.go @@ -7,12 +7,13 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/metrics" + "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/folder" alert_models "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/ngalert/provisioning" - "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/util" ) @@ -132,7 +133,16 @@ func (prov *defaultAlertRuleProvisioner) getOrCreateFolderUID( return cmdResult.UID, nil } -// UserID is 0 to use org quota var provisionerUser = func(orgID int64) identity.Requester { - return &user.SignedInUser{UserID: 0, Login: "alert_provisioner", OrgID: orgID} + // this user has 0 ID and therefore, organization wide quota will be applied + return accesscontrol.BackgroundUser( + "alert_provisioner", + orgID, + org.RoleAdmin, + []accesscontrol.Permission{ + {Action: dashboards.ActionFoldersRead, Scope: dashboards.ScopeFoldersAll}, + {Action: accesscontrol.ActionAlertingProvisioningReadSecrets, Scope: dashboards.ScopeFoldersAll}, + {Action: accesscontrol.ActionAlertingProvisioningWrite, Scope: dashboards.ScopeFoldersAll}, + }, + ) } diff --git a/pkg/services/provisioning/provisioning.go b/pkg/services/provisioning/provisioning.go index 44916f8c394..14e3649a1f0 100644 --- a/pkg/services/provisioning/provisioning.go +++ b/pkg/services/provisioning/provisioning.go @@ -15,6 +15,7 @@ import ( datasourceservice "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/encryption" "github.com/grafana/grafana/pkg/services/folder" + alertingauthz "github.com/grafana/grafana/pkg/services/ngalert/accesscontrol" "github.com/grafana/grafana/pkg/services/ngalert/notifier" "github.com/grafana/grafana/pkg/services/ngalert/provisioning" "github.com/grafana/grafana/pkg/services/ngalert/store" @@ -255,7 +256,10 @@ func (ps *ProvisioningServiceImpl) ProvisionAlerting(ctx context.Context) error int64(ps.Cfg.UnifiedAlerting.DefaultRuleEvaluationInterval.Seconds()), int64(ps.Cfg.UnifiedAlerting.BaseInterval.Seconds()), ps.Cfg.UnifiedAlerting.RulesPerRuleGroupLimit, - ps.log, notifier.NewCachedNotificationSettingsValidationService(&st)) + ps.log, + notifier.NewCachedNotificationSettingsValidationService(&st), + alertingauthz.NewRuleService(ps.ac), + ) receiverSvc := notifier.NewReceiverService(ps.ac, &st, st, ps.secretService, ps.SQLStore, ps.log) contactPointService := provisioning.NewContactPointService(&st, ps.secretService, st, ps.SQLStore, receiverSvc, ps.log, &st) diff --git a/public/api-merged.json b/public/api-merged.json index 59d67331f05..038827f26b1 100644 --- a/public/api-merged.json +++ b/public/api-merged.json @@ -10198,7 +10198,7 @@ "tags": [ "provisioning" ], - "summary": "Update the interval of a rule group.", + "summary": "Create or update alert rule group.", "operationId": "RoutePutAlertRuleGroup", "parameters": [ { diff --git a/public/openapi3.json b/public/openapi3.json index 8f7fa8c41e9..3c289c32df1 100644 --- a/public/openapi3.json +++ b/public/openapi3.json @@ -23342,7 +23342,7 @@ "description": "ValidationError" } }, - "summary": "Update the interval of a rule group.", + "summary": "Create or update alert rule group.", "tags": [ "provisioning" ]