The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/authn/authnimpl/sync/oauth_token_sync.go

208 lines
6.6 KiB

package sync
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
)
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
return &OAuthTokenSync{
log.New("oauth_token.sync"),
localcache.New(maxOAuthTokenCacheTTL, 15*time.Minute),
service,
sessionService,
socialService,
new(singleflight.Group),
}
}
type OAuthTokenSync struct {
log log.Logger
cache *localcache.CacheService
service oauthtoken.OAuthTokenService
sessionService auth.UserTokenService
socialService social.Service
sf *singleflight.Group
}
func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
namespace, _ := identity.NamespacedID()
// only perform oauth token check if identity is a user
if namespace != authn.NamespaceUser {
return nil
}
// not authenticated through session tokens, so we can skip this hook
if identity.SessionToken == nil {
return nil
}
// if we recently have performed this it would be cached, so we can skip the hook
if _, ok := s.cache.Get(identity.ID); ok {
s.log.FromContext(ctx).Debug("OAuth token check is cached", "id", identity.ID)
return nil
}
token, exists, err := s.service.HasOAuthEntry(ctx, identity)
// user is not authenticated through oauth so skip further checks
if !exists {
if err != nil {
s.log.FromContext(ctx).Error("Failed to fetch oauth entry", "id", identity.ID, "error", err)
}
return nil
}
idTokenExpiry, err := getIDTokenExpiry(token)
if err != nil {
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
}
// token has no expire time configured, so we don't have to refresh it
if token.OAuthExpiry.IsZero() {
s.log.FromContext(ctx).Debug("Access token without expiry", "id", identity.ID)
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
return nil
}
// get the token's auth provider (f.e. azuread)
provider := strings.TrimPrefix(token.AuthModule, "oauth_")
currentOAuthInfo := s.socialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
s.log.Warn("OAuth provider not found", "provider", provider)
return nil
}
// if refresh token handling is disabled for this provider, we can skip the hook
if !currentOAuthInfo.UseRefreshToken {
return nil
}
accessTokenExpires, hasAccessTokenExpired := getExpiryWithSkew(token.OAuthExpiry)
hasIdTokenExpired := false
idTokenExpires := time.Time{}
if !idTokenExpiry.IsZero() {
idTokenExpires, hasIdTokenExpired = getExpiryWithSkew(idTokenExpiry)
}
// token has not expired, so we don't have to refresh it
if !hasAccessTokenExpired && !hasIdTokenExpired {
s.log.FromContext(ctx).Debug("Access and id token has not expired yet", "id", identity.ID)
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
return nil
}
_, err, _ = s.sf.Do(identity.ID, func() (interface{}, error) {
s.log.Debug("Singleflight request for OAuth token sync", "key", identity.ID)
// FIXME: Consider using context.WithoutCancel instead of context.Background after Go 1.21 update
updateCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if refreshErr := s.service.TryTokenRefresh(updateCtx, token); refreshErr != nil {
if errors.Is(refreshErr, context.Canceled) {
return nil, nil
}
token, _, err := s.service.HasOAuthEntry(ctx, identity)
if err != nil {
s.log.Error("Failed to get OAuth entry for verifying if token has already been refreshed", "id", identity.ID, "error", err)
return nil, err
}
// if the access token has already been refreshed by another request (for example in HA scenario)
tokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
if !tokenExpires.Before(time.Now()) {
return nil, nil
}
s.log.Error("Failed to refresh OAuth access token", "id", identity.ID, "error", refreshErr)
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
s.log.Warn("Failed to invalidate OAuth tokens", "id", identity.ID, "error", err)
}
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
s.log.Warn("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
}
return nil, refreshErr
}
return nil, nil
})
if err != nil {
return authn.ErrExpiredAccessToken.Errorf("OAuth access token could not be refreshed: %w", err)
}
return nil
}
const maxOAuthTokenCacheTTL = 10 * time.Minute
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return maxOAuthTokenCacheTTL
}
min := func(a, b time.Duration) time.Duration {
if a <= b {
return a
}
return b
}
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
}
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
}
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
if token.OAuthIdToken == "" {
return time.Time{}, nil
}
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
if err != nil {
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
}
type Claims struct {
Exp int64 `json:"exp"`
}
var claims Claims
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
}
return time.Unix(claims.Exp, 0), nil
}
func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpired bool) {
adjustedExpiry = expiry.Round(0).Add(-oauthtoken.ExpiryDelta)
hasTokenExpired = adjustedExpiry.Before(time.Now())
return
}