Auth: Update oauthtoken service to use remote cache and server lock (#90572)

* update oauthtoken service to use remote cache and server lock

* remove token cache

* retry is lock is held by an in-flight refresh

* refactor token renewal to avoid race condition

* re-add refresh token expiry cache, but in SyncOauthTokenHook

* Add delta to the cache ttl

* Fix merge

* Change lockTimeConfig

* Always set the token from within the server lock

* Improvements

* early return when user is not authed by OAuth or refresh is disabled

* Allow more time for token refresh, tracing

* Retry on Mysql Deadlock error 1213

* Update pkg/services/authn/authnimpl/sync/oauth_token_sync.go

Co-authored-by: Dan Cech <dcech@grafana.com>

* Update pkg/services/authn/authnimpl/sync/oauth_token_sync.go

Co-authored-by: Dan Cech <dcech@grafana.com>

* Add settings for configuring min wait time between retries

* Add docs for the new setting

* Clean up

* Update docs/sources/setup-grafana/configure-grafana/_index.md

Co-authored-by: Christopher Moyer <35463610+chri2547@users.noreply.github.com>

---------

Co-authored-by: Mihaly Gyongyosi <mgyongyosi@users.noreply.github.com>
Co-authored-by: Christopher Moyer <35463610+chri2547@users.noreply.github.com>
pull/92098/head
Dan Cech 9 months ago committed by GitHub
parent 5ce9324801
commit 9020eb4b17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 7
      conf/defaults.ini
  2. 7
      conf/sample.ini
  3. 8
      docs/sources/setup-grafana/configure-grafana/_index.md
  4. 4
      pkg/infra/serverlock/serverlock.go
  5. 73
      pkg/services/authn/authnimpl/sync/oauth_token_sync.go
  6. 68
      pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go
  7. 250
      pkg/services/oauthtoken/oauth_token.go
  8. 162
      pkg/services/oauthtoken/oauth_token_test.go
  9. 6
      pkg/services/oauthtoken/oauthtokentest/mock.go
  10. 4
      pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go
  11. 8
      pkg/setting/setting.go

@ -579,6 +579,13 @@ oauth_auto_login = false
# OAuth state max age cookie duration in seconds. Defaults to 600 seconds.
oauth_state_cookie_max_age = 600
# Minimum wait time in milliseconds for the server lock retry mechanism.
# The server lock retry mechanism is used to prevent multiple Grafana instances from
# simultaneously refreshing OAuth tokens. This mechanism waits at least this amount
# of time before retrying to acquire the server lock. There are 5 retries in total.
# The wait time between retries is calculated as random(n, n + 500)
oauth_refresh_token_server_lock_min_wait_ms = 1000
# limit of api_key seconds to live before expiration
api_key_max_seconds_to_live = -1

@ -583,6 +583,13 @@
# OAuth state max age cookie duration in seconds. Defaults to 600 seconds.
;oauth_state_cookie_max_age = 600
# Minimum wait time in milliseconds for the server lock retry mechanism.
# The server lock retry mechanism is used to prevent multiple Grafana instances from
# simultaneously refreshing OAuth tokens. This mechanism waits at least this amount
# of time before retrying to acquire the server lock. There are 5 retries in total.
# The wait time between retries is calculated as random(n, n + 500)
; oauth_refresh_token_server_lock_min_wait_ms = 1000
# limit of api_key seconds to live before expiration
;api_key_max_seconds_to_live = -1

@ -952,6 +952,14 @@ This setting is ignored if multiple OAuth providers are configured. Default is `
How many seconds the OAuth state cookie lives before being deleted. Default is `600` (seconds)
Administrators can increase this if they experience OAuth login state mismatch errors.
### oauth_refresh_token_server_lock_min_wait_ms
Minimum wait time in milliseconds for the server lock retry mechanism. Default is `1000` (milliseconds). The server lock retry mechanism is used to prevent multiple Grafana instances from simultaneously refreshing OAuth tokens. This mechanism waits at least this amount of time before retrying to acquire the server lock.
There are five retries in total, so with the default value, the total wait time (for acquiring the lock) is at least 5 seconds (the wait time between retries is calculated as random(n, n + 500)), which means that the maximum token refresh duration must be less than 5-6 seconds.
If you experience issues with the OAuth token refresh mechanism, you can increase this value to allow more time for the token refresh to complete.
### oauth_skip_org_role_update_sync
{{% admonition type="note" %}}

@ -7,6 +7,7 @@ import (
"math/rand"
"time"
"github.com/go-sql-driver/mysql"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
@ -197,7 +198,8 @@ func (sl *ServerLockService) LockExecuteAndReleaseWithRetries(ctx context.Contex
// could not get the lock
if err != nil {
var lockedErr *ServerLockExistsError
if errors.As(err, &lockedErr) {
var deadlockErr *mysql.MySQLError
if errors.As(err, &lockedErr) || (errors.As(err, &deadlockErr) && deadlockErr.Number == 1213) {
// if the lock is already taken, wait and try again
if lockChecks == 1 { // only warn on first lock check
ctxLogger.Warn("another instance has the lock, waiting for it to be released", "actionName", actionName)

@ -3,12 +3,15 @@ package sync
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/grafana/authlib/claims"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
@ -17,6 +20,8 @@ import (
"github.com/grafana/grafana/pkg/services/oauthtoken"
)
const maxOAuthTokenCacheTTL = 5 * time.Minute
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service, tracer tracing.Tracer) *OAuthTokenSync {
return &OAuthTokenSync{
log.New("oauth_token.sync"),
@ -25,6 +30,7 @@ func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService
socialService,
new(singleflight.Group),
tracer,
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
}
}
@ -35,6 +41,7 @@ type OAuthTokenSync struct {
socialService social.Service
singleflightGroup *singleflight.Group
tracer tracing.Tracer
cache *localcache.CacheService
}
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Identity, _ *authn.Request) error {
@ -56,44 +63,44 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
return nil
}
ctxLogger := s.log.FromContext(ctx).New("userID", id.GetID())
userID, err := id.GetInternalID()
if err != nil {
s.log.FromContext(ctx).Error("Failed to refresh token. Invalid ID for identity", "type", id.GetIdentityType(), "err", err)
return nil
}
_, err, _ := s.singleflightGroup.Do(id.GetID(), func() (interface{}, error) {
ctxLogger := s.log.FromContext(ctx).New("userID", userID)
cacheKey := fmt.Sprintf("token-check-%s", id.GetID())
if _, ok := s.cache.Get(cacheKey); ok {
ctxLogger.Debug("Expiration check has been cached, no need to refresh")
return nil
}
_, err, _ = s.singleflightGroup.Do(cacheKey, func() (interface{}, error) {
ctxLogger.Debug("Singleflight request for OAuth token sync")
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 15*time.Second)
defer cancel()
if refreshErr := s.service.TryTokenRefresh(updateCtx, id); refreshErr != nil {
token, refreshErr := s.service.TryTokenRefresh(updateCtx, id)
if refreshErr != nil {
if errors.Is(refreshErr, context.Canceled) {
return nil, nil
}
token, _, err := s.service.HasOAuthEntry(ctx, id)
if err != nil {
ctxLogger.Error("Failed to get OAuth entry for verifying if token has already been refreshed", "id", id.ID, "error", err)
return nil, err
}
// if the access token has already been refreshed by another request (for example in HA scenario)
tokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
if !tokenExpires.Before(time.Now()) {
return nil, nil
}
ctxLogger.Error("Failed to refresh OAuth access token", "id", id.ID, "error", refreshErr)
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
ctxLogger.Warn("Failed to invalidate OAuth tokens", "id", id.ID, "error", err)
}
// log the user out
if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil {
ctxLogger.Warn("Failed to revoke session token", "id", id.ID, "tokenId", id.SessionToken.Id, "error", err)
}
s.cache.Delete(cacheKey)
return nil, refreshErr
}
s.cache.Set(cacheKey, true, getOAuthTokenCacheTTL(token))
return nil, nil
})
@ -103,3 +110,27 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
return nil
}
func getOAuthTokenCacheTTL(token *oauth2.Token) time.Duration {
ttl := maxOAuthTokenCacheTTL
if token == nil {
return ttl
}
if !token.Expiry.IsZero() {
d := time.Until(token.Expiry.Add(-oauthtoken.ExpiryDelta))
if d < ttl {
ttl = d
}
}
idTokenExpiry, err := oauthtoken.GetIDTokenExpiry(token)
if err == nil && !idTokenExpiry.IsZero() {
d := time.Until(idTokenExpiry.Add(-oauthtoken.ExpiryDelta))
if d < ttl {
ttl = d
}
}
return ttl
}

@ -8,9 +8,11 @@ import (
"github.com/grafana/authlib/claims"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
@ -28,14 +30,12 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
identity *authn.Identity
oauthInfo *social.OAuthInfo
expectedHasEntryToken *login.UserAuth
expectHasEntryCalled bool
expectToken *login.UserAuth
expectedTryRefreshErr error
expectTryRefreshTokenCalled bool
expectRevokeTokenCalled bool
expectInvalidateOauthTokensCalled bool
expectRevokeTokenCalled bool
expectedErr error
}
@ -52,33 +52,26 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectTryRefreshTokenCalled: false,
},
{
desc: "should invalidate access token and session token if token refresh fails",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: true,
expectedTryRefreshErr: errors.New("some err"),
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: true,
expectRevokeTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: authn.ErrExpiredAccessToken,
desc: "should invalidate access token and session token if token refresh fails",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectedTryRefreshErr: errors.New("some err"),
expectTryRefreshTokenCalled: true,
expectRevokeTokenCalled: true,
expectToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: authn.ErrExpiredAccessToken,
},
{
desc: "should refresh the token successfully",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: false,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,
expectRevokeTokenCalled: false,
desc: "should refresh the token successfully",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectTryRefreshTokenCalled: true,
expectRevokeTokenCalled: false,
},
{
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectInvalidateOauthTokensCalled: false,
expectRevokeTokenCalled: false,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
expectedTryRefreshErr: errors.New("some err"),
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectTryRefreshTokenCalled: true,
expectRevokeTokenCalled: false,
expectToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
// TODO: address coverage of oauthtoken sync
@ -87,24 +80,14 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var (
hasEntryCalled bool
tryRefreshCalled bool
invalidateTokensCalled bool
revokeTokenCalled bool
tryRefreshCalled bool
revokeTokenCalled bool
)
service := &oauthtokentest.MockOauthTokenService{
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
hasEntryCalled = true
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
},
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error {
invalidateTokensCalled = true
return nil
},
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) error {
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error) {
tryRefreshCalled = true
return tt.expectedTryRefreshErr
return nil, tt.expectedTryRefreshErr
},
}
@ -132,13 +115,12 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
socialService: socialService,
singleflightGroup: new(singleflight.Group),
tracer: tracing.InitializeTracerForTest(),
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
}
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
assert.ErrorIs(t, err, tt.expectedErr)
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
})
}

@ -9,13 +9,15 @@ import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/authlib/claims"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
@ -30,16 +32,15 @@ var (
ExpiryDelta = 10 * time.Second
ErrNoRefreshTokenFound = errors.New("no refresh token found")
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
ErrCouldntRefreshToken = errors.New("could not refresh token")
)
const maxOAuthTokenCacheTTL = 10 * time.Minute
type Service struct {
Cfg *setting.Cfg
SocialService social.Service
AuthInfoService login.AuthInfoService
singleFlightGroup *singleflight.Group
cache *localcache.CacheService
Cfg *setting.Cfg
SocialService social.Service
AuthInfoService login.AuthInfoService
serverLock *serverlock.ServerLockService
tracer tracing.Tracer
tokenRefreshDuration *prometheus.HistogramVec
}
@ -49,29 +50,46 @@ type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
TryTokenRefresh(context.Context, identity.Requester) error
TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error)
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
}
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service {
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
serverLockService *serverlock.ServerLockService, tracer tracing.Tracer) *Service {
return &Service{
AuthInfoService: authInfoService,
Cfg: cfg,
SocialService: socialService,
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
singleFlightGroup: new(singleflight.Group),
serverLock: serverLockService,
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer),
tracer: tracer,
}
}
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken")
defer span.End()
authInfo, ok, _ := o.HasOAuthEntry(ctx, usr)
if !ok {
return nil
}
token, err := o.tryGetOrRefreshOAuthToken(ctx, authInfo)
if err := checkOAuthRefreshToken(authInfo); err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfo)
}
return nil
}
persistedToken, refreshNeeded := needTokenRefresh(authInfo)
if !refreshNeeded {
return persistedToken
}
token, err := o.TryTokenRefresh(ctx, usr)
if err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfo)
@ -90,6 +108,9 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.HasOAuthEntry")
defer span.End()
if usr == nil || usr.IsNil() {
// No user, therefore no token
return nil, false, nil
@ -126,76 +147,105 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
}
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
// It uses a server lock to prevent getting the Refresh Token multiple times for a given User
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) (*oauth2.Token, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.TryTokenRefresh")
defer span.End()
ctxLogger := logger.FromContext(ctx)
if usr == nil || usr.IsNil() {
ctxLogger.Warn("Can only refresh OAuth tokens for existing users", "user", "nil")
// Not user, no token.
return nil
return nil, nil
}
if !usr.IsIdentityType(claims.TypeUser) {
ctxLogger.Warn("Can only refresh OAuth tokens for users", "id", usr.GetID())
return nil
return nil, nil
}
userID, err := usr.GetInternalID()
if err != nil {
ctxLogger.Warn("Failed to convert user id to int", "id", usr.GetID(), "error", err)
return nil
return nil, nil
}
ctxLogger = ctxLogger.New("userID", userID)
// get the token's auth provider (f.e. azuread)
currAuthenticator := usr.GetAuthenticatedBy()
if !strings.HasPrefix(currAuthenticator, "oauth") {
ctxLogger.Warn("The specified user's auth provider is not OAuth", "authmodule", currAuthenticator)
return nil, nil
}
provider := strings.TrimPrefix(currAuthenticator, "oauth_")
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
ctxLogger.Warn("OAuth provider not found", "provider", provider)
return nil, nil
}
// if refresh token handling is disabled for this provider, we can skip the refresh
if !currentOAuthInfo.UseRefreshToken {
ctxLogger.Debug("Skipping token refresh", "provider", provider)
return nil, nil
}
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
if _, ok := o.cache.Get(lockKey); ok {
ctxLogger.Debug("Expiration check has been cached, no need to refresh")
return nil
lockTimeConfig := serverlock.LockTimeConfig{
MaxInterval: 30 * time.Second,
MinWait: time.Duration(o.Cfg.OAuthRefreshTokenServerLockMinWaitMs) * time.Millisecond,
MaxWait: time.Duration(o.Cfg.OAuthRefreshTokenServerLockMinWaitMs+500) * time.Millisecond,
}
_, err, _ = o.singleFlightGroup.Do(lockKey, func() (any, error) {
ctxLogger.Debug("Singleflight request for getting a new access token", "key", lockKey)
retryOpt := func(attempts int) error {
if attempts < 5 {
return nil
}
return ErrCouldntRefreshToken
}
var newToken *oauth2.Token
var cmdErr error
lockErr := o.serverLock.LockExecuteAndReleaseWithRetries(ctx, lockKey, lockTimeConfig, func(ctx context.Context) {
ctx, span := o.tracer.Start(ctx, "oauthtoken server lock",
trace.WithAttributes(attribute.Int64("userID", userID)))
defer span.End()
ctxLogger.Debug("Serverlock request for getting a new access token", "key", lockKey)
authInfo, exists, err := o.HasOAuthEntry(ctx, usr)
if !exists {
if err != nil {
ctxLogger.Debug("Failed to fetch oauth entry", "error", err)
} else {
// User is not logged in via OAuth no need to check
o.cache.Set(lockKey, struct{}{}, maxOAuthTokenCacheTTL)
}
return nil, nil
return
}
_, needRefresh, ttl := needTokenRefresh(authInfo)
storedToken, needRefresh := needTokenRefresh(authInfo)
if !needRefresh {
o.cache.Set(lockKey, struct{}{}, ttl)
return nil, nil
// Set the token which is returned by the outer function in case there's no need to refresh the token
newToken = storedToken
return
}
// get the token's auth provider (f.e. azuread)
provider := strings.TrimPrefix(authInfo.AuthModule, "oauth_")
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
ctxLogger.Warn("OAuth provider not found", "provider", provider)
return nil, nil
}
// if refresh token handling is disabled for this provider, we can skip the refresh
if !currentOAuthInfo.UseRefreshToken {
ctxLogger.Debug("Skipping token refresh", "provider", provider)
return nil, nil
}
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, authInfo)
}, retryOpt)
if lockErr != nil {
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
return nil, lockErr
}
return o.tryGetOrRefreshOAuthToken(ctx, authInfo)
})
// Silence ErrNoRefreshTokenFound
if errors.Is(err, ErrNoRefreshTokenFound) {
return nil
if errors.Is(cmdErr, ErrNoRefreshTokenFound) {
return nil, nil
}
return err
return newToken, cmdErr
}
func buildOAuthTokenFromAuthInfo(authInfo *login.UserAuth) *oauth2.Token {
@ -230,11 +280,11 @@ func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
}
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth) error {
func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.UserAuth) error {
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
UserId: usr.UserId,
AuthModule: usr.AuthModule,
AuthId: usr.AuthId,
UserId: authInfo.UserId,
AuthModule: authInfo.AuthModule,
AuthId: authInfo.AuthId,
OAuthToken: &oauth2.Token{
AccessToken: "",
RefreshToken: "",
@ -243,36 +293,33 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth
})
}
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.UserAuth) (*oauth2.Token, error) {
ctxLogger := logger.FromContext(ctx).New("userID", usr.UserId)
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login.UserAuth) (*oauth2.Token, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken",
trace.WithAttributes(attribute.Int64("userID", authInfo.UserId)))
defer span.End()
key := getCheckCacheKey(usr.UserId)
if _, ok := o.cache.Get(key); ok {
ctxLogger.Debug("Expiration check has been cached", "userID", usr.UserId)
return buildOAuthTokenFromAuthInfo(usr), nil
}
ctxLogger := logger.FromContext(ctx).New("userID", authInfo.UserId)
if err := checkOAuthRefreshToken(usr); err != nil {
if err := checkOAuthRefreshToken(authInfo); err != nil {
return nil, err
}
persistedToken, refreshNeeded, ttl := needTokenRefresh(usr)
persistedToken, refreshNeeded := needTokenRefresh(authInfo)
if !refreshNeeded {
o.cache.Set(key, struct{}{}, ttl)
return persistedToken, nil
}
authProvider := usr.AuthModule
authProvider := authInfo.AuthModule
connect, err := o.SocialService.GetConnector(authProvider)
if err != nil {
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
return nil, err
return persistedToken, err
}
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
if err != nil {
ctxLogger.Error("Failed to get oauth http client", "provider", authProvider, "error", err)
return nil, err
return persistedToken, err
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
@ -284,22 +331,28 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.User
if err != nil {
ctxLogger.Error("Failed to retrieve oauth access token",
"provider", usr.AuthModule, "userId", usr.UserId, "error", err)
"provider", authInfo.AuthModule, "userId", authInfo.UserId, "error", err)
// token refresh failed, invalidate the old token
if err := o.InvalidateOAuthTokens(ctx, authInfo); err != nil {
ctxLogger.Warn("Failed to invalidate OAuth tokens", "id", authInfo.Id, "error", err)
}
return nil, err
}
// If the tokens are not the same, update the entry in the DB
if !tokensEq(persistedToken, token) {
updateAuthCommand := &login.UpdateAuthInfoCommand{
UserId: usr.UserId,
AuthModule: usr.AuthModule,
AuthId: usr.AuthId,
UserId: authInfo.UserId,
AuthModule: authInfo.AuthModule,
AuthId: authInfo.AuthId,
OAuthToken: token,
}
if o.Cfg.Env == setting.Dev {
ctxLogger.Debug("Oauth got token",
"auth_module", usr.AuthModule,
"auth_module", authInfo.AuthModule,
"expiry", fmt.Sprintf("%v", token.Expiry),
"access_token", fmt.Sprintf("%v", token.AccessToken),
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
@ -307,8 +360,8 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.User
}
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
ctxLogger.Error("Failed to update auth info during token refresh", "userId", usr.UserId, "error", err)
return nil, err
ctxLogger.Error("Failed to update auth info during token refresh", "userId", authInfo.UserId, "error", err)
return token, err
}
ctxLogger.Debug("Updated oauth info for user")
}
@ -348,63 +401,44 @@ func tokensEq(t1, t2 *oauth2.Token) bool {
t1IdToken == t2IdToken
}
func needTokenRefresh(usr *login.UserAuth) (*oauth2.Token, bool, time.Duration) {
var accessTokenExpires, idTokenExpires time.Time
func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) {
var hasAccessTokenExpired, hasIdTokenExpired bool
persistedToken := buildOAuthTokenFromAuthInfo(usr)
idTokenExp, err := getIDTokenExpiry(usr)
persistedToken := buildOAuthTokenFromAuthInfo(authInfo)
idTokenExp, err := GetIDTokenExpiry(persistedToken)
if err != nil {
logger.Warn("Could not get ID Token expiry", "error", err)
}
if !persistedToken.Expiry.IsZero() {
accessTokenExpires, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
_, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry)
}
if !idTokenExp.IsZero() {
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
_, hasIdTokenExpired = getExpiryWithSkew(idTokenExp)
}
if !hasAccessTokenExpired && !hasIdTokenExpired {
logger.Debug("Neither access nor id token have expired yet", "userID", usr.UserId)
return persistedToken, false, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)
logger.Debug("Neither access nor id token have expired yet", "userID", authInfo.UserId)
return persistedToken, false
}
if hasIdTokenExpired {
// Force refreshing token when id token is expired
persistedToken.AccessToken = ""
}
return persistedToken, true, time.Second
return persistedToken, true
}
func getCheckCacheKey(usrID int64) string {
return fmt.Sprintf("token-check-%d", usrID)
}
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
min := maxOAuthTokenCacheTTL
if !accessTokenExpiry.IsZero() {
d := time.Until(accessTokenExpiry)
if d < min {
min = d
}
}
if !idTokenExpiry.IsZero() {
d := time.Until(idTokenExpiry)
if d < min {
min = d
}
}
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return maxOAuthTokenCacheTTL
// GetIDTokenExpiry extracts the expiry time from the ID token
func GetIDTokenExpiry(token *oauth2.Token) (time.Time, error) {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return time.Time{}, nil
}
return min
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry(usr *login.UserAuth) (time.Time, error) {
if usr.OAuthIdToken == "" {
if idToken == "" {
return time.Time{}, nil
}
parsedToken, err := jwt.ParseSigned(usr.OAuthIdToken)
parsedToken, err := jwt.ParseSigned(idToken)
if err != nil {
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
}

@ -12,11 +12,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/authn"
@ -27,10 +28,15 @@ import (
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tests/testsuite"
)
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
func TestMain(m *testing.M) {
testsuite.Run(m)
}
func TestService_HasOAuthEntry(t *testing.T) {
testCases := []struct {
name string
@ -116,13 +122,16 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social
authInfoStore := &FakeAuthInfoStore{ExpectedOAuth: &login.UserAuth{}}
authInfoService := authinfoimpl.ProvideService(authInfoStore, remotecache.NewFakeCacheStorage(), secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore()))
store := db.InitTestDB(t)
return &Service{
Cfg: setting.NewCfg(),
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: &singleflight.Group{},
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
tracer: tracing.InitializeTracerForTest(),
}, authInfoStore, socialConnector
}
@ -155,7 +164,7 @@ func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.Delet
func TestService_TryTokenRefresh(t *testing.T) {
type environment struct {
authInfoService *authinfotest.FakeService
cache *localcache.CacheService
serverLock *serverlock.ServerLockService
identity identity.Requester
socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService
@ -209,13 +218,6 @@ func TestService_TryTokenRefresh(t *testing.T) {
}
},
},
{
desc: "should skip token refresh if the expiration check has already been cached",
setup: func(env *environment) {
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
},
},
{
desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
setup: func(env *environment) {
@ -287,7 +289,7 @@ func TestService_TryTokenRefresh(t *testing.T) {
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
@ -312,7 +314,7 @@ func TestService_TryTokenRefresh(t *testing.T) {
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser}
env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}
env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
UseRefreshToken: true,
}
@ -334,9 +336,11 @@ func TestService_TryTokenRefresh(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{}
store := db.InitTestDB(t)
env := environment{
authInfoService: &authinfotest.FakeService{},
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
socialConnector: socialConnector,
socialService: &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
@ -347,17 +351,17 @@ func TestService_TryTokenRefresh(t *testing.T) {
tt.setup(&env)
}
env.service = &Service{
AuthInfoService: env.authInfoService,
Cfg: setting.NewCfg(),
cache: env.cache,
singleFlightGroup: &singleflight.Group{},
SocialService: env.socialService,
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
}
env.service = ProvideService(
env.socialService,
env.authInfoService,
setting.NewCfg(),
prometheus.NewRegistry(),
env.serverLock,
tracing.InitializeTracerForTest(),
)
// token refresh
err := env.service.TryTokenRefresh(context.Background(), env.identity)
_, err := env.service.TryTokenRefresh(context.Background(), env.identity)
// test and validations
assert.ErrorIs(t, err, tt.expectedErr)
@ -366,74 +370,6 @@ func TestService_TryTokenRefresh(t *testing.T) {
}
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
defaultTime := time.Now()
tests := []struct {
name string
accessTokenExpiry time.Time
idTokenExpiry time.Time
want time.Duration
}{
{
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
accessTokenExpiry: time.Time{},
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
{
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
})
}
}
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
tests := []struct {
name string
@ -445,7 +381,6 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
name: "should not need token refresh when token has no expiration date",
usr: &login.UserAuth{},
expectedTokenRefreshFlag: false,
expectedTokenDuration: maxOAuthTokenCacheTTL,
},
{
name: "should not need token refresh with an invalid jwt token that might result in an error when parsing",
@ -453,7 +388,6 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
OAuthIdToken: "invalid_jwt_format",
},
expectedTokenRefreshFlag: false,
expectedTokenDuration: maxOAuthTokenCacheTTL,
},
{
name: "should flag token refresh with id token is expired",
@ -474,11 +408,10 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, needsTokenRefresh, tokenDuration := needTokenRefresh(tt.usr)
token, needsTokenRefresh := needTokenRefresh(tt.usr)
assert.NotNil(t, token)
assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
assert.Equal(t, tt.expectedTokenDuration, tokenDuration)
})
}
}
@ -493,7 +426,7 @@ func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
}
type environment struct {
authInfoService *authinfotest.FakeService
cache *localcache.CacheService
serverLock *serverlock.ServerLockService
socialConnector *socialtest.MockSocialConnector
socialService *socialtest.FakeSocialService
@ -506,21 +439,6 @@ func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
usr *login.UserAuth
setup func(env *environment)
}{
{
desc: "should find and retrieve token from cache",
usr: &login.UserAuth{
UserId: int64(1234),
OAuthAccessToken: "new_access_token",
OAuthExpiry: timeNow,
},
setup: func(env *environment) {
env.cache.Set("token-check-1234", token, 1*time.Minute)
},
expectedToken: &oauth2.Token{
AccessToken: "new_access_token",
Expiry: timeNow,
},
},
{
desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
usr: &login.UserAuth{
@ -578,9 +496,11 @@ func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
socialConnector := &socialtest.MockSocialConnector{}
store := db.InitTestDB(t)
env := environment{
authInfoService: &authinfotest.FakeService{},
cache: localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
socialConnector: socialConnector,
socialService: &socialtest.FakeSocialService{
ExpectedConnector: socialConnector,
@ -591,14 +511,14 @@ func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
tt.setup(&env)
}
env.service = &Service{
AuthInfoService: env.authInfoService,
Cfg: setting.NewCfg(),
cache: env.cache,
singleFlightGroup: &singleflight.Group{},
SocialService: env.socialService,
tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
}
env.service = ProvideService(
env.socialService,
env.authInfoService,
setting.NewCfg(),
prometheus.NewRegistry(),
env.serverLock,
tracing.InitializeTracerForTest(),
)
token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)

@ -15,7 +15,7 @@ type MockOauthTokenService struct {
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error)
}
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
@ -46,9 +46,9 @@ func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *
return nil
}
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester) error {
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester) (*oauth2.Token, error) {
if m.TryTokenRefreshFunc != nil {
return m.TryTokenRefreshFunc(ctx, usr)
}
return nil
return nil, nil
}

@ -33,8 +33,8 @@ func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.Use
return nil, false, nil
}
func (s *Service) TryTokenRefresh(context.Context, identity.Requester) error {
return nil
func (s *Service) TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error) {
return s.Token, nil
}
func (s *Service) InvalidateOAuthTokens(context.Context, *login.UserAuth) error {

@ -258,9 +258,10 @@ type Cfg struct {
AuthProxy AuthProxySettings
// OAuth
OAuthAutoLogin bool
OAuthCookieMaxAge int
OAuthAllowInsecureEmailLookup bool
OAuthAutoLogin bool
OAuthCookieMaxAge int
OAuthAllowInsecureEmailLookup bool
OAuthRefreshTokenServerLockMinWaitMs int64
JWTAuth AuthJWTSettings
ExtJWTAuth ExtJWTSettings
@ -1603,6 +1604,7 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) {
}
cfg.OAuthCookieMaxAge = auth.Key("oauth_state_cookie_max_age").MustInt(600)
cfg.OAuthRefreshTokenServerLockMinWaitMs = auth.Key("oauth_refresh_token_server_lock_min_wait_ms").MustInt64(1000)
cfg.SignoutRedirectUrl = valueAsString(auth, "signout_redirect_url", "")
// Deprecated

Loading…
Cancel
Save