|
|
|
|
@ -7,6 +7,7 @@ import ( |
|
|
|
|
"strings" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"github.com/prometheus/client_golang/prometheus" |
|
|
|
|
"golang.org/x/oauth2" |
|
|
|
|
"golang.org/x/sync/singleflight" |
|
|
|
|
|
|
|
|
|
@ -33,6 +34,8 @@ type Service struct { |
|
|
|
|
SocialService social.Service |
|
|
|
|
AuthInfoService login.AuthInfoService |
|
|
|
|
singleFlightGroup *singleflight.Group |
|
|
|
|
|
|
|
|
|
tokenRefreshDuration *prometheus.HistogramVec |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type OAuthTokenService interface { |
|
|
|
|
@ -43,12 +46,13 @@ type OAuthTokenService interface { |
|
|
|
|
InvalidateOAuthTokens(context.Context, *login.UserAuth) error |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service { |
|
|
|
|
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer) *Service { |
|
|
|
|
return &Service{ |
|
|
|
|
Cfg: cfg, |
|
|
|
|
SocialService: socialService, |
|
|
|
|
AuthInfoService: authInfoService, |
|
|
|
|
singleFlightGroup: new(singleflight.Group), |
|
|
|
|
Cfg: cfg, |
|
|
|
|
SocialService: socialService, |
|
|
|
|
AuthInfoService: authInfoService, |
|
|
|
|
singleFlightGroup: new(singleflight.Group), |
|
|
|
|
tokenRefreshDuration: newTokenRefreshDurationMetric(registerer), |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@ -212,8 +216,12 @@ func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *login.Use |
|
|
|
|
|
|
|
|
|
persistedToken := buildOAuthTokenFromAuthInfo(usr) |
|
|
|
|
|
|
|
|
|
start := time.Now() |
|
|
|
|
// TokenSource handles refreshing the token if it has expired
|
|
|
|
|
token, err := connect.TokenSource(ctx, persistedToken).Token() |
|
|
|
|
duration := time.Since(start) |
|
|
|
|
o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds()) |
|
|
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
logger.Error("Failed to retrieve oauth access token", |
|
|
|
|
"provider", usr.AuthModule, "userId", usr.UserId, "error", err) |
|
|
|
|
@ -254,6 +262,20 @@ func IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { |
|
|
|
|
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func newTokenRefreshDurationMetric(registerer prometheus.Registerer) *prometheus.HistogramVec { |
|
|
|
|
tokenRefreshDuration := prometheus.NewHistogramVec(prometheus.HistogramOpts{ |
|
|
|
|
Namespace: "grafana", |
|
|
|
|
Subsystem: "oauth", |
|
|
|
|
Name: "token_refresh_fetch_duration_seconds", |
|
|
|
|
Help: "Time taken to fetch access token using refresh token", |
|
|
|
|
}, |
|
|
|
|
[]string{"auth_provider", "success"}) |
|
|
|
|
if registerer != nil { |
|
|
|
|
registerer.MustRegister(tokenRefreshDuration) |
|
|
|
|
} |
|
|
|
|
return tokenRefreshDuration |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
|
|
|
|
func tokensEq(t1, t2 *oauth2.Token) bool { |
|
|
|
|
return t1.AccessToken == t2.AccessToken && |
|
|
|
|
|