Live: allow connections with request host matching origin host (#38538)

pull/38631/head
Alexander Emelin 4 years ago committed by GitHub
parent b25eb0aa74
commit a95a40d87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 19
      pkg/services/live/live.go
  2. 9
      pkg/services/live/live_test.go

@ -373,25 +373,28 @@ func getCheckOriginFunc(appURL *url.URL, originPatterns []string, originGlobs []
// fast path for *. // fast path for *.
return true return true
} }
ok, err := checkAllowedOrigin(strings.ToLower(origin), appURL, originGlobs) originURL, err := url.Parse(strings.ToLower(origin))
if err != nil {
logger.Warn("Failed to parse request origin", "error", err, "origin", origin)
return false
}
if strings.EqualFold(originURL.Host, r.Host) {
return true
}
ok, err := checkAllowedOrigin(origin, originURL, appURL, originGlobs)
if err != nil { if err != nil {
logger.Warn("Error parsing request origin", "error", err, "origin", origin) logger.Warn("Error parsing request origin", "error", err, "origin", origin)
return false return false
} }
if !ok { if !ok {
logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String(), "allowedOrigins", strings.Join(originPatterns, ",")) logger.Warn("Request Origin is not authorized", "origin", origin, "host", r.Host, "appUrl", appURL.String(), "allowedOrigins", strings.Join(originPatterns, ","))
return false return false
} }
return true return true
} }
} }
func checkAllowedOrigin(origin string, appURL *url.URL, originGlobs []glob.Glob) (bool, error) { func checkAllowedOrigin(origin string, originURL *url.URL, appURL *url.URL, originGlobs []glob.Glob) (bool, error) {
originURL, err := url.Parse(origin)
if err != nil {
logger.Warn("Failed to parse request origin", "error", err, "origin", origin)
return false, err
}
// Try to match over configured [server] root_url first. // Try to match over configured [server] root_url first.
if originURL.Port() == "" { if originURL.Port() == "" {
if strings.EqualFold(originURL.Scheme, appURL.Scheme) && strings.EqualFold(originURL.Host, appURL.Hostname()) { if strings.EqualFold(originURL.Scheme, appURL.Scheme) && strings.EqualFold(originURL.Host, appURL.Hostname()) {

@ -62,6 +62,7 @@ func TestCheckOrigin(t *testing.T) {
appURL string appURL string
allowedOrigins []string allowedOrigins []string
success bool success bool
host string
}{ }{
{ {
name: "empty_origin", name: "empty_origin",
@ -126,6 +127,13 @@ func TestCheckOrigin(t *testing.T) {
allowedOrigins: []string{"*"}, allowedOrigins: []string{"*"},
success: true, success: true,
}, },
{
name: "request_host_matches_origin_host",
origin: "http://example.com",
appURL: "https://example.com",
success: true,
host: "example.com",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -141,6 +149,7 @@ func TestCheckOrigin(t *testing.T) {
checkOrigin := getCheckOriginFunc(appURL, tc.allowedOrigins, originGlobs) checkOrigin := getCheckOriginFunc(appURL, tc.allowedOrigins, originGlobs)
r := httptest.NewRequest("GET", tc.appURL, nil) r := httptest.NewRequest("GET", tc.appURL, nil)
r.Host = tc.host
r.Header.Set("Origin", tc.origin) r.Header.Set("Origin", tc.origin)
require.Equal(t, tc.success, checkOrigin(r), require.Equal(t, tc.success, checkOrigin(r),
"origin %s, appURL: %s", tc.origin, tc.appURL, "origin %s, appURL: %s", tc.origin, tc.appURL,

Loading…
Cancel
Save