diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 30a1e787ab0..3176164fd34 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -70,61 +70,61 @@ func genPKCECode() (string, string, error) { return string(ascii), pkce, nil } -func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { - name := web.Params(ctx.Req)[":name"] +func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) { + name := web.Params(reqCtx.Req)[":name"] loginInfo := loginservice.LoginInfo{AuthModule: name} - if errorParam := ctx.Query("error"); errorParam != "" { - errorDesc := ctx.Query("error_description") + if errorParam := reqCtx.Query("error"); errorParam != "" { + errorDesc := reqCtx.Query("error_description") oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc) - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc) return } - code := ctx.Query("code") + code := reqCtx.Query("code") if hs.Cfg.AuthBrokerEnabled { - req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp} + req := &authn.Request{HTTPRequest: reqCtx.Req, Resp: reqCtx.Resp} if code == "" { - redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req) + redirect, err := hs.authnService.RedirectURL(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req) if err != nil { - ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err)) + reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err)) return } if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" { - cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) } - cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) - ctx.Redirect(redirect.URL) + cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + reqCtx.Redirect(redirect.URL) return } - identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req) + identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req) // NOTE: always delete these cookies, even if login failed - cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) - cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) if err != nil { - ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err)) + reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err)) return } metrics.MApiLoginOAuth.Inc() - authn.HandleLoginRedirect(ctx.Req, ctx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo) + authn.HandleLoginRedirect(reqCtx.Req, reqCtx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo) return } provider := hs.SocialService.GetOAuthInfoProvider(name) if provider == nil { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled")) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, errors.New("OAuth not enabled")) return } connect, err := hs.SocialService.GetConnector(name) if err != nil { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name)) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name)) return } @@ -133,15 +133,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { if provider.UsePKCE { ascii, pkce, err := genPKCECode() if err != nil { - ctx.Logger.Error("Generating PKCE failed", "error", err) - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + reqCtx.Logger.Error("Generating PKCE failed", "error", err) + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "An internal error occurred", }) return } - cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) opts = append(opts, oauth2.SetAuthURLParam("code_challenge", pkce), @@ -151,8 +151,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { state, err := GenStateString() if err != nil { - ctx.Logger.Error("Generating state string failed", "err", err) - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + reqCtx.Logger.Error("Generating state string failed", "err", err) + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "An internal error occurred", }) @@ -160,32 +160,32 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { } hashedState := hs.hashStatecode(state, provider.ClientSecret) - cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) if provider.HostedDomain != "" { opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain)) } - ctx.Redirect(connect.AuthCodeURL(state, opts...)) + reqCtx.Redirect(connect.AuthCodeURL(state, opts...)) return } - cookieState := ctx.GetCookie(OauthStateCookieName) + cookieState := reqCtx.GetCookie(OauthStateCookieName) // delete cookie - cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) if cookieState == "" { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "login.OAuthLogin(missing saved state)", }) return } - queryState := hs.hashStatecode(ctx.Query("state"), provider.ClientSecret) + queryState := hs.hashStatecode(reqCtx.Query("state"), provider.ClientSecret) oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState) if cookieState != queryState { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "login.OAuthLogin(state mismatch)", }) @@ -194,19 +194,20 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { oauthClient, err := hs.SocialService.GetOAuthHttpClient(name) if err != nil { - ctx.Logger.Error("Failed to create OAuth http client", "error", err) - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + reqCtx.Logger.Error("Failed to create OAuth http client", "error", err) + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "login.OAuthLogin(" + err.Error() + ")", }) return } - oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient) + ctx := reqCtx.Req.Context() + oauthCtx := context.WithValue(ctx, oauth2.HTTPClient, oauthClient) opts := []oauth2.AuthCodeOption{} - codeVerifier := ctx.GetCookie(OauthPKCECookieName) - cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) + codeVerifier := reqCtx.GetCookie(OauthPKCECookieName) + cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) if codeVerifier != "" { opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier), @@ -216,7 +217,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { // get token from provider token, err := connect.Exchange(oauthCtx, code, opts...) if err != nil { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: "login.OAuthLogin(NewTransportWithCode)", Err: err, @@ -245,13 +246,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { client := connect.Client(oauthCtx, token) // get user info - userInfo, err := connect.UserInfo(client, token) + userInfo, err := connect.UserInfo(ctx, client, token) if err != nil { var sErr *social.Error if errors.As(err, &sErr) { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, sErr) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, sErr) } else { - hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ + hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{ HttpStatus: http.StatusInternalServerError, PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name), Err: err, @@ -264,34 +265,34 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { // validate that we got at least an email address if userInfo.Email == "" { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrNoEmail) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrNoEmail) return } // validate that the email is allowed to login to grafana if !connect.IsEmailAllowed(userInfo.Email) { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrEmailNotAllowed) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrEmailNotAllowed) return } loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name) - loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect) + loginInfo.User, err = hs.SyncUser(reqCtx, &loginInfo.ExternalUser, connect) if err != nil { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err) + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err) return } // login - if err := hs.loginUserWithUser(loginInfo.User, ctx); err != nil { - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err) + if err := hs.loginUserWithUser(loginInfo.User, reqCtx); err != nil { + hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err) return } loginInfo.HTTPStatus = http.StatusOK - hs.HooksService.RunLoginHook(&loginInfo, ctx) + hs.HooksService.RunLoginHook(&loginInfo, reqCtx) metrics.MApiLoginOAuth.Inc() - ctx.Redirect(hs.GetRedirectURL(ctx)) + reqCtx.Redirect(hs.GetRedirectURL(reqCtx)) } // buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index 32d65e56619..603eabaa5a8 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -2,6 +2,7 @@ package social import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -47,7 +48,7 @@ type azureAccessClaims struct { TenantID string `json:"tid"` } -func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { return nil, ErrIDTokenNotFound @@ -83,7 +84,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas } logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role) - groups, err := s.extractGroups(client, claims, token) + groups, err := s.extractGroups(ctx, client, claims, token) if err != nil { return nil, fmt.Errorf("failed to extract groups: %w", err) } @@ -176,7 +177,7 @@ type getAzureGroupResponse struct { // Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be // given instead. // See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim -func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) { +func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) { if !s.forceUseGraphAPI { logger.Debug("checking the claim for groups") if len(claims.Groups) > 0 { @@ -199,7 +200,13 @@ func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, t return nil, err } - res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + res, err := client.Do(req) if err != nil { return nil, err } diff --git a/pkg/login/social/azuread_oauth_test.go b/pkg/login/social/azuread_oauth_test.go index ecdda16e01d..08e29389554 100644 --- a/pkg/login/social/azuread_oauth_test.go +++ b/pkg/login/social/azuread_oauth_test.go @@ -473,7 +473,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) { tt.args.client = s.Client(context.Background(), token) } - got, err := s.UserInfo(tt.args.client, token) + got, err := s.UserInfo(context.Background(), tt.args.client, token) if (err != nil) != tt.wantErr { t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) return @@ -617,7 +617,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) { tt.args.client = s.Client(context.Background(), token) } - got, err := s.UserInfo(tt.args.client, token) + got, err := s.UserInfo(context.Background(), tt.args.client, token) if (err != nil) != tt.wantErr { t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/login/social/common.go b/pkg/login/social/common.go index 1dfea668312..3947919ec84 100644 --- a/pkg/login/social/common.go +++ b/pkg/login/social/common.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "errors" "fmt" @@ -42,10 +43,15 @@ func isEmailAllowed(email string, allowedDomains []string) bool { return valid } -func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetResponse, err error) { - r, err := client.Get(url) - if err != nil { - return +func (s *SocialBase) httpGet(ctx context.Context, client *http.Client, url string) (*httpGetResponse, error) { + req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if errReq != nil { + return nil, errReq + } + + r, errDo := client.Do(req) + if errDo != nil { + return nil, errDo } defer func() { @@ -54,21 +60,20 @@ func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetR } }() - body, err := io.ReadAll(r.Body) - if err != nil { - return + body, errRead := io.ReadAll(r.Body) + if errRead != nil { + return nil, errRead } - response = httpGetResponse{body, r.Header} + response := &httpGetResponse{body, r.Header} if r.StatusCode >= 300 { - err = fmt.Errorf(string(response.Body)) - return + return nil, fmt.Errorf("unsuccessful response status code %d: %s", r.StatusCode, string(response.Body)) } + s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body)) - err = nil - return + return response, nil } func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) { diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index e755b85e4eb..b3731e7e329 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -3,6 +3,7 @@ package social import ( "bytes" "compress/zlib" + "context" "encoding/base64" "encoding/json" "errors" @@ -50,12 +51,12 @@ func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool { return false } -func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { +func (s *SocialGenericOAuth) IsTeamMember(ctx context.Context, client *http.Client) bool { if len(s.teamIds) == 0 { return true } - teamMemberships, err := s.FetchTeamMemberships(client) + teamMemberships, err := s.FetchTeamMemberships(ctx, client) if err != nil { return false } @@ -71,12 +72,12 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { return false } -func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool { +func (s *SocialGenericOAuth) IsOrganizationMember(ctx context.Context, client *http.Client) bool { if len(s.allowedOrganizations) == 0 { return true } - organizations, ok := s.FetchOrganizations(client) + organizations, ok := s.FetchOrganizations(ctx, client) if !ok { return false } @@ -111,14 +112,14 @@ func (info *UserInfoJson) String() string { info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes) } -func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { s.log.Debug("Getting user info") toCheck := make([]*UserInfoJson, 0, 2) if tokenData := s.extractFromToken(token); tokenData != nil { toCheck = append(toCheck, tokenData) } - if apiData := s.extractFromAPI(client); apiData != nil { + if apiData := s.extractFromAPI(ctx, client); apiData != nil { toCheck = append(toCheck, apiData) } @@ -179,7 +180,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) if userInfo.Email == "" { var err error - userInfo.Email, err = s.FetchPrivateEmail(client) + userInfo.Email, err = s.FetchPrivateEmail(ctx, client) if err != nil { return nil, err } @@ -191,11 +192,11 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) userInfo.Login = userInfo.Email } - if !s.IsTeamMember(client) { + if !s.IsTeamMember(ctx, client) { return nil, errors.New("user not a member of one of the required teams") } - if !s.IsOrganizationMember(client) { + if !s.IsOrganizationMember(ctx, client) { return nil, errors.New("user not a member of one of the required organizations") } @@ -288,14 +289,14 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson return &data } -func (s *SocialGenericOAuth) extractFromAPI(client *http.Client) *UserInfoJson { +func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson { s.log.Debug("Getting user info from API") if s.apiUrl == "" { s.log.Debug("No api url configured") return nil } - rawUserInfoResponse, err := s.httpGet(client, s.apiUrl) + rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl) if err != nil { s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err) return nil @@ -404,7 +405,7 @@ func (s *SocialGenericOAuth) extractGroups(data *UserInfoJson) ([]string, error) return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON) } -func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) { +func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) { type Record struct { Email string `json:"email"` Primary bool `json:"primary"` @@ -413,7 +414,7 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err IsConfirmed bool `json:"is_confirmed"` } - response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails")) + response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails")) if err != nil { s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err) return "", fmt.Errorf("%v: %w", "Error getting email address", err) @@ -451,14 +452,14 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err return email, nil } -func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string, error) { +func (s *SocialGenericOAuth) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]string, error) { var err error var ids []string if s.teamsUrl == "" { - ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(client) + ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx, client) } else { - ids, err = s.fetchTeamMembershipsFromTeamsUrl(client) + ids, err = s.fetchTeamMembershipsFromTeamsUrl(ctx, client) } if err == nil { @@ -468,16 +469,14 @@ func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string return ids, err } -func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *http.Client) ([]string, error) { - var response httpGetResponse - var err error +func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) { var ids []string type Record struct { Id int `json:"id"` } - response, err = s.httpGet(client, fmt.Sprintf(s.apiUrl+"/teams")) + response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/teams")) if err != nil { s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err) return []string{}, err @@ -499,15 +498,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client * return ids, nil } -func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Client) ([]string, error) { +func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) { if s.teamIdsAttributePath == "" { return []string{}, nil } - var response httpGetResponse - var err error - - response, err = s.httpGet(client, fmt.Sprintf(s.teamsUrl)) + response, err := s.httpGet(ctx, client, fmt.Sprintf(s.teamsUrl)) if err != nil { s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err) return nil, err @@ -516,12 +512,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Clien return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body) } -func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) { +func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *http.Client) ([]string, bool) { type Record struct { Login string `json:"login"` } - response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/orgs")) + response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/orgs")) if err != nil { s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err) return nil, false diff --git a/pkg/login/social/generic_oauth_test.go b/pkg/login/social/generic_oauth_test.go index a9e3a3537bc..d954b696c6b 100644 --- a/pkg/login/social/generic_oauth_test.go +++ b/pkg/login/social/generic_oauth_test.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -490,7 +491,7 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) { } token := staticToken.WithExtra(test.OAuth2Extra) - actualResult, err := provider.UserInfo(ts.Client(), token) + actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token) require.NoError(t, err) require.Equal(t, test.ExpectedEmail, actualResult.Email) require.Equal(t, test.ExpectedEmail, actualResult.Login) @@ -588,7 +589,7 @@ func TestUserInfoSearchesForLogin(t *testing.T) { } token := staticToken.WithExtra(test.OAuth2Extra) - actualResult, err := provider.UserInfo(ts.Client(), token) + actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token) require.NoError(t, err) require.Equal(t, test.ExpectedLogin, actualResult.Login) }) @@ -686,7 +687,7 @@ func TestUserInfoSearchesForName(t *testing.T) { } token := staticToken.WithExtra(test.OAuth2Extra) - actualResult, err := provider.UserInfo(ts.Client(), token) + actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token) require.NoError(t, err) require.Equal(t, test.ExpectedName, actualResult.Name) }) @@ -755,7 +756,7 @@ func TestUserInfoSearchesForGroup(t *testing.T) { Expiry: time.Now(), } - userInfo, err := provider.UserInfo(ts.Client(), token) + userInfo, err := provider.UserInfo(context.Background(), ts.Client(), token) assert.NoError(t, err) assert.Equal(t, test.expectedResult, userInfo.Groups) }) diff --git a/pkg/login/social/github_oauth.go b/pkg/login/social/github_oauth.go index 4d04f412321..18584a3b8a4 100644 --- a/pkg/login/social/github_oauth.go +++ b/pkg/login/social/github_oauth.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "errors" "fmt" @@ -35,12 +36,12 @@ var ( ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"} ) -func (s *SocialGithub) IsTeamMember(client *http.Client) bool { +func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool { if len(s.teamIds) == 0 { return true } - teamMemberships, err := s.FetchTeamMemberships(client) + teamMemberships, err := s.FetchTeamMemberships(ctx, client) if err != nil { return false } @@ -56,12 +57,13 @@ func (s *SocialGithub) IsTeamMember(client *http.Client) bool { return false } -func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUrl string) bool { +func (s *SocialGithub) IsOrganizationMember(ctx context.Context, + client *http.Client, organizationsUrl string) bool { if len(s.allowedOrganizations) == 0 { return true } - organizations, err := s.FetchOrganizations(client, organizationsUrl) + organizations, err := s.FetchOrganizations(ctx, client, organizationsUrl) if err != nil { return false } @@ -77,14 +79,14 @@ func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUr return false } -func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) { +func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) { type Record struct { Email string `json:"email"` Primary bool `json:"primary"` Verified bool `json:"verified"` } - response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails")) + response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails")) if err != nil { return "", fmt.Errorf("Error getting email address: %s", err) } @@ -106,13 +108,13 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) { return email, nil } -func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]GithubTeam, error) { +func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) { url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100") hasMore := true teams := make([]GithubTeam, 0) for hasMore { - response, err := s.httpGet(client, url) + response, err := s.httpGet(ctx, client, url) if err != nil { return nil, fmt.Errorf("Error getting team memberships: %s", err) } @@ -150,7 +152,7 @@ func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) { return url, true } -func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl string) ([]string, error) { +func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Client, organizationsUrl string) ([]string, error) { url := organizationsUrl hasMore := true logins := make([]string, 0) @@ -160,7 +162,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl } for hasMore { - response, err := s.httpGet(client, url) + response, err := s.httpGet(ctx, client, url) if err != nil { return nil, fmt.Errorf("error getting organizations: %s", err) } @@ -181,7 +183,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl return logins, nil } -func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id int `json:"id"` Login string `json:"login"` @@ -189,7 +191,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi Name string `json:"name"` } - response, err := s.httpGet(client, s.apiUrl) + response, err := s.httpGet(ctx, client, s.apiUrl) if err != nil { return nil, fmt.Errorf("error getting user info: %s", err) } @@ -198,7 +200,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi return nil, fmt.Errorf("error unmarshalling user info: %s", err) } - teamMemberships, err := s.FetchTeamMemberships(client) + teamMemberships, err := s.FetchTeamMemberships(ctx, client) if err != nil { return nil, fmt.Errorf("error getting user teams: %s", err) } @@ -241,16 +243,16 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100") - if !s.IsTeamMember(client) { + if !s.IsTeamMember(ctx, client) { return nil, ErrMissingTeamMembership } - if !s.IsOrganizationMember(client, organizationsUrl) { + if !s.IsOrganizationMember(ctx, client, organizationsUrl) { return nil, ErrMissingOrganizationMembership } if userInfo.Email == "" { - userInfo.Email, err = s.FetchPrivateEmail(client) + userInfo.Email, err = s.FetchPrivateEmail(ctx, client) if err != nil { return nil, err } diff --git a/pkg/login/social/github_oauth_test.go b/pkg/login/social/github_oauth_test.go index 9d92e21cd35..38cc4b0ed03 100644 --- a/pkg/login/social/github_oauth_test.go +++ b/pkg/login/social/github_oauth_test.go @@ -1,6 +1,7 @@ package social import ( + "context" "net/http" "net/http/httptest" "reflect" @@ -250,7 +251,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) { AccessToken: "fake_token", } - got, err := s.UserInfo(server.Client(), token) + got, err := s.UserInfo(context.Background(), server.Client(), token) if (err != nil) != tt.wantErr { t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/login/social/gitlab_oauth.go b/pkg/login/social/gitlab_oauth.go index 5f0f48ee960..96030227c86 100644 --- a/pkg/login/social/gitlab_oauth.go +++ b/pkg/login/social/gitlab_oauth.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "fmt" "net/http" @@ -34,10 +35,10 @@ func (s *SocialGitlab) IsGroupMember(groups []string) bool { return false } -func (s *SocialGitlab) GetGroups(client *http.Client) []string { +func (s *SocialGitlab) GetGroups(ctx context.Context, client *http.Client) []string { groups := make([]string, 0) - for page, url := s.GetGroupsPage(client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(client, url) { + for page, url := s.GetGroupsPage(ctx, client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(ctx, client, url) { groups = append(groups, page...) } @@ -45,7 +46,7 @@ func (s *SocialGitlab) GetGroups(client *http.Client) []string { } // GetGroupsPage returns groups and link to the next page if response is paginated -func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, string) { +func (s *SocialGitlab) GetGroupsPage(ctx context.Context, client *http.Client, url string) ([]string, string) { type Group struct { FullPath string `json:"full_path"` } @@ -59,7 +60,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, return nil, next } - response, err := s.httpGet(client, url) + response, err := s.httpGet(ctx, client, url) if err != nil { s.log.Error("Error getting groups from GitLab API", "err", err) return nil, next @@ -86,7 +87,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, return fullPaths, next } -func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id int Username string @@ -95,7 +96,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse State string } - response, err := s.httpGet(client, s.apiUrl+"/user") + response, err := s.httpGet(ctx, client, s.apiUrl+"/user") if err != nil { return nil, fmt.Errorf("Error getting user info: %s", err) } @@ -108,7 +109,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse return nil, fmt.Errorf("user %s is inactive", data.Username) } - groups := s.GetGroups(client) + groups := s.GetGroups(ctx, client) var role roletype.RoleType var isGrafanaAdmin *bool = nil diff --git a/pkg/login/social/gitlab_oauth_test.go b/pkg/login/social/gitlab_oauth_test.go index a4eb8480447..54929623061 100644 --- a/pkg/login/social/gitlab_oauth_test.go +++ b/pkg/login/social/gitlab_oauth_test.go @@ -1,6 +1,7 @@ package social import ( + "context" "net/http" "net/http/httptest" "strings" @@ -159,7 +160,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) { } })) provider.apiUrl = ts.URL + apiURI - actualResult, err := provider.UserInfo(ts.Client(), nil) + actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil) if test.ExpectedError != nil { require.Equal(t, err, test.ExpectedError) return diff --git a/pkg/login/social/google_oauth.go b/pkg/login/social/google_oauth.go index b499cc613be..9771a0195df 100644 --- a/pkg/login/social/google_oauth.go +++ b/pkg/login/social/google_oauth.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "fmt" "net/http" @@ -16,14 +17,14 @@ type SocialGoogle struct { apiUrl string } -func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id string `json:"id"` Name string `json:"name"` Email string `json:"email"` } - response, err := s.httpGet(client, s.apiUrl) + response, err := s.httpGet(ctx, client, s.apiUrl) if err != nil { return nil, fmt.Errorf("Error getting user info: %s", err) } diff --git a/pkg/login/social/grafana_com_oauth.go b/pkg/login/social/grafana_com_oauth.go index 7ac1568cd1e..77272b1aa3c 100644 --- a/pkg/login/social/grafana_com_oauth.go +++ b/pkg/login/social/grafana_com_oauth.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "fmt" "net/http" @@ -43,7 +44,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool } // UserInfo is used for login credentials for the user -func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { var data struct { Id int `json:"id"` Name string `json:"name"` @@ -53,7 +54,8 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*Basi Orgs []OrgRecord `json:"orgs"` } - response, err := s.httpGet(client, s.url+"/api/oauth2/user") + response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user") + if err != nil { return nil, fmt.Errorf("Error getting user info: %s", err) } diff --git a/pkg/login/social/grafana_com_oauth_test.go b/pkg/login/social/grafana_com_oauth_test.go index 3f322951cb3..3739ec38ad0 100644 --- a/pkg/login/social/grafana_com_oauth_test.go +++ b/pkg/login/social/grafana_com_oauth_test.go @@ -1,6 +1,7 @@ package social import ( + "context" "net/http" "net/http/httptest" "testing" @@ -81,7 +82,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) { } })) provider.url = ts.URL - actualResult, err := provider.UserInfo(ts.Client(), nil) + actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil) if test.ExpectedError != nil { require.Equal(t, err, test.ExpectedError) return diff --git a/pkg/login/social/okta_oauth.go b/pkg/login/social/okta_oauth.go index 090b47b5feb..5ce3e416d71 100644 --- a/pkg/login/social/okta_oauth.go +++ b/pkg/login/social/okta_oauth.go @@ -1,6 +1,7 @@ package social import ( + "context" "encoding/json" "errors" "fmt" @@ -46,7 +47,7 @@ func (claims *OktaClaims) extractEmail() string { return claims.Email } -func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { +func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { idToken := token.Extra("id_token") if idToken == nil { return nil, fmt.Errorf("no id_token found") @@ -68,7 +69,7 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU } var data OktaUserInfoJson - err = s.extractAPI(&data, client) + err = s.extractAPI(ctx, &data, client) if err != nil { return nil, err } @@ -105,8 +106,8 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU }, nil } -func (s *SocialOkta) extractAPI(data *OktaUserInfoJson, client *http.Client) error { - rawUserInfoResponse, err := s.httpGet(client, s.apiUrl) +func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error { + rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl) if err != nil { s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err) return fmt.Errorf("error getting user info response: %w", err) diff --git a/pkg/login/social/okta_oauth_test.go b/pkg/login/social/okta_oauth_test.go index fd955f081d1..29f4a23f608 100644 --- a/pkg/login/social/okta_oauth_test.go +++ b/pkg/login/social/okta_oauth_test.go @@ -1,6 +1,7 @@ package social import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -110,7 +111,7 @@ func TestSocialOkta_UserInfo(t *testing.T) { Expiry: time.Now(), } token := staticToken.WithExtra(tt.OAuth2Extra) - got, err := provider.UserInfo(server.Client(), token) + got, err := provider.UserInfo(context.Background(), server.Client(), token) if (err != nil) != tt.wantErr { t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index a379d06ef8c..899e48919f2 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -7,9 +7,11 @@ import ( "crypto/x509" "encoding/json" "fmt" + "net" "net/http" "os" "strings" + "time" "golang.org/x/oauth2" "golang.org/x/text/cases" @@ -261,7 +263,7 @@ func (b *BasicUserInfo) String() string { } type SocialConnector interface { - UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) + UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) IsEmailAllowed(email string) bool IsSignupAllowed() bool @@ -450,9 +452,19 @@ func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) { TLSClientConfig: &tls.Config{ InsecureSkipVerify: info.TlsSkipVerify, }, + DialContext: (&net.Dialer{ + Timeout: time.Second * 10, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 15 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, } + oauthClient := &http.Client{ Transport: tr, + Timeout: time.Second * 15, } if info.TlsClientCert != "" || info.TlsClientKey != "" { diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index e6f4eeb95c6..b459dc69b26 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -116,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden } token.TokenType = "Bearer" - userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token) + userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token) if err != nil { var sErr *social.Error if errors.As(err, &sErr) { diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index 4999e5d0293..84329f87d56 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -8,13 +8,14 @@ import ( "golang.org/x/oauth2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/setting" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestOAuth_Authenticate(t *testing.T) { @@ -278,7 +279,7 @@ type fakeConnector struct { social.SocialConnector } -func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { +func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { return f.ExpectedUserInfo, f.ExpectedUserInfoErr } diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index df53c1e95b3..b3ec0933e7a 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -273,7 +273,7 @@ func (m *MockSocialConnector) Type() int { return args.Int(0) } -func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { +func (m *MockSocialConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { args := m.Called(client, token) return args.Get(0).(*social.BasicUserInfo), args.Error(1) }