diff --git a/pkg/infra/serverlock/serverlock.go b/pkg/infra/serverlock/serverlock.go
index ab69eaaf941..9dffb1c0871 100644
--- a/pkg/infra/serverlock/serverlock.go
+++ b/pkg/infra/serverlock/serverlock.go
@@ -2,6 +2,8 @@ package serverlock
 
 import (
 	"context"
+	"errors"
+	"math/rand"
 	"time"
 
 	"go.opentelemetry.io/otel/attribute"
@@ -164,6 +166,96 @@ func (sl *ServerLockService) LockExecuteAndRelease(ctx context.Context, actionNa
 	return nil
 }
 
+// RetryOpt is a callback function called after each failed lock acquisition try.
+// It gets the number of tries passed as an arg.
+type RetryOpt func(int) error
+
+// LockTimeConfig groups the durations used by LockExecuteAndReleaseWithRetries.
+type LockTimeConfig struct {
+	MaxInterval time.Duration // Duration after which we consider a lock to be dead and overtake it. Make sure this is big enough so that a server cannot acquire the lock while another server is processing.
+	MinWait     time.Duration // Minimum time to wait before retrying to acquire the lock.
+	MaxWait     time.Duration // Maximum time to wait before retrying to acquire the lock.
+}
+
+// LockExecuteAndReleaseWithRetries mimics LockExecuteAndRelease but waits for the lock to be released if it is already taken.
+//
+// After every failed acquisition the retryOpts callbacks are invoked in order with the
+// number of attempts made so far; the first callback error aborts the wait and is returned.
+// Between attempts the call waits for a random duration between timeConfig.MinWait and
+// timeConfig.MaxWait and returns ctx.Err() if ctx is done before the wait elapses.
+// Any acquisition error other than ServerLockExistsError is returned as is. A failure to
+// release the lock after fn has run is only logged and recorded on the span.
+func (sl *ServerLockService) LockExecuteAndReleaseWithRetries(ctx context.Context, actionName string, timeConfig LockTimeConfig, fn func(ctx context.Context), retryOpts ...RetryOpt) error {
+	start := time.Now()
+	ctx, span := sl.tracer.Start(ctx, "ServerLockService.LockExecuteAndReleaseWithRetries")
+	span.SetAttributes(attribute.String("serverlock.actionName", actionName))
+	defer span.End()
+
+	ctxLogger := sl.log.FromContext(ctx)
+	ctxLogger.Debug("Start LockExecuteAndReleaseWithRetries", "actionName", actionName)
+
+	lockChecks := 0
+
+	for {
+		lockChecks++
+		err := sl.acquireForRelease(ctx, actionName, timeConfig.MaxInterval)
+		if err == nil {
+			// lock was acquired successfully
+			break
+		}
+
+		// could not get the lock
+		var lockedErr *ServerLockExistsError
+		if !errors.As(err, &lockedErr) {
+			span.RecordError(err)
+			return err
+		}
+
+		// if the lock is already taken, wait and try again
+		if lockChecks == 1 { // only warn on first lock check
+			ctxLogger.Warn("another instance has the lock, waiting for it to be released", "actionName", actionName)
+		}
+
+		for _, op := range retryOpts {
+			if err := op(lockChecks); err != nil {
+				span.RecordError(err)
+				return err
+			}
+		}
+
+		// Wait with a timer instead of time.Sleep so that a canceled context stops the retries.
+		timer := time.NewTimer(lockWait(timeConfig.MinWait, timeConfig.MaxWait))
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			span.RecordError(ctx.Err())
+			return ctx.Err()
+		case <-timer.C:
+		}
+	}
+
+	sl.executeFunc(ctx, actionName, fn)
+
+	if err := sl.releaseLock(ctx, actionName); err != nil {
+		span.RecordError(err)
+		ctxLogger.Error("Failed to release the lock", "error", err)
+	}
+
+	ctxLogger.Debug("LockExecuteAndReleaseWithRetries finished", "actionName", actionName, "duration", time.Since(start))
+
+	return nil
+}
+
+// lockWait returns a random duration in [minWait, maxWait) to ensure instances unlock gradually.
+// When maxWait is not greater than minWait there is no interval to draw from, so minWait is
+// returned as is; rand.Int63n would panic on a non-positive argument.
+func lockWait(minWait time.Duration, maxWait time.Duration) time.Duration {
+	if maxWait <= minWait {
+		return minWait
+	}
+	return minWait + time.Duration(rand.Int63n(int64(maxWait-minWait)))
+}
+
 // acquireForRelease will check if the lock is already on the database, if it is, will check with maxInterval if it is
 // timeouted.
