@ -31,7 +31,9 @@ import (
"github.com/grafana/grafana/pkg/tests/testsuite"
)
var EXPIRED_JWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
const EXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjoxNjAwMDAwMDAwLCJpYXQiOjE2MDAwMDAwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
const UNEXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjo0ODg1NjA4MDAwLCJpYXQiOjE2ODU2MDgwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
func TestMain ( m * testing . M ) {
testsuite . Run ( m )
@ -162,19 +164,44 @@ func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.Delet
}
func TestService_TryTokenRefresh ( t * testing . T ) {
unexpiredToken := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
unexpiredTokenWithIDToken := unexpiredToken . WithExtra ( map [ string ] interface { } {
"id_token" : UNEXPIRED_ID_TOKEN ,
} )
expiredToken := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( - time . Hour ) ,
TokenType : "Bearer" ,
}
type environment struct {
authInfoService * authinfotest . FakeService
serverLock * serverlock . ServerLockService
identity identity . Requester
socialConnector * socialtest . MockSocialConnector
socialService * socialtest . FakeSocialService
service * Service
}
type testCase struct {
desc string
expectedErr error
setup func ( env * environment )
desc string
identity identity . Requester
setup func ( env * environment )
expectedToken * oauth2 . Token
expectedErr error
}
userIdentity := & authn . Identity {
AuthenticatedBy : login . GenericOAuthModule ,
ID : "1234" ,
Type : claims . TypeUser ,
}
tests := [ ] testCase {
@ -182,114 +209,111 @@ func TestService_TryTokenRefresh(t *testing.T) {
desc : "should skip sync when identity is nil" ,
} ,
{
desc : "should skip sync when identity is not a user" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1" , Type : claims . TypeServiceAccount }
} ,
desc : "should skip sync when identity is not a user" ,
identity : & authn . Identity { ID : "1" , Type : claims . TypeServiceAccount } ,
} ,
{
desc : "should skip token refresh and return nil if namespace and id cannot be converted to user ID" ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "invalid" , Type : claims . TypeUser }
} ,
desc : "should skip token refresh and return nil if namespace and id cannot be converted to user ID" ,
identity : & authn . Identity { ID : "invalid" , Type : claims . TypeUser } ,
} ,
{
desc : "should skip token refresh since the token is still valid" ,
desc : "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned" ,
identity : userIdentity ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
}
env . identity = & authn . Identity {
AuthenticatedBy : login . GenericOAuthModule ,
ID : "1234" ,
Type : claims . TypeUser ,
}
env . authInfoService . ExpectedError = errors . New ( "some error" )
} ,
} ,
{
desc : "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned" ,
desc : "should skip token refresh if the user doesn't have an oauth entry" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedError = errors . New ( "some error" )
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . SAMLAuthModule ,
}
} ,
} ,
{
desc : "should skip token refresh if the user doesn't have an oauth entry" ,
desc : "should skip token refresh when no oauth provider was found" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . SAML AuthModule,
AuthModule : login . GenericOAuthModule ,
}
} ,
} ,
{
desc : "should do token refresh if access token or id token have not expired yet" ,
desc : "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : false ,
}
} ,
} ,
{
desc : "should skip token refresh when no oauth provider was found" ,
desc : "should skip token refresh when the token is still valid and no id token is present" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : unexpiredTokenWithIDToken . AccessToken ,
OAuthRefreshToken : unexpiredTokenWithIDToken . RefreshToken ,
OAuthExpiry : unexpiredTokenWithIDToken . Expiry ,
OAuthTokenType : unexpiredTokenWithIDToken . TokenType ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
} ,
expectedToken : unexpiredToken ,
} ,
{
desc : "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)" ,
desc : "should not refresh the tokens if access token or id token have not expired yet" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : UNEXPIRED_ID_TOKEN ,
OAuthAccessToken : unexpiredTokenWithIDToken . AccessToken ,
OAuthRefreshToken : unexpiredTokenWithIDToken . RefreshToken ,
OAuthExpiry : unexpiredTokenWithIDToken . Expiry ,
OAuthTokenType : unexpiredTokenWithIDToken . TokenType ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : false ,
UseRefreshToken : tru e,
}
} ,
expectedToken : unexpiredTokenWithIDToken ,
} ,
{
desc : "should skip token refresh when there is no refresh token" ,
desc : "should skip token refresh when there is no refresh token" ,
identity : userIdentity ,
setup : func ( env * environment ) {
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser }
env . authInfoService . ExpectedUserAuth = & login . UserAuth {
AuthModule : login . GenericOAuthModule ,
OAuthIdToken : EXPIRED_JWT ,
OAuthAccessToken : unexpiredTokenWithIDToken . AccessToken ,
OAuthRefreshToken : "" ,
OAuthExpiry : unexpiredTokenWithIDToken . Expiry ,
}
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
} ,
expectedToken : & oauth2 . Token {
AccessToken : unexpiredTokenWithIDToken . AccessToken ,
RefreshToken : "" ,
Expiry : unexpiredTokenWithIDToken . Expiry ,
} ,
} ,
{
desc : "should do token refresh when the token is expired" ,
desc : "should do token refresh when the token is expired" ,
identity : userIdentity ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( - time . Hour ) ,
TokenType : "Bearer" ,
}
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser , AuthenticatedBy : login . GenericOAuthModule }
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
@ -297,24 +321,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
AuthModule : login . GenericOAuthModule ,
AuthId : "subject" ,
UserId : 1 ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
OAuthAccessToken : expiredToken . AccessToken ,
OAuthRefreshToken : expiredToken . RefreshToken ,
OAuthExpiry : expiredToken . Expiry ,
OAuthTokenType : expiredToken . TokenType ,
OAuthIdToken : EXPIRED_ID_TOKEN ,
}
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( unexpiredTokenWi thIDT oken) ) . Once ( )
} ,
expectedToken : unexpiredTokenWithIDToken ,
} ,
{
desc : "should refresh token when the id token is expired" ,
desc : "should refresh token when the id token is expired" ,
identity : & authn . Identity { ID : "1234" , Type : claims . TypeUser , AuthenticatedBy : login . GenericOAuthModule } ,
setup : func ( env * environment ) {
token := & oauth2 . Token {
AccessToken : "testaccess" ,
RefreshToken : "testrefresh" ,
Expiry : time . Now ( ) . Add ( time . Hour ) ,
TokenType : "Bearer" ,
}
env . identity = & authn . Identity { ID : "1234" , Type : claims . TypeUser , AuthenticatedBy : login . GenericOAuthModule }
env . socialService . ExpectedAuthInfoProvider = & social . OAuthInfo {
UseRefreshToken : true ,
}
@ -322,19 +342,20 @@ func TestService_TryTokenRefresh(t *testing.T) {
AuthModule : login . GenericOAuthModule ,
AuthId : "subject" ,
UserId : 1 ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : token . Expiry ,
OAuthTokenType : token . TokenType ,
OAuthIdToken : EXPIRED_JWT ,
OAuthAccessToken : unexpiredTokenWi thIDT oken. AccessToken ,
OAuthRefreshToken : unexpiredTokenWi thIDT oken. RefreshToken ,
OAuthExpiry : unexpiredTokenWi thIDT oken. Expiry ,
OAuthTokenType : unexpiredTokenWi thIDT oken. TokenType ,
OAuthIdToken : EXPIRED_ID_TOKEN ,
}
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( unexpiredTokenWi thIDT oken) ) . Once ( )
} ,
expectedToken : unexpiredTokenWithIDToken ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . desc , func ( t * testing . T ) {
socialConnector := & socialtest . MockSocialConnector { }
socialConnector := socialtest . NewMockSocialConnector ( t )
store := db . InitTestDB ( t )
@ -361,11 +382,27 @@ func TestService_TryTokenRefresh(t *testing.T) {
)
// token refresh
_ , err := env . service . TryTokenRefresh ( context . Background ( ) , env . identity )
actualToken , err := env . service . TryTokenRefresh ( context . Background ( ) , tt . identity )
// test and validations
assert . ErrorIs ( t , err , tt . expectedErr )
socialConnector . AssertExpectations ( t )
if tt . expectedErr != nil {
assert . ErrorIs ( t , err , tt . expectedErr )
return
}
if tt . expectedToken == nil {
assert . Nil ( t , actualToken )
return
}
assert . Equal ( t , tt . expectedToken . AccessToken , actualToken . AccessToken )
assert . Equal ( t , tt . expectedToken . RefreshToken , actualToken . RefreshToken )
assert . Equal ( t , tt . expectedToken . Expiry , actualToken . Expiry )
assert . Equal ( t , tt . expectedToken . TokenType , actualToken . TokenType )
if tt . expectedToken . Extra ( "id_token" ) != nil {
assert . Equal ( t , tt . expectedToken . Extra ( "id_token" ) . ( string ) , actualToken . Extra ( "id_token" ) . ( string ) )
} else {
assert . Nil ( t , actualToken . Extra ( "id_token" ) )
}
} )
}
}
@ -392,7 +429,7 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
{
name : "should flag token refresh with id token is expired" ,
usr : & login . UserAuth {
OAuthIdToken : EXPIRED_JWT ,
OAuthIdToken : EXPIRED_ID_TOKEN ,
} ,
expectedTokenRefreshFlag : true ,
expectedTokenDuration : time . Second ,
@ -408,125 +445,10 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
token , needsTokenRefresh := needTokenRefresh ( tt . usr )
token := buildOAuthTokenFromAuthInfo ( tt . usr )
needsTokenRefresh := needTokenRefresh ( context . Background ( ) , token )
assert . NotNil ( t , token )
assert . Equal ( t , tt . expectedTokenRefreshFlag , needsTokenRefresh )
} )
}
}
func TestOAuthTokenSync_tryGetOrRefreshOAuthToken ( t * testing . T ) {
timeNow := time . Now ( )
token := & oauth2 . Token {
AccessToken : "oauth_access_token" ,
RefreshToken : "refresh_token_found" ,
Expiry : timeNow ,
TokenType : "Bearer" ,
}
type environment struct {
authInfoService * authinfotest . FakeService
serverLock * serverlock . ServerLockService
socialConnector * socialtest . MockSocialConnector
socialService * socialtest . FakeSocialService
service * Service
}
tests := [ ] struct {
desc string
expectedErr error
expectedToken * oauth2 . Token
usr * login . UserAuth
setup func ( env * environment )
} {
{
desc : "should return ErrNotAnOAuthProvider error when the user is not an oauth provider" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . SAMLAuthModule ,
} ,
expectedErr : ErrNotAnOAuthProvider ,
} ,
{
desc : "should return ErrNoRefreshTokenFound error when the no refresh token was found" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
} ,
expectedErr : ErrNoRefreshTokenFound ,
} ,
{
desc : "should not refresh token if the token is not expired" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : token . AccessToken ,
OAuthRefreshToken : token . RefreshToken ,
OAuthExpiry : timeNow . Add ( time . Hour ) ,
OAuthTokenType : "Bearer" ,
} ,
expectedToken : & oauth2 . Token {
AccessToken : token . AccessToken ,
RefreshToken : token . RefreshToken ,
Expiry : timeNow . Add ( time . Hour ) ,
TokenType : "Bearer" ,
} ,
} ,
{
desc : "should update saved token if the user auth has new access/refresh tokens" ,
usr : & login . UserAuth {
UserId : int64 ( 1234 ) ,
AuthModule : login . GenericOAuthModule ,
OAuthAccessToken : "new_oauth_access_token" ,
OAuthRefreshToken : "new_refresh_token_found" ,
OAuthExpiry : timeNow ,
} ,
expectedToken : & oauth2 . Token {
AccessToken : "oauth_access_token" ,
RefreshToken : "refresh_token_found" ,
Expiry : timeNow ,
TokenType : "Bearer" ,
} ,
setup : func ( env * environment ) {
env . socialConnector . On ( "TokenSource" , mock . Anything , mock . Anything ) . Return ( oauth2 . StaticTokenSource ( token ) ) . Once ( )
} ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . desc , func ( t * testing . T ) {
socialConnector := & socialtest . MockSocialConnector { }
store := db . InitTestDB ( t )
env := environment {
authInfoService : & authinfotest . FakeService { } ,
serverLock : serverlock . ProvideService ( store , tracing . InitializeTracerForTest ( ) ) ,
socialConnector : socialConnector ,
socialService : & socialtest . FakeSocialService {
ExpectedConnector : socialConnector ,
} ,
}
if tt . setup != nil {
tt . setup ( & env )
}
env . service = ProvideService (
env . socialService ,
env . authInfoService ,
setting . NewCfg ( ) ,
prometheus . NewRegistry ( ) ,
env . serverLock ,
tracing . InitializeTracerForTest ( ) ,
)
token , err := env . service . tryGetOrRefreshOAuthToken ( context . Background ( ) , tt . usr )
if tt . expectedToken != nil {
assert . Equal ( t , tt . expectedToken , token )
}
assert . ErrorIs ( t , tt . expectedErr , err )
socialConnector . AssertExpectations ( t )
} )
}
}