|
|
|
@ -2,36 +2,35 @@ package csrf |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"net/http" |
|
|
|
|
"net/url" |
|
|
|
|
"reflect" |
|
|
|
|
|
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log" |
|
|
|
|
"github.com/grafana/grafana/pkg/setting" |
|
|
|
|
"github.com/grafana/grafana/pkg/util" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type Service interface { |
|
|
|
|
Middleware(logger log.Logger) func(http.Handler) http.Handler |
|
|
|
|
Middleware() func(http.Handler) http.Handler |
|
|
|
|
TrustOrigin(origin string) |
|
|
|
|
AddOriginHeader(headerName string) |
|
|
|
|
AddAdditionalHeaders(headerName string) |
|
|
|
|
AddSafeEndpoint(endpoint string) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type Implementation struct { |
|
|
|
|
type CSRF struct { |
|
|
|
|
cfg *setting.Cfg |
|
|
|
|
|
|
|
|
|
trustedOrigins map[string]struct{} |
|
|
|
|
originHeaders map[string]struct{} |
|
|
|
|
headers map[string]struct{} |
|
|
|
|
safeEndpoints map[string]struct{} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func ProvideCSRFFilter(cfg *setting.Cfg) Service { |
|
|
|
|
i := &Implementation{ |
|
|
|
|
c := &CSRF{ |
|
|
|
|
cfg: cfg, |
|
|
|
|
trustedOrigins: map[string]struct{}{}, |
|
|
|
|
originHeaders: map[string]struct{}{ |
|
|
|
|
"Origin": {}, |
|
|
|
|
}, |
|
|
|
|
headers: map[string]struct{}{}, |
|
|
|
|
safeEndpoints: map[string]struct{}{}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -39,92 +38,124 @@ func ProvideCSRFFilter(cfg *setting.Cfg) Service { |
|
|
|
|
trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ") |
|
|
|
|
|
|
|
|
|
for _, header := range additionalHeaders { |
|
|
|
|
i.originHeaders[header] = struct{}{} |
|
|
|
|
c.headers[header] = struct{}{} |
|
|
|
|
} |
|
|
|
|
for _, origin := range trustedOrigins { |
|
|
|
|
i.trustedOrigins[origin] = struct{}{} |
|
|
|
|
c.trustedOrigins[origin] = struct{}{} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return i |
|
|
|
|
return c |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler { |
|
|
|
|
func (c *CSRF) Middleware() func(http.Handler) http.Handler { |
|
|
|
|
return func(next http.Handler) http.Handler { |
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
e := &errorWithStatus{} |
|
|
|
|
|
|
|
|
|
err := c.check(r) |
|
|
|
|
if err != nil { |
|
|
|
|
if !errors.As(err, &e) { |
|
|
|
|
http.Error(w, fmt.Sprintf("internal server error: expected error type errorWithStatus, got %s. Error: %v", reflect.TypeOf(err), err), http.StatusInternalServerError) |
|
|
|
|
} |
|
|
|
|
http.Error(w, err.Error(), e.HTTPStatus) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (c *CSRF) check(r *http.Request) error { |
|
|
|
|
// As per RFC 7231/4.2.2 these methods are idempotent:
|
|
|
|
|
// (GET is excluded because it may have side effects in some APIs)
|
|
|
|
|
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} |
|
|
|
|
|
|
|
|
|
return func(next http.Handler) http.Handler { |
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
// If request has no login cookie - skip CSRF checks
|
|
|
|
|
if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) { |
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
return |
|
|
|
|
if _, err := r.Cookie(c.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
// Skip CSRF checks for "safe" methods
|
|
|
|
|
for _, method := range safeMethods { |
|
|
|
|
if r.Method == method { |
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
return |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
// Skip CSRF checks for "safe" endpoints
|
|
|
|
|
for safeEndpoint := range i.safeEndpoints { |
|
|
|
|
for safeEndpoint := range c.safeEndpoints { |
|
|
|
|
if r.URL.Path == safeEndpoint { |
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
return |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
// Otherwise - verify that Origin matches the server origin
|
|
|
|
|
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
|
|
|
|
if err != nil { |
|
|
|
|
http.Error(w, err.Error(), http.StatusBadRequest) |
|
|
|
|
return |
|
|
|
|
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} |
|
|
|
|
} |
|
|
|
|
origins := map[string]struct{}{} |
|
|
|
|
for header := range i.originHeaders { |
|
|
|
|
origin, err := url.Parse(r.Header.Get(header)) |
|
|
|
|
|
|
|
|
|
o := r.Header.Get("Origin") |
|
|
|
|
|
|
|
|
|
// No Origin header sent, skip CSRF check.
|
|
|
|
|
if o == "" { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
originURL, err := url.Parse(o) |
|
|
|
|
if err != nil { |
|
|
|
|
logger.Error("error parsing Origin header", "header", header, "err", err) |
|
|
|
|
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} |
|
|
|
|
} |
|
|
|
|
if origin.String() != "" { |
|
|
|
|
origins[origin.Hostname()] = struct{}{} |
|
|
|
|
origin := originURL.Hostname() |
|
|
|
|
|
|
|
|
|
trustedOrigin := false |
|
|
|
|
for h := range c.headers { |
|
|
|
|
customHost := r.Header.Get(h) |
|
|
|
|
addr, err := util.SplitHostPortDefault(customHost, "", "0") // we ignore the port
|
|
|
|
|
if err != nil { |
|
|
|
|
return &errorWithStatus{Underlying: err, HTTPStatus: http.StatusBadRequest} |
|
|
|
|
} |
|
|
|
|
if addr.Host == origin { |
|
|
|
|
trustedOrigin = true |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// No Origin header sent, skip CSRF check.
|
|
|
|
|
if len(origins) == 0 { |
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
trustedOrigin := false |
|
|
|
|
for o := range i.trustedOrigins { |
|
|
|
|
if _, ok := origins[o]; ok { |
|
|
|
|
for o := range c.trustedOrigins { |
|
|
|
|
if o == origin { |
|
|
|
|
trustedOrigin = true |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
_, hostnameMatches := origins[netAddr.Host] |
|
|
|
|
hostnameMatches := origin == netAddr.Host |
|
|
|
|
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches { |
|
|
|
|
http.Error(w, "origin not allowed", http.StatusForbidden) |
|
|
|
|
return |
|
|
|
|
return &errorWithStatus{Underlying: errors.New("origin not allowed"), HTTPStatus: http.StatusForbidden} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
next.ServeHTTP(w, r) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (i *Implementation) TrustOrigin(origin string) { |
|
|
|
|
i.trustedOrigins[origin] = struct{}{} |
|
|
|
|
func (c *CSRF) TrustOrigin(origin string) { |
|
|
|
|
c.trustedOrigins[origin] = struct{}{} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (i *Implementation) AddOriginHeader(headerName string) { |
|
|
|
|
i.originHeaders[headerName] = struct{}{} |
|
|
|
|
func (c *CSRF) AddAdditionalHeaders(headerName string) { |
|
|
|
|
c.headers[headerName] = struct{}{} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// AddSafeEndpoint is used for endpoints requests to skip CSRF check
|
|
|
|
|
func (i *Implementation) AddSafeEndpoint(endpoint string) { |
|
|
|
|
i.safeEndpoints[endpoint] = struct{}{} |
|
|
|
|
func (c *CSRF) AddSafeEndpoint(endpoint string) { |
|
|
|
|
c.safeEndpoints[endpoint] = struct{}{} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type errorWithStatus struct { |
|
|
|
|
Underlying error |
|
|
|
|
HTTPStatus int |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (e errorWithStatus) Error() string { |
|
|
|
|
return e.Underlying.Error() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (e errorWithStatus) Unwrap() error { |
|
|
|
|
return e.Underlying |
|
|
|
|
} |
|
|
|
|