AuthN: Support HA setups with External Service Account management (#78425)

* Lock when creating external service

* Add local lock back

* Improve function signature

* Define lockName separately to make it more explicit

* Update pkg/infra/serverlock/serverlock.go

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>

* Update pkg/infra/serverlock/serverlock.go

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>

---------

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
pull/78518/head
Xavi Lacasa 2 years ago committed by GitHub
parent 61553e1693
commit 72759be6ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 70
      pkg/infra/serverlock/serverlock.go
  2. 63
      pkg/infra/serverlock/serverlock_integration_test.go
  3. 37
      pkg/services/extsvcauth/registry/service.go
  4. 2
      pkg/services/extsvcauth/registry/service_test.go

@ -2,6 +2,8 @@ package serverlock
import ( import (
"context" "context"
"errors"
"math/rand"
"time" "time"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
@ -164,6 +166,74 @@ func (sl *ServerLockService) LockExecuteAndRelease(ctx context.Context, actionNa
return nil return nil
} }
// RetryOpt is a hook invoked each time an attempt to acquire the lock fails
// because the lock is already held. It receives the number of attempts made
// so far; returning a non-nil error aborts the retry loop and that error is
// handed back to the caller.
type RetryOpt func(tries int) error
// LockTimeConfig groups the timing settings used when acquiring a lock with retries.
type LockTimeConfig struct {
	// MaxInterval is the duration after which we consider a lock to be dead
	// and overtake it. Make sure this is big enough so that a server cannot
	// acquire the lock while another server is processing.
	MaxInterval time.Duration
	// MinWait is the minimum time to wait before retrying to acquire the lock.
	MinWait time.Duration
	// MaxWait is the maximum time to wait before retrying to acquire the lock.
	MaxWait time.Duration
}
// LockExecuteAndReleaseWithRetries mimics LockExecuteAndRelease but, instead of giving up when the
// lock is already taken, keeps retrying until it manages to acquire it.
//
// After each attempt that fails because the lock is held, every retryOpts callback is called with
// the number of attempts made so far (a non-nil error from a callback aborts and is returned), then
// the call waits for a random duration derived from timeConfig.MinWait and timeConfig.MaxWait.
// If ctx is done while waiting, the wait is abandoned and ctx.Err() is returned.
// timeConfig.MaxInterval is forwarded to acquireForRelease as the lock's max interval.
// Any acquisition error other than ServerLockExistsError is recorded on the span and returned.
//
// Once the lock is held, fn is executed and the lock is released. A failure to release the lock is
// logged and recorded on the span but is not returned.
func (sl *ServerLockService) LockExecuteAndReleaseWithRetries(ctx context.Context, actionName string, timeConfig LockTimeConfig, fn func(ctx context.Context), retryOpts ...RetryOpt) error {
	start := time.Now()
	ctx, span := sl.tracer.Start(ctx, "ServerLockService.LockExecuteAndReleaseWithRetries")
	span.SetAttributes(attribute.String("serverlock.actionName", actionName))
	defer span.End()

	ctxLogger := sl.log.FromContext(ctx)
	ctxLogger.Debug("Start LockExecuteAndReleaseWithRetries", "actionName", actionName)

	lockChecks := 0
	for {
		lockChecks++
		err := sl.acquireForRelease(ctx, actionName, timeConfig.MaxInterval)
		if err == nil {
			// lock was acquired
			break
		}

		// could not get the lock for a reason other than it being held
		var lockedErr *ServerLockExistsError
		if !errors.As(err, &lockedErr) {
			span.RecordError(err)
			return err
		}

		// the lock is already taken: wait and try again
		if lockChecks == 1 { // only warn on first lock check
			ctxLogger.Warn("another instance has the lock, waiting for it to be released", "actionName", actionName)
		}

		for _, op := range retryOpts {
			if err := op(lockChecks); err != nil {
				return err
			}
		}

		// Wait before the next attempt, but stop waiting as soon as the caller's
		// context is cancelled instead of blocking for the whole interval.
		timer := time.NewTimer(lockWait(timeConfig.MinWait, timeConfig.MaxWait))
		select {
		case <-ctx.Done():
			timer.Stop()
			span.RecordError(ctx.Err())
			return ctx.Err()
		case <-timer.C:
		}
	}

	sl.executeFunc(ctx, actionName, fn)

	if err := sl.releaseLock(ctx, actionName); err != nil {
		span.RecordError(err)
		ctxLogger.Error("Failed to release the lock", "error", err)
	}

	ctxLogger.Debug("LockExecuteAndReleaseWithRetries finished", "actionName", actionName, "duration", time.Since(start))

	return nil
}
// lockWait returns a random duration in [minWait, maxWait) so that instances waiting on the
// same lock retry at staggered times rather than all at once.
//
// When maxWait is not greater than minWait there is no range to draw from, so minWait is
// returned as is; rand.Int63n would panic on a non-positive argument.
func lockWait(minWait time.Duration, maxWait time.Duration) time.Duration {
	spread := int64(maxWait - minWait)
	if spread <= 0 {
		return minWait
	}
	return minWait + time.Duration(rand.Int63n(spread))
}
// acquireForRelease will check if the lock is already on the database, if it is, will check with maxInterval if it is // acquireForRelease will check if the lock is already on the database, if it is, will check with maxInterval if it is
// timeouted. Returns nil error if the lock was acquired correctly // timeouted. Returns nil error if the lock was acquired correctly
func (sl *ServerLockService) acquireForRelease(ctx context.Context, actionName string, maxInterval time.Duration) error { func (sl *ServerLockService) acquireForRelease(ctx context.Context, actionName string, maxInterval time.Duration) error {

@ -2,6 +2,7 @@ package serverlock
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
@ -62,3 +63,65 @@ func TestIntegrationServerLock_LockExecuteAndRelease(t *testing.T) {
require.Equal(t, 4, counter) require.Equal(t, 4, counter)
} }
// TestIntegrationServerLock_LockExecuteAndReleaseWithRetries checks that
// LockExecuteAndReleaseWithRetries keeps retrying while the lock is held by someone else and
// runs the function exactly once after the lock is released.
func TestIntegrationServerLock_LockExecuteAndReleaseWithRetries(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}

	sl := createTestableServerLock(t)

	retries := 0
	expectedRetries := 10
	funcRuns := 0
	fn := func(context.Context) {
		funcRuns++
	}
	lockTimeConfig := LockTimeConfig{
		MaxInterval: time.Hour,
		MinWait:     0 * time.Millisecond,
		MaxWait:     1 * time.Millisecond,
	}
	ctx := context.Background()
	actionName := "test-operation"

	// Acquire lock so that when `LockExecuteAndReleaseWithRetries` runs, it is forced
	// to retry
	err := sl.acquireForRelease(ctx, actionName, lockTimeConfig.MaxInterval)
	require.NoError(t, err)

	wgRetries := sync.WaitGroup{}
	wgRetries.Add(expectedRetries)
	wgRelease := sync.WaitGroup{}
	wgRelease.Add(1)
	wgCompleted := sync.WaitGroup{}
	wgCompleted.Add(1)

	onRetryFn := func(int) error {
		retries++
		wgRetries.Done()
		if retries == expectedRetries {
			// When we reach `expectedRetries`, wait for the lock to be released
			// to guarantee that next try will succeed
			wgRelease.Wait()
		}
		return nil
	}

	// The result is captured and asserted on the test goroutine: require's FailNow must not be
	// called from another goroutine, and the deferred Done guarantees wgCompleted.Wait() below
	// cannot hang if the call returns an error.
	var runErr error
	go func() {
		defer wgCompleted.Done()
		runErr = sl.LockExecuteAndReleaseWithRetries(ctx, actionName, lockTimeConfig, fn, onRetryFn)
	}()

	// Wait to release the lock until `LockExecuteAndReleaseWithRetries` has retried `expectedRetries` times.
	wgRetries.Wait()
	err = sl.releaseLock(ctx, actionName)
	require.NoError(t, err)
	wgRelease.Done()

	// `LockExecuteAndReleaseWithRetries` has run completely.
	// Check that it had to retry because the lock was already taken.
	wgCompleted.Wait()
	require.NoError(t, runErr)
	require.Equal(t, expectedRetries, retries)
	require.Equal(t, 1, funcRuns)
}

@ -3,8 +3,10 @@ package registry
import ( import (
"context" "context"
"sync" "sync"
"time"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/slugify" "github.com/grafana/grafana/pkg/infra/slugify"
"github.com/grafana/grafana/pkg/services/extsvcauth" "github.com/grafana/grafana/pkg/services/extsvcauth"
"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver/oasimpl" "github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver/oasimpl"
@ -14,6 +16,12 @@ import (
var _ extsvcauth.ExternalServiceRegistry = &Registry{} var _ extsvcauth.ExternalServiceRegistry = &Registry{}
// lockTimeConfig holds the server lock timings passed to
// LockExecuteAndReleaseWithRetries when saving an external service:
// MaxInterval is how long a held lock is honoured before it is considered dead,
// and MinWait/MaxWait bound the wait before retrying to acquire the lock.
var lockTimeConfig = serverlock.LockTimeConfig{
MaxInterval: 2 * time.Minute,
MinWait: 1 * time.Second,
MaxWait: 5 * time.Second,
}
type Registry struct { type Registry struct {
features featuremgmt.FeatureToggles features featuremgmt.FeatureToggles
logger log.Logger logger log.Logger
@ -22,9 +30,10 @@ type Registry struct {
extSvcProviders map[string]extsvcauth.AuthProvider extSvcProviders map[string]extsvcauth.AuthProvider
lock sync.Mutex lock sync.Mutex
serverLock *serverlock.ServerLockService
} }
func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, features featuremgmt.FeatureToggles) *Registry { func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvcaccounts.ExtSvcAccountsService, serverLock *serverlock.ServerLockService, features featuremgmt.FeatureToggles) *Registry {
return &Registry{ return &Registry{
extSvcProviders: map[string]extsvcauth.AuthProvider{}, extSvcProviders: map[string]extsvcauth.AuthProvider{},
features: features, features: features,
@ -32,6 +41,7 @@ func ProvideExtSvcRegistry(oauthServer *oasimpl.OAuth2ServiceImpl, saSvc *extsvc
logger: log.New("extsvcauth.registry"), logger: log.New("extsvcauth.registry"),
oauthReg: oauthServer, oauthReg: oauthServer,
saReg: saSvc, saReg: saSvc,
serverLock: serverLock,
} }
} }
@ -104,7 +114,7 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
r.logger.Debug("Routing External Service removal to the OAuth2Server", "service", name) r.logger.Debug("Routing External Service removal to the OAuth2Server", "service", name)
return r.oauthReg.RemoveExternalService(ctx, name) return r.oauthReg.RemoveExternalService(ctx, name)
default: default:
return extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", provider) return extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", provider)
} }
} }
@ -112,6 +122,13 @@ func (r *Registry) RemoveExternalService(ctx context.Context, name string) error
// it generates client_id, secrets and any additional provider specificities (ex: rsa keys). It also ensures that the // it generates client_id, secrets and any additional provider specificities (ex: rsa keys). It also ensures that the
// associated service account has the correct permissions. // associated service account has the correct permissions.
func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.ExternalServiceRegistration) (*extsvcauth.ExternalService, error) { func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.ExternalServiceRegistration) (*extsvcauth.ExternalService, error) {
var (
errSave error
extSvc *extsvcauth.ExternalService
lockName = "ext-svc-save-" + cmd.Name
)
err := r.serverLock.LockExecuteAndReleaseWithRetries(ctx, lockName, lockTimeConfig, func(ctx context.Context) {
// Record provider in case of removal // Record provider in case of removal
r.lock.Lock() r.lock.Lock()
r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider r.extSvcProviders[slugify.Slugify(cmd.Name)] = cmd.AuthProvider
@ -121,20 +138,26 @@ func (r *Registry) SaveExternalService(ctx context.Context, cmd *extsvcauth.Exte
case extsvcauth.ServiceAccounts: case extsvcauth.ServiceAccounts:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) { if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAccounts) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts) r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAccounts)
return nil, nil return
} }
r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name) r.logger.Debug("Routing the External Service registration to the External Service Account service", "service", cmd.Name)
return r.saReg.SaveExternalService(ctx, cmd) extSvc, errSave = r.saReg.SaveExternalService(ctx, cmd)
case extsvcauth.OAuth2Server: case extsvcauth.OAuth2Server:
if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) { if !r.features.IsEnabled(ctx, featuremgmt.FlagExternalServiceAuth) {
r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth) r.logger.Warn("Skipping External Service authentication, flag disabled", "service", cmd.Name, "flag", featuremgmt.FlagExternalServiceAuth)
return nil, nil return
} }
r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name) r.logger.Debug("Routing the External Service registration to the OAuth2Server", "service", cmd.Name)
return r.oauthReg.SaveExternalService(ctx, cmd) extSvc, errSave = r.oauthReg.SaveExternalService(ctx, cmd)
default: default:
return nil, extsvcauth.ErrUnknownProvider.Errorf("unknow provider '%v'", cmd.AuthProvider) errSave = extsvcauth.ErrUnknownProvider.Errorf("unknown provider '%v'", cmd.AuthProvider)
} }
})
if err != nil {
return nil, err
}
return extSvc, errSave
} }
// retrieveExtSvcProviders fetches external services from store and map their associated provider // retrieveExtSvcProviders fetches external services from store and map their associated provider

@ -2,7 +2,6 @@ package registry
import ( import (
"context" "context"
"sync"
"testing" "testing"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
@ -29,7 +28,6 @@ func setupTestEnv(t *testing.T) *TestEnv {
oauthReg: env.oauthReg, oauthReg: env.oauthReg,
saReg: env.saReg, saReg: env.saReg,
extSvcProviders: map[string]extsvcauth.AuthProvider{}, extSvcProviders: map[string]extsvcauth.AuthProvider{},
lock: sync.Mutex{},
} }
return &env return &env
} }

Loading…
Cancel
Save