feat: ability to send query context for limit enforcement (#19900)

pull/19871/head^2
Trevor Whitney 2 months ago committed by GitHub
parent 06da42a8ac
commit 1a66d2ddab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      Makefile
  2. 10
      pkg/querier/queryrange/codec.go
  3. 37
      pkg/querier/queryrange/limits.go
  4. 140
      pkg/querier/queryrange/limits_test.go
  5. 23
      pkg/querier/queryrange/marshal.go
  6. 60
      pkg/util/querylimits/grpc.go
  7. 27
      pkg/util/querylimits/grpc_test.go
  8. 16
      pkg/util/querylimits/limiter.go
  9. 10
      pkg/util/querylimits/limiter_test.go
  10. 12
      pkg/util/querylimits/middleware.go
  11. 63
      pkg/util/querylimits/middleware_test.go
  12. 82
      pkg/util/querylimits/propagation.go
  13. 66
      pkg/util/querylimits/propagation_test.go

@ -223,8 +223,6 @@ cmd/loki/loki:
cmd/loki/loki-debug:
CGO_ENABLED=0 go build $(DEBUG_GO_FLAGS) -o $@ ./$(@D)
ui-assets:
make -C pkg/ui/frontend build
###############
# Loki-Canary #
###############

@ -717,13 +717,21 @@ func (c Codec) EncodeRequest(ctx context.Context, r queryrangebase.Request) (*ht
}
// Add limits
if limits := querylimits.ExtractQueryLimitsContext(ctx); limits != nil {
if limits := querylimits.ExtractQueryLimitsFromContext(ctx); limits != nil {
err := querylimits.InjectQueryLimitsHeader(&header, limits)
if err != nil {
return nil, err
}
}
// Add limits context
if limitsCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx); limitsCtx != nil {
err := querylimits.InjectQueryLimitsContextHeader(&header, limitsCtx)
if err != nil {
return nil, err
}
}
// Add org id
orgID, err := user.ExtractOrgID(ctx)
if err != nil {

@ -37,6 +37,7 @@ import (
"github.com/grafana/loki/v3/pkg/util/constants"
"github.com/grafana/loki/v3/pkg/util/httpreq"
util_log "github.com/grafana/loki/v3/pkg/util/log"
"github.com/grafana/loki/v3/pkg/util/querylimits"
"github.com/grafana/loki/v3/pkg/util/spanlogger"
"github.com/grafana/loki/v3/pkg/util/validation"
)
@ -281,7 +282,30 @@ func (q *querySizeLimiter) getBytesReadForRequest(ctx context.Context, r queryra
ctx, sp := tracer.Start(ctx, "querySizeLimiter.getBytesReadForRequest")
defer sp.End()
expr, err := syntax.ParseExpr(r.GetQuery())
queryLimitCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx)
fullCtxBytes := uint64(0)
if queryLimitCtx != nil && queryLimitCtx.Expr != "" && !queryLimitCtx.From.IsZero() && !queryLimitCtx.To.IsZero() {
var err error
fullCtxBytes, err = q.getBytesForQueryAndRange(ctx, queryLimitCtx.Expr, queryLimitCtx.From, queryLimitCtx.To)
if err != nil {
return 0, nil
}
}
queryBytes, err := q.getBytesForQueryAndRange(ctx, r.GetQuery(), r.GetStart(), r.GetEnd())
if err != nil {
return 0, nil
}
if fullCtxBytes > queryBytes {
return fullCtxBytes, nil
}
return queryBytes, nil
}
func (q *querySizeLimiter) getBytesForQueryAndRange(ctx context.Context, query string, from, to time.Time) (uint64, error) {
expr, err := syntax.ParseExpr(query)
if err != nil {
return 0, err
}
@ -294,7 +318,16 @@ func (q *querySizeLimiter) getBytesReadForRequest(ctx context.Context, r queryra
// TODO: Set concurrency dynamically as in shardResolverForConf?
start := time.Now()
const maxConcurrentIndexReq = 10
matcherStats, err := getStatsForMatchers(ctx, q.logger, q.statsHandler, model.Time(r.GetStart().UnixMilli()), model.Time(r.GetEnd().UnixMilli()), matcherGroups, maxConcurrentIndexReq, q.maxLookBackPeriod)
matcherStats, err := getStatsForMatchers(
ctx,
q.logger,
q.statsHandler,
model.Time(from.UnixMilli()),
model.Time(to.UnixMilli()),
matcherGroups,
maxConcurrentIndexReq,
q.maxLookBackPeriod,
)
if err != nil {
return 0, err
}

@ -29,6 +29,7 @@ import (
"github.com/grafana/loki/v3/pkg/util/constants"
"github.com/grafana/loki/v3/pkg/util/httpreq"
util_log "github.com/grafana/loki/v3/pkg/util/log"
"github.com/grafana/loki/v3/pkg/util/querylimits"
)
func TestLimits(t *testing.T) {
@ -1051,6 +1052,145 @@ func Test_MaxQuerySize(t *testing.T) {
}
}
func Test_MaxQuerySize_WithQueryLimitsContext(t *testing.T) {
// a sentinal query value to control when our mock stats handler returns context stats
ctxSentinal := `{context="true"}`
schemas := []config.PeriodConfig{
{
From: config.DayTime{Time: model.TimeFromUnix(testTime.Add(-48 * time.Hour).Unix())},
IndexType: types.TSDBType,
},
}
for _, tc := range []struct {
desc string
query string
queryStart time.Time
queryEnd time.Time
queryBytes uint64
contextStart time.Time
contextEnd time.Time
contextBytes uint64
limit int
shouldErr bool
expectedStatsHits int
}{
{
desc: "No context, query under limit",
query: `{app="foo"} |= "foo"`,
queryStart: testTime.Add(-1 * time.Hour),
queryEnd: testTime,
queryBytes: 500,
limit: 1000,
shouldErr: false,
expectedStatsHits: 1,
},
{
desc: "Context range larger, both under limit",
query: `{app="foo"} |= "foo"`,
queryStart: testTime.Add(-1 * time.Hour),
queryEnd: testTime,
queryBytes: 200,
contextStart: testTime.Add(-24 * time.Hour),
contextEnd: testTime,
contextBytes: 800,
limit: 1000,
shouldErr: false,
expectedStatsHits: 2,
},
{
desc: "Context range larger, context exceeds limit",
query: `{app="foo"} |= "foo"`,
queryStart: testTime.Add(-1 * time.Hour),
queryEnd: testTime,
queryBytes: 200,
contextStart: testTime.Add(-24 * time.Hour),
contextEnd: testTime,
contextBytes: 1200,
limit: 1000,
shouldErr: true,
expectedStatsHits: 2,
},
{
desc: "Query range larger, query exceeds limit",
query: `{app="foo"} |= "foo"`,
queryStart: testTime.Add(-24 * time.Hour),
queryEnd: testTime,
queryBytes: 1200,
contextStart: testTime.Add(-1 * time.Hour),
contextEnd: testTime,
contextBytes: 200,
limit: 1000,
shouldErr: true,
expectedStatsHits: 2,
},
} {
t.Run(tc.desc, func(t *testing.T) {
statsHits := atomic.NewInt32(0)
statsHandler := queryrangebase.HandlerFunc(func(_ context.Context, req queryrangebase.Request) (queryrangebase.Response, error) {
statsHits.Inc()
bytes := tc.queryBytes
if req.GetQuery() == ctxSentinal {
bytes = tc.contextBytes
}
return &IndexStatsResponse{
Response: &logproto.IndexStatsResponse{
Bytes: bytes,
},
}, nil
})
promHandler := queryrangebase.HandlerFunc(func(_ context.Context, _ queryrangebase.Request) (queryrangebase.Response, error) {
return &LokiPromResponse{
Response: &queryrangebase.PrometheusResponse{
Status: "success",
},
}, nil
})
lokiReq := &LokiRequest{
Query: tc.query,
StartTs: tc.queryStart,
EndTs: tc.queryEnd,
Direction: logproto.FORWARD,
Path: "/query_range",
Plan: &plan.QueryPlan{
AST: syntax.MustParseExpr(tc.query),
},
}
ctx := user.InjectOrgID(context.Background(), "foo")
if !tc.contextStart.IsZero() && !tc.contextEnd.IsZero() {
ctx = querylimits.InjectQueryLimitsContextIntoContext(ctx, querylimits.Context{
Expr: ctxSentinal, // a hack to make mocking the stats handler easier, irl this should be the same query as in the request
From: tc.contextStart,
To: tc.contextEnd,
})
}
middlewares := []queryrangebase.Middleware{
NewQuerySizeLimiterMiddleware(schemas, testEngineOpts, util_log.Logger, fakeLimits{
maxQueryBytesRead: tc.limit,
}, statsHandler),
}
_, err := queryrangebase.MergeMiddlewares(middlewares...).Wrap(promHandler).Do(ctx, lokiReq)
if tc.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.expectedStatsHits, int(statsHits.Load()))
})
}
}
func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) {
engineOpts := testEngineOpts
engineOpts.MaxLookBackPeriod = 1 * time.Hour

@ -342,7 +342,16 @@ func (Codec) QueryRequestUnwrap(ctx context.Context, req *QueryRequest) (queryra
if err != nil {
return nil, ctx, err
}
ctx = querylimits.InjectQueryLimitsContext(ctx, *limits)
ctx = querylimits.InjectQueryLimitsIntoContext(ctx, *limits)
}
// Add limits context
if encodedLimitsCtx, ok := req.Metadata[querylimits.HTTPHeaderQueryLimitsContextKey]; ok {
limitsCtx, err := querylimits.UnmarshalQueryLimitsContext([]byte(encodedLimitsCtx))
if err != nil {
return nil, ctx, err
}
ctx = querylimits.InjectQueryLimitsContextIntoContext(ctx, *limitsCtx)
}
// Add query time
@ -454,7 +463,7 @@ func (Codec) QueryRequestWrap(ctx context.Context, r queryrangebase.Request) (*Q
}
// Add limits
limits := querylimits.ExtractQueryLimitsContext(ctx)
limits := querylimits.ExtractQueryLimitsFromContext(ctx)
if limits != nil {
encodedLimits, err := querylimits.MarshalQueryLimits(limits)
if err != nil {
@ -463,6 +472,16 @@ func (Codec) QueryRequestWrap(ctx context.Context, r queryrangebase.Request) (*Q
result.Metadata[querylimits.HTTPHeaderQueryLimitsKey] = string(encodedLimits)
}
// Add limits context
limitsCtx := querylimits.ExtractQueryLimitsContextFromContext(ctx)
if limitsCtx != nil {
encodedLimitsCtx, err := querylimits.MarshalQueryLimitsContext(limitsCtx)
if err != nil {
return nil, err
}
result.Metadata[querylimits.HTTPHeaderQueryLimitsContextKey] = string(encodedLimitsCtx)
}
// Add org ID
orgID, err := user.ExtractOrgID(ctx)
if err != nil {

@ -2,35 +2,43 @@ package querylimits
import (
"context"
"fmt"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
lowerQueryLimitsHeaderName = "x-loki-query-limits"
lowerQueryLimitsHeaderName = "x-loki-query-limits"
lowerQueryLimitsContextHeaderName = "x-loki-query-limits-context"
)
func injectIntoGRPCRequest(ctx context.Context) (context.Context, error) {
limits := ExtractQueryLimitsContext(ctx)
fmt.Printf("extract limits grpc: %v", limits)
if limits == nil {
return ctx, nil
}
// inject into GRPC metadata
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = metadata.New(map[string]string{})
}
md = md.Copy()
headerValue, err := MarshalQueryLimits(limits)
if err != nil {
return nil, err
limits := ExtractQueryLimitsFromContext(ctx)
if limits != nil {
headerValue, err := MarshalQueryLimits(limits)
if err != nil {
return nil, err
}
md.Set(lowerQueryLimitsHeaderName, string(headerValue))
}
limitsCtx := ExtractQueryLimitsContextFromContext(ctx)
if limitsCtx != nil {
headerValue, err := MarshalQueryLimitsContext(limitsCtx)
if err != nil {
return nil, err
}
md.Set(lowerQueryLimitsContextHeaderName, string(headerValue))
}
md.Set(lowerQueryLimitsHeaderName, string(headerValue))
newCtx := metadata.NewOutgoingContext(ctx, md)
newCtx := metadata.NewOutgoingContext(ctx, md)
return newCtx, nil
}
@ -60,21 +68,27 @@ func extractFromGRPCRequest(ctx context.Context) (context.Context, error) {
}
headerValues, ok := md[lowerQueryLimitsHeaderName]
if !ok {
// No QueryLimits header in metadata, just return context
return ctx, nil
if ok && len(headerValues) > 0 {
// Pick first header
limits, err := UnmarshalQueryLimits([]byte(headerValues[0]))
if err != nil {
return ctx, err
}
ctx = InjectQueryLimitsIntoContext(ctx, *limits)
}
if len(headerValues) == 0 {
return ctx, nil
// Extract QueryLimitsContext if present
headerContextValues, ok := md[lowerQueryLimitsContextHeaderName]
if ok && len(headerContextValues) > 0 {
// Pick first header
limitsCtx, err := UnmarshalQueryLimitsContext([]byte(headerContextValues[0]))
if err != nil {
return ctx, err
}
ctx = InjectQueryLimitsContextIntoContext(ctx, *limitsCtx)
}
// Pick first header
limits, err := UnmarshalQueryLimits([]byte(headerValues[0]))
if err != nil {
return ctx, err
}
return InjectQueryLimitsContext(ctx, *limits), nil
return ctx, nil
}
func ServerQueryLimitsInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {

@ -16,16 +16,37 @@ func TestGRPCQueryLimits(t *testing.T) {
MaxQueryLookback: model.Duration(14 * 24 * time.Hour),
MaxEntriesLimitPerQuery: 100,
}
c1 := InjectQueryLimitsContext(context.Background(), limits)
c1 := InjectQueryLimitsIntoContext(context.Background(), limits)
c1, err = injectIntoGRPCRequest(c1)
require.NoError(t, err)
c2, err := extractFromGRPCRequest(c1)
require.NoError(t, err)
require.Equal(t, limits, *(c2.Value(queryLimitsContextKey).(*QueryLimits)))
require.Equal(t, limits, *(c2.Value(queryLimitsCtxKey).(*QueryLimits)))
c3, err := extractFromGRPCRequest(context.Background())
require.NoError(t, err)
require.Nil(t, c3.Value(queryLimitsContextKey))
require.Nil(t, c3.Value(queryLimitsCtxKey))
}
func TestGRPCQueryLimitsContext(t *testing.T) {
var err error
limitsCtx := Context{
Expr: "{app=\"test\"}",
From: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
To: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
}
c1 := InjectQueryLimitsContextIntoContext(context.Background(), limitsCtx)
c1, err = injectIntoGRPCRequest(c1)
require.NoError(t, err)
c2, err := extractFromGRPCRequest(c1)
require.NoError(t, err)
require.Equal(t, limitsCtx, *(c2.Value(queryLimitsContextCtxKey).(*Context)))
c3, err := extractFromGRPCRequest(context.Background())
require.NoError(t, err)
require.Nil(t, c3.Value(queryLimitsContextCtxKey))
}

@ -27,7 +27,7 @@ func NewLimiter(log log.Logger, original limiter.CombinedLimits) *Limiter {
// MaxQueryLength returns the limit of the length (in time) of a query.
func (l *Limiter) MaxQueryLength(ctx context.Context, userID string) time.Duration {
original := l.CombinedLimits.MaxQueryLength(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.MaxQueryLength == 0 ||
(time.Duration(requestLimits.MaxQueryLength) > original && original != 0) {
return original
@ -39,7 +39,7 @@ func (l *Limiter) MaxQueryLength(ctx context.Context, userID string) time.Durati
// MaxQueryLookback returns the max lookback period of queries.
func (l *Limiter) MaxQueryLookback(ctx context.Context, userID string) time.Duration {
original := l.CombinedLimits.MaxQueryLookback(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.MaxQueryLookback == 0 ||
(time.Duration(requestLimits.MaxQueryLookback) > original && original != 0) {
return original
@ -51,7 +51,7 @@ func (l *Limiter) MaxQueryLookback(ctx context.Context, userID string) time.Dura
// MaxQueryRange retruns the max query range/interval of a query.
func (l *Limiter) MaxQueryRange(ctx context.Context, userID string) time.Duration {
original := l.CombinedLimits.MaxQueryRange(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.MaxQueryRange == 0 ||
(time.Duration(requestLimits.MaxQueryRange) > original && original != 0) {
return original
@ -63,7 +63,7 @@ func (l *Limiter) MaxQueryRange(ctx context.Context, userID string) time.Duratio
// MaxEntriesLimitPerQuery returns the limit to number of entries the querier should return per query.
func (l *Limiter) MaxEntriesLimitPerQuery(ctx context.Context, userID string) int {
original := l.CombinedLimits.MaxEntriesLimitPerQuery(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.MaxEntriesLimitPerQuery == 0 ||
(requestLimits.MaxEntriesLimitPerQuery > original && original != 0) {
return original
@ -75,7 +75,7 @@ func (l *Limiter) MaxEntriesLimitPerQuery(ctx context.Context, userID string) in
func (l *Limiter) QueryTimeout(ctx context.Context, userID string) time.Duration {
original := l.CombinedLimits.QueryTimeout(ctx, userID)
// in theory this error should never happen
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.QueryTimeout == 0 ||
(time.Duration(requestLimits.QueryTimeout) > original && original != 0) {
return original
@ -86,7 +86,7 @@ func (l *Limiter) QueryTimeout(ctx context.Context, userID string) time.Duration
func (l *Limiter) RequiredLabels(ctx context.Context, userID string) []string {
original := l.CombinedLimits.RequiredLabels(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil {
return original
@ -112,7 +112,7 @@ func (l *Limiter) RequiredLabels(ctx context.Context, userID string) []string {
func (l *Limiter) RequiredNumberLabels(ctx context.Context, userID string) int {
original := l.CombinedLimits.RequiredNumberLabels(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.RequiredNumberLabels == 0 || requestLimits.RequiredNumberLabels < original {
return original
}
@ -122,7 +122,7 @@ func (l *Limiter) RequiredNumberLabels(ctx context.Context, userID string) int {
func (l *Limiter) MaxQueryBytesRead(ctx context.Context, userID string) int {
original := l.CombinedLimits.MaxQueryBytesRead(ctx, userID)
requestLimits := ExtractQueryLimitsContext(ctx)
requestLimits := ExtractQueryLimitsFromContext(ctx)
if requestLimits == nil || requestLimits.MaxQueryBytesRead.Val() == 0 ||
(requestLimits.MaxQueryBytesRead.Val() > original && original != 0) {
return original

@ -89,7 +89,7 @@ func TestLimiter_Defaults(t *testing.T) {
MaxQueryBytesRead: 10,
}
{
ctx2 := InjectQueryLimitsContext(context.Background(), *limits)
ctx2 := InjectQueryLimitsIntoContext(context.Background(), *limits)
queryLookback := l.MaxQueryLookback(ctx2, "fake")
require.Equal(t, time.Duration(expectedLimits2.MaxQueryLookback), queryLookback)
queryLength := l.MaxQueryLength(ctx2, "fake")
@ -143,7 +143,7 @@ func TestLimiter_RejectHighLimits(t *testing.T) {
RequiredNumberLabels: 100,
}
ctx := InjectQueryLimitsContext(context.Background(), limits)
ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
require.Equal(t, time.Duration(expectedLimits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
require.Equal(t, time.Duration(expectedLimits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
require.Equal(t, expectedLimits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@ -180,7 +180,7 @@ func TestLimiter_AcceptLowerLimits(t *testing.T) {
RequiredNumberLabels: 10,
}
ctx := InjectQueryLimitsContext(context.Background(), limits)
ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
require.Equal(t, time.Duration(limits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
require.Equal(t, time.Duration(limits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
require.Equal(t, limits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@ -218,7 +218,7 @@ func TestLimiter_AcceptRequestLimitsOverNotInitializedLimits(t *testing.T) {
RequiredNumberLabels: 10,
}
ctx := InjectQueryLimitsContext(context.Background(), limits)
ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
require.Equal(t, time.Duration(limits.MaxQueryLookback), l.MaxQueryLookback(ctx, "fake"))
require.Equal(t, time.Duration(limits.MaxQueryLength), l.MaxQueryLength(ctx, "fake"))
require.Equal(t, limits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake"))
@ -243,7 +243,7 @@ func TestLimiter_MergeLimits(t *testing.T) {
require.ElementsMatch(t, []string{"one", "two"}, l.RequiredLabels(context.Background(), "fake"))
ctx := InjectQueryLimitsContext(context.Background(), limits)
ctx := InjectQueryLimitsIntoContext(context.Background(), limits)
require.ElementsMatch(t, []string{"one", "two", "three"}, l.RequiredLabels(ctx, "fake"))
}

@ -32,7 +32,17 @@ func (l *queryLimitsMiddleware) Wrap(next http.Handler) http.Handler {
}
if limits != nil {
r = r.Clone(InjectQueryLimitsContext(r.Context(), *limits))
r = r.Clone(InjectQueryLimitsIntoContext(r.Context(), *limits))
}
limitsCtx, err := ExtractQueryLimitsContextHTTP(r)
if err != nil {
level.Warn(util_log.Logger).Log("msg", "could not extract query limits context from header", "err", err)
limitsCtx = nil
}
if limitsCtx != nil {
r = r.Clone(InjectQueryLimitsContextIntoContext(r.Context(), *limitsCtx))
}
next.ServeHTTP(w, r)

@ -13,7 +13,7 @@ import (
func Test_MiddlewareWithoutHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
limits := ExtractQueryLimitsContext(r.Context())
limits := ExtractQueryLimitsFromContext(r.Context())
require.Nil(t, limits)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
@ -29,7 +29,7 @@ func Test_MiddlewareWithoutHeader(t *testing.T) {
func Test_MiddlewareWithBrokenHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
limits := ExtractQueryLimitsContext(r.Context())
limits := ExtractQueryLimitsFromContext(r.Context())
require.Nil(t, limits)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
@ -57,7 +57,7 @@ func Test_MiddlewareWithHeader(t *testing.T) {
}
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
actual := ExtractQueryLimitsContext(r.Context())
actual := ExtractQueryLimitsFromContext(r.Context())
require.Equal(t, limits, *actual)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
@ -72,3 +72,60 @@ func Test_MiddlewareWithHeader(t *testing.T) {
response := rr.Result()
require.Equal(t, http.StatusOK, response.StatusCode)
}
func Test_MiddlewareWithoutContextHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
limitsCtx := ExtractQueryLimitsContextFromContext(r.Context())
require.Nil(t, limitsCtx)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
wrapped := m.Wrap(nextHandler)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/example", nil)
require.NoError(t, err)
wrapped.ServeHTTP(rr, r)
response := rr.Result()
require.Equal(t, http.StatusOK, response.StatusCode)
}
func Test_MiddlewareWithBrokenContextHeader(t *testing.T) {
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
limitsCtx := ExtractQueryLimitsContextFromContext(r.Context())
require.Nil(t, limitsCtx)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
wrapped := m.Wrap(nextHandler)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/example", nil)
require.NoError(t, err)
r.Header.Add(HTTPHeaderQueryLimitsContextKey, "{broken}")
wrapped.ServeHTTP(rr, r)
response := rr.Result()
require.Equal(t, http.StatusOK, response.StatusCode)
}
func Test_MiddlewareWithContextHeader(t *testing.T) {
limitsCtx := Context{
Expr: "{app=\"test\"}",
From: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
To: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC),
}
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
actual := ExtractQueryLimitsContextFromContext(r.Context())
require.Equal(t, limitsCtx, *actual)
})
m := NewQueryLimitsMiddleware(log.NewNopLogger())
wrapped := m.Wrap(nextHandler)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/example", nil)
require.NoError(t, err)
err = InjectQueryLimitsContextHTTP(r, &limitsCtx)
require.NoError(t, err)
wrapped.ServeHTTP(rr, r)
response := rr.Result()
require.Equal(t, http.StatusOK, response.StatusCode)
}

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/prometheus/common/model"
@ -14,9 +15,11 @@ import (
type key int
const (
queryLimitsContextKey key = 1
queryLimitsCtxKey key = 1
queryLimitsContextCtxKey key = 2
HTTPHeaderQueryLimitsKey = "X-Loki-Query-Limits"
HTTPHeaderQueryLimitsKey = "X-Loki-Query-Limits"
HTTPHeaderQueryLimitsContextKey = "X-Loki-Query-Limits-Context"
)
// NOTE: we use custom `model.Duration` instead of standard `time.Duration` because,
@ -73,9 +76,9 @@ func ExtractQueryLimitsHTTP(r *http.Request) (*QueryLimits, error) {
return nil, nil
}
// ExtractQueryLimitsContext gets the embedded limits from the context
func ExtractQueryLimitsContext(ctx context.Context) *QueryLimits {
source, ok := ctx.Value(queryLimitsContextKey).(*QueryLimits)
// ExtractQueryLimitsFromContext gets the embedded limits from the context
func ExtractQueryLimitsFromContext(ctx context.Context) *QueryLimits {
source, ok := ctx.Value(queryLimitsCtxKey).(*QueryLimits)
if !ok {
return nil
@ -84,7 +87,70 @@ func ExtractQueryLimitsContext(ctx context.Context) *QueryLimits {
return source
}
// InjectQueryLimitsContext returns a derived context containing the provided query limits
func InjectQueryLimitsContext(ctx context.Context, limits QueryLimits) context.Context {
return context.WithValue(ctx, interface{}(queryLimitsContextKey), &limits)
// InjectQueryLimitsIntoContext returns a derived context containing the provided query limits
func InjectQueryLimitsIntoContext(ctx context.Context, limits QueryLimits) context.Context {
return context.WithValue(ctx, interface{}(queryLimitsCtxKey), &limits)
}
type Context struct {
Expr string `json:"expr"`
From time.Time `json:"from"`
To time.Time `json:"to"`
}
func UnmarshalQueryLimitsContext(data []byte) (*Context, error) {
limitsCtx := &Context{}
err := json.Unmarshal(data, limitsCtx)
return limitsCtx, err
}
func MarshalQueryLimitsContext(limits *Context) ([]byte, error) {
return json.Marshal(limits)
}
// InjectQueryLimitsContextHTTP adds the query limits context to the request headers.
func InjectQueryLimitsContextHTTP(r *http.Request, limitsCtx *Context) error {
return InjectQueryLimitsContextHeader(&r.Header, limitsCtx)
}
// InjectQueryLimitsContextHeader adds the query limits context to the headers.
func InjectQueryLimitsContextHeader(h *http.Header, limitsCtx *Context) error {
// Ensure any existing policy sets are erased
h.Del(HTTPHeaderQueryLimitsContextKey)
encodedLimits, err := MarshalQueryLimitsContext(limitsCtx)
if err != nil {
return err
}
h.Add(HTTPHeaderQueryLimitsContextKey, string(encodedLimits))
return nil
}
// ExtractQueryLimitsContextHTTP retrieves the query limits context from the HTTP header and returns it.
func ExtractQueryLimitsContextHTTP(r *http.Request) (*Context, error) {
headerValues := r.Header.Values(HTTPHeaderQueryLimitsContextKey)
// Iterate through each set header value
for _, headerValue := range headerValues {
return UnmarshalQueryLimitsContext([]byte(headerValue))
}
return nil, nil
}
// ExtractQueryLimitsContextFromContext gets the embedded query limits context from the context
func ExtractQueryLimitsContextFromContext(ctx context.Context) *Context {
source, ok := ctx.Value(queryLimitsContextCtxKey).(*Context)
if !ok {
return nil
}
return source
}
// InjectQueryLimitsContextIntoContext returns a derived context containing the provided query limits context
func InjectQueryLimitsContextIntoContext(ctx context.Context, limitsCtx Context) context.Context {
return context.WithValue(ctx, any(queryLimitsContextCtxKey), &limitsCtx)
}

@ -18,8 +18,8 @@ func TestInjectAndExtractQueryLimits(t *testing.T) {
QueryTimeout: model.Duration(5 * time.Second),
}
ctx = InjectQueryLimitsContext(ctx, limits)
res := ExtractQueryLimitsContext(ctx)
ctx = InjectQueryLimitsIntoContext(ctx, limits)
res := ExtractQueryLimitsFromContext(ctx)
require.Equal(t, limits, *res)
}
@ -70,3 +70,65 @@ func TestSerializingQueryLimits(t *testing.T) {
expected = `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w"}`
require.JSONEq(t, expected, string(actual))
}
func TestInjectAndExtractQueryLimitsContext(t *testing.T) {
ctx := context.Background()
baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
limitsCtx := Context{
Expr: `{job="app"}`,
From: baseTime,
To: baseTime.Add(1 * time.Hour),
}
ctx = InjectQueryLimitsContextIntoContext(ctx, limitsCtx)
res := ExtractQueryLimitsContextFromContext(ctx)
require.Equal(t, limitsCtx, *res)
}
func TestDeserializingQueryLimitsContext(t *testing.T) {
baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
endTime := baseTime.Add(1 * time.Hour)
// full context
payload := `{"expr": "{job=\"app\"}", "from": "2024-01-15T10:30:00Z", "to": "2024-01-15T11:30:00Z"}`
limitsCtx, err := UnmarshalQueryLimitsContext([]byte(payload))
require.NoError(t, err)
require.Equal(t, `{job="app"}`, limitsCtx.Expr)
require.Equal(t, baseTime.Unix(), limitsCtx.From.Unix())
require.Equal(t, endTime.Unix(), limitsCtx.To.Unix())
// some fields are empty
payload = `{"expr": "rate({job=\"app\"}[5m])"}`
limitsCtx, err = UnmarshalQueryLimitsContext([]byte(payload))
require.NoError(t, err)
require.Equal(t, `rate({job="app"}[5m])`, limitsCtx.Expr)
require.True(t, limitsCtx.From.IsZero())
require.True(t, limitsCtx.To.IsZero())
}
func TestSerializingQueryLimitsContext(t *testing.T) {
baseTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
endTime := baseTime.Add(1 * time.Hour)
// full struct
limitsCtx := Context{
Expr: `{job="app"}`,
From: baseTime,
To: endTime,
}
actual, err := MarshalQueryLimitsContext(&limitsCtx)
require.NoError(t, err)
expected := `{"expr": "{job=\"app\"}", "from": "2024-01-15T10:30:00Z", "to": "2024-01-15T11:30:00Z"}`
require.JSONEq(t, expected, string(actual))
// some fields are empty
limitsCtx = Context{
Expr: `rate({job="app"}[5m])`,
}
actual, err = MarshalQueryLimitsContext(&limitsCtx)
require.NoError(t, err)
expected = `{"expr": "rate({job=\"app\"}[5m])", "from": "0001-01-01T00:00:00Z", "to": "0001-01-01T00:00:00Z"}`
require.JSONEq(t, expected, string(actual))
}

Loading…
Cancel
Save