diff --git a/pkg/api/dashboard_test.go b/pkg/api/dashboard_test.go index f07d37ade1c..d665d5bc4dc 100644 --- a/pkg/api/dashboard_test.go +++ b/pkg/api/dashboard_test.go @@ -84,6 +84,7 @@ type testState struct { func newTestLive(t *testing.T) *live.GrafanaLive { gLive := live.NewGrafanaLive() gLive.RouteRegister = routing.NewRouteRegister() + gLive.Cfg = &setting.Cfg{AppURL: "http://localhost:3000/"} err := gLive.Init() require.NoError(t, err) return gLive diff --git a/pkg/services/live/live.go b/pkg/services/live/live.go index f156db04899..9d1b08d652b 100644 --- a/pkg/services/live/live.go +++ b/pkg/services/live/live.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "net/http" + "net/url" "strconv" + "strings" "sync" "time" @@ -230,15 +232,26 @@ func (g *GrafanaLive) Init() error { return err } + appURL, err := url.Parse(g.Cfg.AppURL) + if err != nil { + return fmt.Errorf("error parsing AppURL %s: %w", g.Cfg.AppURL, err) + } + // Use a pure websocket transport. wsHandler := centrifuge.NewWebsocketHandler(node, centrifuge.WebsocketConfig{ ReadBufferSize: 1024, WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return checkOrigin(r, appURL) + }, }) pushWSHandler := pushws.NewHandler(g.ManagedStreamRunner, pushws.Config{ ReadBufferSize: 1024, WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return checkOrigin(r, appURL) + }, }) g.websocketHandler = func(ctx *models.ReqContext) { @@ -277,6 +290,23 @@ func (g *GrafanaLive) Init() error { return nil } +func checkOrigin(r *http.Request, appURL *url.URL) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + originURL, err := url.Parse(origin) + if err != nil { + logger.Warn("Failed to parse request origin", "error", err, "origin", origin) + return false + } + if !strings.EqualFold(originURL.Scheme, appURL.Scheme) || !strings.EqualFold(originURL.Host, appURL.Host) { + logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String()) + return false + } + return true +} + func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn func()) error { if cap(semaphore) > 1 { select { diff --git a/pkg/services/live/live_test.go b/pkg/services/live/live_test.go index e790f04081b..06a501bc197 100644 --- a/pkg/services/live/live_test.go +++ b/pkg/services/live/live_test.go @@ -2,6 +2,8 @@ package live import ( "context" + "net/http/httptest" + "net/url" "testing" "time" @@ -50,3 +52,63 @@ func Test_runConcurrentlyIfNeeded_DeadlineExceeded(t *testing.T) { err := runConcurrentlyIfNeeded(ctx, semaphore, f) require.ErrorIs(t, err, context.DeadlineExceeded) } + +func TestCheckOrigin(t *testing.T) { + testCases := []struct { + name string + origin string + appURL string + success bool + }{ + { + name: "empty_origin", + origin: "", + appURL: "http://localhost:3000/", + success: true, + }, + { + name: "valid_origin", + origin: "http://localhost:3000", + appURL: "http://localhost:3000/", + success: true, + }, + { + name: "unauthorized_origin", + origin: "http://localhost:8000", + appURL: "http://localhost:3000/", + success: false, + }, + { + name: "bad_origin", + origin: ":::http://localhost:8000", + appURL: "http://localhost:3000/", + success: false, + }, + { + name: "different_scheme", + origin: "http://example.com", + appURL: "https://example.com", + success: false, + }, + { + name: "authorized_case_insensitive", + origin: "https://examplE.com", + appURL: "https://example.com", + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + appURL, err := url.Parse(tc.appURL) + require.NoError(t, err) + r := httptest.NewRequest("GET", tc.appURL, nil) + r.Header.Set("Origin", tc.origin) + require.Equal(t, tc.success, checkOrigin(r, appURL), + "origin %s, appURL: %s", tc.origin, tc.appURL, + ) + }) + } +}