Returns nil error if the lock was acquired correctly
 func (sl *ServerLockService) acquireForRelease(ctx context.Context, actionName string, maxInterval time.Duration) error {
diff --git a/pkg/infra/serverlock/serverlock_integration_test.go b/pkg/infra/serverlock/serverlock_integration_test.go
index 0cfbba25a03..9e1d43d70c0 100644
--- a/pkg/infra/serverlock/serverlock_integration_test.go
+++ b/pkg/infra/serverlock/serverlock_integration_test.go
@@ -2,6 +2,7 @@ package serverlock
 
 import (
 	"context"
+	"sync"
 	"testing"
 	"time"
 
@@ -19,11 +20,11 @@ func TestIntegrationServerLock_LockAndExecute(t *testing.T) {
 	atInterval := time.Hour
 	ctx := context.Background()
 
-	//this time `fn` should be executed
+	// this time `fn` should be executed
 	require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
 	require.Equal(t, 1, counter)
 
-	//this should not execute `fn`
+	// this should not execute `fn`
 	require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
 	require.Nil(t, sl.LockAndExecute(ctx, "test-operation", atInterval, fn))
 	require.Equal(t, 1, counter)
@@ -62,3 +63,68 @@ func TestIntegrationServerLock_LockExecuteAndRelease(t *testing.T) {
 
 	require.Equal(t, 4, counter)
 }
+
+func TestIntegrationServerLock_LockExecuteAndReleaseWithRetries(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping integration test")
+	}
+	sl := createTestableServerLock(t)
+
+	retries := 0
+	expectedRetries := 10
+	funcRuns := 0
+	fn := func(context.Context) {
+		funcRuns++
+	}
+	lockTimeConfig := LockTimeConfig{
+		MaxInterval: time.Hour,
+		MinWait:     0 * time.Millisecond,
+		MaxWait:     1 * time.Millisecond,
+	}
+	ctx := context.Background()
+	actionName := "test-operation"
+
+	// Acquire lock so that when `LockExecuteAndReleaseWithRetries` runs, it is forced
+	// to retry
+	err := sl.acquireForRelease(ctx, actionName, lockTimeConfig.MaxInterval)
+	require.NoError(t, err)
+
+	wgRetries := sync.WaitGroup{}
+	wgRetries.Add(expectedRetries)
+	wgRelease := sync.WaitGroup{}
+	wgRelease.Add(1)
+	wgCompleted := sync.WaitGroup{}
+	wgCompleted.Add(1)
+
+	onRetryFn := func(int) error {
+		retries++
+		wgRetries.Done()
+		if retries == expectedRetries {
+			// When we reach `expectedRetries`, wait for the lock to be released
+			// to guarantee that next try will succeed
+			wgRelease.Wait()
+		}
+		return nil
+	}
+
+	// The result is checked on the test goroutine after completion: `require` stops the test
+	// with t.FailNow, which must only be called from the goroutine running the test.
+	var errRun error
+	go func() {
+		defer wgCompleted.Done()
+		errRun = sl.LockExecuteAndReleaseWithRetries(ctx, actionName, lockTimeConfig, fn, onRetryFn)
+	}()
+
+	// Wait to release the lock until `LockExecuteAndReleaseWithRetries` has retried `expectedRetries` times.
+	wgRetries.Wait()
+	err = sl.releaseLock(ctx, actionName)
+	require.NoError(t, err)
+	wgRelease.Done()
+
+	// `LockExecuteAndReleaseWithRetries` has run completely.
+	// Check that it succeeded and that it had to retry because the lock was already taken.
+	wgCompleted.Wait()
+	require.NoError(t, errRun)
+	require.Equal(t, expectedRetries, retries)
+	require.Equal(t, 1, funcRuns)
+}
diff --git a/pkg/services/extsvcauth/registry/service.go b/pkg/services/extsvcauth/registry/service.go
index e5370e4df3a..1c07c4011b9 100644
--- a/pkg/services/extsvcauth/registry/service.go
+++ b/pkg/services/extsvcauth/registry/service.go
@@ -3,8 +3,10 @@ package registry
 import (
 	"context"
 	"sync"
+	"time"
 
 	"github.com/grafana/grafana/pkg/infra/log"
+	"github.com/grafana/grafana/pkg/infra/serverlock"
 	"github.com/grafana/grafana/pkg/infra/slugify"
 	"github.com/grafana/grafana/pkg/services/extsvcauth"
 	"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver/oasimpl"
@@ -14,6 +16,12 @@ import (
 
 var _ extsvcauth.ExternalServiceRegistry = &Registry{}
 
+var lockTimeConfig = serverlock.LockTimeConfig{
+	MaxInterval: 2 * time.Minute,
+	MinWait:     1 * time.Second,
+	MaxWait:     5 * time.Second,
+}
+
 type Registry struct {
 	features featuremgmt.FeatureToggles
 	logger   log.Logger
@@ -22,9 +30,10 @@ type Registry struct {
 
 	extSvcProviders map[string]extsvcauth.AuthProvider
 	lock            sync.Mutex
+	serverLock      *serverlock.ServerLockService
 }
 
-func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, features featuremgmt.FeatureToggles) *Registry {
+func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, serverLock *serverlock.ServerLockService, features featuremgmt.FeatureToggles) *Registry {
 	return &Registry{
 		extSvcProviders: map[string]extsvcauth.AuthProvider{},
 		features:        features,
@@ -32,6 +41,7 @@ func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvc
 		logger:          log.New("extsvcauth.registry"),
 		oauthReg:        oauthServer,
 		saReg:           saSvc,
+		serverLock:      serverLock,
 	}
 }
 
@@ -104,7 +114,7 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
 		r.logger.Debug("Routing External Service removal to the OAuth2Server", "service", name)
 		return r.oauthReg.RemoveExternalService(ctx, name)
 	default:
-		return extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", provider)
+		return extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", provider)
 	}
 }
 
@@ -112,29 +122,42 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
 // it generates client_id, secrets and any additional provider specificities (ex: rsa keys). It also ensures that the
 // associated service account has the correct permissions.
func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.ExternalServiceRegistration) (*extsvcauth.ExternalService, error) { - // Record provider in case of removal - r.lock.Lock() - r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider - r.lock.Unlock() + var ( + errSave error + extSvc *extsvcauth.ExternalService + lockName = "ext-svc-save-" + cmd.Name + ) - switch cmd.AuthProvider { - case extsvcauth.ServiceAccounts: - if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) { - r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts) - return nil, nil - } - r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name) - return r.saReg.SaveExternalService(ctx, cmd) - case extsvcauth.OAuth2Server: - if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) { - r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth) - return nil, nil + err := r.serverLock.LockExecuteAndReleaseWithRetries(ctx, lockName, lockTimeConfig, func(ctx context.Context) { + // Record provider in case of removal + r.lock.Lock() + r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider + r.lock.Unlock() + + switch cmd.AuthProvider { + case extsvcauth.ServiceAccounts: + if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) { + r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts) + return + } + r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name) + extSvc, errSave = r.saReg.SaveExternalService(ctx, cmd) + case extsvcauth.OAuth2Server: + if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) { + r.logger.Warn("Skipping External 
Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth) + return + } + r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name) + extSvc, errSave = r.oauthReg.SaveExternalService(ctx, cmd) + default: + errSave = extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", cmd.AuthProvider) } - r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name) - return r.oauthReg.SaveExternalService(ctx, cmd) - default: - return nil, extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", cmd.AuthProvider) + }) + if err != nil { + return nil, err } + + return extSvc, errSave } // retrieveExtSvcProviders fetches external services from store and map their associated provider diff --git a/pkg/services/extsvcauth/registry/service_test.go b/pkg/services/extsvcauth/registry/service_test.go index a1bf617e967..d09e0c4477a 100644 --- a/pkg/services/extsvcauth/registry/service_test.go +++ b/pkg/services/extsvcauth/registry/service_test.go @@ -2,7 +2,6 @@ package registry import ( "context" - "sync" "testing" "github.com/grafana/grafana/pkg/infra/log" @@ -29,7 +28,6 @@ func setupTestEnv(t *testing.T) *TestEnv { oauthReg: env.oauthReg, saReg: env.saReg, extSvcProviders: map[string]extsvcauth.AuthProvider{}, - lock: sync.Mutex{}, } return &env }