diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index ac80801de41..85699c01490 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -519,7 +519,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { } m.Use(middleware.Recovery(hs.Cfg)) - m.UseMiddleware(hs.Csrf.Middleware(hs.log)) + m.UseMiddleware(hs.Csrf.Middleware()) hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build") hs.mapStatic(m, hs.Cfg.StaticRootPath, "", "public", "/public/views/swagger.html") diff --git a/pkg/middleware/csrf/csrf.go b/pkg/middleware/csrf/csrf.go index a9694665099..85ca6069837 100644 --- a/pkg/middleware/csrf/csrf.go +++ b/pkg/middleware/csrf/csrf.go @@ -2,112 +2,62 @@ package csrf import ( "errors" + "fmt" "net/http" "net/url" + "reflect" - "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) type Service interface { - Middleware(logger log.Logger) func(http.Handler) http.Handler + Middleware() func(http.Handler) http.Handler TrustOrigin(origin string) - AddOriginHeader(headerName string) + AddAdditionalHeaders(headerName string) AddSafeEndpoint(endpoint string) } -type Implementation struct { +type CSRF struct { cfg *setting.Cfg trustedOrigins map[string]struct{} - originHeaders map[string]struct{} + headers map[string]struct{} safeEndpoints map[string]struct{} } func ProvideCSRFFilter(cfg *setting.Cfg) Service { - i := &Implementation{ + c := &CSRF{ cfg: cfg, trustedOrigins: map[string]struct{}{}, - originHeaders: map[string]struct{}{ - "Origin": {}, - }, - safeEndpoints: map[string]struct{}{}, + headers: map[string]struct{}{}, + safeEndpoints: map[string]struct{}{}, } additionalHeaders := cfg.SectionWithEnvOverrides("security").Key("csrf_additional_headers").Strings(" ") trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ") for _, header := range additionalHeaders { - i.originHeaders[header] = struct{}{} + c.headers[header] = struct{}{} } for _, origin := range trustedOrigins { - i.trustedOrigins[origin] = struct{}{} + c.trustedOrigins[origin] = struct{}{} } - return i + return c } -func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler { - // As per RFC 7231/4.2.2 these methods are idempotent: - // (GET is excluded because it may have side effects in some APIs) - safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} - +func (c *CSRF) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // If request has no login cookie - skip CSRF checks - if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) { - next.ServeHTTP(w, r) - return - } - // Skip CSRF checks for "safe" methods - for _, method := range safeMethods { - if r.Method == method { - next.ServeHTTP(w, r) - return - } - } - // Skip CSRF checks for "safe" endpoints - for safeEndpoint := range i.safeEndpoints { - if r.URL.Path == safeEndpoint { - next.ServeHTTP(w, r) - return - } - } - // Otherwise - verify that Origin matches the server origin - netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - origins := map[string]struct{}{} - for header := range i.originHeaders { - origin, err := url.Parse(r.Header.Get(header)) - if err != nil { - logger.Error("error parsing Origin header", "header", header, "err", err) - } - if origin.String() != "" { - origins[origin.Hostname()] = struct{}{} - } - } - - // No Origin header sent, skip CSRF check. - if len(origins) == 0 { - next.ServeHTTP(w, r) - return - } + e := &errorWithStatus{} - trustedOrigin := false - for o := range i.trustedOrigins { - if _, ok := origins[o]; ok { - trustedOrigin = true - break + err := c.check(r) + if err != nil { + if !errors.As(err, &e) { + http.Error(w, fmt.Sprintf("internal server error: expected error type errorWithStatus, got %s. Error: %v", reflect.TypeOf(err), err), http.StatusInternalServerError) } - } - - _, hostnameMatches := origins[netAddr.Host] - if netAddr.Host == "" || !trustedOrigin && !hostnameMatches { - http.Error(w, "origin not allowed", http.StatusForbidden) + http.Error(w, err.Error(), e.HTTPStatus) return } @@ -116,15 +66,96 @@ func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.H } } -func (i *Implementation) TrustOrigin(origin string) { - i.trustedOrigins[origin] = struct{}{} +func (c *CSRF) check(r *http.Request) error { + // As per RFC 7231/4.2.2 these methods are idempotent: + // (GET is excluded because it may have side effects in some APIs) + safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} + + // If request has no login cookie - skip CSRF checks + if _, err := r.Cookie(c.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) { + return nil + } + // Skip CSRF checks for "safe" methods + for _, method := range safeMethods { + if r.Method == method { + return nil + } + } + // Skip CSRF checks for "safe" endpoints + for safeEndpoint := range c.safeEndpoints { + if r.URL.Path == safeEndpoint { + return nil + } + } + // Otherwise - verify that Origin matches the server origin + netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port + if err != nil { + return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} + } + + o := r.Header.Get("Origin") + + // No Origin header sent, skip CSRF check. + if o == "" { + return nil + } + + originURL, err := url.Parse(o) + if err != nil { + return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} + } + origin := originURL.Hostname() + + trustedOrigin := false + for h := range c.headers { + customHost := r.Header.Get(h) + addr, err := util.SplitHostPortDefault(customHost, "", "0") // we ignore the port + if err != nil { + return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} + } + if addr.Host == origin { + trustedOrigin = true + break + } + } + + for o := range c.trustedOrigins { + if o == origin { + trustedOrigin = true + break + } + } + + hostnameMatches := origin == netAddr.Host + if netAddr.Host == "" || !trustedOrigin && !hostnameMatches { + return &errorWithStatus{Underlying: errors.New("origin not allowed"), HTTPStatus: http.StatusForbidden} + } + + return nil +} + +func (c *CSRF) TrustOrigin(origin string) { + c.trustedOrigins[origin] = struct{}{} } -func (i *Implementation) AddOriginHeader(headerName string) { - i.originHeaders[headerName] = struct{}{} +func (c *CSRF) AddAdditionalHeaders(headerName string) { + c.headers[headerName] = struct{}{} } // AddSafeEndpoint is used for endpoints requests to skip CSRF check -func (i *Implementation) AddSafeEndpoint(endpoint string) { - i.safeEndpoints[endpoint] = struct{}{} +func (c *CSRF) AddSafeEndpoint(endpoint string) { + c.safeEndpoints[endpoint] = struct{}{} +} + +type errorWithStatus struct { + Underlying error + HTTPStatus int +} + +func (e errorWithStatus) Error() string { + return e.Underlying.Error() +} + +func (e errorWithStatus) Unwrap() error { + return e.Underlying } diff --git a/pkg/middleware/csrf/csrf_test.go b/pkg/middleware/csrf/csrf_test.go index f1e132ef593..470f20613ae 100644 --- a/pkg/middleware/csrf/csrf_test.go +++ b/pkg/middleware/csrf/csrf_test.go @@ -1,13 +1,15 @@ package csrf import ( + "errors" "net/http" "net/http/httptest" + "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/setting" ) @@ -100,6 +102,117 @@ func TestMiddlewareCSRF(t *testing.T) { } } +func TestCSRF_Check(t *testing.T) { + tests := []struct { + name string + request *http.Request + addtHeader map[string]struct{} + trustedOrigins map[string]struct{} + safeEndpoints map[string]struct{} + expectedOK bool + expectedStatus int + }{ + { + name: "base case", + request: postRequest(t, "", nil), + expectedOK: true, + }, + { + name: "base with null origin header", + request: postRequest(t, "", map[string]string{"Origin": "null"}), + expectedStatus: http.StatusForbidden, + }, + { + name: "grafana.org", + request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.org"}), + expectedOK: true, + }, + { + name: "grafana.org with X-Forwarded-Host", + request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.org"}), + expectedStatus: http.StatusForbidden, + }, + { + name: "grafana.org with X-Forwarded-Host and header trusted", + request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.org"}), + addtHeader: map[string]struct{}{"X-Forwarded-Host": {}}, + expectedOK: true, + }, + { + name: "grafana.org from grafana.com", + request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.com"}), + expectedStatus: http.StatusForbidden, + }, + { + name: "grafana.org from grafana.com explicit trust for grafana.com", + request: postRequest(t, "grafana.org", map[string]string{"Origin": "https://grafana.com"}), + trustedOrigins: map[string]struct{}{"grafana.com": {}}, + expectedOK: true, + }, + { + name: "grafana.org from grafana.com with X-Forwarded-Host and header trusted", + request: postRequest(t, "grafana.localhost", map[string]string{"X-Forwarded-Host": "grafana.org", "Origin": "https://grafana.com"}), + addtHeader: map[string]struct{}{"X-Forwarded-Host": {}}, + trustedOrigins: map[string]struct{}{"grafana.com": {}}, + expectedOK: true, + }, + { + name: "safe endpoint", + request: postRequest(t, "example.org/foo/bar", map[string]string{"Origin": "null"}), + safeEndpoints: map[string]struct{}{"foo/bar": {}}, + expectedOK: true, + }, + } + + for _, tc := range tests { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + c := CSRF{ + cfg: setting.NewCfg(), + trustedOrigins: tc.trustedOrigins, + headers: tc.addtHeader, + safeEndpoints: tc.safeEndpoints, + } + c.cfg.LoginCookieName = "LoginCookie" + + err := c.check(tc.request) + if tc.expectedOK { + require.NoError(t, err) + } else { + require.Error(t, err) + var actual *errorWithStatus + require.True(t, errors.As(err, &actual)) + assert.EqualValues(t, tc.expectedStatus, actual.HTTPStatus) + } + }) + } +} + +func postRequest(t testing.TB, hostname string, headers map[string]string) *http.Request { + t.Helper() + urlParts := strings.SplitN(hostname, "/", 2) + + path := "/" + if len(urlParts) == 2 { + path = urlParts[1] + } + r, err := http.NewRequest(http.MethodPost, path, nil) + require.NoError(t, err) + + r.Host = urlParts[0] + + r.AddCookie(&http.Cookie{ + Name: "LoginCookie", + Value: "this should not be important", + }) + + for k, v := range headers { + r.Header.Set(k, v) + } + return r +} + func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httptest.ResponseRecorder { req, err := http.NewRequest(method, "/", nil) if err != nil { @@ -123,7 +236,7 @@ func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httpte cfg := setting.NewCfg() cfg.LoginCookieName = cookieName service := ProvideCSRFFilter(cfg) - handler := service.Middleware(log.New())(testHandler) + handler := service.Middleware()(testHandler) handler.ServeHTTP(rr, req) return rr }