From 45ffebd4d81303269bee9cc4e4e39bb82375e983 Mon Sep 17 00:00:00 2001 From: Karsten Jeschkies Date: Tue, 12 Apr 2022 14:38:16 +0200 Subject: [PATCH] Return HTTP 400 when multiple tenants are present in push. (#5800) --- pkg/distributor/http.go | 19 ++++++++++++------- pkg/loki/modules.go | 6 ++++++ pkg/querier/http_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 pkg/querier/http_test.go diff --git a/pkg/distributor/http.go b/pkg/distributor/http.go index f9224d8c43..d8b778c087 100644 --- a/pkg/distributor/http.go +++ b/pkg/distributor/http.go @@ -19,10 +19,15 @@ import ( // PushHandler reads a snappy-compressed proto from the HTTP body. func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) { logger := util_log.WithContext(r.Context(), util_log.Logger) - userID, _ := tenant.TenantID(r.Context()) - req, err := push.ParseRequest(logger, userID, r, d.tenantsRetention) + tenantID, err := tenant.TenantID(r.Context()) if err != nil { - if d.tenantConfigs.LogPushRequest(userID) { + level.Error(logger).Log("msg", "error getting tenant id", "err", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + req, err := push.ParseRequest(logger, tenantID, r, d.tenantsRetention) + if err != nil { + if d.tenantConfigs.LogPushRequest(tenantID) { level.Debug(logger).Log( "msg", "push request failed", "code", http.StatusBadRequest, @@ -33,7 +38,7 @@ func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) { return } - if d.tenantConfigs.LogPushRequestStreams(userID) { + if d.tenantConfigs.LogPushRequestStreams(tenantID) { var sb strings.Builder for _, s := range req.Streams { sb.WriteString(s.Labels) @@ -46,7 +51,7 @@ func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) { _, err = d.Push(r.Context(), req) if err == nil { - if d.tenantConfigs.LogPushRequest(userID) { + if d.tenantConfigs.LogPushRequest(tenantID) { level.Debug(logger).Log( "msg", "push request successful", ) @@ -58,7 +63,7 @@ func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) { resp, ok := httpgrpc.HTTPResponseFromError(err) if ok { body := string(resp.Body) - if d.tenantConfigs.LogPushRequest(userID) { + if d.tenantConfigs.LogPushRequest(tenantID) { level.Debug(logger).Log( "msg", "push request failed", "code", resp.Code, @@ -67,7 +72,7 @@ func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) { } http.Error(w, body, int(resp.Code)) } else { - if d.tenantConfigs.LogPushRequest(userID) { + if d.tenantConfigs.LogPushRequest(tenantID) { level.Debug(logger).Log( "msg", "push request failed", "code", http.StatusInternalServerError, diff --git a/pkg/loki/modules.go b/pkg/loki/modules.go index 98b8499c77..90f4faebae 100644 --- a/pkg/loki/modules.go +++ b/pkg/loki/modules.go @@ -205,6 +205,12 @@ func (t *Loki) initDistributor() (services.Service, error) { logproto.RegisterPusherServer(t.Server.GRPC, t.distributor) } + // If the querier module is not part of this process we need to check if multi-tenant queries are enabled. + // If the querier module is part of this process the querier module will configure everything. + if !t.Cfg.isModuleEnabled(Querier) && t.Cfg.Querier.MultiTenantQueriesEnabled { + tenant.WithDefaultResolver(tenant.NewMultiResolver()) + } + pushHandler := middleware.Merge( serverutil.RecoveryHTTPMiddleware, t.HTTPAuthMiddleware, diff --git a/pkg/querier/http_test.go b/pkg/querier/http_test.go new file mode 100644 index 0000000000..32315351d4 --- /dev/null +++ b/pkg/querier/http_test.go @@ -0,0 +1,36 @@ +package querier + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-kit/log" + "github.com/grafana/dskit/tenant" + "github.com/stretchr/testify/require" + "github.com/weaveworks/common/user" + + "github.com/grafana/loki/pkg/validation" +) + +func TestTailHandler(t *testing.T) { + tenant.WithDefaultResolver(tenant.NewMultiResolver()) + + defaultLimits := defaultLimitsTestConfig() + limits, err := validation.NewOverrides(defaultLimits, nil) + require.NoError(t, err) + + api := NewQuerierAPI(mockQuerierConfig(), nil, limits, log.NewNopLogger()) + + req, err := http.NewRequest("GET", "/", nil) + ctx := user.InjectOrgID(req.Context(), "1|2") + req = req.WithContext(ctx) + require.NoError(t, err) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.TailHandler) + + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Equal(t, "multiple org IDs present\n", rr.Body.String()) +}