@ -9,13 +9,15 @@ import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/authlib/claims"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
@ -30,16 +32,15 @@ var (
ExpiryDelta = 10 * time . Second
ErrNoRefreshTokenFound = errors . New ( "no refresh token found" )
ErrNotAnOAuthProvider = errors . New ( "not an oauth provider" )
ErrCouldntRefreshToken = errors . New ( "could not refresh token" )
)
// maxOAuthTokenCacheTTL is the longest time an entry is kept in the service's local cache.
// It is passed to localcache.New when the cache is created and used as the starting upper bound
// when a cache TTL is derived from token expiry times.
const maxOAuthTokenCacheTTL = 10 * time . Minute
// Service reads the OAuth tokens stored for users and refreshes them when needed.
//
// NOTE(review): Cfg, SocialService and AuthInfoService appear twice in this listing, and
// singleFlightGroup/cache sit next to serverLock/tracer. This looks like the pre- and post-change
// lines of a diff shown together — confirm which field set is current.
type Service struct {
// Cfg provides the settings read by this service (e.g. Env and the server lock wait time).
Cfg * setting . Cfg
// SocialService resolves OAuth provider info, connectors and HTTP clients.
SocialService social . Service
// AuthInfoService is used to persist token changes through UpdateAuthInfo.
AuthInfoService login . AuthInfoService
// singleFlightGroup is used with a per-user key around the token refresh.
singleFlightGroup * singleflight . Group
// cache holds per-user markers that an expiration check was already done.
cache * localcache . CacheService
Cfg * setting . Cfg
SocialService social . Service
AuthInfoService login . AuthInfoService
// serverLock runs the token refresh under a per-user lock key.
serverLock * serverlock . ServerLockService
// tracer starts the spans opened by the service's methods.
tracer tracing . Tracer
// tokenRefreshDuration is the histogram created by newTokenRefreshDurationMetric.
// NOTE(review): where it is observed is not visible in this listing.
tokenRefreshDuration * prometheus . HistogramVec
}
@ -49,29 +50,46 @@ type OAuthTokenService interface {
GetCurrentOAuthToken ( context . Context , identity . Requester ) * oauth2 . Token
IsOAuthPassThruEnabled ( * datasources . DataSource ) bool
HasOAuthEntry ( context . Context , identity . Requester ) ( * login . UserAuth , bool , error )
TryTokenRefresh ( context . Context , identity . Requester ) error
TryTokenRefresh ( context . Context , identity . Requester ) ( * oauth2 . Token , error )
InvalidateOAuthTokens ( context . Context , * login . UserAuth ) error
}
// ProvideService builds a Service from the given dependencies and hands registerer to
// newTokenRefreshDurationMetric to create the refresh duration histogram.
//
// NOTE(review): two signatures are listed, one without and one with serverLockService and tracer.
// This looks like the old and new lines of a diff shown together — confirm which one is current.
func ProvideService ( socialService social . Service , authInfoService login . AuthInfoService , cfg * setting . Cfg , registerer prometheus . Registerer ) * Service {
func ProvideService ( socialService social . Service , authInfoService login . AuthInfoService , cfg * setting . Cfg , registerer prometheus . Registerer ,
serverLockService * serverlock . ServerLockService , tracer tracing . Tracer ) * Service {
return & Service {
AuthInfoService : authInfoService ,
Cfg : cfg ,
SocialService : socialService ,
// presumably (default expiration, cleanup interval) — verify against localcache.New
cache : localcache . New ( maxOAuthTokenCacheTTL , 15 * time . Minute ) ,
singleFlightGroup : new ( singleflight . Group ) ,
serverLock : serverLockService ,
tokenRefreshDuration : newTokenRefreshDurationMetric ( registerer ) ,
tracer : tracer ,
}
}
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func ( o * Service ) GetCurrentOAuthToken ( ctx context . Context , usr identity . Requester ) * oauth2 . Token {
ctx , span := o . tracer . Start ( ctx , "oauthtoken.GetCurrentOAuthToken" )
defer span . End ( )
authInfo , ok , _ := o . HasOAuthEntry ( ctx , usr )
if ! ok {
return nil
}
token , err := o . tryGetOrRefreshOAuthToken ( ctx , authInfo )
if err := checkOAuthRefreshToken ( authInfo ) ; err != nil {
if errors . Is ( err , ErrNoRefreshTokenFound ) {
return buildOAuthTokenFromAuthInfo ( authInfo )
}
return nil
}
persistedToken , refreshNeeded := needTokenRefresh ( authInfo )
if ! refreshNeeded {
return persistedToken
}
token , err := o . TryTokenRefresh ( ctx , usr )
if err != nil {
if errors . Is ( err , ErrNoRefreshTokenFound ) {
return buildOAuthTokenFromAuthInfo ( authInfo )
@ -90,6 +108,9 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
func ( o * Service ) HasOAuthEntry ( ctx context . Context , usr identity . Requester ) ( * login . UserAuth , bool , error ) {
ctx , span := o . tracer . Start ( ctx , "oauthtoken.HasOAuthEntry" )
defer span . End ( )
if usr == nil || usr . IsNil ( ) {
// No user, therefore no token
return nil , false , nil
@ -126,76 +147,105 @@ func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
}
// TryTokenRefresh tries to refresh the OAuth token of usr and returns an error in case the
// refresh was unsuccessful.
//
// Requests that cannot be refreshed are logged and skipped without an error: a nil user, a
// non-user identity, an ID that cannot be converted, a non-OAuth auth module, an unknown provider,
// or a provider with UseRefreshToken disabled. ErrNoRefreshTokenFound is silenced as well.
//
// NOTE(review): this listing contains two signatures and two de-duplication strategies
// (singleflight.Group and server lock) side by side. It looks like the pre- and post-change lines
// of a diff shown together — confirm which variant is current.
//
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
func ( o * Service ) TryTokenRefresh ( ctx context . Context , usr identity . Requester ) error {
// It uses a server lock to prevent getting the Refresh Token multiple times for a given User
func ( o * Service ) TryTokenRefresh ( ctx context . Context , usr identity . Requester ) ( * oauth2 . Token , error ) {
ctx , span := o . tracer . Start ( ctx , "oauthtoken.TryTokenRefresh" )
defer span . End ( )
ctxLogger := logger . FromContext ( ctx )
if usr == nil || usr . IsNil ( ) {
ctxLogger . Warn ( "Can only refresh OAuth tokens for existing users" , "user" , "nil" )
// No user, no token.
return nil
return nil , nil
}
// Only identities of type user are refreshed; anything else is skipped.
if ! usr . IsIdentityType ( claims . TypeUser ) {
ctxLogger . Warn ( "Can only refresh OAuth tokens for users" , "id" , usr . GetID ( ) )
return nil
return nil , nil
}
userID , err := usr . GetInternalID ( )
if err != nil {
ctxLogger . Warn ( "Failed to convert user id to int" , "id" , usr . GetID ( ) , "error" , err )
return nil
return nil , nil
}
// From here on every log line carries the user ID.
ctxLogger = ctxLogger . New ( "userID" , userID )
// get the token's auth provider (e.g. azuread)
currAuthenticator := usr . GetAuthenticatedBy ( )
if ! strings . HasPrefix ( currAuthenticator , "oauth" ) {
ctxLogger . Warn ( "The specified user's auth provider is not OAuth" , "authmodule" , currAuthenticator )
return nil , nil
}
provider := strings . TrimPrefix ( currAuthenticator , "oauth_" )
currentOAuthInfo := o . SocialService . GetOAuthInfoProvider ( provider )
if currentOAuthInfo == nil {
ctxLogger . Warn ( "OAuth provider not found" , "provider" , provider )
return nil , nil
}
// if refresh token handling is disabled for this provider, we can skip the refresh
if ! currentOAuthInfo . UseRefreshToken {
ctxLogger . Debug ( "Skipping token refresh" , "provider" , provider )
return nil , nil
}
// The key is derived from the user ID, so each user gets a separate key.
lockKey := fmt . Sprintf ( "oauth-refresh-token-%d" , userID )
if _ , ok := o . cache . Get ( lockKey ) ; ok {
ctxLogger . Debug ( "Expiration check has been cached, no need to refresh" )
return nil
// MinWait comes from configuration; MaxWait is that same value plus 500ms.
lockTimeConfig := serverlock . LockTimeConfig {
MaxInterval : 30 * time . Second ,
MinWait : time . Duration ( o . Cfg . OAuthRefreshTokenServerLockMinWaitMs ) * time . Millisecond ,
MaxWait : time . Duration ( o . Cfg . OAuthRefreshTokenServerLockMinWaitMs + 500 ) * time . Millisecond ,
}
_ , err , _ = o . singleFlightGroup . Do ( lockKey , func ( ) ( any , error ) {
ctxLogger . Debug ( "Singleflight request for getting a new access token" , "key" , lockKey )
// retryOpt returns nil while attempts is below 5 and ErrCouldntRefreshToken from then on.
retryOpt := func ( attempts int ) error {
if attempts < 5 {
return nil
}
return ErrCouldntRefreshToken
}
// newToken and cmdErr carry the result out of the lock callback, which has no return value.
var newToken * oauth2 . Token
var cmdErr error
lockErr := o . serverLock . LockExecuteAndReleaseWithRetries ( ctx , lockKey , lockTimeConfig , func ( ctx context . Context ) {
ctx , span := o . tracer . Start ( ctx , "oauthtoken server lock" ,
trace . WithAttributes ( attribute . Int64 ( "userID" , userID ) ) )
defer span . End ( )
ctxLogger . Debug ( "Serverlock request for getting a new access token" , "key" , lockKey )
// The auth info is loaded inside the callback.
// NOTE(review): presumably so a token stored by a previous lock holder is picked up — confirm.
authInfo , exists , err := o . HasOAuthEntry ( ctx , usr )
if ! exists {
if err != nil {
ctxLogger . Debug ( "Failed to fetch oauth entry" , "error" , err )
} else {
// User is not logged in via OAuth, no need to check
o . cache . Set ( lockKey , struct { } { } , maxOAuthTokenCacheTTL )
}
return nil , nil
return
}
_ , needRefresh , ttl := needTokenRefresh ( authInfo )
storedToken , needRefresh := needTokenRefresh ( authInfo )
if ! needRefresh {
o . cache . Set ( lockKey , struct { } { } , ttl )
return nil , nil
// Set the token which is returned by the outer function in case there's no need to refresh the token
newToken = storedToken
return
}
// get the token's auth provider (e.g. azuread)
provider := strings . TrimPrefix ( authInfo . AuthModule , "oauth_" )
currentOAuthInfo := o . SocialService . GetOAuthInfoProvider ( provider )
if currentOAuthInfo == nil {
ctxLogger . Warn ( "OAuth provider not found" , "provider" , provider )
return nil , nil
}
// if refresh token handling is disabled for this provider, we can skip the refresh
if ! currentOAuthInfo . UseRefreshToken {
ctxLogger . Debug ( "Skipping token refresh" , "provider" , provider )
return nil , nil
}
newToken , cmdErr = o . tryGetOrRefreshOAuthToken ( ctx , authInfo )
} , retryOpt )
// A lock failure is logged and returned to the caller as is.
if lockErr != nil {
ctxLogger . Error ( "Failed to obtain token refresh lock" , "error" , lockErr )
return nil , lockErr
}
return o . tryGetOrRefreshOAuthToken ( ctx , authInfo )
} )
// Silence ErrNoRefreshTokenFound
// NOTE(review): "e rr" on the next line looks like an extraction artifact for "err" — confirm against the original file.
if errors . Is ( e rr, ErrNoRefreshTokenFound ) {
return nil
if errors . Is ( cmdErr , ErrNoRefreshTokenFound ) {
return nil , nil
}
return err
return newToken , cmdErr
}
func buildOAuthTokenFromAuthInfo ( authInfo * login . UserAuth ) * oauth2 . Token {
@ -230,11 +280,11 @@ func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
}
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func ( o * Service ) InvalidateOAuthTokens ( ctx context . Context , usr * login . UserAuth ) error {
func ( o * Service ) InvalidateOAuthTokens ( ctx context . Context , authInfo * login . UserAuth ) error {
return o . AuthInfoService . UpdateAuthInfo ( ctx , & login . UpdateAuthInfoCommand {
UserId : usr . UserId ,
AuthModule : usr . AuthModule ,
AuthId : usr . AuthId ,
UserId : authInfo . UserId ,
AuthModule : authInfo . AuthModule ,
AuthId : authInfo . AuthId ,
OAuthToken : & oauth2 . Token {
AccessToken : "" ,
RefreshToken : "" ,
@ -243,36 +293,33 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *login.UserAuth
} )
}
func ( o * Service ) tryGetOrRefreshOAuthToken ( ctx context . Context , usr * login . UserAuth ) ( * oauth2 . Token , error ) {
ctxLogger := logger . FromContext ( ctx ) . New ( "userID" , usr . UserId )
func ( o * Service ) tryGetOrRefreshOAuthToken ( ctx context . Context , authInfo * login . UserAuth ) ( * oauth2 . Token , error ) {
ctx , span := o . tracer . Start ( ctx , "oauthtoken.tryGetOrRefreshOAuthToken" ,
trace . WithAttributes ( attribute . Int64 ( "userID" , authInfo . UserId ) ) )
defer span . End ( )
key := getCheckCacheKey ( usr . UserId )
if _ , ok := o . cache . Get ( key ) ; ok {
ctxLogger . Debug ( "Expiration check has been cached" , "userID" , usr . UserId )
return buildOAuthTokenFromAuthInfo ( usr ) , nil
}
ctxLogger := logger . FromContext ( ctx ) . New ( "userID" , authInfo . UserId )
if err := checkOAuthRefreshToken ( usr ) ; err != nil {
if err := checkOAuthRefreshToken ( authInfo ) ; err != nil {
return nil , err
}
persistedToken , refreshNeeded , ttl := needTokenRefresh ( usr )
persistedToken , refreshNeeded := needTokenRefresh ( authInfo )
if ! refreshNeeded {
o . cache . Set ( key , struct { } { } , ttl )
return persistedToken , nil
}
authProvider := usr . AuthModule
authProvider := authInfo . AuthModule
connect , err := o . SocialService . GetConnector ( authProvider )
if err != nil {
ctxLogger . Error ( "Failed to get oauth connector" , "provider" , authProvider , "error" , err )
return nil , err
return persistedToken , err
}
client , err := o . SocialService . GetOAuthHttpClient ( authProvider )
if err != nil {
ctxLogger . Error ( "Failed to get oauth http client" , "provider" , authProvider , "error" , err )
return nil , err
return persistedToken , err
}
ctx = context . WithValue ( ctx , oauth2 . HTTPClient , client )
@ -284,22 +331,28 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.User
if err != nil {
ctxLogger . Error ( "Failed to retrieve oauth access token" ,
"provider" , usr . AuthModule , "userId" , usr . UserId , "error" , err )
"provider" , authInfo . AuthModule , "userId" , authInfo . UserId , "error" , err )
// token refresh failed, invalidate the old token
if err := o . InvalidateOAuthTokens ( ctx , authInfo ) ; err != nil {
ctxLogger . Warn ( "Failed to invalidate OAuth tokens" , "id" , authInfo . Id , "error" , err )
}
return nil , err
}
// If the tokens are not the same, update the entry in the DB
if ! tokensEq ( persistedToken , token ) {
updateAuthCommand := & login . UpdateAuthInfoCommand {
UserId : usr . UserId ,
AuthModule : usr . AuthModule ,
AuthId : usr . AuthId ,
UserId : authInfo . UserId ,
AuthModule : authInfo . AuthModule ,
AuthId : authInfo . AuthId ,
OAuthToken : token ,
}
if o . Cfg . Env == setting . Dev {
ctxLogger . Debug ( "Oauth got token" ,
"auth_module" , usr . AuthModule ,
"auth_module" , authInfo . AuthModule ,
"expiry" , fmt . Sprintf ( "%v" , token . Expiry ) ,
"access_token" , fmt . Sprintf ( "%v" , token . AccessToken ) ,
"refresh_token" , fmt . Sprintf ( "%v" , token . RefreshToken ) ,
@ -307,8 +360,8 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, usr *login.User
}
if err := o . AuthInfoService . UpdateAuthInfo ( ctx , updateAuthCommand ) ; err != nil {
ctxLogger . Error ( "Failed to update auth info during token refresh" , "userId" , usr . UserId , "error" , err )
return nil , err
ctxLogger . Error ( "Failed to update auth info during token refresh" , "userId" , authInfo . UserId , "error" , err )
return token , err
}
ctxLogger . Debug ( "Updated oauth info for user" )
}
@ -348,63 +401,44 @@ func tokensEq(t1, t2 *oauth2.Token) bool {
t1IdToken == t2IdToken
}
// needTokenRefresh builds the persisted token from the stored auth info and reports whether it has
// to be refreshed. A refresh is needed when the access token or the ID token has expired according
// to getExpiryWithSkew. A token whose expiry is the zero time is not checked and counts as not expired.
//
// It returns the persisted token and the refresh flag. When the ID token has expired, the returned
// token has its AccessToken cleared.
//
// NOTE(review): two signatures are listed (with and without a trailing time.Duration result) and
// several statements appear in two forms. This looks like the old and new lines of a diff shown
// together — confirm which variant is current.
func needTokenRefresh ( usr * login . UserAuth ) ( * oauth2 . Token , bool , time . Duration ) {
var accessTokenExpires , idTokenExpires time . Time
func needTokenRefresh ( authInfo * login . UserAuth ) ( * oauth2 . Token , bool ) {
var hasAccessTokenExpired , hasIdTokenExpired bool
persistedToken := buildOAuthTokenFromAuthInfo ( usr )
idTokenExp , err := getIDTokenExpiry ( usr )
persistedToken := buildOAuthTokenFromAuthInfo ( authInfo )
idTokenExp , err := GetIDTokenExpiry ( persistedToken )
// A failure to read the ID token expiry is only logged; the checks below still run.
if err != nil {
logger . Warn ( "Could not get ID Token expiry" , "error" , err )
}
if ! persistedToken . Expiry . IsZero ( ) {
accessTokenExpires , hasAccessTokenExpired = getExpiryWithSkew ( persistedToken . Expiry )
_ , hasAccessTokenExpired = getExpiryWithSkew ( persistedToken . Expiry )
}
if ! idTokenExp . IsZero ( ) {
idTokenExpires , hasIdTokenExpired = getExpiryWithSkew ( idTokenExp )
_ , hasIdTokenExpired = getExpiryWithSkew ( idTokenExp )
}
if ! hasAccessTokenExpired && ! hasIdTokenExpired {
logger . Debug ( "Neither access nor id token have expired yet" , "userID" , usr . UserId )
return persistedToken , false , getOAuthTokenCacheTTL ( accessTokenExpires , idTokenExpires )
logger . Debug ( "Neither access nor id token have expired yet" , "userID" , authInfo . UserId )
return persistedToken , false
}
if hasIdTokenExpired {
// Force refreshing token when id token is expired
persistedToken . AccessToken = ""
}
return persistedToken , true , time . Second
return persistedToken , true
}
// getCheckCacheKey builds the local cache key under which the outcome of the
// token expiration check for the user with the given ID is stored.
func getCheckCacheKey(usrID int64) string {
	cacheKey := fmt.Sprintf("token-check-%d", usrID)
	return cacheKey
}
func getOAuthTokenCacheTTL ( accessTokenExpiry , idTokenExpiry time . Time ) time . Duration {
min := maxOAuthTokenCacheTTL
if ! accessTokenExpiry . IsZero ( ) {
d := time . Until ( accessTokenExpiry )
if d < min {
min = d
}
}
if ! idTokenExpiry . IsZero ( ) {
d := time . Until ( idTokenExpiry )
if d < min {
min = d
}
}
if accessTokenExpiry . IsZero ( ) && idTokenExpiry . IsZero ( ) {
return maxOAuthTokenCacheTTL
// GetIDTokenExpiry extracts the expiry time from the ID token
func GetIDTokenExpiry ( token * oauth2 . Token ) ( time . Time , error ) {
idToken , ok := token . Extra ( "id_token" ) . ( string )
if ! ok {
return time . Time { } , nil
}
return min
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry ( usr * login . UserAuth ) ( time . Time , error ) {
if usr . OAuthIdToken == "" {
if idToken == "" {
return time . Time { } , nil
}
parsedToken , err := jwt . ParseSigned ( usr . OAuthI dToken)
parsedToken , err := jwt . ParseSigned ( idToken )
if err != nil {
return time . Time { } , fmt . Errorf ( "error parsing id token: %w" , err )
}