The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/oauthtoken/oauth_token_test.go

610 lines
20 KiB

package oauthtoken
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
// EXPIRED_JWT is a compact JWT whose header decodes to {"alg":"HS256","typ":"JWT"}
// and whose payload decodes to {"sub":"1234567890"}; it carries no "exp" claim.
// The tests in this file store it as the user's OAuth ID token in scenarios
// where the ID token is expected to be treated as expired.
// NOTE(review): the ALL_CAPS name is not idiomatic Go (MixedCaps is preferred);
// it is left unchanged because it is referenced throughout this file.
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
// TestService_HasOAuthEntry is a table-driven test for Service.HasOAuthEntry.
//
// Each case configures the fake auth-info store with the entry and error it
// should return, calls HasOAuthEntry for the given user, and then checks all
// three results: the returned auth entry, the "exists" flag, and the error.
func TestService_HasOAuthEntry(t *testing.T) {
	testCases := []struct {
		name string
		// user is the signed-in user passed to HasOAuthEntry (may be nil).
		user *user.SignedInUser
		// want is the expected auth entry; nil means no entry must be returned.
		want *login.UserAuth
		// wantExist is the expected value of the "exists" result.
		wantExist bool
		// wantErr tells whether HasOAuthEntry must return a non-nil error.
		wantErr bool
		// getAuthInfoErr is the error the fake store returns from GetAuthInfo.
		getAuthInfoErr error
		// getAuthInfoUser is the entry the fake store returns from GetAuthInfo.
		getAuthInfoUser login.UserAuth
	}{
		{
			name:      "returns false without an error in case user is nil",
			user:      nil,
			want:      nil,
			wantExist: false,
			wantErr:   false,
		},
		{
			name:           "returns false and an error in case GetAuthInfo returns an error",
			user:           &user.SignedInUser{UserID: 1},
			want:           nil,
			wantExist:      false,
			wantErr:        true,
			getAuthInfoErr: errors.New("error"),
		},
		{
			name:           "returns false without an error in case auth entry is not found",
			user:           &user.SignedInUser{UserID: 1},
			want:           nil,
			wantExist:      false,
			wantErr:        false,
			getAuthInfoErr: user.ErrUserNotFound,
		},
		{
			name:            "returns false without an error in case the auth entry is not oauth",
			user:            &user.SignedInUser{UserID: 1},
			want:            nil,
			wantExist:       false,
			wantErr:         false,
			getAuthInfoUser: login.UserAuth{AuthModule: "auth_saml"},
		},
		{
			name:            "returns true when the auth entry is found",
			user:            &user.SignedInUser{UserID: 1},
			want:            &login.UserAuth{AuthModule: login.GenericOAuthModule},
			wantExist:       true,
			wantErr:         false,
			getAuthInfoUser: login.UserAuth{AuthModule: login.GenericOAuthModule},
		},
	}

	for _, tc := range testCases {
		tc := tc // per-iteration copy; &tc.getAuthInfoUser is handed to the fake store below
		t.Run(tc.name, func(t *testing.T) {
			srv, authInfoStore, _ := setupOAuthTokenService(t)
			authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser
			authInfoStore.ExpectedError = tc.getAuthInfoErr

			entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user)

			// Check the error in both directions so that an unexpected error in
			// a "without an error" case cannot slip through unnoticed.
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}

			// Likewise check the entry in both directions: cases that expect no
			// entry must really get nil back.
			if tc.want != nil {
				assert.True(t, reflect.DeepEqual(tc.want, entry), "unexpected auth entry: want %+v, got %+v", tc.want, entry)
			} else {
				assert.Nil(t, entry)
			}

			assert.Equal(t, tc.wantExist, exists)
		})
	}
}
// setupOAuthTokenService assembles a Service wired entirely with test doubles.
//
// It returns the service together with the fake auth-info store and the mock
// social connector backing it, so callers can adjust what the store returns
// and register expectations on the connector.
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *socialtest.MockSocialConnector) {
	t.Helper()

	connector := new(socialtest.MockSocialConnector)
	fakeSocial := &socialtest.FakeSocialService{
		ExpectedConnector:        connector,
		ExpectedAuthInfoProvider: &social.OAuthInfo{UseRefreshToken: true},
	}

	// The store starts with an empty, non-nil auth entry.
	store := &FakeAuthInfoStore{ExpectedOAuth: new(login.UserAuth)}

	cacheStorage := remotecache.NewFakeCacheStorage()
	secretsSvc := secretsManager.SetupTestService(t, fakes.NewFakeSecretsStore())
	authInfoSvc := authinfoimpl.ProvideService(store, cacheStorage, secretsSvc)

	svc := &Service{
		Cfg:                  setting.NewCfg(),
		SocialService:        fakeSocial,
		AuthInfoService:      authInfoSvc,
		singleFlightGroup:    new(singleflight.Group),
		tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
		cache:                localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
	}

	return svc, store, connector
}
// FakeAuthInfoStore is a test double for login.Store.
//
// It embeds login.Store and overrides GetAuthInfo, SetAuthInfo, UpdateAuthInfo
// and DeleteAuthInfo with canned behaviour driven by its exported fields.
// NOTE(review): nothing in this file assigns the embedded login.Store, so any
// method that is not overridden here would presumably panic on a nil
// interface if called — confirm before relying on other methods.
type FakeAuthInfoStore struct {
	login.Store

	// ExpectedError is the error returned by every overridden method.
	ExpectedError error
	// ExpectedOAuth is the entry returned by GetAuthInfo; UpdateAuthInfo
	// writes the new token fields into it.
	ExpectedOAuth *login.UserAuth
}

