From 1a66d2ddab11e7d0219040477ed2c0b95f87bfdb Mon Sep 17 00:00:00 2001
From: Trevor Whitney
Date: Wed, 19 Nov 2025 17:05:30 -0700
Subject: [PATCH] feat: ability to send query context for limit enforcement
 (#19900)

---
 Makefile                                 |   2 -
 pkg/querier/queryrange/codec.go          |  10 +-
 pkg/querier/queryrange/limits.go         |  37 +++++-
 pkg/querier/queryrange/limits_test.go    | 140 +++++++++++++++++++++++
 pkg/querier/queryrange/marshal.go        |  23 +++-
 pkg/util/querylimits/grpc.go             |  60 ++++++----
 pkg/util/querylimits/grpc_test.go        |  27 ++++-
 pkg/util/querylimits/limiter.go          |  16 +--
 pkg/util/querylimits/limiter_test.go     |  10 +-
 pkg/util/querylimits/middleware.go       |  12 +-
 pkg/util/querylimits/middleware_test.go  |  63 +++++++++-
 pkg/util/querylimits/propagation.go      |  82 +++++++++++--
 pkg/util/querylimits/propagation_test.go |  66 ++++++++++-
 13 files changed, 488 insertions(+), 60 deletions(-)

diff --git a/Makefile b/Makefile
index 945c9833e3..5c95c1a8c9 100644
--- a/Makefile
+++ b/Makefile
@@ -223,8 +223,6 @@ cmd/loki/loki:
 cmd/loki/loki-debug:
 	CGO_ENABLED=0 go build $(DEBUG_GO_FLAGS) -o $@ ./$(@D)
 
-ui-assets:
-	make -C pkg/ui/frontend build
 ###############
 # Loki-Canary #
 ###############

diff --git a/pkg/querier/queryrange/codec.go b/pkg/querier/queryrange/codec.go
index 12e4e74339..2ba6f4740b 100644
--- a/pkg/querier/queryrange/codec.go
+++ b/pkg/querier/queryrange/codec.go
@@ -717,13 +717,21 @@ func (c Codec) EncodeRequest(ctx context.Context, r queryrangebase.Request) (*ht
 	}
 
 	// Add limits
-	if limits := querylimits.ExtractQueryLimitsContext(ctx); limits != nil {
+	if limits := querylimits.ExtractQueryLimitsFromContext(ctx); limits != nil {
 		err := querylimits.InjectQueryLimitsHeader(&header, limits)
 		if err != nil {
 			return nil, err
 		}
 	}
 
+	// Add limits context
+	if limitsCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx); limitsCtx != nil {
+		err := querylimits.InjectQueryLimitsContextHeader(&header, limitsCtx)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	// Add org id
 	orgID, err := user.ExtractOrgID(ctx)
 	if err != nil {

diff --git a/pkg/querier/queryrange/limits.go b/pkg/querier/queryrange/limits.go
index 35753333de..e07f1d9882 100644
--- a/pkg/querier/queryrange/limits.go
+++ b/pkg/querier/queryrange/limits.go
@@ -37,6 +37,7 @@ import (
 	"github.com/grafana/loki/v3/pkg/util/constants"
 	"github.com/grafana/loki/v3/pkg/util/httpreq"
 	util_log "github.com/grafana/loki/v3/pkg/util/log"
+	"github.com/grafana/loki/v3/pkg/util/querylimits"
 	"github.com/grafana/loki/v3/pkg/util/spanlogger"
 	"github.com/grafana/loki/v3/pkg/util/validation"
 )
@@ -281,7 +282,30 @@ func (q *querySizeLimiter) getBytesReadForRequest(ctx context.Context, r queryra
 	ctx, sp := tracer.Start(ctx, "querySizeLimiter.getBytesReadForRequest")
 	defer sp.End()
 
-	expr, err := syntax.ParseExpr(r.GetQuery())
+	queryLimitCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx)
+	fullCtxBytes := uint64(0)
+	if queryLimitCtx != nil && queryLimitCtx.Expr != "" && !queryLimitCtx.From.IsZero() && !queryLimitCtx.To.IsZero() {
+		var err error
+		fullCtxBytes, err = q.getBytesForQueryAndRange(ctx, queryLimitCtx.Expr, queryLimitCtx.From, queryLimitCtx.To)
+		if err != nil {
+			return 0, nil
+		}
+	}
+
+	queryBytes, err := q.getBytesForQueryAndRange(ctx, r.GetQuery(), r.GetStart(), r.GetEnd())
+	if err != nil {
+		return 0, nil
+	}
+
+	if fullCtxBytes > queryBytes {
+		return fullCtxBytes, nil
+	}
+
+	return queryBytes, nil
+}
+
+func (q *querySizeLimiter) getBytesForQueryAndRange(ctx context.Context, query string, from, to time.Time) (uint64, error) {
+	expr, err := syntax.ParseExpr(query)
 	if err != nil {
 		return 0, err
 	}
@@ -294,7 +318,16 @@ func (q *querySizeLimiter) getBytesReadForRequest(ctx context.Context, r queryra
 	// TODO: Set concurrency dynamically as in shardResolverForConf?
 	start := time.Now()
 	const maxConcurrentIndexReq = 10
-	matcherStats, err := getStatsForMatchers(ctx, q.logger, q.statsHandler, model.Time(r.GetStart().UnixMilli()), model.Time(r.GetEnd().UnixMilli()), matcherGroups, maxConcurrentIndexReq, q.maxLookBackPeriod)
+	matcherStats, err := getStatsForMatchers(
+		ctx,
+		q.logger,
+		q.statsHandler,
+		model.Time(from.UnixMilli()),
+		model.Time(to.UnixMilli()),
+		matcherGroups,
+		maxConcurrentIndexReq,
+		q.maxLookBackPeriod,
+	)
 	if err != nil {
 		return 0, err
 	}

diff --git a/pkg/querier/queryrange/limits_test.go b/pkg/querier/queryrange/limits_test.go
index f2eafd2ca1..d4ae99652c 100644
--- a/pkg/querier/queryrange/limits_test.go
+++ b/pkg/querier/queryrange/limits_test.go
@@ -29,6 +29,7 @@ import (
 	"github.com/grafana/loki/v3/pkg/util/constants"
 	"github.com/grafana/loki/v3/pkg/util/httpreq"
 	util_log "github.com/grafana/loki/v3/pkg/util/log"
+	"github.com/grafana/loki/v3/pkg/util/querylimits"
 )
 
 func TestLimits(t *testing.T) {
@@ -1051,6 +1052,145 @@ func Test_MaxQuerySize(t *testing.T) {
 	}
 }
 
+func Test_MaxQuerySize_WithQueryLimitsContext(t *testing.T) {
+	// a sentinel query value to control when our mock stats handler returns context stats
+	ctxSentinel := `{context="true"}`
+	schemas := []config.PeriodConfig{
+		{
+			From:      config.DayTime{Time: model.TimeFromUnix(testTime.Add(-48 * time.Hour).Unix())},
+			IndexType: types.TSDBType,
+		},
+	}
+
+	for _, tc := range []struct {
+		desc              string
+		query             string
+		queryStart        time.Time
+		queryEnd          time.Time
+		queryBytes        uint64
+		contextStart      time.Time
+		contextEnd        time.Time
+		contextBytes      uint64
+		limit             int
+		shouldErr         bool
+		expectedStatsHits int
+	}{
+		{
+			desc:              "No context, query under limit",
+			query:             `{app="foo"} |= "foo"`,
+			queryStart:        testTime.Add(-1 * time.Hour),
+			queryEnd:          testTime,
+			queryBytes:        500,
+			limit:             1000,
+			shouldErr:         false,
+			expectedStatsHits: 1,
+		},
+		{
+			desc:              "Context range larger, both under limit",
+			query:             `{app="foo"} |= "foo"`,
+			queryStart:        testTime.Add(-1 * time.Hour),
+			queryEnd:          testTime,
+			queryBytes:        200,
+			contextStart:      testTime.Add(-24 * time.Hour),
+			contextEnd:        testTime,
+			contextBytes:      800,
+			limit:             1000,
+			shouldErr:         false,
+			expectedStatsHits: 2,
+		},
+		{
+			desc:              "Context range larger, context exceeds limit",
+			query:             `{app="foo"} |= "foo"`,
+			queryStart:        testTime.Add(-1 * time.Hour),
+			queryEnd:          testTime,
+			queryBytes:        200,
+			contextStart:      testTime.Add(-24 * time.Hour),
+			contextEnd:        testTime,
+			contextBytes:      1200,
+			limit:             1000,
+			shouldErr:         true,
+			expectedStatsHits: 2,
+		},
+		{
+			desc:              "Query range larger, query exceeds limit",
+			query:             `{app="foo"} |= "foo"`,
+			queryStart:        testTime.Add(-24 * time.Hour),
+			queryEnd:          testTime,
+			queryBytes:        1200,
+			contextStart:      testTime.Add(-1 * time.Hour),
+			contextEnd:        testTime,
+			contextBytes:      200,
+			limit:             1000,
+			shouldErr:         true,
+			expectedStatsHits: 2,
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			statsHits := atomic.NewInt32(0)
+
+			statsHandler := queryrangebase.HandlerFunc(func(_ context.Context, req queryrangebase.Request) (queryrangebase.Response, error) {
+				statsHits.Inc()
+
+				bytes := tc.queryBytes
+				if req.GetQuery() == ctxSentinel {
+					bytes = tc.contextBytes
+				}
+
+				return &IndexStatsResponse{
+					Response: &logproto.IndexStatsResponse{
+						Bytes: bytes,
+					},
+				}, nil
+			})
+
+			promHandler := queryrangebase.HandlerFunc(func(_ context.Context, _ queryrangebase.Request) (queryrangebase.Response, error) {
+				return &LokiPromResponse{
+					Response: &queryrangebase.PrometheusResponse{
+						Status: "success",
+					},
+				}, nil
+			})
+
+			lokiReq := &LokiRequest{
+				Query:     tc.query,
+				StartTs:   tc.queryStart,
+				EndTs:     tc.queryEnd,
+				Direction: logproto.FORWARD,
+				Path:      "/query_range",
+				Plan: &plan.QueryPlan{
+					AST: syntax.MustParseExpr(tc.query),
+				},
+			}
+
+			ctx := user.InjectOrgID(context.Background(), "foo")
+
+			if !tc.contextStart.IsZero() && !tc.contextEnd.IsZero() {
+				ctx = querylimits.InjectQueryLimitsContextIntoContext(ctx, querylimits.Context{
+					Expr: ctxSentinel, // in practice this would be the same query as in the request; a sentinel keeps the mock stats handler simple
+					From: tc.contextStart,
+					To:   tc.contextEnd,
+				})
+			}
+
+			middlewares := []queryrangebase.Middleware{
+				NewQuerySizeLimiterMiddleware(schemas, testEngineOpts, util_log.Logger, fakeLimits{
+					maxQueryBytesRead: tc.limit,
+				}, statsHandler),
+			}
+
+			_, err := queryrangebase.MergeMiddlewares(middlewares...).Wrap(promHandler).Do(ctx, lokiReq)
+
+			if tc.shouldErr {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+			}
+
+			require.Equal(t, tc.expectedStatsHits, int(statsHits.Load()))
+		})
+	}
+}
+
 func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) {
 	engineOpts := testEngineOpts
 	engineOpts.MaxLookBackPeriod = 1 * time.Hour

diff --git a/pkg/querier/queryrange/marshal.go b/pkg/querier/queryrange/marshal.go
index 10fab0c1d2..4d2ae4de0a 100644
--- a/pkg/querier/queryrange/marshal.go
+++ b/pkg/querier/queryrange/marshal.go
@@ -342,7 +342,16 @@ func (Codec) QueryRequestUnwrap(ctx context.Context, req *QueryRequest) (queryra
 		if err != nil {
 			return nil, ctx, err
 		}
-		ctx = querylimits.InjectQueryLimitsContext(ctx, *limits)
+		ctx = querylimits.InjectQueryLimitsIntoContext(ctx, *limits)
+	}
+
+	// Add limits context
+	if encodedLimitsCtx, ok := req.Metadata[querylimits.HTTPHeaderQueryLimitsContextKey]; ok {
+		limitsCtx, err := querylimits.UnmarshalQueryLimitsContext([]byte(encodedLimitsCtx))
+		if err != nil {
+			return nil, ctx, err
+		}
+		ctx = querylimits.InjectQueryLimitsContextIntoContext(ctx, *limitsCtx)
 	}
 
 	// Add query time
@@ -454,7 +463,7 @@ func (Codec) QueryRequestWrap(ctx context.Context, r queryrangebase.Request) (*Q
 	}
 
 	// Add limits
-	limits := querylimits.ExtractQueryLimitsContext(ctx)
+	limits := querylimits.ExtractQueryLimitsFromContext(ctx)
 	if limits != nil {
 		encodedLimits, err := querylimits.MarshalQueryLimits(limits)
 		if err != nil {
@@ -463,6 +472,16 @@ func (Codec) QueryRequestWrap(ctx context.Context, r queryrangebase.Request) (*Q
 	}
 
+	// Add limits context
+	limitsCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx)
+	if limitsCtx != nil {
+		encodedLimitsCtx, err := querylimits.MarshalQueryLimitsContext(limitsCtx)
+		if err != nil {
+			return nil, err
+		}
+		result.Metadata[querylimits.HTTPHeaderQueryLimitsContextKey] = string(encodedLimitsCtx)
+	}
+
 	// Add org ID
 	orgID, err := user.ExtractOrgID(ctx)
 	if err != nil {

diff --git a/pkg/util/querylimits/grpc.go b/pkg/util/querylimits/grpc.go
index b9c2c63610..c07bcdd277 100644
--- a/pkg/util/querylimits/grpc.go
+++ b/pkg/util/querylimits/grpc.go
@@ -2,35 +2,43 @@ package querylimits
 
 import (
 	"context"
-	"fmt"
 
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 )
 
 const (
-	lowerQueryLimitsHeaderName = "x-loki-query-limits"
+	lowerQueryLimitsHeaderName        = "x-loki-query-limits"
+	lowerQueryLimitsContextHeaderName = "x-loki-query-limits-context"
 )
 
 func injectIntoGRPCRequest(ctx context.Context) (context.Context, error) {
-	limits := ExtractQueryLimitsContext(ctx)
-	fmt.Printf("extract limits grpc: %v", limits)
-	if limits == nil {
-		return ctx, nil
-	}
 	// inject into GRPC metadata
 	md, ok := metadata.FromOutgoingContext(ctx)
 	if !ok {
 		md = metadata.New(map[string]string{})
 	}
 	md = md.Copy()
-	headerValue, err := MarshalQueryLimits(limits)
-	if err != nil {
-		return nil, err
+
+	limits := ExtractQueryLimitsFromContext(ctx)
+	if limits != nil {
+		headerValue, err := MarshalQueryLimits(limits)
+		if err != nil {
+			return nil, err
+		}
+		md.Set(lowerQueryLimitsHeaderName, string(headerValue))
+	}
+
+	limitsCtx := ExtractQueryLimitsContextFromContext(ctx)
+	if limitsCtx != nil {
+		headerValue, err := MarshalQueryLimitsContext(limitsCtx)
+		if err != nil {
+			return nil, err
+		}
+		md.Set(lowerQueryLimitsContextHeaderName, string(headerValue))
 	}
-	md.Set(lowerQueryLimitsHeaderName, string(headerValue))
-	newCtx := metadata.NewOutgoingContext(ctx, md)
+	newCtx := metadata.NewOutgoingContext(ctx, md)
 	return newCtx, nil
 }
@@ -60,21 +68,27 @@ func extractFromGRPCRequest(ctx context.Context) (context.Context, error) {
 	}
 
 	headerValues, ok := md[lowerQueryLimitsHeaderName]
-	if !ok {
-		// No QueryLimits header in metadata, just return context
-		return ctx, nil
+	if ok && len(headerValues) > 0 {
+		// Pick first header
+		limits, err := UnmarshalQueryLimits([]byte(headerValues[0]))
+		if err != nil {
+			return ctx, err
+		}
+		ctx = InjectQueryLimitsIntoContext(ctx, *limits)
 	}
 
-	if len(headerValues) == 0 {
-		return ctx, nil
+	// Extract QueryLimitsContext if present
+	headerContextValues, ok := md[lowerQueryLimitsContextHeaderName]
+	if ok && len(headerContextValues) > 0 {
+		// Pick first header
+		limitsCtx, err := UnmarshalQueryLimitsContext([]byte(headerContextValues[0]))
+		if err != nil {
+			return ctx, err
+		}
+		ctx = InjectQueryLimitsContextIntoContext(ctx, *limitsCtx)
 	}
 
-	// Pick first header
-	limits, err := UnmarshalQueryLimits([]byte(headerValues[0]))
-	if err != nil {
-		return ctx, err
-	}
-	return InjectQueryLimitsContext(ctx, *limits), nil
+	return ctx, nil
 }
 
 func ServerQueryLimitsInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {

diff --git a/pkg/util/querylimits/grpc_test.go b/pkg/util/querylimits/grpc_test.go
index 07dd8ec083..af9c40c7e9 100644
--- a/pkg/util/querylimits/grpc_test.go
+++ b/pkg/util/querylimits/grpc_test.go
@@ -16,16 +16,37 @@ func TestGRPCQueryLimits(t *testing.T) {
 		MaxQueryLookback:        model.Duration(14 * 24 * time.Hour),
 		MaxEntriesLimitPerQuery: 100,
 	}
-	c1 := InjectQueryLimitsContext(context.Background(), limits)
+	c1 := InjectQueryLimitsIntoContext(context.Background(), limits)
 
 	c1, err = injectIntoGRPCRequest(c1)
 	require.NoError(t, err)
 
 	c2, err := extractFromGRPCRequest(c1)
 	require.NoError(t, err)
-	require.Equal(t, limits, *(c2.Value(queryLimitsContextKey).(*QueryLimits)))
+	require.Equal(t, limits, *(c2.Value(queryLimitsCtxKey).(*QueryLimits)))
 
 	c3, err := extractFromGRPCRequest(context.Background())
 	require.NoError(t, err)
-	require.Nil(t, c3.Value(queryLimitsContextKey))
+	require.Nil(t, c3.Value(queryLimitsCtxKey))
+}
+
+func TestGRPCQueryLimitsContext(t *testing.T) {
+	var err error
+	limitsCtx := Context{
+		Expr: "{app=\"test\"}",
+		From: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
+		To:   time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
+	}
+	c1 := InjectQueryLimitsContextIntoContext(context.Background(), limitsCtx)
+
+	c1, err = injectIntoGRPCRequest(c1)
+	require.NoError(t, err)
+
+	c2, err := extractFromGRPCRequest(c1)
+	require.NoError(t, err)
+	require.Equal(t, limitsCtx, *(c2.Value(queryLimitsContextCtxKey).(*Context)))
+
+	c3, err := extractFromGRPCRequest(context.Background())
+	require.NoError(t, err)
+	require.Nil(t, c3.Value(queryLimitsContextCtxKey))
 }

diff --git a/pkg/util/querylimits/limiter.go b/pkg/util/querylimits/limiter.go
index 439f1d927d..1a78df1959 100644
--- a/pkg/util/querylimits/limiter.go
+++ b/pkg/util/querylimits/limiter.go
@@ -27,7 +27,7 @@ func NewLimiter(log log.Logger, original limiter.CombinedLimits) *Limiter {
 
 // MaxQueryLength returns the limit of the length (in time) of a query.
 func (l *Limiter) MaxQueryLength(ctx context.Context, userID string) time.Duration {
 	original := l.CombinedLimits.MaxQueryLength(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.MaxQueryLength == 0 || (time.Duration(requestLimits.MaxQueryLength) > original && original != 0) {
 		return original
@@ -39,7 +39,7 @@ func (l *Limiter) MaxQueryLength(ctx context.Context, userID string) time.Durati
 
 // MaxQueryLookback returns the max lookback period of queries.
 func (l *Limiter) MaxQueryLookback(ctx context.Context, userID string) time.Duration {
 	original := l.CombinedLimits.MaxQueryLookback(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.MaxQueryLookback == 0 || (time.Duration(requestLimits.MaxQueryLookback) > original && original != 0) {
 		return original
@@ -51,7 +51,7 @@ func (l *Limiter) MaxQueryLookback(ctx context.Context, userID string) time.Dura
 
 // MaxQueryRange retruns the max query range/interval of a query.
 func (l *Limiter) MaxQueryRange(ctx context.Context, userID string) time.Duration {
 	original := l.CombinedLimits.MaxQueryRange(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.MaxQueryRange == 0 || (time.Duration(requestLimits.MaxQueryRange) > original && original != 0) {
 		return original
@@ -63,7 +63,7 @@ func (l *Limiter) MaxQueryRange(ctx context.Context, userID string) time.Duratio
 
 // MaxEntriesLimitPerQuery returns the limit to number of entries the querier should return per query.
 func (l *Limiter) MaxEntriesLimitPerQuery(ctx context.Context, userID string) int {
 	original := l.CombinedLimits.MaxEntriesLimitPerQuery(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.MaxEntriesLimitPerQuery == 0 || (requestLimits.MaxEntriesLimitPerQuery > original && original != 0) {
 		return original
@@ -75,7 +75,7 @@ func (l *Limiter) MaxEntriesLimitPerQuery(ctx context.Context, userID string) in
 
 func (l *Limiter) QueryTimeout(ctx context.Context, userID string) time.Duration {
 	original := l.CombinedLimits.QueryTimeout(ctx, userID)
 	// in theory this error should never happen
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.QueryTimeout == 0 || (time.Duration(requestLimits.QueryTimeout) > original && original != 0) {
 		return original
@@ -86,7 +86,7 @@ func (l *Limiter) QueryTimeout(ctx context.Context, userID string) time.Duration
 
 func (l *Limiter) RequiredLabels(ctx context.Context, userID string) []string {
 	original := l.CombinedLimits.RequiredLabels(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 
 	if requestLimits == nil {
 		return original
@@ -112,7 +112,7 @@ func (l *Limiter) RequiredLabels(ctx context.Context, userID string) []string {
 
 func (l *Limiter) RequiredNumberLabels(ctx context.Context, userID string) int {
 	original := l.CombinedLimits.RequiredNumberLabels(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.RequiredNumberLabels == 0 || requestLimits.RequiredNumberLabels < original {
 		return original
 	}
@@ -122,7 +122,7 @@ func (l *Limiter) RequiredNumberLabels(ctx context.Context, userID string) int {
 
 func (l *Limiter) MaxQueryBytesRead(ctx context.Context, userID string) int {
 	original := l.CombinedLimits.MaxQueryBytesRead(ctx, userID)
-	requestLimits := ExtractQueryLimitsContext(ctx)
+	requestLimits := ExtractQueryLimitsFromContext(ctx)
 	if requestLimits == nil || requestLimits.MaxQueryBytesRead.Val() == 0 || (requestLimits.MaxQueryBytesRead.Val() > original && original != 0) {
 		return original

diff --git a/pkg/util/querylimits/limiter_test.go b/pkg/util/querylimits/limiter_test.go
index f230d23189..b849b2f3a8 100644
--- a/pkg/util/querylimits/limiter_test.go
+++ b/pkg/util/querylimits/limiter_test.go
@@ -89,7 +89,7 @@ func TestLimiter_Defaults(t *testing.T) {
 		MaxQueryBytesRead:       10,
 	}
 	{
-		ctx2 := InjectQueryLimitsContext(context.Background(), *limits)
+		ctx2 := InjectQueryLimitsIntoContext(context.Background(), *limits)
 		queryLookback := l.MaxQueryLookback(ctx2, "fake")
 		require.Equal(t, time.Duration(expectedLimits2.MaxQueryLookback), queryLookback)
 		queryLength := l.MaxQueryLength(ctx2, "fake")
@@ -143,7 +143,7 @@ func TestLimiter_RejectHighLimits(t *testing.T) {
 		RequiredNumberLabels:    100,
 	}
 
-	ctx := InjectQueryLimitsContext(context.Background(), limits)
+	ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
 	require.Equal(t, time.Duration(expectedLimits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
 	require.Equal(t, time.Duration(expectedLimits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
 	require.Equal(t, expectedLimits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@@ -180,7 +180,7 @@ func TestLimiter_AcceptLowerLimits(t *testing.T) {
 		RequiredNumberLabels:    10,
 	}
 
-	ctx := InjectQueryLimitsContext(context.Background(), limits)
+	ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
 	require.Equal(t, time.Duration(limits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
 	require.Equal(t, time.Duration(limits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
 	require.Equal(t, limits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@@ -218,7 +218,7 @@ func TestLimiter_AcceptRequestLimitsOverNotInitializedLimits(t *testing.T) {
 		RequiredNumberLabels:    10,
 	}
 
-	ctx := InjectQueryLimitsContext(context.Background(), limits)
+	ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
 	require.Equal(t, time.Duration(limits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
 	require.Equal(t, time.Duration(limits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
 	require.Equal(t, limits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@@ -243,7 +243,7 @@ func TestLimiter_MergeLimits(t *testing.T) {
 
 	require.ElementsMatch(t, []string{"one", "two"}, l.RequiredLabels(context.Background(), "fake"))
 
-	ctx := InjectQueryLimitsContext(context.Background(), limits)
+	ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
 
 	require.ElementsMatch(t, []string{"one", "two", "three"}, l.RequiredLabels(ctx, "fake"))
 }

diff --git a/pkg/util/querylimits/middleware.go b/pkg/util/querylimits/middleware.go
index a25d53949b..3062709f01 100644
--- a/pkg/util/querylimits/middleware.go
+++ b/pkg/util/querylimits/middleware.go
@@ -32,7 +32,17 @@ func (l *queryLimitsMiddleware) Wrap(next http.Handler) http.Handler {
 		}
 
 		if limits != nil {
-			r = r.Clone(InjectQueryLimitsContext(r.Context(), *limits))
+			r = r.Clone(InjectQueryLimitsIntoContext(r.Context(), *limits))
+		}
+
+		limitsCtx, err := ExtractQueryLimitsContextHTTP(r)
+		if err != nil {
+			level.Warn(util_log.Logger).Log("msg", "could not extract query limits context from header", "err", err)
+			limitsCtx = nil
+		}
+
+		if limitsCtx != nil {
+			r = r.Clone(InjectQueryLimitsContextIntoContext(r.Context(), *limitsCtx))
 		}
 
 		next.ServeHTTP(w, r)

diff --git a/pkg/util/querylimits/middleware_test.go b/pkg/util/querylimits/middleware_test.go
index acea9fd5d3..7e498deb6e 100644
--- a/pkg/util/querylimits/middleware_test.go
+++ b/pkg/util/querylimits/middleware_test.go
@@ -13,7 +13,7 @@ import (
 
 func Test_MiddlewareWithoutHeader(t *testing.T) {
 	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
-		limits := ExtractQueryLimitsContext(r.Context())
+		limits := ExtractQueryLimitsFromContext(r.Context())
 		require.Nil(t, limits)
 	})
 	m := NewQueryLimitsMiddleware(log.NewNopLogger())
@@ -29,7 +29,7 @@ func Test_MiddlewareWithoutHeader(t *testing.T) {
 
 func Test_MiddlewareWithBrokenHeader(t *testing.T) {
 	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
-		limits := ExtractQueryLimitsContext(r.Context())
+		limits := ExtractQueryLimitsFromContext(r.Context())
 		require.Nil(t, limits)
 	})
 	m := NewQueryLimitsMiddleware(log.NewNopLogger())
@@ -57,7 +57,7 @@ func Test_MiddlewareWithHeader(t *testing.T) {
 	}
 
 	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
-		actual := ExtractQueryLimitsContext(r.Context())
+		actual := ExtractQueryLimitsFromContext(r.Context())
 		require.Equal(t, limits, *actual)
 	})
 	m := NewQueryLimitsMiddleware(log.NewNopLogger())
@@ -72,3 +72,60 @@ func Test_MiddlewareWithHeader(t *testing.T) {
 	response := rr.Result()
 	require.Equal(t, http.StatusOK, response.StatusCode)
 }
+
+func Test_MiddlewareWithoutContextHeader(t *testing.T) {
+	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+		limitsCtx := ExtractQueryLimitsContextFromContext(r.Context())
+		require.Nil(t, limitsCtx)
+	})
+	m := NewQueryLimitsMiddleware(log.NewNopLogger())
+	wrapped := m.Wrap(nextHandler)
+
+	rr := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/example", nil)
+	require.NoError(t, err)
+	wrapped.ServeHTTP(rr, r)
+	response := rr.Result()
+	require.Equal(t, http.StatusOK, response.StatusCode)
+}
+
+func Test_MiddlewareWithBrokenContextHeader(t *testing.T) {
+	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+		limitsCtx := ExtractQueryLimitsContextFromContext(r.Context())
+		require.Nil(t, limitsCtx)
+	})
+	m := NewQueryLimitsMiddleware(log.NewNopLogger())
+	wrapped := m.Wrap(nextHandler)
+
+	rr := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/example", nil)
+	require.NoError(t, err)
+	r.Header.Add(HTTPHeaderQueryLimitsContextKey, "{broken}")
+	wrapped.ServeHTTP(rr, r)
+	response := rr.Result()
+	require.Equal(t, http.StatusOK, response.StatusCode)
+}
+
+func Test_MiddlewareWithContextHeader(t *testing.T) {
+	limitsCtx := Context{
+		Expr: "{app=\"test\"}",
+		From: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
+		To:   time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
+	}
+
+	nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+		actual := ExtractQueryLimitsContextFromContext(r.Context())
+		require.Equal(t, limitsCtx, *actual)
+	})
+	m := NewQueryLimitsMiddleware(log.NewNopLogger())
+	wrapped := m.Wrap(nextHandler)
+
+	rr := httptest.NewRecorder()
+	r, err := http.NewRequest("GET", "/example", nil)
+	require.NoError(t, err)
+	err = InjectQueryLimitsContextHTTP(r, &limitsCtx)
+	require.NoError(t, err)
+	wrapped.ServeHTTP(rr, r)
+	response := rr.Result()
+	require.Equal(t, http.StatusOK, response.StatusCode)
+}

diff --git a/pkg/util/querylimits/propagation.go b/pkg/util/querylimits/propagation.go
index a9cb06e347..26fdba7354 100644
--- a/pkg/util/querylimits/propagation.go
+++ b/pkg/util/querylimits/propagation.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"net/http"
+	"time"
 
 	"github.com/prometheus/common/model"
@@ -14,9 +15,11 @@
 type key int
 
 const (
-	queryLimitsContextKey key = 1
+	queryLimitsCtxKey        key = 1
+	queryLimitsContextCtxKey key = 2
 
-	HTTPHeaderQueryLimitsKey = "X-Loki-Query-Limits"
+	HTTPHeaderQueryLimitsKey        = "X-Loki-Query-Limits"
+	HTTPHeaderQueryLimitsContextKey = "X-Loki-Query-Limits-Context"
 )
 
 // NOTE: we use custom `model.Duration` instead of standard `time.Duration` because,
@@ -73,9 +76,9 @@ func ExtractQueryLimitsHTTP(r *http.Request) (*QueryLimits, error) {
 	return nil, nil
 }
 
-// ExtractQueryLimitsContext gets the embedded limits from the context
-func ExtractQueryLimitsContext(ctx context.Context) *QueryLimits {
-	source, ok := ctx.Value(queryLimitsContextKey).(*QueryLimits)
+// ExtractQueryLimitsFromContext gets the embedded limits from the context
+func ExtractQueryLimitsFromContext(ctx context.Context) *QueryLimits {
+	source, ok := ctx.Value(queryLimitsCtxKey).(*QueryLimits)
 
 	if !ok {
 		return nil
@@ -84,7 +87,70 @@ func ExtractQueryLimitsContext(ctx context.Context) *QueryLimits {
 	return source
 }
 
-// InjectQueryLimitsContext returns a derived context containing the provided query limits
-func InjectQueryLimitsContext(ctx context.Context, limits QueryLimits) context.Context {
-	return context.WithValue(ctx, interface{}(queryLimitsContextKey), &limits)
+// InjectQueryLimitsIntoContext returns a derived context containing the provided query limits
+func InjectQueryLimitsIntoContext(ctx context.Context, limits QueryLimits) context.Context {
+	return context.WithValue(ctx, interface{}(queryLimitsCtxKey), &limits)
+}
+
+type Context struct {
+	Expr string    `json:"expr"`
+	From time.Time `json:"from"`
+	To   time.Time `json:"to"`
+}
+
+func UnmarshalQueryLimitsContext(data []byte) (*Context, error) {
+	limitsCtx := &Context{}
+	err := json.Unmarshal(data, limitsCtx)
+	return limitsCtx, err
+}
+
+func MarshalQueryLimitsContext(limits *Context) ([]byte, error) {
+	return json.Marshal(limits)
+}
+
+// InjectQueryLimitsContextHTTP adds the query limits context to the request headers.
+func InjectQueryLimitsContextHTTP(r *http.Request, limitsCtx *Context) error {
+	return InjectQueryLimitsContextHeader(&r.Header, limitsCtx)
+}
+
+// InjectQueryLimitsContextHeader adds the query limits context to the headers.
+func InjectQueryLimitsContextHeader(h *http.Header, limitsCtx *Context) error {
+	// Ensure any existing values are erased
+	h.Del(HTTPHeaderQueryLimitsContextKey)
+
+	encodedLimits, err := MarshalQueryLimitsContext(limitsCtx)
+	if err != nil {
+		return err
+	}
+	h.Add(HTTPHeaderQueryLimitsContextKey, string(encodedLimits))
+	return nil
+}
+
+// ExtractQueryLimitsContextHTTP retrieves the query limits context from the HTTP header and returns it.
+func ExtractQueryLimitsContextHTTP(r *http.Request) (*Context, error) {
+	headerValues := r.Header.Values(HTTPHeaderQueryLimitsContextKey)
+
+	// Iterate through each set header value
+	for _, headerValue := range headerValues {
+		return UnmarshalQueryLimitsContext([]byte(headerValue))
+	}
+
+	return nil, nil
+}
+
+// ExtractQueryLimitsContextFromContext gets the embedded query limits context from the context
+func ExtractQueryLimitsContextFromContext(ctx context.Context) *Context {
+	source, ok := ctx.Value(queryLimitsContextCtxKey).(*Context)
+
+	if !ok {
+		return nil
+	}
+
+	return source
+}
+
+// InjectQueryLimitsContextIntoContext returns a derived context containing the provided query limits context
+func InjectQueryLimitsContextIntoContext(ctx context.Context, limitsCtx Context) context.Context {
+	return context.WithValue(ctx, any(queryLimitsContextCtxKey), &limitsCtx)
 }

diff --git a/pkg/util/querylimits/propagation_test.go b/pkg/util/querylimits/propagation_test.go
index 31b7e7dc77..7569d3df0c 100644
--- a/pkg/util/querylimits/propagation_test.go
+++ b/pkg/util/querylimits/propagation_test.go
@@ -18,8 +18,8 @@ func TestInjectAndExtractQueryLimits(t *testing.T) {
 		QueryTimeout:            model.Duration(5 * time.Second),
 	}
 
-	ctx = InjectQueryLimitsContext(ctx, limits)
-	res := ExtractQueryLimitsContext(ctx)
+	ctx = InjectQueryLimitsIntoContext(ctx, limits)
+	res := ExtractQueryLimitsFromContext(ctx)
 
 	require.Equal(t, limits, *res)
 }
@@ -70,3 +70,65 @@ func TestSerializingQueryLimits(t *testing.T) {
 	expected = `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w"}`
 	require.JSONEq(t, expected, string(actual))
 }
+
+func TestInjectAndExtractQueryLimitsContext(t *testing.T) {
+	ctx := context.Background()
+	baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	limitsCtx := Context{
+		Expr: `{job="app"}`,
+		From: baseTime,
+		To:   baseTime.Add(1 * time.Hour),
+	}
+
+	ctx = InjectQueryLimitsContextIntoContext(ctx, limitsCtx)
+	res := ExtractQueryLimitsContextFromContext(ctx)
+	require.Equal(t, limitsCtx, *res)
+}
+
+func TestDeserializingQueryLimitsContext(t *testing.T) {
+	baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	endTime := baseTime.Add(1 * time.Hour)
+
+	// full context
+	payload := `{"expr": "{job=\"app\"}", "from": "2024-01-15T10:30:00Z", "to": "2024-01-15T11:30:00Z"}`
+	limitsCtx, err := UnmarshalQueryLimitsContext([]byte(payload))
+	require.NoError(t, err)
+	require.Equal(t, `{job="app"}`, limitsCtx.Expr)
+	require.Equal(t, baseTime.Unix(), limitsCtx.From.Unix())
+	require.Equal(t, endTime.Unix(), limitsCtx.To.Unix())
+
+	// some fields are empty
+	payload = `{"expr": "rate({job=\"app\"}[5m])"}`
+	limitsCtx, err = UnmarshalQueryLimitsContext([]byte(payload))
+	require.NoError(t, err)
+	require.Equal(t, `rate({job="app"}[5m])`, limitsCtx.Expr)
+	require.True(t, limitsCtx.From.IsZero())
+	require.True(t, limitsCtx.To.IsZero())
+}
+
+func TestSerializingQueryLimitsContext(t *testing.T) {
+	baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
+	endTime := baseTime.Add(1 * time.Hour)
+
+	// full struct
+	limitsCtx := Context{
+		Expr: `{job="app"}`,
+		From: baseTime,
+		To:   endTime,
+	}
+
+	actual, err := MarshalQueryLimitsContext(&limitsCtx)
+	require.NoError(t, err)
+	expected := `{"expr": "{job=\"app\"}", "from": "2024-01-15T10:30:00Z", "to": "2024-01-15T11:30:00Z"}`
+	require.JSONEq(t, expected, string(actual))
+
+	// some fields are empty
+	limitsCtx = Context{
+		Expr: `rate({job="app"}[5m])`,
+	}
+
+	actual, err = MarshalQueryLimitsContext(&limitsCtx)
+	require.NoError(t, err)
+	expected = `{"expr": "rate({job=\"app\"}[5m])", "from": "0001-01-01T00:00:00Z", "to": "0001-01-01T00:00:00Z"}`
+	require.JSONEq(t, expected, string(actual))
+}
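
Usage sketch (illustrative only, not part of the patch): how a caller could attach the new X-Loki-Query-Limits-Context header using the helpers this change adds in pkg/util/querylimits/propagation.go. The target URL, query, and time range below are hypothetical; only the header name, the querylimits.Context type, and InjectQueryLimitsContextHTTP come from the patch.

// query_limits_context_example.go - illustrative sketch, assuming the patched
// Loki tree is importable; values marked hypothetical are not from the patch.
package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/grafana/loki/v3/pkg/util/querylimits"
)

func main() {
	// Hypothetical query-frontend endpoint and sub-request.
	req, err := http.NewRequest("GET", "http://localhost:3100/loki/api/v1/query_range", nil)
	if err != nil {
		panic(err)
	}

	// Describe the full query and range this request belongs to, so byte-read
	// limits can be checked against the whole range rather than just this request.
	limitsCtx := querylimits.Context{
		Expr: `{app="foo"} |= "foo"`,
		From: time.Now().Add(-24 * time.Hour),
		To:   time.Now(),
	}
	if err := querylimits.InjectQueryLimitsContextHTTP(req, &limitsCtx); err != nil {
		panic(err)
	}

	// The header value is JSON of the form {"expr":"...","from":"...","to":"..."}
	// with RFC 3339 timestamps, as exercised in propagation_test.go above.
	fmt.Println(req.Header.Get("X-Loki-Query-Limits-Context"))
}

On the receiving side, the middleware, codec, and gRPC interceptor changes above carry this value through to the querier path, where getBytesReadForRequest takes the larger of the byte estimates for the sub-request and for the full context range before comparing against MaxQueryBytesRead.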