From a95a40d87d80fe9fb8a025b7c96e50cae8af34f1 Mon Sep 17 00:00:00 2001 From: Alexander Emelin Date: Fri, 27 Aug 2021 13:26:28 +0300 Subject: [PATCH] Live: allow connections with request host matching origin host (#38538) --- pkg/services/live/live.go | 19 +++++++++++-------- pkg/services/live/live_test.go | 9 +++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pkg/services/live/live.go b/pkg/services/live/live.go index 7e28ec0d99d..4294b1c6d28 100644 --- a/pkg/services/live/live.go +++ b/pkg/services/live/live.go @@ -373,25 +373,28 @@ func getCheckOriginFunc(appURL *url.URL, originPatterns []string, originGlobs [] // fast path for *. return true } - ok, err := checkAllowedOrigin(strings.ToLower(origin), appURL, originGlobs) + originURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + logger.Warn("Failed to parse request origin", "error", err, "origin", origin) + return false + } + if strings.EqualFold(originURL.Host, r.Host) { + return true + } + ok, err := checkAllowedOrigin(origin, originURL, appURL, originGlobs) if err != nil { logger.Warn("Error parsing request origin", "error", err, "origin", origin) return false } if !ok { - logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String(), "allowedOrigins", strings.Join(originPatterns, ",")) + logger.Warn("Request Origin is not authorized", "origin", origin, "host", r.Host, "appUrl", appURL.String(), "allowedOrigins", strings.Join(originPatterns, ",")) return false } return true } } -func checkAllowedOrigin(origin string, appURL *url.URL, originGlobs []glob.Glob) (bool, error) { - originURL, err := url.Parse(origin) - if err != nil { - logger.Warn("Failed to parse request origin", "error", err, "origin", origin) - return false, err - } +func checkAllowedOrigin(origin string, originURL *url.URL, appURL *url.URL, originGlobs []glob.Glob) (bool, error) { // Try to match over configured [server] root_url first. if originURL.Port() == "" { if strings.EqualFold(originURL.Scheme, appURL.Scheme) && strings.EqualFold(originURL.Host, appURL.Hostname()) { diff --git a/pkg/services/live/live_test.go b/pkg/services/live/live_test.go index a1850d58f11..9c738c0aa30 100644 --- a/pkg/services/live/live_test.go +++ b/pkg/services/live/live_test.go @@ -62,6 +62,7 @@ func TestCheckOrigin(t *testing.T) { appURL string allowedOrigins []string success bool + host string }{ { name: "empty_origin", @@ -126,6 +127,13 @@ func TestCheckOrigin(t *testing.T) { allowedOrigins: []string{"*"}, success: true, }, + { + name: "request_host_matches_origin_host", + origin: "http://example.com", + appURL: "https://example.com", + success: true, + host: "example.com", + }, } for _, tc := range testCases { @@ -141,6 +149,7 @@ func TestCheckOrigin(t *testing.T) { checkOrigin := getCheckOriginFunc(appURL, tc.allowedOrigins, originGlobs) r := httptest.NewRequest("GET", tc.appURL, nil) + r.Host = tc.host r.Header.Set("Origin", tc.origin) require.Equal(t, tc.success, checkOrigin(r), "origin %s, appURL: %s", tc.origin, tc.appURL,