From 76ea754dbb0bfb3e157981ec73ba4ef1efae4207 Mon Sep 17 00:00:00 2001 From: Serge Zaitsev Date: Wed, 2 Jul 2025 17:42:20 +0200 Subject: [PATCH] apply 444-202506261140 manually --- pkg/api/login.go | 14 +++++ pkg/api/login_oauth_test.go | 16 ++--- pkg/api/user_token_test.go | 91 +++++++++++++++++++++++++++++ pkg/middleware/org_redirect.go | 15 +++++ pkg/middleware/org_redirect_test.go | 17 ++++++ pkg/web/webtest/webtest.go | 5 +- 6 files changed, 147 insertions(+), 11 deletions(-) diff --git a/pkg/api/login.go b/pkg/api/login.go index b6d5a950706..8539b3a08bf 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -7,6 +7,8 @@ import ( "fmt" "net/http" "net/url" + "path" + "regexp" "strings" "github.com/grafana/grafana/pkg/api/response" @@ -39,6 +41,9 @@ var getViewIndex = func() string { return viewIndex } +// Only allow redirects that start with an alphanumerical character, a dash or an underscore. +var redirectRe = regexp.MustCompile(`^/[a-zA-Z0-9-_].*`) + var ( errAbsoluteRedirectTo = errors.New("absolute URLs are not allowed for redirect_to cookie value") errInvalidRedirectTo = errors.New("invalid redirect_to cookie value") @@ -68,6 +73,15 @@ func (hs *HTTPServer) ValidateRedirectTo(redirectTo string) error { return errForbiddenRedirectTo } + cleanPath := path.Clean(to.Path) + // "." is what path.Clean returns for empty paths + if cleanPath == "." { + return errForbiddenRedirectTo + } + if to.Path != "/" && !redirectRe.MatchString(cleanPath) { + return errForbiddenRedirectTo + } + // when using a subUrl, the redirect_to should start with the subUrl (which contains the leading slash), otherwise the redirect // will send the user to the wrong location if hs.Cfg.AppSubURL != "" && !strings.HasPrefix(to.Path, hs.Cfg.AppSubURL+"/") { diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index fefff77b260..210d03ad0f7 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -14,20 +14,16 @@ import ( "github.com/grafana/grafana/pkg/services/authn/authntest" "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/web/webtest" ) -func setClientWithoutRedirectFollow(t *testing.T) { +func setClientWithoutRedirectFollow(t *testing.T, s *webtest.Server) { t.Helper() - old := http.DefaultClient - http.DefaultClient = &http.Client{ + s.HttpClient = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - - t.Cleanup(func() { - http.DefaultClient = old - }) } func TestOAuthLogin_Redirect(t *testing.T) { @@ -79,7 +75,7 @@ func TestOAuthLogin_Redirect(t *testing.T) { }) // we need to prevent the http.Client from following redirects - setClientWithoutRedirectFollow(t) + setClientWithoutRedirectFollow(t, server) res, err := server.Send(server.NewGetRequest("/login/generic_oauth")) require.NoError(t, err) @@ -155,7 +151,7 @@ func TestOAuthLogin_AuthorizationCode(t *testing.T) { }) // we need to prevent the http.Client from following redirects - setClientWithoutRedirectFollow(t) + setClientWithoutRedirectFollow(t, server) res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code")) require.NoError(t, err) @@ -199,7 +195,7 @@ func TestOAuthLogin_Error(t *testing.T) { hs.SecretsService = fakes.NewFakeSecretsService() }) - setClientWithoutRedirectFollow(t) + setClientWithoutRedirectFollow(t, server) res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror")) require.NoError(t, err) diff --git a/pkg/api/user_token_test.go b/pkg/api/user_token_test.go index 62fa0c2b0db..204ab6af6e6 100644 --- a/pkg/api/user_token_test.go +++ b/pkg/api/user_token_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "testing" "time" @@ -20,6 +21,7 @@ import ( "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user/usertest" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/web/webtest" ) func TestUserTokenAPIEndpoint(t *testing.T) { @@ -150,6 +152,95 @@ func TestUserTokenAPIEndpoint(t *testing.T) { }) } +func TestHTTPServer_RotateUserAuthTokenRedirect(t *testing.T) { + redirectTestCases := []struct { + name string + redirectUrl string + expectedUrl string + }{ + // Valid redirects should be preserved + {"valid root path", "/", "/"}, + {"valid simple path", "/hello", "/hello"}, + {"valid single char path", "/a", "/a"}, + {"valid nested path", "/asd/hello", "/asd/hello"}, + + // Invalid redirects should be converted to root + {"backslash domain", `/\grafana.com`, "/"}, + {"traversal backslash domain", `/a/../\grafana.com`, "/"}, + {"double slash", "//grafana", "/"}, + {"missing initial slash", "missingInitialSlash", "/"}, + {"parent directory", "/../", "/"}, + } + + sessionTestCases := []struct { + name string + useSessionStorageRedirect bool + }{ + {"when useSessionStorageRedirect is enabled", true}, + {"when useSessionStorageRedirect is disabled", false}, + } + + for _, sessionCase := range sessionTestCases { + t.Run(sessionCase.name, func(t *testing.T) { + for _, redirectCase := range redirectTestCases { + t.Run(redirectCase.name, func(t *testing.T) { + server := SetupAPITestServer(t, func(hs *HTTPServer) { + cfg := setting.NewCfg() + cfg.LoginCookieName = "grafana_session" + cfg.LoginMaxLifetime = 10 * time.Hour + hs.Cfg = cfg + hs.log = log.New() + hs.AuthTokenService = &authtest.FakeUserAuthTokenService{ + RotateTokenProvider: func(ctx context.Context, cmd auth.RotateCommand) (*auth.UserToken, error) { + return &auth.UserToken{UnhashedToken: "new"}, nil + }, + } + }) + + redirectToQuery := url.QueryEscape(redirectCase.redirectUrl) + urlString := "/user/auth-tokens/rotate" + + if sessionCase.useSessionStorageRedirect { + urlString = urlString + "?redirectTo=" + redirectToQuery + } + + req := server.NewGetRequest(urlString) + req.AddCookie(&http.Cookie{Name: "grafana_session", Value: "123", Path: "/"}) + + if sessionCase.useSessionStorageRedirect { + req = webtest.RequestWithWebContext(req, &contextmodel.ReqContext{UseSessionStorageRedirect: true}) + } else { + req.AddCookie(&http.Cookie{Name: "redirect_to", Value: redirectToQuery, Path: "/"}) + } + + var redirectStatusCode int + var redirectLocation string + + server.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) > 1 { + // Stop after first redirect + return http.ErrUseLastResponse + } + + if req.Response == nil { + return nil + } + redirectStatusCode = req.Response.StatusCode + redirectLocation = req.Response.Header.Get("Location") + return nil + } + res, err := server.Send(req) + require.NoError(t, err) + assert.Equal(t, 302, redirectStatusCode) + assert.Equal(t, redirectCase.expectedUrl, redirectLocation) + + require.NoError(t, res.Body.Close()) + }) + } + }) + } +} + func TestHTTPServer_RotateUserAuthToken(t *testing.T) { type testCase struct { desc string diff --git a/pkg/middleware/org_redirect.go b/pkg/middleware/org_redirect.go index b794f9acfdf..8ce90ac643d 100644 --- a/pkg/middleware/org_redirect.go +++ b/pkg/middleware/org_redirect.go @@ -3,6 +3,8 @@ package middleware import ( "fmt" "net/http" + "path" + "regexp" "strconv" "github.com/grafana/grafana/pkg/services/contexthandler" @@ -11,6 +13,9 @@ import ( "github.com/grafana/grafana/pkg/web" ) +// Only allow redirects that start with an alphanumerical character, a dash or an underscore. +var redirectRe = regexp.MustCompile(`^/?[a-zA-Z0-9-_].*`) + // OrgRedirect changes org and redirects users if the // querystring `orgId` doesn't match the active org. func OrgRedirect(cfg *setting.Cfg, userSvc user.Service) web.Handler { @@ -31,6 +36,11 @@ func OrgRedirect(cfg *setting.Cfg, userSvc user.Service) web.Handler { return } + if !validRedirectPath(c.Req.URL.Path) { + // Do not switch orgs or perform the redirect because the new path is not valid + return + } + if err := userSvc.Update(ctx.Req.Context(), &user.UpdateUserCommand{UserID: ctx.UserID, OrgID: &orgId}); err != nil { if ctx.IsApiRequest() { ctx.JsonApiErr(404, "Not found", nil) @@ -54,3 +64,8 @@ func OrgRedirect(cfg *setting.Cfg, userSvc user.Service) web.Handler { c.Redirect(newURL, 302) } } + +func validRedirectPath(p string) bool { + cleanPath := path.Clean(p) + return cleanPath == "." || cleanPath == "/" || redirectRe.MatchString(cleanPath) +} diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index 4087eeb91c6..06800c5ac69 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/url" "testing" "github.com/stretchr/testify/require" @@ -23,6 +24,12 @@ func TestOrgRedirectMiddleware(t *testing.T) { expStatus: 302, expLocation: "/?orgId=3", }, + { + desc: "when setting a correct org for the user with an empty path", + input: "?orgId=3", + expStatus: 302, + expLocation: "/?orgId=3", + }, { desc: "when setting a correct org for the user with '&kiosk'", input: "/?orgId=3&kiosk", @@ -64,6 +71,16 @@ func TestOrgRedirectMiddleware(t *testing.T) { require.Equal(t, 404, sc.resp.Code) }) + middlewareScenario(t, "when redirecting to an invalid path", func(t *testing.T, sc *scenarioContext) { + sc.withIdentity(&authn.Identity{}) + + path := url.QueryEscape(`/\example.com`) + sc.m.Get(url.QueryEscape(path), sc.defaultHandler) + sc.fakeReq("GET", fmt.Sprintf("%s?orgId=3", path)).exec() + + require.Equal(t, 404, sc.resp.Code) + }) + middlewareScenario(t, "works correctly when grafana is served under a subpath", func(t *testing.T, sc *scenarioContext) { sc.withIdentity(&authn.Identity{}) diff --git a/pkg/web/webtest/webtest.go b/pkg/web/webtest/webtest.go index 8dd42eba7db..92cff85b2b5 100644 --- a/pkg/web/webtest/webtest.go +++ b/pkg/web/webtest/webtest.go @@ -24,6 +24,7 @@ type Server struct { Mux *web.Mux RouteRegister routing.RouteRegister TestServer *httptest.Server + HttpClient *http.Client } // NewServer starts and returns a new server. @@ -50,6 +51,7 @@ func NewServer(t testing.TB, routeRegister routing.RouteRegister) *Server { RouteRegister: routeRegister, Mux: m, TestServer: testServer, + HttpClient: &http.Client{}, } } @@ -81,7 +83,7 @@ func (s *Server) NewRequest(method string, target string, body io.Reader) *http. // Send sends a HTTP request to the test server and returns an HTTP response. func (s *Server) Send(req *http.Request) (*http.Response, error) { - return http.DefaultClient.Do(req) + return s.HttpClient.Do(req) } // SendJSON sets the Content-Type header to application/json and sends @@ -144,6 +146,7 @@ func requestContextMiddleware() web.Middleware { c.RequestNonce = ctx.RequestNonce c.PerfmonTimer = ctx.PerfmonTimer c.LookupTokenErr = ctx.LookupTokenErr + c.UseSessionStorageRedirect = ctx.UseSessionStorageRedirect } next.ServeHTTP(w, r)