The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/util/proxyutil/reverse_proxy_test.go

233 lines
8.0 KiB

package proxyutil
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/middleware/requestmeta"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/setting"
)
func TestReverseProxy(t *testing.T) {
t.Run("When proxying a request should enforce request and response constraints", func(t *testing.T) {
var actualReq *http.Request
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
actualReq = req
http.SetCookie(w, &http.Cookie{Name: "test"})
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
w.Header().Set("X-Custom-Hdr", "Ok!")
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(upstream.Close)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
req.Header.Set("X-Forwarded-Host", "forwarded.host.com")
req.Header.Set("X-Forwarded-Port", "8080")
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("Origin", "test.com")
req.Header.Set("Referer", "https://test.com/api")
req.RemoteAddr = "10.0.0.1"
req = req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg()))
req.Header.Set("Authorization", "val")
rp := NewReverseProxy(log.New("test"), func(req *http.Request) {
req.Header.Set("X-KEY", "value")
})
require.NotNil(t, rp)
require.NotNil(t, rp.ModifyResponse)
rp.ServeHTTP(rec, req)
require.NotNil(t, actualReq)
require.Empty(t, actualReq.Header.Get("X-Forwarded-Host"))
require.Empty(t, actualReq.Header.Get("X-Forwarded-Port"))
require.Empty(t, actualReq.Header.Get("X-Forwarded-Proto"))
require.Equal(t, "10.0.0.1", actualReq.Header.Get("X-Forwarded-For"))
require.Empty(t, actualReq.Header.Get("Origin"))
require.Empty(t, actualReq.Header.Get("Referer"))
require.Equal(t, "https://test.com/api", actualReq.Header.Get("X-Grafana-Referer"))
require.Equal(t, "value", actualReq.Header.Get("X-KEY"))
require.Empty(t, actualReq.Header.Get("Authorization"))
resp := rec.Result()
require.Empty(t, resp.Cookies())
require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy"))
require.Contains(t, resp.Header, "X-Custom-Hdr")
require.NotContains(t, resp.Header, "Strict-Transport-Security")
require.Contains(t, resp.Header.Get("Via"), "grafana")
require.NoError(t, resp.Body.Close())
})
t.Run("When proxying a request using WithModifyResponse should call it before default ModifyResponse func", func(t *testing.T) {
var actualReq *http.Request
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
actualReq = req
http.SetCookie(w, &http.Cookie{Name: "test"})
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(upstream.Close)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
rp := NewReverseProxy(
log.New("test"),
func(req *http.Request) {
req.Header.Set("X-KEY", "value")
},
WithModifyResponse(func(r *http.Response) error {
r.Header.Set("X-KEY2", "value2")
return nil
}),
)
require.NotNil(t, rp)
require.NotNil(t, rp.ModifyResponse)
rp.ServeHTTP(rec, req)
require.NotNil(t, actualReq)
require.Equal(t, "value", actualReq.Header.Get("X-KEY"))
resp := rec.Result()
require.Empty(t, resp.Cookies())
require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy"))
require.Equal(t, "value2", resp.Header.Get("X-KEY2"))
require.NoError(t, resp.Body.Close())
})
t.Run("Error handling should convert status codes depending on what kind of error it is and set downstream status source", func(t *testing.T) {
timedOutTransport := http.DefaultTransport.(*http.Transport)
timedOutTransport.ResponseHeaderTimeout = time.Millisecond
testCases := []struct {
desc string
transport http.RoundTripper
responseWaitTime time.Duration
expectedStatusCode int
}{
{
desc: "Cancelled request should return 499 Client closed request",
transport: &cancelledRoundTripper{},
expectedStatusCode: StatusClientClosedRequest,
},
{
desc: "Timed out request should return 504 Gateway timeout",
transport: timedOutTransport,
responseWaitTime: 100 * time.Millisecond,
expectedStatusCode: http.StatusGatewayTimeout,
},
{
desc: "Failed request should return 502 Bad gateway",
transport: &failingRoundTripper{},
expectedStatusCode: http.StatusBadGateway,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if tc.responseWaitTime > 0 {
time.Sleep(tc.responseWaitTime)
}
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(upstream.Close)
rec := httptest.NewRecorder()
ctx := requestmeta.SetRequestMetaData(context.Background(), requestmeta.RequestMetaData{
StatusSource: requestmeta.StatusSourceServer,
})
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
req = req.WithContext(ctx)
rp := NewReverseProxy(
log.New("test"),
func(req *http.Request) {},
WithTransport(tc.transport),
)
require.NotNil(t, rp)
require.NotNil(t, rp.Transport)
require.Same(t, tc.transport, rp.Transport)
rp.ServeHTTP(rec, req)
resp := rec.Result()
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
require.NoError(t, resp.Body.Close())
rmd := requestmeta.GetRequestMetaData(ctx)
require.Equal(t, requestmeta.StatusSourceDownstream, rmd.StatusSource)
})
}
})
t.Run("5xx response status codes should set downstream status source", func(t *testing.T) {
testCases := []struct {
status int
expectedSource requestmeta.StatusSource
}{
{status: http.StatusOK, expectedSource: requestmeta.StatusSourceServer},
{status: http.StatusBadRequest, expectedSource: requestmeta.StatusSourceServer},
{status: http.StatusForbidden, expectedSource: requestmeta.StatusSourceServer},
{status: http.StatusUnauthorized, expectedSource: requestmeta.StatusSourceServer},
{status: http.StatusInternalServerError, expectedSource: requestmeta.StatusSourceDownstream},
{status: http.StatusBadGateway, expectedSource: requestmeta.StatusSourceDownstream},
{status: http.StatusGatewayTimeout, expectedSource: requestmeta.StatusSourceDownstream},
{status: 599, expectedSource: requestmeta.StatusSourceDownstream},
}
for _, testCase := range testCases {
tc := testCase
t.Run(fmt.Sprintf("status %d => source %s ", tc.status, tc.expectedSource), func(t *testing.T) {
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tc.status)
}))
t.Cleanup(upstream.Close)
rec := httptest.NewRecorder()
ctx := requestmeta.SetRequestMetaData(context.Background(), requestmeta.RequestMetaData{
StatusSource: requestmeta.StatusSourceServer,
})
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
req = req.WithContext(ctx)
rp := NewReverseProxy(
log.New("test"),
func(req *http.Request) {},
)
require.NotNil(t, rp)
rp.ServeHTTP(rec, req)
resp := rec.Result()
require.Equal(t, tc.status, resp.StatusCode)
require.NoError(t, resp.Body.Close())
rmd := requestmeta.GetRequestMetaData(ctx)
require.Equal(t, tc.expectedSource, rmd.StatusSource)
})
}
})
}
func newUpstreamServer(t *testing.T, handler http.Handler) *httptest.Server {
t.Helper()
upstream := httptest.NewServer(handler)
return upstream
}
type cancelledRoundTripper struct{}
func (cancelledRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, context.Canceled
}
type failingRoundTripper struct{}
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, errors.New("some error")
}