@ -10,20 +10,26 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoimpl"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
func TestService_HasOAuthEntry ( t * testing . T ) {
testCases := [ ] struct {
name string
@ -69,10 +75,10 @@ func TestService_HasOAuthEntry(t *testing.T) {
{
name : "returns true when the auth entry is found" ,
user : & user . SignedInUser { UserID : 1 } ,
want : & login . UserAuth { AuthModule : "oauth_generic_oauth" } ,
want : & login . UserAuth { AuthModule : login . GenericOAuthModule } ,
wantExist : true ,
wantErr : false ,
getAuthInfoUser : login . UserAuth { AuthModule : "oauth_generic_oauth" } ,
getAuthInfoUser : login . UserAuth { AuthModule : login . GenericOAuthModule } ,
} ,
}
for _ , tc := range testCases {
@ -96,152 +102,26 @@ func TestService_HasOAuthEntry(t *testing.T) {
}
}
func TestService_TryTokenRefresh_ValidToken ( t * testing . T ) {
srv , authInfoStore , socialConnector := setupOAuthTokenService ( t )
ctx := context . Background ( )
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) ,
TokenType : "Bearer" ,
}
usr := & login . UserAuth {
AuthModule : "oauth_generic_oauth" ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
authInfoStore . ExpectedOAuth = usr
socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) )
err := srv . TryTokenRefresh ( ctx , usr )
require . Nil ( t , err )
socialConnector . AssertNumberOfCalls ( t , "TokenSource" , 1 )
authInfoQuery := & login . GetAuthInfoQuery { UserId : 1 }
resultUsr , err := srv . AuthInfoService . GetAuthInfo ( ctx , authInfoQuery )
require . Nil ( t , err )
// User's token data had not been updated
assert . Equal ( t , resultUsr . OAuthAccessToken , token . AccessToken )
assert . Equal ( t , resultUsr . OAuthExpiry , token . Expiry )
assert . Equal ( t , resultUsr . OAuthRefreshToken , token . RefreshToken )
assert . Equal ( t , resultUsr . OAuthTokenType , token . TokenType )
}
func TestService_TryTokenRefresh_NoRefreshToken ( t * testing . T ) {
srv , _ , socialConnector := setupOAuthTokenService ( t )
ctx := context . Background ( )
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "" ,
Expiry : time . Now ( ) . Add ( - time . Hour ) ,
TokenType : "Bearer" ,
}
usr := & login . UserAuth {
AuthModule : "oauth_generic_oauth" ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) )
err := srv . TryTokenRefresh ( ctx , usr )
assert . NotNil ( t , err )
assert . ErrorIs ( t , err , ErrNoRefreshTokenFound )
socialConnector . AssertNotCalled ( t , "TokenSource" )
}
func TestService_TryTokenRefresh_ExpiredToken ( t * testing . T ) {
srv , authInfoStore , socialConnector := setupOAuthTokenService ( t )
ctx := context . Background ( )
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( - time . Hour ) ,
TokenType : "Bearer" ,
}
newToken := & oauth2 . Token {
AccessToken : "testaccess_new" ,
RefreshToken : "testrefresh_new" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
usr := & login . UserAuth {
AuthModule : "oauth_generic_oauth" ,
UserId : 1 ,
AuthId : "test" ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
authInfoStore . ExpectedOAuth = usr
socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . ReuseTokenSource ( token , oauth2 . StaticTokenSource ( newToken ) ) , nil )
err := srv . TryTokenRefresh ( ctx , usr )
require . Nil ( t , err )
socialConnector . AssertNumberOfCalls ( t , "TokenSource" , 1 )
authInfoQuery := & login . GetAuthInfoQuery { UserId : 1 }
authInfo , err := srv . AuthInfoService . GetAuthInfo ( ctx , authInfoQuery )
require . Nil ( t , err )
// newToken should be returned after the .Token() call, therefore the User had to be updated
assert . Equal ( t , authInfo . OAuthAccessToken , newToken . AccessToken )
assert . Equal ( t , authInfo . OAuthExpiry , newToken . Expiry )
assert . Equal ( t , authInfo . OAuthRefreshToken , newToken . RefreshToken )
assert . Equal ( t , authInfo . OAuthTokenType , newToken . TokenType )
}
func TestService_TryTokenRefresh_DifferentAuthModuleForUser ( t * testing . T ) {
srv , _ , socialConnector := setupOAuthTokenService ( t )
ctx := context . Background ( )
token := & oauth2 . Token { }
usr := & login . UserAuth {
AuthModule : "auth.saml" ,
}
socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) )
err := srv . TryTokenRefresh ( ctx , usr )
assert . NotNil ( t , err )
assert . ErrorIs ( t , err , ErrNotAnOAuthProvider )
socialConnector . AssertNotCalled ( t , "TokenSource" )
}
func setupOAuthTokenService ( t * testing . T ) ( * Service , * FakeAuthInfoStore , * socialtest . MockSocialConnector ) {
t . Helper ( )
socialConnector := & socialtest . MockSocialConnector { }
socialService := & socialtest . FakeSocialService {
ExpectedConnector : socialConnector ,
ExpectedAuthInfoProvider : & social . OAuthInfo {
UseRefreshToken : true ,
} ,
}
authInfoStore := & FakeAuthInfoStore { }
authInfoService := authinfoimpl . ProvideService ( authInfoStore , remotecache . NewFakeCacheStorage ( ) ,
secretsManager . SetupTestService ( t , fakes . NewFakeSecretsStore ( ) ) )
authInfoStore := & FakeAuthInfoStore { ExpectedOAuth : & login . UserAuth { } }
authInfoService := authinfoimpl . ProvideService ( authInfoStore , remotecache . NewFakeCacheStorage ( ) , secretsManager . SetupTestService ( t , fakes . NewFakeSecretsStore ( ) ) )
return & Service {
Cfg : setting . NewCfg ( ) ,
SocialService : socialService ,
AuthInfoService : authInfoService ,
singleFlightGroup : & singleflight . Group { } ,
tokenRefreshDuration : newTokenRefreshDurationMetric ( prometheus . NewRegistry ( ) ) ,
cache : localcache . New ( maxOAuthTokenCacheTTL , 15 * time . Minute ) ,
} , authInfoStore , socialConnector
}
@ -270,3 +150,461 @@ func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.Updat
func ( f * FakeAuthInfoStore ) DeleteAuthInfo ( ctx context . Context , cmd * login . DeleteAuthInfoCommand ) error {
return f . ExpectedError
}
func TestService_TryTokenRefresh ( t * testing . T ) {
type environment struct {
authInfoService * authinfotest . FakeService
cache * localcache . CacheService
identity identity . Requester
socialConnector * socialtest . MockSocialConnector
socialService * socialtest . FakeSocialService
service * Service
}
type testCase struct {
desc string
expectedErr error
setup func ( env * environment )
}
tests := [ ] testCase {
{
desc : "should skip sync when identity is nil" ,
} ,
{
desc : "should skip sync when identity is not a user" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "service-account:1" }
} ,
} ,
{
desc : "should skip token refresh and return nil if namespace and id cannot be converted to user ID" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:invalidIdentifierFormat" }
} ,
} ,
{
desc : "should skip token refresh since the token is still valid" ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
env . identity = & authn . Identity {
AuthenticatedBy : login . GenericOAuthModule ,
ID : "user:1234" ,
}
} ,
} ,
{
desc : "should skip token refresh if the expiration check has already been cached" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . cache . Set ( "oauth-refresh-token-1234" , true , 1 * time . Minute )
} ,
} ,
{
desc : "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedError = errors . New ( "some error" )
} ,
} ,
{
desc : "should skip token refresh if the user doesn't have an oauth entry" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . SAMLAuthModule ,
}
} ,
} ,
{
desc : "should do token refresh if access token or id token have not expired yet" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
}
} ,
} ,
{
desc : "should skip token refresh when no oauth provider was found" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
}
} ,
} ,
{
desc : "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : false ,
}
} ,
} ,
{
desc : "should skip token refresh when there is no refresh token" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "user:1234" }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
OAuthRefreshToken : "" ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
} ,
} ,
{
desc : "should do token refresh when the token is expired" ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( - time . Hour ) ,
TokenType : "Bearer" ,
}
env . identity = & authn . Identity { ID : "user:1234" }
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
AuthId : "subject" ,
UserId : 1 ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
} ,
} ,
{
desc : "should refresh token when the id token is expired" ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
env . identity = & authn . Identity { ID : "user:1234" }
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
AuthId : "subject" ,
UserId : 1 ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
OAuthIdToken : EXPIRED_JWT ,
}
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
} ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . desc , func ( t * testing . T ) {
socialConnector := & socialtest . MockSocialConnector { }
env := environment {
authInfoService : & authinfotest . FakeService { } ,
cache : localcache . New ( maxOAuthTokenCacheTTL , 15 * time . Minute ) ,
socialConnector : socialConnector ,
socialService : & socialtest . FakeSocialService {
ExpectedConnector : socialConnector ,
} ,
}
if tt . setup != nil {
tt . setup ( & env )
}
env . service = & Service {
AuthInfoService : env . authInfoService ,
Cfg : setting . NewCfg ( ) ,
cache : env . cache ,
singleFlightGroup : & singleflight . Group { } ,
SocialService : env . socialService ,
tokenRefreshDuration : newTokenRefreshDurationMetric ( prometheus . NewRegistry ( ) ) ,
}
// token refresh
err := env . service . TryTokenRefresh ( context . Background ( ) , env . identity )
// test and validations
assert . ErrorIs ( t , err , tt . expectedErr )
socialConnector . AssertExpectations ( t )
} )
}
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL ( t * testing . T ) {
defaultTime := time . Now ( )
tests := [ ] struct {
name string
accessTokenExpiry time . Time
idTokenExpiry time . Time
want time . Duration
} {
{
name : "should return maxOAuthTokenCacheTTL when no expiry is given" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : time . Time { } ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : time . Time { } ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given" ,
accessTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : time . Time { } ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry" ,
accessTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( - 3 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 3 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : maxOAuthTokenCacheTTL ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got := getOAuthTokenCacheTTL ( tt . accessTokenExpiry , tt . idTokenExpiry )
assert . Equal ( t , tt . want . Round ( time . Second ) , got . Round ( time . Second ) )
} )
}
}
func TestOAuthTokenSync_needTokenRefresh ( t * testing . T ) {
tests := [ ] struct {
name string
usr * login . UserAuth
expectedTokenRefreshFlag bool
expectedTokenDuration time . Duration
} {
{
name : "should not need token refresh when token has no expiration date" ,
usr : & login . UserAuth { } ,
expectedTokenRefreshFlag : false ,
expectedTokenDuration : maxOAuthTokenCacheTTL ,
} ,
{
name : "should not need token refresh with an invalid jwt token that might result in an error when parsing" ,
usr : & login . UserAuth {
OAuthIdToken : "invalid_jwt_format" ,
} ,
expectedTokenRefreshFlag : false ,
expectedTokenDuration : maxOAuthTokenCacheTTL ,
} ,
{
name : "should flag token refresh with id token is expired" ,
usr : & login . UserAuth {
OAuthIdToken : EXPIRED_JWT ,
} ,
expectedTokenRefreshFlag : true ,
expectedTokenDuration : time . Second ,
} ,
{
name : "should flag token refresh when expiry date is zero" ,
usr : & login . UserAuth {
OAuthExpiry : time . Unix ( 0 , 0 ) ,
} ,
expectedTokenRefreshFlag : true ,
expectedTokenDuration : time . Second ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
token , needsTokenRefresh , tokenDuration := needTokenRefresh ( tt . usr )
assert . NotNil ( t , token )
assert . Equal ( t , tt . expectedTokenRefreshFlag , needsTokenRefresh )
assert . Equal ( t , tt . expectedTokenDuration , tokenDuration )
} )
}
}
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken ( t * testing . T ) {
timeNow := time . Now ( )
token := & oauth2 . Token {
AccessToken : "oauth_access_token" ,
RefreshToken : "refresh_token_found" ,
Expiry : timeNow ,
TokenType : "Bearer" ,
}
type environment struct {
authInfoService * authinfotest . FakeService
cache * localcache . CacheService
socialConnector * socialtest . MockSocialConnector
socialService * socialtest . FakeSocialService
service * Service
}
tests := [ ] struct {
desc string
expectedErr error
expectedToken * oauth2 . Token
usr * login . UserAuth
setup func ( env * environment )
} {
{
desc : "should find and retrieve token from cache" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
OAuthAccessToken : "new_access_token" ,
OAuthExpiry : timeNow ,
} ,
setup : func ( env * environment ) {
env . cache . Set ( "token-check-1234" , token , 1 * time . Minute )
} ,
expectedToken : & oauth2 . Token {
AccessToken : "new_access_token" ,
Expiry : timeNow ,
} ,
} ,
{
desc : "should return ErrNotAnOAuthProvider error when the user is not an oauth provider" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . SAMLAuthModule ,
} ,
expectedErr : ErrNotAnOAuthProvider ,
} ,
{
desc : "should return ErrNoRefreshTokenFound error when the no refresh token was found" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
} ,
expectedErr : ErrNoRefreshTokenFound ,
} ,
{
desc : "should not refresh token if the token is not expired" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : timeNow . Add ( time . Hour ) ,
OAuthTokenType : "Bearer" ,
} ,
expectedToken : & oauth2 . Token {
AccessToken : token . AccessToken ,
RefreshToken : token . RefreshToken ,
Expiry : timeNow . Add ( time . Hour ) ,
TokenType : "Bearer" ,
} ,
} ,
{
desc : "should update saved token if the user auth has new access/refresh tokens" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : "new_oauth_access_token" ,
OAuthRefreshToken : "new_refresh_token_found" ,
OAuthExpiry : timeNow ,
} ,
expectedToken : & oauth2 . Token {
AccessToken : "oauth_access_token" ,
RefreshToken : "refresh_token_found" ,
Expiry : timeNow ,
TokenType : "Bearer" ,
} ,
setup : func ( env * environment ) {
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
} ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . desc , func ( t * testing . T ) {
socialConnector := & socialtest . MockSocialConnector { }
env := environment {
authInfoService : & authinfotest . FakeService { } ,
cache : localcache . New ( maxOAuthTokenCacheTTL , 15 * time . Minute ) ,
socialConnector : socialConnector ,
socialService : & socialtest . FakeSocialService {
ExpectedConnector : socialConnector ,
} ,
}
if tt . setup != nil {
tt . setup ( & env )
}
env . service = & Service {
AuthInfoService : env . authInfoService ,
Cfg : setting . NewCfg ( ) ,
cache : env . cache ,
singleFlightGroup : & singleflight . Group { } ,
SocialService : env . socialService ,
tokenRefreshDuration : newTokenRefreshDurationMetric ( prometheus . NewRegistry ( ) ) ,
}
token , err := env . service . tryGetOrRefreshOAuthToken ( context . Background ( ) , tt . usr )
if tt . expectedToken != nil {
assert . Equal ( t , tt . expectedToken , token )
}
assert . ErrorIs ( t , tt . expectedErr , err )
socialConnector . AssertExpectations ( t )
} )
}
}