mirror of https://github.com/grafana/grafana
API: Migrate CSRF to service and support additional options (#48120)
* API: Migrate CSRF to service and support additional options * minor * public Csrf service to use in tests * WIP * remove fmt * comment * WIP * remove fmt prints * todo add prefix slash * remove fmt prints * linting fix * remove trimPrefix Co-authored-by: Eric Leijonmarck <eric.leijonmarck@gmail.com> Co-authored-by: IevaVasiljeva <ieva.vasiljeva@grafana.com>pull/50099/head
parent
84860ffc96
commit
3e81fa0716
@ -1,50 +0,0 @@ |
|||||||
package middleware |
|
||||||
|
|
||||||
import ( |
|
||||||
"errors" |
|
||||||
"net/http" |
|
||||||
"net/url" |
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log" |
|
||||||
"github.com/grafana/grafana/pkg/util" |
|
||||||
) |
|
||||||
|
|
||||||
func CSRF(loginCookieName string, logger log.Logger) func(http.Handler) http.Handler { |
|
||||||
// As per RFC 7231/4.2.2 these methods are idempotent:
|
|
||||||
// (GET is excluded because it may have side effects in some APIs)
|
|
||||||
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} |
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler { |
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
||||||
// If request has no login cookie - skip CSRF checks
|
|
||||||
if _, err := r.Cookie(loginCookieName); errors.Is(err, http.ErrNoCookie) { |
|
||||||
next.ServeHTTP(w, r) |
|
||||||
return |
|
||||||
} |
|
||||||
// Skip CSRF checks for "safe" methods
|
|
||||||
for _, method := range safeMethods { |
|
||||||
if r.Method == method { |
|
||||||
next.ServeHTTP(w, r) |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
// Otherwise - verify that Origin matches the server origin
|
|
||||||
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
|
||||||
if err != nil { |
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
origin, err := url.Parse(r.Header.Get("Origin")) |
|
||||||
if err != nil { |
|
||||||
logger.Error("error parsing Origin header", "err", err) |
|
||||||
} |
|
||||||
if err != nil || netAddr.Host == "" || (origin.String() != "" && origin.Hostname() != netAddr.Host) { |
|
||||||
http.Error(w, "origin not allowed", http.StatusForbidden) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
next.ServeHTTP(w, r) |
|
||||||
}) |
|
||||||
} |
|
||||||
} |
|
||||||
@ -0,0 +1,130 @@ |
|||||||
|
package csrf |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"net/http" |
||||||
|
"net/url" |
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/log" |
||||||
|
"github.com/grafana/grafana/pkg/setting" |
||||||
|
"github.com/grafana/grafana/pkg/util" |
||||||
|
) |
||||||
|
|
||||||
|
type Service interface { |
||||||
|
Middleware(logger log.Logger) func(http.Handler) http.Handler |
||||||
|
TrustOrigin(origin string) |
||||||
|
AddOriginHeader(headerName string) |
||||||
|
AddSafeEndpoint(endpoint string) |
||||||
|
} |
||||||
|
|
||||||
|
type Implementation struct { |
||||||
|
cfg *setting.Cfg |
||||||
|
|
||||||
|
trustedOrigins map[string]struct{} |
||||||
|
originHeaders map[string]struct{} |
||||||
|
safeEndpoints map[string]struct{} |
||||||
|
} |
||||||
|
|
||||||
|
func ProvideCSRFFilter(cfg *setting.Cfg) Service { |
||||||
|
i := &Implementation{ |
||||||
|
cfg: cfg, |
||||||
|
trustedOrigins: map[string]struct{}{}, |
||||||
|
originHeaders: map[string]struct{}{ |
||||||
|
"Origin": {}, |
||||||
|
}, |
||||||
|
safeEndpoints: map[string]struct{}{}, |
||||||
|
} |
||||||
|
|
||||||
|
additionalHeaders := cfg.SectionWithEnvOverrides("security").Key("csrf_additional_headers").Strings(" ") |
||||||
|
trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ") |
||||||
|
|
||||||
|
for _, header := range additionalHeaders { |
||||||
|
i.originHeaders[header] = struct{}{} |
||||||
|
} |
||||||
|
for _, origin := range trustedOrigins { |
||||||
|
i.trustedOrigins[origin] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
return i |
||||||
|
} |
||||||
|
|
||||||
|
func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler { |
||||||
|
// As per RFC 7231/4.2.2 these methods are idempotent:
|
||||||
|
// (GET is excluded because it may have side effects in some APIs)
|
||||||
|
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} |
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler { |
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||||
|
// If request has no login cookie - skip CSRF checks
|
||||||
|
if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) { |
||||||
|
next.ServeHTTP(w, r) |
||||||
|
return |
||||||
|
} |
||||||
|
// Skip CSRF checks for "safe" methods
|
||||||
|
for _, method := range safeMethods { |
||||||
|
if r.Method == method { |
||||||
|
next.ServeHTTP(w, r) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
// Skip CSRF checks for "safe" endpoints
|
||||||
|
for safeEndpoint := range i.safeEndpoints { |
||||||
|
if r.URL.Path == safeEndpoint { |
||||||
|
next.ServeHTTP(w, r) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
// Otherwise - verify that Origin matches the server origin
|
||||||
|
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
||||||
|
if err != nil { |
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest) |
||||||
|
return |
||||||
|
} |
||||||
|
origins := map[string]struct{}{} |
||||||
|
for header := range i.originHeaders { |
||||||
|
origin, err := url.Parse(r.Header.Get(header)) |
||||||
|
if err != nil { |
||||||
|
logger.Error("error parsing Origin header", "header", header, "err", err) |
||||||
|
} |
||||||
|
if origin.String() != "" { |
||||||
|
origins[origin.Hostname()] = struct{}{} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// No Origin header sent, skip CSRF check.
|
||||||
|
if len(origins) == 0 { |
||||||
|
next.ServeHTTP(w, r) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
trustedOrigin := false |
||||||
|
for o := range i.trustedOrigins { |
||||||
|
if _, ok := origins[o]; ok { |
||||||
|
trustedOrigin = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
_, hostnameMatches := origins[netAddr.Host] |
||||||
|
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches { |
||||||
|
http.Error(w, "origin not allowed", http.StatusForbidden) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
next.ServeHTTP(w, r) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (i *Implementation) TrustOrigin(origin string) { |
||||||
|
i.trustedOrigins[origin] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
func (i *Implementation) AddOriginHeader(headerName string) { |
||||||
|
i.originHeaders[headerName] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
// AddSafeEndpoint is used for endpoints requests to skip CSRF check
|
||||||
|
func (i *Implementation) AddSafeEndpoint(endpoint string) { |
||||||
|
i.safeEndpoints[endpoint] = struct{}{} |
||||||
|
} |
||||||
Loading…
Reference in new issue