// GetAuthInfo ignores the query and returns the configured entry and error.
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
	entry, err := f.ExpectedOAuth, f.ExpectedError
	return entry, err
}

// SetAuthInfo stores nothing and returns the configured error.
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
	err := f.ExpectedError
	return err
}

// UpdateAuthInfo copies the token carried by cmd into ExpectedOAuth and then
// returns the configured error.
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
	stored, tok := f.ExpectedOAuth, cmd.OAuthToken

	stored.OAuthAccessToken = tok.AccessToken
	stored.OAuthExpiry = tok.Expiry
	stored.OAuthTokenType = tok.TokenType
	stored.OAuthRefreshToken = tok.RefreshToken

	return f.ExpectedError
}

// DeleteAuthInfo deletes nothing and returns the configured error.
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
	err := f.ExpectedError
	return err
}
// TestService_TryTokenRefresh is a table-driven test for Service.TryTokenRefresh.
//
// Every case starts from a fresh environment (fake auth-info service, local
// cache, mock social connector, fake social service). The optional setup hook
// customises that environment before the Service is built from it. After the
// call, the test checks the returned error against expectedErr and verifies
// that every expectation registered on the mock connector was met. Cases that
// must refresh register a single TokenSource expectation; cases that must skip
// register none.
func TestService_TryTokenRefresh(t *testing.T) {
	type environment struct {
		authInfoService *authinfotest.FakeService
		cache           *localcache.CacheService
		identity        identity.Requester
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		service *Service
	}

	type testCase struct {
		desc string
		// expectedErr is the error TryTokenRefresh must return (nil for none).
		expectedErr error
		// setup customises the environment before the Service is created.
		setup func(env *environment)
	}

	tests := []testCase{
		{
			desc: "should skip sync when identity is nil",
		},
		{
			desc: "should skip sync when identity is not a user",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "service-account:1"}
			},
		},
		{
			desc: "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:invalidIdentifierFormat"}
			},
		},
		{
			desc: "should skip token refresh since the token is still valid",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(time.Hour),
					TokenType:    "Bearer",
				}

				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
				}

				env.identity = &authn.Identity{
					AuthenticatedBy: login.GenericOAuthModule,
					ID:              "user:1234",
				}
			},
		},
		{
			desc: "should skip token refresh if the expiration check has already been cached",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.cache.Set("oauth-refresh-token-1234", true, 1*time.Minute)
			},
		},
		{
			desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedError = errors.New("some error")
			},
		},
		{
			desc: "should skip token refresh if the user doesn't have an oauth entry",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.SAMLAuthModule,
				}
			},
		},
		{
			// The stored entry has neither an expiry nor an ID token, and no
			// TokenSource expectation is registered, so this case exercises
			// the skip path (the previous description said "should do").
			desc: "should skip token refresh if access token or id token have not expired yet",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.GenericOAuthModule,
				}
			},
		},
		{
			// ExpectedAuthInfoProvider is left unset on the fake social service.
			desc: "should skip token refresh when no oauth provider was found",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:   login.GenericOAuthModule,
					OAuthIdToken: EXPIRED_JWT,
				}
			},
		},
		{
			desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:   login.GenericOAuthModule,
					OAuthIdToken: EXPIRED_JWT,
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: false,
				}
			},
		},
		{
			desc: "should skip token refresh when there is no refresh token",
			setup: func(env *environment) {
				env.identity = &authn.Identity{ID: "user:1234"}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthIdToken:      EXPIRED_JWT,
					OAuthRefreshToken: "",
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
		},
		{
			desc: "should do token refresh when the token is expired",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(-time.Hour),
					TokenType:    "Bearer",
				}
				env.identity = &authn.Identity{ID: "user:1234"}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
				}
				// Exactly one TokenSource call is expected for the refresh.
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
		{
			desc: "should refresh token when the id token is expired",
			setup: func(env *environment) {
				token := &oauth2.Token{
					AccessToken:  "testaccess",
					RefreshToken: "testrefresh",
					Expiry:       time.Now().Add(time.Hour),
					TokenType:    "Bearer",
				}
				env.identity = &authn.Identity{ID: "user:1234"}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  token.AccessToken,
					OAuthRefreshToken: token.RefreshToken,
					OAuthExpiry:       token.Expiry,
					OAuthTokenType:    token.TokenType,
					OAuthIdToken:      EXPIRED_JWT,
				}
				// Exactly one TokenSource call is expected for the refresh.
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := &socialtest.MockSocialConnector{}

			env := environment{
				authInfoService: &authinfotest.FakeService{},
				cache:           localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			env.service = &Service{
				AuthInfoService:      env.authInfoService,
				Cfg:                  setting.NewCfg(),
				cache:                env.cache,
				singleFlightGroup:    &singleflight.Group{},
				SocialService:        env.socialService,
				tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
			}

			// token refresh
			err := env.service.TryTokenRefresh(context.Background(), env.identity)

			// test and validations
			assert.ErrorIs(t, err, tt.expectedErr)
			socialConnector.AssertExpectations(t)
		})
	}
}
// TestOAuthTokenSync_getOAuthTokenCacheTTL is a table-driven test for
// getOAuthTokenCacheTTL covering the combinations of missing, shorter-than-max
// and longer-than-max expiries for the access token and the ID token.
//
// Expected durations that depend on the clock are computed with time.Until when
// the table is built, while the function under test reads the clock slightly
// later. The comparison therefore allows a tolerance of one second instead of
// rounding both sides, which could fail spuriously whenever a rounding boundary
// fell between the two clock reads.
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
	defaultTime := time.Now()
	tests := []struct {
		name              string
		accessTokenExpiry time.Time
		idTokenExpiry     time.Time
		want              time.Duration
	}{
		{
			name:              "should return maxOAuthTokenCacheTTL when no expiry is given",
			accessTokenExpiry: time.Time{},
			idTokenExpiry:     time.Time{},
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
			accessTokenExpiry: time.Time{},
			idTokenExpiry:     defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
			accessTokenExpiry: time.Time{},
			idTokenExpiry:     defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
			want:              time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
			accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			idTokenExpiry:     time.Time{},
			want:              maxOAuthTokenCacheTTL,
		},
		{
			name:              "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
			accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
			idTokenExpiry:     time.Time{},
			want:              time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
		},
		{
			name:              "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
			accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
			idTokenExpiry:     defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			want:              time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
		},
		{
			name:              "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
			accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			idTokenExpiry:     defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
			want:              time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
		},
		{
			name:              "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
			accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			idTokenExpiry:     defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
			want:              maxOAuthTokenCacheTTL,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)

			// Durations are compared in nanoseconds with a one-second tolerance
			// to absorb the time elapsed between building the table and the call.
			assert.InDelta(t, float64(tt.want), float64(got), float64(time.Second))
		})
	}
}
// TestOAuthTokenSync_needTokenRefresh drives needTokenRefresh with a handful of
// stored auth entries and checks, for each one, that a non-nil token is
// returned along with the expected refresh flag and cache duration.
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
	type scenario struct {
		name        string
		usr         *login.UserAuth
		wantRefresh bool
		wantTTL     time.Duration
	}

	scenarios := []scenario{
		{
			name:        "should not need token refresh when token has no expiration date",
			usr:         new(login.UserAuth),
			wantRefresh: false,
			wantTTL:     maxOAuthTokenCacheTTL,
		},
		{
			name:        "should not need token refresh with an invalid jwt token that might result in an error when parsing",
			usr:         &login.UserAuth{OAuthIdToken: "invalid_jwt_format"},
			wantRefresh: false,
			wantTTL:     maxOAuthTokenCacheTTL,
		},
		{
			name:        "should flag token refresh with id token is expired",
			usr:         &login.UserAuth{OAuthIdToken: EXPIRED_JWT},
			wantRefresh: true,
			wantTTL:     time.Second,
		},
		{
			// The expiry here is the Unix epoch, not the zero time.Time value.
			name:        "should flag token refresh when expiry date is zero",
			usr:         &login.UserAuth{OAuthExpiry: time.Unix(0, 0)},
			wantRefresh: true,
			wantTTL:     time.Second,
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			tok, refresh, ttl := needTokenRefresh(sc.usr)

			assert.NotNil(t, tok)
			assert.Equal(t, sc.wantRefresh, refresh)
			assert.Equal(t, sc.wantTTL, ttl)
		})
	}
}
// TestOAuthTokenSync_tryGetOrRefreshOAuthToken is a table-driven test for
// Service.tryGetOrRefreshOAuthToken.
//
// Every case builds a fresh environment (fake auth-info service, local cache,
// mock social connector, fake social service), optionally customises it via
// setup, and calls tryGetOrRefreshOAuthToken with the case's stored auth entry.
// The returned token is compared with expectedToken when one is given, the
// returned error is matched against expectedErr, and all expectations
// registered on the mock connector must have been met.
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken(t *testing.T) {
	timeNow := time.Now()
	// token is the fixture handed out by the mocked TokenSource and, in the
	// first case, placed in the cache.
	token := &oauth2.Token{
		AccessToken:  "oauth_access_token",
		RefreshToken: "refresh_token_found",
		Expiry:       timeNow,
		TokenType:    "Bearer",
	}

	type environment struct {
		authInfoService *authinfotest.FakeService
		cache           *localcache.CacheService
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		service *Service
	}

	tests := []struct {
		desc string
		// expectedErr is the error the call must return (nil for none).
		expectedErr error
		// expectedToken, when non-nil, is the token the call must return.
		expectedToken *oauth2.Token
		// usr is the stored auth entry passed to the function under test.
		usr *login.UserAuth
		// setup customises the environment before the Service is created.
		setup func(env *environment)
	}{
		{
			desc: "should find and retrieve token from cache",
			usr: &login.UserAuth{
				UserId:           int64(1234),
				OAuthAccessToken: "new_access_token",
				OAuthExpiry:      timeNow,
			},
			setup: func(env *environment) {
				env.cache.Set("token-check-1234", token, 1*time.Minute)
			},
			expectedToken: &oauth2.Token{
				AccessToken: "new_access_token",
				Expiry:      timeNow,
			},
		},
		{
			desc: "should return ErrNotAnOAuthProvider error when the user is not an oauth provider",
			usr: &login.UserAuth{
				UserId:     int64(1234),
				AuthModule: login.SAMLAuthModule,
			},
			expectedErr: ErrNotAnOAuthProvider,
		},
		{
			desc: "should return ErrNoRefreshTokenFound error when the no refresh token was found",
			usr: &login.UserAuth{
				UserId:     int64(1234),
				AuthModule: login.GenericOAuthModule,
			},
			expectedErr: ErrNoRefreshTokenFound,
		},
		{
			desc: "should not refresh token if the token is not expired",
			usr: &login.UserAuth{
				UserId:            int64(1234),
				AuthModule:        login.GenericOAuthModule,
				OAuthAccessToken:  token.AccessToken,
				OAuthRefreshToken: token.RefreshToken,
				OAuthExpiry:       timeNow.Add(time.Hour),
				OAuthTokenType:    "Bearer",
			},
			expectedToken: &oauth2.Token{
				AccessToken:  token.AccessToken,
				RefreshToken: token.RefreshToken,
				Expiry:       timeNow.Add(time.Hour),
				TokenType:    "Bearer",
			},
		},
		{
			desc: "should update saved token if the user auth has new access/refresh tokens",
			usr: &login.UserAuth{
				UserId:            int64(1234),
				AuthModule:        login.GenericOAuthModule,
				OAuthAccessToken:  "new_oauth_access_token",
				OAuthRefreshToken: "new_refresh_token_found",
				OAuthExpiry:       timeNow,
			},
			expectedToken: &oauth2.Token{
				AccessToken:  "oauth_access_token",
				RefreshToken: "refresh_token_found",
				Expiry:       timeNow,
				TokenType:    "Bearer",
			},
			setup: func(env *environment) {
				// Exactly one TokenSource call is expected for the refresh.
				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)).Once()
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := &socialtest.MockSocialConnector{}

			env := environment{
				authInfoService: &authinfotest.FakeService{},
				cache:           localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			env.service = &Service{
				AuthInfoService:      env.authInfoService,
				Cfg:                  setting.NewCfg(),
				cache:                env.cache,
				singleFlightGroup:    &singleflight.Group{},
				SocialService:        env.socialService,
				tokenRefreshDuration: newTokenRefreshDurationMetric(prometheus.NewRegistry()),
			}

			// Named gotToken so it does not shadow the outer token fixture.
			gotToken, err := env.service.tryGetOrRefreshOAuthToken(context.Background(), tt.usr)

			if tt.expectedToken != nil {
				assert.Equal(t, tt.expectedToken, gotToken)
			}

			// assert.ErrorIs takes the actual error first and the target second.
			assert.ErrorIs(t, err, tt.expectedErr)
			socialConnector.AssertExpectations(t)
		})
	}
}