mirror of https://github.com/grafana/grafana
AuthN: move oauth token hook into session client (#76688)
* Move rotate logic into its own function * Move oauth token sync to session client * Add user to the local cache if refresh tokens are not enabled for the provider so we can skip the check in other requestspull/76761/head
parent
8b16f2aca8
commit
455cede699
@ -1,174 +0,0 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/login/social" |
||||
"github.com/grafana/grafana/pkg/services/auth" |
||||
"github.com/grafana/grafana/pkg/services/authn" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/oauthtoken" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
) |
||||
|
||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync { |
||||
return &OAuthTokenSync{ |
||||
log.New("oauth_token.sync"), |
||||
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute), |
||||
service, |
||||
sessionService, |
||||
socialService, |
||||
} |
||||
} |
||||
|
||||
type OAuthTokenSync struct { |
||||
log log.Logger |
||||
cache *localcache.CacheService |
||||
service oauthtoken.OAuthTokenService |
||||
sessionService auth.UserTokenService |
||||
socialService social.Service |
||||
} |
||||
|
||||
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { |
||||
namespace, id := identity.NamespacedID() |
||||
// only perform oauth token check if identity is a user
|
||||
if namespace != authn.NamespaceUser { |
||||
return nil |
||||
} |
||||
|
||||
// not authenticated through session tokens, so we can skip this hook
|
||||
if identity.SessionToken == nil { |
||||
return nil |
||||
} |
||||
|
||||
// if we recently have performed this it would be cached, so we can skip the hook
|
||||
if _, ok := s.cache.Get(identity.ID); ok { |
||||
return nil |
||||
} |
||||
|
||||
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id}) |
||||
// user is not authenticated through oauth so skip further checks
|
||||
if !exists { |
||||
return nil |
||||
} |
||||
|
||||
idTokenExpiry, err := getIDTokenExpiry(token) |
||||
if err != nil { |
||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err) |
||||
} |
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() { |
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry)) |
||||
return nil |
||||
} |
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
provider := strings.TrimPrefix(token.AuthModule, "oauth_") |
||||
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider) |
||||
if currentOAuthInfo == nil { |
||||
s.log.Warn("OAuth provider not found", "provider", provider) |
||||
return nil |
||||
} |
||||
|
||||
// if refresh token handling is disabled for this provider, we can skip the hook
|
||||
if !currentOAuthInfo.UseRefreshToken { |
||||
return nil |
||||
} |
||||
|
||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) |
||||
|
||||
hasIdTokenExpired := false |
||||
idTokenExpires := time.Time{} |
||||
|
||||
if !idTokenExpiry.IsZero() { |
||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) |
||||
hasIdTokenExpired = idTokenExpires.Before(time.Now()) |
||||
} |
||||
// token has not expired, so we don't have to refresh it
|
||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired { |
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)) |
||||
return nil |
||||
} |
||||
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) |
||||
defer cancel() |
||||
|
||||
if err := s.service.TryTokenRefresh(updateCtx, token); err != nil { |
||||
if errors.Is(err, context.Canceled) { |
||||
return nil |
||||
} |
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { |
||||
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", err) |
||||
} |
||||
|
||||
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil { |
||||
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err) |
||||
} |
||||
|
||||
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil { |
||||
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err) |
||||
} |
||||
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute |
||||
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration { |
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { |
||||
return maxOAuthTokenCacheTTL |
||||
} |
||||
|
||||
min := func(a, b time.Duration) time.Duration { |
||||
if a <= b { |
||||
return a |
||||
} |
||||
return b |
||||
} |
||||
|
||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() { |
||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL) |
||||
} |
||||
|
||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { |
||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL) |
||||
} |
||||
|
||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL) |
||||
} |
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) { |
||||
if token.OAuthIdToken == "" { |
||||
return time.Time{}, nil |
||||
} |
||||
|
||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken) |
||||
if err != nil { |
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err) |
||||
} |
||||
|
||||
type Claims struct { |
||||
Exp int64 `json:"exp"` |
||||
} |
||||
var claims Claims |
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { |
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err) |
||||
} |
||||
|
||||
return time.Unix(claims.Exp, 0), nil |
||||
} |
||||
@ -1,258 +0,0 @@ |
||||
package sync |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/login/social" |
||||
"github.com/grafana/grafana/pkg/login/socialtest" |
||||
"github.com/grafana/grafana/pkg/services/auth" |
||||
"github.com/grafana/grafana/pkg/services/auth/authtest" |
||||
"github.com/grafana/grafana/pkg/services/auth/identity" |
||||
"github.com/grafana/grafana/pkg/services/authn" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" |
||||
) |
||||
|
||||
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { |
||||
type testCase struct { |
||||
desc string |
||||
identity *authn.Identity |
||||
oauthInfo *social.OAuthInfo |
||||
|
||||
expectedHasEntryToken *login.UserAuth |
||||
expectHasEntryCalled bool |
||||
|
||||
expectedTryRefreshErr error |
||||
expectTryRefreshTokenCalled bool |
||||
|
||||
expectRevokeTokenCalled bool |
||||
expectInvalidateOauthTokensCalled bool |
||||
|
||||
expectedErr error |
||||
} |
||||
|
||||
tests := []testCase{ |
||||
{ |
||||
desc: "should skip sync when identity is not a user", |
||||
identity: &authn.Identity{ID: "service-account:1"}, |
||||
}, |
||||
{ |
||||
desc: "should skip sync when identity is a user but is not authenticated with session token", |
||||
identity: &authn.Identity{ID: "user:1"}, |
||||
}, |
||||
{ |
||||
desc: "should skip sync when user has session but is not authenticated with oauth", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
}, |
||||
{ |
||||
desc: "should skip sync for when access token don't have expire time", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{}, |
||||
}, |
||||
{ |
||||
desc: "should skip sync when access token has no expired yet", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, |
||||
}, |
||||
{ |
||||
desc: "should skip sync when access token has no expired yet", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, |
||||
}, |
||||
{ |
||||
desc: "should refresh access token when is has expired", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectTryRefreshTokenCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, |
||||
}, |
||||
{ |
||||
desc: "should invalidate access token and session token if access token can't be refreshed", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectedTryRefreshErr: errors.New("some err"), |
||||
expectTryRefreshTokenCalled: true, |
||||
expectInvalidateOauthTokensCalled: true, |
||||
expectRevokeTokenCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, |
||||
expectedErr: authn.ErrExpiredAccessToken, |
||||
}, { |
||||
desc: "should skip sync when use_refresh_token is disabled", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule}, |
||||
expectHasEntryCalled: true, |
||||
expectTryRefreshTokenCalled: false, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, |
||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false}, |
||||
}, |
||||
{ |
||||
desc: "should refresh access token when ID token has expired", |
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, |
||||
expectHasEntryCalled: true, |
||||
expectTryRefreshTokenCalled: true, |
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))}, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.desc, func(t *testing.T) { |
||||
var ( |
||||
hasEntryCalled bool |
||||
tryRefreshCalled bool |
||||
invalidateTokensCalled bool |
||||
revokeTokenCalled bool |
||||
) |
||||
|
||||
service := &oauthtokentest.MockOauthTokenService{ |
||||
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) { |
||||
hasEntryCalled = true |
||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil |
||||
}, |
||||
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error { |
||||
invalidateTokensCalled = true |
||||
return nil |
||||
}, |
||||
TryTokenRefreshFunc: func(ctx context.Context, usr *login.UserAuth) error { |
||||
tryRefreshCalled = true |
||||
return tt.expectedTryRefreshErr |
||||
}, |
||||
} |
||||
|
||||
sessionService := &authtest.FakeUserAuthTokenService{ |
||||
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error { |
||||
revokeTokenCalled = true |
||||
return nil |
||||
}, |
||||
} |
||||
|
||||
if tt.oauthInfo == nil { |
||||
tt.oauthInfo = &social.OAuthInfo{ |
||||
UseRefreshToken: true, |
||||
} |
||||
} |
||||
|
||||
socialService := &socialtest.FakeSocialService{ |
||||
ExpectedAuthInfoProvider: tt.oauthInfo, |
||||
} |
||||
|
||||
sync := &OAuthTokenSync{ |
||||
log: log.NewNopLogger(), |
||||
cache: localcache.New(0, 0), |
||||
service: service, |
||||
sessionService: sessionService, |
||||
socialService: socialService, |
||||
} |
||||
|
||||
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil) |
||||
assert.ErrorIs(t, err, tt.expectedErr) |
||||
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled) |
||||
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled) |
||||
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled) |
||||
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string { |
||||
type Header struct { |
||||
Kid string `json:"kid"` |
||||
Alg string `json:"alg"` |
||||
} |
||||
type Payload struct { |
||||
Iss string `json:"iss"` |
||||
Sub string `json:"sub"` |
||||
Exp int64 `json:"exp"` |
||||
} |
||||
|
||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"}) |
||||
require.NoError(t, err) |
||||
u := expiryDate.UTC().Unix() |
||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u}) |
||||
require.NoError(t, err) |
||||
|
||||
fakeSignature := []byte("6ICJm") |
||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature)) |
||||
} |
||||
|
||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) { |
||||
defaultTime := time.Now() |
||||
tests := []struct { |
||||
name string |
||||
accessTokenExpiry time.Time |
||||
idTokenExpiry time.Time |
||||
want time.Duration |
||||
}{ |
||||
{ |
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given", |
||||
accessTokenExpiry: time.Time{}, |
||||
idTokenExpiry: time.Time{}, |
||||
|
||||
want: maxOAuthTokenCacheTTL, |
||||
}, |
||||
{ |
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl", |
||||
accessTokenExpiry: time.Time{}, |
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
|
||||
want: maxOAuthTokenCacheTTL, |
||||
}, |
||||
{ |
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl", |
||||
accessTokenExpiry: time.Time{}, |
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), |
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), |
||||
}, |
||||
{ |
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given", |
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
idTokenExpiry: time.Time{}, |
||||
want: maxOAuthTokenCacheTTL, |
||||
}, |
||||
{ |
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given", |
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), |
||||
idTokenExpiry: time.Time{}, |
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), |
||||
}, |
||||
{ |
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry", |
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), |
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), |
||||
}, |
||||
{ |
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry", |
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL), |
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)), |
||||
}, |
||||
{ |
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl", |
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), |
||||
want: maxOAuthTokenCacheTTL, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry) |
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second)) |
||||
}) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue