From 1061e4712f1d2454e75edf862de259fad4bd8980 Mon Sep 17 00:00:00 2001 From: Misi Date: Thu, 21 Nov 2024 14:36:28 +0100 Subject: [PATCH] OAuth: Refactor OAuthToken service to make it easier to use the new external sessions (#96667) * Refactor OAuthToken service * introduce user.SessionAwareIdentityRequester * replace login.UserAuth parameters with user.SessionAwareIdentityRequester * Add nosec G101 to fake ID tokens * Opt 2, min changes * Revert a change to the current version --- pkg/services/authn/clients/oauth.go | 14 +- pkg/services/authn/clients/oauth_test.go | 60 ++-- pkg/services/oauthtoken/oauth_token.go | 104 ++++-- pkg/services/oauthtoken/oauth_token_test.go | 330 +++++++----------- .../oauthtoken/oauthtokentest/mock.go | 4 +- .../oauthtokentest/oauthtokentest.go | 2 +- 6 files changed, 239 insertions(+), 275 deletions(-) diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index 3c09fc46ee2..47601d4b548 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -266,23 +266,21 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester) (*authn.Red return nil, false } - if err := c.oauthService.InvalidateOAuthTokens(ctx, &login.UserAuth{ - UserId: userID, - AuthId: user.GetAuthID(), - AuthModule: user.GetAuthenticatedBy(), - }); err != nil { - c.log.FromContext(ctx).Error("Failed to invalidate tokens", "id", user.GetID(), "error", err) + ctxLogger := c.log.FromContext(ctx).New("userID", userID) + + if err := c.oauthService.InvalidateOAuthTokens(ctx, user); err != nil { + ctxLogger.Error("Failed to invalidate tokens", "error", err) } oauthCfg := c.socialService.GetOAuthInfoProvider(c.providerName) if !oauthCfg.Enabled { - c.log.FromContext(ctx).Debug("OAuth client is disabled") + ctxLogger.Debug("OAuth client is disabled") return nil, false } redirectURL := getOAuthSignoutRedirectURL(c.cfg, oauthCfg) if redirectURL == "" { - c.log.FromContext(ctx).Debug("No signout redirect url configured") + ctxLogger.Debug("No signout redirect url configured") return nil, false } diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index 4f5347428e1..288dc31db91 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -71,10 +71,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return error when state from ipd does not match stored state", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-other-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-other-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, @@ -83,10 +84,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return error when pkce is configured but the cookie is not present", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, @@ -95,10 +97,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return error when email is empty", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, @@ -110,10 +113,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return error when email is not allowed", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, @@ -144,10 +148,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return identity for valid request", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, addStateCookie: true, @@ -182,10 +187,11 @@ func TestOAuth_Authenticate(t *testing.T) { }, { desc: "should return identity for valid request - and lookup user by email", - req: &authn.Request{HTTPRequest: &http.Request{ - Header: map[string][]string{}, - URL: mustParseURL("http://grafana.com/?state=some-state"), - }, + req: &authn.Request{ + HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, }, oauthCfg: &social.OAuthInfo{UsePKCE: true, Enabled: true}, allowInsecureTakeover: true, @@ -354,9 +360,7 @@ func TestOAuth_RedirectURL(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - var ( - authCodeUrlCalled = false - ) + authCodeUrlCalled := false fakeSocialSvc := &socialtest.FakeSocialService{ ExpectedAuthInfoProvider: tt.oauthCfg, @@ -475,7 +479,7 @@ func TestOAuth_Logout(t *testing.T) { "id_token": "some.id.token", }) }, - InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error { + InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester) error { invalidateTokenCalled = true return nil }, diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index f0046599ae7..44fe279f630 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -51,11 +51,12 @@ type OAuthTokenService interface { IsOAuthPassThruEnabled(*datasources.DataSource) bool HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) TryTokenRefresh(context.Context, identity.Requester) (*oauth2.Token, error) - InvalidateOAuthTokens(context.Context, *login.UserAuth) error + InvalidateOAuthTokens(context.Context, identity.Requester) error } func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer, - serverLockService *serverlock.ServerLockService, tracer tracing.Tracer) *Service { + serverLockService *serverlock.ServerLockService, tracer tracing.Tracer, +) *Service { return &Service{ AuthInfoService: authInfoService, Cfg: cfg, @@ -71,6 +72,27 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request ctx, span := o.tracer.Start(ctx, "oauthtoken.GetCurrentOAuthToken") defer span.End() + ctxLogger := logger.FromContext(ctx) + + if usr == nil || usr.IsNil() { + ctxLogger.Warn("Can only get OAuth tokens for existing users", "user", "nil") + // Not user, no token. + return nil + } + + if !usr.IsIdentityType(claims.TypeUser) { + ctxLogger.Warn("Can only get OAuth tokens for users", "id", usr.GetID()) + return nil + } + + userID, err := usr.GetInternalID() + if err != nil { + logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err) + return nil + } + + ctxLogger = ctxLogger.New("userID", userID) + authInfo, ok, _ := o.HasOAuthEntry(ctx, usr) if !ok { return nil @@ -84,7 +106,9 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request return nil } - persistedToken, refreshNeeded := needTokenRefresh(authInfo) + persistedToken := buildOAuthTokenFromAuthInfo(authInfo) + + refreshNeeded := needTokenRefresh(ctx, persistedToken) if !refreshNeeded { return persistedToken } @@ -226,14 +250,16 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester) ( return } - storedToken, needRefresh := needTokenRefresh(authInfo) + persistedToken := buildOAuthTokenFromAuthInfo(authInfo) + + needRefresh := needTokenRefresh(ctx, persistedToken) if !needRefresh { // Set the token which is returned by the outer function in case there's no need to refresh the token - newToken = storedToken + newToken = persistedToken return } - newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, authInfo) + newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr) }, retryOpt) if lockErr != nil { ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr) @@ -280,11 +306,17 @@ func checkOAuthRefreshToken(authInfo *login.UserAuth) error { } // InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero -func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.UserAuth) error { +func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error { + userID, err := usr.GetInternalID() + if err != nil { + logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err) + return err + } + return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{ - UserId: authInfo.UserId, - AuthModule: authInfo.AuthModule, - AuthId: authInfo.AuthId, + UserId: userID, + AuthModule: usr.GetAuthenticatedBy(), + AuthId: usr.GetAuthID(), OAuthToken: &oauth2.Token{ AccessToken: "", RefreshToken: "", @@ -293,23 +325,31 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, authInfo *login.Use }) } -func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login.UserAuth) (*oauth2.Token, error) { - ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken", - trace.WithAttributes(attribute.Int64("userID", authInfo.UserId))) +func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester) (*oauth2.Token, error) { + ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken") defer span.End() - ctxLogger := logger.FromContext(ctx).New("userID", authInfo.UserId) - - if err := checkOAuthRefreshToken(authInfo); err != nil { + userID, err := usr.GetInternalID() + if err != nil { + logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err) return nil, err } - persistedToken, refreshNeeded := needTokenRefresh(authInfo) + span.SetAttributes(attribute.Int64("userID", userID)) + + ctxLogger := logger.FromContext(ctx).New("userID", userID) + + if persistedToken.RefreshToken == "" { + ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy()) + return nil, ErrNoRefreshTokenFound + } + + refreshNeeded := needTokenRefresh(ctx, persistedToken) if !refreshNeeded { return persistedToken, nil } - authProvider := authInfo.AuthModule + authProvider := usr.GetAuthenticatedBy() connect, err := o.SocialService.GetConnector(authProvider) if err != nil { ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err) @@ -331,11 +371,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login if err != nil { ctxLogger.Error("Failed to retrieve oauth access token", - "provider", authInfo.AuthModule, "userId", authInfo.UserId, "error", err) + "provider", usr.GetAuthenticatedBy(), "error", err) // token refresh failed, invalidate the old token - if err := o.InvalidateOAuthTokens(ctx, authInfo); err != nil { - ctxLogger.Warn("Failed to invalidate OAuth tokens", "id", authInfo.Id, "error", err) + if err := o.InvalidateOAuthTokens(ctx, usr); err != nil { + ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err) } return nil, err @@ -344,15 +384,15 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login // If the tokens are not the same, update the entry in the DB if !tokensEq(persistedToken, token) { updateAuthCommand := &login.UpdateAuthInfoCommand{ - UserId: authInfo.UserId, - AuthModule: authInfo.AuthModule, - AuthId: authInfo.AuthId, + UserId: userID, + AuthModule: usr.GetAuthenticatedBy(), + AuthId: usr.GetAuthID(), OAuthToken: token, } if o.Cfg.Env == setting.Dev { ctxLogger.Debug("Oauth got token", - "auth_module", authInfo.AuthModule, + "auth_module", usr.GetAuthID(), "expiry", fmt.Sprintf("%v", token.Expiry), "access_token", fmt.Sprintf("%v", token.AccessToken), "refresh_token", fmt.Sprintf("%v", token.RefreshToken), @@ -360,7 +400,7 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, authInfo *login } if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { - ctxLogger.Error("Failed to update auth info during token refresh", "userId", authInfo.UserId, "error", err) + ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err) return token, err } ctxLogger.Debug("Updated oauth info for user") @@ -401,14 +441,14 @@ func tokensEq(t1, t2 *oauth2.Token) bool { t1IdToken == t2IdToken } -func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) { +func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool { var hasAccessTokenExpired, hasIdTokenExpired bool - persistedToken := buildOAuthTokenFromAuthInfo(authInfo) + ctxLogger := logger.FromContext(ctx) idTokenExp, err := GetIDTokenExpiry(persistedToken) if err != nil { - logger.Warn("Could not get ID Token expiry", "error", err) + ctxLogger.Warn("Could not get ID Token expiry", "error", err) } if !persistedToken.Expiry.IsZero() { _, hasAccessTokenExpired = getExpiryWithSkew(persistedToken.Expiry) @@ -417,14 +457,14 @@ func needTokenRefresh(authInfo *login.UserAuth) (*oauth2.Token, bool) { _, hasIdTokenExpired = getExpiryWithSkew(idTokenExp) } if !hasAccessTokenExpired && !hasIdTokenExpired { - logger.Debug("Neither access nor id token have expired yet", "userID", authInfo.UserId) - return persistedToken, false + ctxLogger.Debug("Neither access nor id token have expired yet") + return false } if hasIdTokenExpired { // Force refreshing token when id token is expired persistedToken.AccessToken = "" } - return persistedToken, true + return true } // GetIDTokenExpiry extracts the expiry time from the ID token diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index a872bc9e0e9..203784f503f 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -31,7 +31,9 @@ import ( "github.com/grafana/grafana/pkg/tests/testsuite" ) -var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U" +const EXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjoxNjAwMDAwMDAwLCJpYXQiOjE2MDAwMDAwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential + +const UNEXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjo0ODg1NjA4MDAwLCJpYXQiOjE2ODU2MDgwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential func TestMain(m *testing.M) { testsuite.Run(m) @@ -162,19 +164,44 @@ func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.Delet } func TestService_TryTokenRefresh(t *testing.T) { + unexpiredToken := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now().Add(time.Hour), + TokenType: "Bearer", + } + unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{ + "id_token": UNEXPIRED_ID_TOKEN, + }) + + expiredToken := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now().Add(-time.Hour), + TokenType: "Bearer", + } + type environment struct { authInfoService *authinfotest.FakeService serverLock *serverlock.ServerLockService - identity identity.Requester socialConnector *socialtest.MockSocialConnector socialService *socialtest.FakeSocialService service *Service } + type testCase struct { - desc string - expectedErr error - setup func(env *environment) + desc string + identity identity.Requester + setup func(env *environment) + expectedToken *oauth2.Token + expectedErr error + } + + userIdentity := &authn.Identity{ + AuthenticatedBy: login.GenericOAuthModule, + ID: "1234", + Type: claims.TypeUser, } tests := []testCase{ @@ -182,114 +209,111 @@ func TestService_TryTokenRefresh(t *testing.T) { desc: "should skip sync when identity is nil", }, { - desc: "should skip sync when identity is not a user", - setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1", Type: claims.TypeServiceAccount} - }, + desc: "should skip sync when identity is not a user", + identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount}, }, { - desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID", - setup: func(env *environment) { - env.identity = &authn.Identity{ID: "invalid", Type: claims.TypeUser} - }, + desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID", + identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser}, }, { - desc: "should skip token refresh since the token is still valid", + desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned", + identity: userIdentity, setup: func(env *environment) { - token := &oauth2.Token{ - AccessToken: "testaccess", - RefreshToken: "testrefresh", - Expiry: time.Now().Add(time.Hour), - TokenType: "Bearer", - } - - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - OAuthAccessToken: token.AccessToken, - OAuthRefreshToken: token.RefreshToken, - OAuthExpiry: token.Expiry, - OAuthTokenType: token.TokenType, - } - - env.identity = &authn.Identity{ - AuthenticatedBy: login.GenericOAuthModule, - ID: "1234", - Type: claims.TypeUser, - } + env.authInfoService.ExpectedError = errors.New("some error") }, }, { - desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned", + desc: "should skip token refresh if the user doesn't have an oauth entry", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} - env.authInfoService.ExpectedError = errors.New("some error") + env.authInfoService.ExpectedUserAuth = &login.UserAuth{ + AuthModule: login.SAMLAuthModule, + } }, }, { - desc: "should skip token refresh if the user doesn't have an oauth entry", + desc: "should skip token refresh when no oauth provider was found", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.SAMLAuthModule, + AuthModule: login.GenericOAuthModule, } }, }, { - desc: "should do token refresh if access token or id token have not expired yet", + desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} env.authInfoService.ExpectedUserAuth = &login.UserAuth{ AuthModule: login.GenericOAuthModule, } + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: false, + } }, }, { - desc: "should skip token refresh when no oauth provider was found", + desc: "should skip token refresh when the token is still valid and no id token is present", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - OAuthIdToken: EXPIRED_JWT, + AuthModule: login.GenericOAuthModule, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, + } + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, } }, + expectedToken: unexpiredToken, }, { - desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", + desc: "should not refresh the tokens if access token or id token have not expired yet", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - OAuthIdToken: EXPIRED_JWT, + AuthModule: login.GenericOAuthModule, + OAuthIdToken: UNEXPIRED_ID_TOKEN, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, } + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ - UseRefreshToken: false, + UseRefreshToken: true, } }, + expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should skip token refresh when there is no refresh token", + desc: "should skip token refresh when there is no refresh token", + identity: userIdentity, setup: func(env *environment) { - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser} env.authInfoService.ExpectedUserAuth = &login.UserAuth{ AuthModule: login.GenericOAuthModule, - OAuthIdToken: EXPIRED_JWT, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, OAuthRefreshToken: "", + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, } env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } }, + expectedToken: &oauth2.Token{ + AccessToken: unexpiredTokenWithIDToken.AccessToken, + RefreshToken: "", + Expiry: unexpiredTokenWithIDToken.Expiry, + }, }, { - desc: "should do token refresh when the token is expired", + desc: "should do token refresh when the token is expired", + identity: userIdentity, setup: func(env *environment) { - token := &oauth2.Token{ - AccessToken: "testaccess", - RefreshToken: "testrefresh", - Expiry: time.Now().Add(-time.Hour), - TokenType: "Bearer", - } - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule} env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } @@ -297,24 +321,20 @@ func TestService_TryTokenRefresh(t *testing.T) { AuthModule: login.GenericOAuthModule, AuthId: "subject", UserId: 1, - OAuthAccessToken: token.AccessToken, - OAuthRefreshToken: token.RefreshToken, - OAuthExpiry: token.Expiry, - OAuthTokenType: token.TokenType, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, } - env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once() + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() }, + expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should refresh token when the id token is expired", + desc: "should refresh token when the id token is expired", + identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, setup: func(env *environment) { - token := &oauth2.Token{ - AccessToken: "testaccess", - RefreshToken: "testrefresh", - Expiry: time.Now().Add(time.Hour), - TokenType: "Bearer", - } - env.identity = &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule} env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } @@ -322,19 +342,20 @@ func TestService_TryTokenRefresh(t *testing.T) { AuthModule: login.GenericOAuthModule, AuthId: "subject", UserId: 1, - OAuthAccessToken: token.AccessToken, - OAuthRefreshToken: token.RefreshToken, - OAuthExpiry: token.Expiry, - OAuthTokenType: token.TokenType, - OAuthIdToken: EXPIRED_JWT, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, } - env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once() + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() }, + expectedToken: unexpiredTokenWithIDToken, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - socialConnector := &socialtest.MockSocialConnector{} + socialConnector := socialtest.NewMockSocialConnector(t) store := db.InitTestDB(t) @@ -361,11 +382,27 @@ func TestService_TryTokenRefresh(t *testing.T) { ) // token refresh - _, err := env.service.TryTokenRefresh(context.Background(), env.identity) + actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity) - // test and validations - assert.ErrorIs(t, err, tt.expectedErr) - socialConnector.AssertExpectations(t) + if tt.expectedErr != nil { + assert.ErrorIs(t, err, tt.expectedErr) + return + } + + if tt.expectedToken == nil { + assert.Nil(t, actualToken) + return + } + + assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken) + assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken) + assert.Equal(t, tt.expectedToken.Expiry, actualToken.Expiry) + assert.Equal(t, tt.expectedToken.TokenType, actualToken.TokenType) + if tt.expectedToken.Extra("id_token") != nil { + assert.Equal(t, tt.expectedToken.Extra("id_token").(string), actualToken.Extra("id_token").(string)) + } else { + assert.Nil(t, actualToken.Extra("id_token")) + } }) } } @@ -392,7 +429,7 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) { { name: "should flag token refresh with id token is expired", usr: &login.UserAuth{ - OAuthIdToken: EXPIRED_JWT, + OAuthIdToken: EXPIRED_ID_TOKEN, }, expectedTokenRefreshFlag: true, expectedTokenDuration: time.Second, @@ -408,125 +445,10 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - token, needsTokenRefresh := needTokenRefresh(tt.usr) + token := buildOAuthTokenFromAuthInfo(tt.usr) + needsTokenRefresh := needTokenRefresh(context.Background(), token) - assert.NotNil(t, token) assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh) }) } } - -func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) { - timeNow := time.Now() - token := &oauth2.Token{ - AccessToken: "oauth_access_token", - RefreshToken: "refresh_token_found", - Expiry: timeNow, - TokenType: "Bearer", - } - type environment struct { - authInfoService *authinfotest.FakeService - serverLock *serverlock.ServerLockService - socialConnector *socialtest.MockSocialConnector - socialService *socialtest.FakeSocialService - - service *Service - } - tests := []struct { - desc string - expectedErr error - expectedToken *oauth2.Token - usr *login.UserAuth - setup func(env *environment) - }{ - { - desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider", - usr: &login.UserAuth{ - UserId: int64(1234), - AuthModule: login.SAMLAuthModule, - }, - expectedErr: ErrNotAnOAuthProvider, - }, - { - desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found", - usr: &login.UserAuth{ - UserId: int64(1234), - AuthModule: login.GenericOAuthModule, - }, - expectedErr: ErrNoRefreshTokenFound, - }, - { - desc: "should not refresh token if the token is not expired", - usr: &login.UserAuth{ - UserId: int64(1234), - AuthModule: login.GenericOAuthModule, - OAuthAccessToken: token.AccessToken, - OAuthRefreshToken: token.RefreshToken, - OAuthExpiry: timeNow.Add(time.Hour), - OAuthTokenType: "Bearer", - }, - expectedToken: &oauth2.Token{ - AccessToken: token.AccessToken, - RefreshToken: token.RefreshToken, - Expiry: timeNow.Add(time.Hour), - TokenType: "Bearer", - }, - }, - { - desc: "should update saved token if the user auth has new access/refresh tokens", - usr: &login.UserAuth{ - UserId: int64(1234), - AuthModule: login.GenericOAuthModule, - OAuthAccessToken: "new_oauth_access_token", - OAuthRefreshToken: "new_refresh_token_found", - OAuthExpiry: timeNow, - }, - expectedToken: &oauth2.Token{ - AccessToken: "oauth_access_token", - RefreshToken: "refresh_token_found", - Expiry: timeNow, - TokenType: "Bearer", - }, - setup: func(env *environment) { - env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once() - }, - }, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - socialConnector := &socialtest.MockSocialConnector{} - - store := db.InitTestDB(t) - - env := environment{ - authInfoService: &authinfotest.FakeService{}, - serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), - socialConnector: socialConnector, - socialService: &socialtest.FakeSocialService{ - ExpectedConnector: socialConnector, - }, - } - - if tt.setup != nil { - tt.setup(&env) - } - - env.service = ProvideService( - env.socialService, - env.authInfoService, - setting.NewCfg(), - prometheus.NewRegistry(), - env.serverLock, - tracing.InitializeTracerForTest(), - ) - - token, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr) - - if tt.expectedToken != nil { - assert.Equal(t, tt.expectedToken, token) - } - assert.ErrorIs(t, tt.expectedErr, err) - socialConnector.AssertExpectations(t) - }) - } -} diff --git a/pkg/services/oauthtoken/oauthtokentest/mock.go b/pkg/services/oauthtoken/oauthtokentest/mock.go index c930a42b79f..7c473ae11f0 100644 --- a/pkg/services/oauthtoken/oauthtokentest/mock.go +++ b/pkg/services/oauthtoken/oauthtokentest/mock.go @@ -14,7 +14,7 @@ type MockOauthTokenService struct { GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) - InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error + InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester) error TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester) (*oauth2.Token, error) } @@ -39,7 +39,7 @@ func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity. return nil, false, nil } -func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth) error { +func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester) error { if m.InvalidateOAuthTokensFunc != nil { return m.InvalidateOAuthTokensFunc(ctx, usr) } diff --git a/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go index b2bd2be8427..f626683c5b3 100644 --- a/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go +++ b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go @@ -37,6 +37,6 @@ func (s *Service) TryTokenRefresh(context.Context, identity.Requester) (*oauth2. return s.Token, nil } -func (s *Service) InvalidateOAuthTokens(context.Context, *login.UserAuth) error { +func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester) error { return nil }