From d24fe3e68be4be16b3654e6aedafaa71ac9d307c Mon Sep 17 00:00:00 2001
From: Salva Corts
Date: Thu, 23 Mar 2023 21:18:54 +0100
Subject: [PATCH] Max bytes read limit (#8670)

**What this PR does / why we need it**:

This PR implements two new per-tenant limits that are enforced on log and metric queries (both range and instant) when TSDB is used:

- `max_query_bytes_read`: Refuse queries that would read more than the configured number of bytes. This is an overall limit, enforced regardless of splitting/sharding. The goal is to refuse queries that would take too long. The default value of 0 disables this limit.
- `max_querier_bytes_read`: Refuse queries in which any subquery, after splitting and sharding, would read more than the configured number of bytes. The goal is to prevent a querier from running a query that would load too much data into memory and potentially get OOM-killed. The default value of 0 disables this limit.

These new limits can be configured per tenant and per query (see https://github.com/grafana/loki/pull/8727).

The bytes a query would read are estimated through TSDB's index stats. Even though these estimates are not exact, they are good enough for a rough judgement of whether a query is too big to run. For more details, refer to this discussion in the PR: https://github.com/grafana/loki/pull/8670#discussion_r1124858508.

Both limits are implemented in the frontend. Even though we considered implementing `max_querier_bytes_read` in the querier, enforcing it in the frontend keeps the pre- and post-splitting/sharding limits close to each other in the same component. Moreover, it lets us reduce the number of index stats requests issued to the index gateways by reusing the stats gathered while sharding the query.

With regard to how index stats requests are issued:

- We parallelize index stats requests by splitting them into queries that span up to 24h, since our indices are sharded in 24h periods. This also prevents a single index gateway from processing one huge request such as `{app=~".+"}` over 30d. A sketch of this day-aligned splitting appears below, after the reviewer notes.
- If sharding is enabled and the query is shardable, for `max_querier_bytes_read` we re-use the stats requests issued by the sharding middleware. Specifically, we look at the [bytesPerShard][1] to enforce this limit.

Note that once we merge this PR and enable these limits, the load of index stats requests will increase substantially and we may discover bottlenecks in our index gateways and TSDB. After speaking with @owen-d, we think this should be fine: if needed, we can scale up our index gateways and add support for caching index stats requests.

**Which issue(s) this PR fixes**:

This PR addresses https://github.com/grafana/loki-private/issues/674.

**Special notes for your reviewer**:

- @jeschkies has reviewed the changes related to query-time limits.
- I've done some refactoring in this PR:
  - Extracted the logic to get stats for a set of matchers into a new function, [getStatsForMatchers][2].
  - Extracted the _Handler_ interface implementation of [queryrangebase.roundTripper][3] into a new type, [queryrangebase.roundTripperHandler][4]. This is used to create the handler that skips the rest of the configured middlewares when sending an index stats request ([example][5]).
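To make the day-aligned splitting described above concrete, here is a minimal, stand-alone sketch. It is illustrative only and not code from this PR: the `interval` type and the `splitByDay` helper are hypothetical names, and the only behaviour borrowed from the PR is that index stats requests are split at 24h index-period (UTC day) boundaries, which the split counts asserted in `Test_MaxQuerySize` below also assume.

```go
package main

import (
	"fmt"
	"time"
)

// interval is a hypothetical [start, end) time range type used only in this sketch.
type interval struct{ start, end time.Time }

// splitByDay splits [start, end) into sub-ranges that never cross a UTC day
// boundary, so each resulting stats request covers at most one 24h index period.
// time.Truncate(24 * time.Hour) floors a time to the previous UTC midnight.
func splitByDay(start, end time.Time) []interval {
	var out []interval
	for cur := start; cur.Before(end); {
		next := cur.Truncate(24 * time.Hour).Add(24 * time.Hour)
		if next.After(end) {
			next = end
		}
		out = append(out, interval{start: cur, end: next})
		cur = next
	}
	return out
}

func main() {
	// A 48h query ending at 10:00 UTC yields three stats requests, matching the
	// "48 hour range" case in Test_MaxQuerySize below:
	// [start, midnight-1d), [midnight-1d, midnight), [midnight, end).
	end := time.Date(2023, 3, 23, 10, 0, 0, 0, time.UTC)
	start := end.Add(-48 * time.Hour)
	for _, r := range splitByDay(start, end) {
		fmt.Printf("%s -> %s\n", r.start.Format(time.RFC3339), r.end.Format(time.RFC3339))
	}
}
```

With the sub-requests aligned this way, the frontend can fan them out in parallel to the index gateways and merge the per-interval stats, which is what the new `IndexStatsResponse` case in `Codec.MergeResponse` handles.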
**Checklist**

- [x] Reviewed the [`CONTRIBUTING.md`](https://github.com/grafana/loki/blob/main/CONTRIBUTING.md) guide (**required**)
- [x] Documentation added
- [x] Tests updated
- [x] `CHANGELOG.md` updated
- [ ] Changes that require user attention or interaction to upgrade are documented in `docs/sources/upgrading/_index.md`

[1]: https://github.com/grafana/loki/blob/ff847305afaf7de5eb56436f3683773e88701075/pkg/querier/queryrange/shard_resolver.go#L179-L186
[2]: https://github.com/grafana/loki/blob/ff847305afaf7de5eb56436f3683773e88701075/pkg/querier/queryrange/shard_resolver.go#L72
[3]: https://github.com/grafana/loki/blob/3d2fff3a2d416a48a73346a53ba7499b0eeb67f7/pkg/querier/queryrange/queryrangebase/roundtrip.go#L124
[4]: https://github.com/grafana/loki/blob/3d2fff3a2d416a48a73346a53ba7499b0eeb67f7/pkg/querier/queryrange/queryrangebase/roundtrip.go#L163
[5]: https://github.com/grafana/loki/blob/f422e0a52b743a11209b8276510feb2ab8241486/pkg/querier/queryrange/roundtrip.go#L521

--- CHANGELOG.md | 1 + docs/sources/configuration/_index.md | 11 + pkg/logql/downstream_test.go | 4 +- pkg/logql/shardmapper.go | 133 +++++--- pkg/logql/shardmapper_test.go | 8 +- pkg/querier/queryrange/codec.go | 15 + pkg/querier/queryrange/limits.go | 194 +++++++++++ pkg/querier/queryrange/limits_test.go | 213 ++++++++++++ .../queryrange/queryrangebase/roundtrip.go | 25 +- pkg/querier/queryrange/querysharding.go | 36 +- pkg/querier/queryrange/querysharding_test.go | 137 +++++++- pkg/querier/queryrange/roundtrip.go | 314 +++++++++--------- pkg/querier/queryrange/roundtrip_test.go | 147 +++++++- pkg/querier/queryrange/shard_resolver.go | 129 ++++--- pkg/querier/queryrange/shard_resolver_test.go | 3 +- pkg/querier/queryrange/split_by_interval.go | 23 +- .../queryrange/split_by_interval_test.go | 15 + pkg/util/querylimits/limiter.go | 10 + pkg/util/querylimits/limiter_test.go | 17 + pkg/util/querylimits/middleware_test.go | 1 + pkg/util/querylimits/propagation.go | 2 + pkg/util/querylimits/propagation_test.go | 7 +- pkg/validation/limits.go | 19 +- 23 files changed, 1183 insertions(+), 281 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86e7401c8c..381e5cad3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ * [6675](https://github.com/grafana/loki/pull/6675) **btaani**: Add logfmt expression parser for selective extraction of labels from logfmt formatted logs * [8474](https://github.com/grafana/loki/pull/8474) **farodin91**: Add support for short-lived S3 session tokens * [8774](https://github.com/grafana/loki/pull/8774) **slim-bean**: Add new logql template functions `bytes`, `duration`, `unixEpochMillis`, `unixEpochNanos`, `toDateInZone`, `b64Enc`, and `b64Dec` +* [8670](https://github.com/grafana/loki/pull/8670) **salvacorts**: Introduce two new limits to refuse log and metric queries that would read too much data. ##### Fixes diff --git a/docs/sources/configuration/_index.md b/docs/sources/configuration/_index.md index f4fd7a7458..ededb02129 100644 --- a/docs/sources/configuration/_index.md +++ b/docs/sources/configuration/_index.md @@ -2331,6 +2331,17 @@ The `limits_config` block configures global and per-tenant limits in Loki. # CLI flag: -frontend.min-sharding-lookback [min_sharding_lookback: <duration> | default = 0s] +# Max number of bytes a query can fetch. Enforced in log and metric queries only +# when TSDB is used. The default value of 0 disables this limit.
+# CLI flag: -frontend.max-query-bytes-read +[max_query_bytes_read: <int> | default = 0B] + +# Max number of bytes a query can fetch after splitting and sharding. Enforced +# in log and metric queries only when TSDB is used. The default value of 0 +# disables this limit. +# CLI flag: -frontend.max-querier-bytes-read +[max_querier_bytes_read: <int> | default = 0B] + # Duration to delay the evaluation of rules to ensure the underlying metrics # have been pushed to Cortex. # CLI flag: -ruler.evaluation-delay-duration diff --git a/pkg/logql/downstream_test.go b/pkg/logql/downstream_test.go index 0c67605f8d..67f54a4652 100644 --- a/pkg/logql/downstream_test.go +++ b/pkg/logql/downstream_test.go @@ -81,7 +81,7 @@ func TestMappingEquivalence(t *testing.T) { ctx := user.InjectOrgID(context.Background(), "fake") mapper := NewShardMapper(ConstantShards(shards), nilShardMetrics) - _, mapped, err := mapper.Parse(tc.query) + _, _, mapped, err := mapper.Parse(tc.query) require.Nil(t, err) shardedQry := sharded.Query(ctx, params, mapped) @@ -146,7 +146,7 @@ func TestShardCounter(t *testing.T) { ctx := user.InjectOrgID(context.Background(), "fake") mapper := NewShardMapper(ConstantShards(shards), nilShardMetrics) - noop, mapped, err := mapper.Parse(tc.query) + noop, _, mapped, err := mapper.Parse(tc.query) require.Nil(t, err) shardedQry := sharded.Query(ctx, params, mapped) diff --git a/pkg/logql/shardmapper.go b/pkg/logql/shardmapper.go index 9652f7a792..8422c1f7c0 100644 --- a/pkg/logql/shardmapper.go +++ b/pkg/logql/shardmapper.go @@ -4,21 +4,25 @@ import ( "fmt" "github.com/go-kit/log/level" + "github.com/grafana/loki/pkg/util/math" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/grafana/loki/pkg/logql/syntax" "github.com/grafana/loki/pkg/querier/astmapper" + "github.com/grafana/loki/pkg/storage/stores/index/stats" util_log "github.com/grafana/loki/pkg/util/log" ) type ShardResolver interface { - Shards(expr syntax.Expr) (int, error) + Shards(expr syntax.Expr) (int, uint64, error) + GetStats(e syntax.Expr) (stats.Stats, error) } type ConstantShards int -func (s ConstantShards) Shards(_ syntax.Expr) (int, error) { return int(s), nil } +func (s ConstantShards) Shards(_ syntax.Expr) (int, uint64, error) { return int(s), 0, nil } +func (s ConstantShards) GetStats(_ syntax.Expr) (stats.Stats, error) { return stats.Stats{}, nil } type ShardMapper struct { shards ShardResolver @@ -36,18 +40,18 @@ func NewShardMapperMetrics(registerer prometheus.Registerer) *MapperMetrics { return newMapperMetrics(registerer, "shard") } -func (m ShardMapper) Parse(query string) (noop bool, expr syntax.Expr, err error) { +func (m ShardMapper) Parse(query string) (noop bool, bytesPerShard uint64, expr syntax.Expr, err error) { parsed, err := syntax.ParseExpr(query) if err != nil { - return false, nil, err + return false, 0, nil, err } recorder := m.metrics.downstreamRecorder() - mapped, err := m.Map(parsed, recorder) + mapped, bytesPerShard, err := m.Map(parsed, recorder) if err != nil { m.metrics.ParsedQueries.WithLabelValues(FailureKey).Inc() - return false, nil, err + return false, 0, nil, err } originalStr := parsed.String() @@ -61,21 +65,21 @@ func (m ShardMapper) Parse(query string) (noop bool, expr syntax.Expr, err error recorder.Finish() // only record metrics for successful mappings - return noop, mapped, err + return noop, bytesPerShard, mapped, err } -func (m ShardMapper) Map(expr syntax.Expr, r *downstreamRecorder) (syntax.Expr, error) { +func (m ShardMapper) Map(expr syntax.Expr, r
*downstreamRecorder) (syntax.Expr, uint64, error) { // immediately clone the passed expr to avoid mutating the original expr, err := syntax.Clone(expr) if err != nil { - return nil, err + return nil, 0, err } switch e := expr.(type) { case *syntax.LiteralExpr: - return e, nil + return e, 0, nil case *syntax.VectorExpr: - return e, nil + return e, 0, nil case *syntax.MatchersExpr, *syntax.PipelineExpr: return m.mapLogSelectorExpr(e.(syntax.LogSelectorExpr), r) case *syntax.VectorAggregationExpr: @@ -85,35 +89,39 @@ func (m ShardMapper) Map(expr syntax.Expr, r *downstreamRecorder) (syntax.Expr, case *syntax.RangeAggregationExpr: return m.mapRangeAggregationExpr(e, r) case *syntax.BinOpExpr: - lhsMapped, err := m.Map(e.SampleExpr, r) + lhsMapped, lhsBytesPerShard, err := m.Map(e.SampleExpr, r) if err != nil { - return nil, err + return nil, 0, err } - rhsMapped, err := m.Map(e.RHS, r) + rhsMapped, rhsBytesPerShard, err := m.Map(e.RHS, r) if err != nil { - return nil, err + return nil, 0, err } lhsSampleExpr, ok := lhsMapped.(syntax.SampleExpr) if !ok { - return nil, badASTMapping(lhsMapped) + return nil, 0, badASTMapping(lhsMapped) } rhsSampleExpr, ok := rhsMapped.(syntax.SampleExpr) if !ok { - return nil, badASTMapping(rhsMapped) + return nil, 0, badASTMapping(rhsMapped) } e.SampleExpr = lhsSampleExpr e.RHS = rhsSampleExpr - return e, nil + + // We take the maximum bytes per shard of both sides of the operation + bytesPerShard := uint64(math.Max(int(lhsBytesPerShard), int(rhsBytesPerShard))) + + return e, bytesPerShard, nil default: - return nil, errors.Errorf("unexpected expr type (%T) for ASTMapper type (%T) ", expr, m) + return nil, 0, errors.Errorf("unexpected expr type (%T) for ASTMapper type (%T) ", expr, m) } } -func (m ShardMapper) mapLogSelectorExpr(expr syntax.LogSelectorExpr, r *downstreamRecorder) (syntax.LogSelectorExpr, error) { +func (m ShardMapper) mapLogSelectorExpr(expr syntax.LogSelectorExpr, r *downstreamRecorder) (syntax.LogSelectorExpr, uint64, error) { var head *ConcatLogSelectorExpr - shards, err := m.shards.Shards(expr) + shards, bytesPerShard, err := m.shards.Shards(expr) if err != nil { - return nil, err + return nil, 0, err } if shards == 0 { return &ConcatLogSelectorExpr{ @@ -121,7 +129,7 @@ func (m ShardMapper) mapLogSelectorExpr(expr syntax.LogSelectorExpr, r *downstre shard: nil, LogSelectorExpr: expr, }, - }, nil + }, bytesPerShard, nil } for i := shards - 1; i >= 0; i-- { head = &ConcatLogSelectorExpr{ @@ -137,14 +145,14 @@ func (m ShardMapper) mapLogSelectorExpr(expr syntax.LogSelectorExpr, r *downstre } r.Add(shards, StreamsKey) - return head, nil + return head, bytesPerShard, nil } -func (m ShardMapper) mapSampleExpr(expr syntax.SampleExpr, r *downstreamRecorder) (syntax.SampleExpr, error) { +func (m ShardMapper) mapSampleExpr(expr syntax.SampleExpr, r *downstreamRecorder) (syntax.SampleExpr, uint64, error) { var head *ConcatSampleExpr - shards, err := m.shards.Shards(expr) + shards, bytesPerShard, err := m.shards.Shards(expr) if err != nil { - return nil, err + return nil, 0, err } if shards == 0 { return &ConcatSampleExpr{ @@ -152,7 +160,7 @@ func (m ShardMapper) mapSampleExpr(expr syntax.SampleExpr, r *downstreamRecorder shard: nil, SampleExpr: expr, }, - }, nil + }, bytesPerShard, nil } for i := shards - 1; i >= 0; i-- { head = &ConcatSampleExpr{ @@ -168,22 +176,22 @@ func (m ShardMapper) mapSampleExpr(expr syntax.SampleExpr, r *downstreamRecorder } r.Add(shards, MetricsKey) - return head, nil + return head, bytesPerShard, nil } // technically, 
std{dev,var} are also parallelizable if there is no cross-shard merging // in descendent nodes in the AST. This optimization is currently avoided for simplicity. -func (m ShardMapper) mapVectorAggregationExpr(expr *syntax.VectorAggregationExpr, r *downstreamRecorder) (syntax.SampleExpr, error) { +func (m ShardMapper) mapVectorAggregationExpr(expr *syntax.VectorAggregationExpr, r *downstreamRecorder) (syntax.SampleExpr, uint64, error) { // if this AST contains unshardable operations, don't shard this at this level, // but attempt to shard a child node. if !expr.Shardable() { - subMapped, err := m.Map(expr.Left, r) + subMapped, bytesPerShard, err := m.Map(expr.Left, r) if err != nil { - return nil, err + return nil, 0, err } sampleExpr, ok := subMapped.(syntax.SampleExpr) if !ok { - return nil, badASTMapping(subMapped) + return nil, 0, badASTMapping(subMapped) } return &syntax.VectorAggregationExpr{ @@ -191,60 +199,63 @@ func (m ShardMapper) mapVectorAggregationExpr(expr *syntax.VectorAggregationExpr Grouping: expr.Grouping, Params: expr.Params, Operation: expr.Operation, - }, nil + }, bytesPerShard, nil } switch expr.Operation { case syntax.OpTypeSum: // sum(x) -> sum(sum(x, shard=1) ++ sum(x, shard=2)...) - sharded, err := m.mapSampleExpr(expr, r) + sharded, bytesPerShard, err := m.mapSampleExpr(expr, r) if err != nil { - return nil, err + return nil, 0, err } return &syntax.VectorAggregationExpr{ Left: sharded, Grouping: expr.Grouping, Params: expr.Params, Operation: expr.Operation, - }, nil + }, bytesPerShard, nil case syntax.OpTypeAvg: // avg(x) -> sum(x)/count(x) - lhs, err := m.mapVectorAggregationExpr(&syntax.VectorAggregationExpr{ + lhs, lhsBytesPerShard, err := m.mapVectorAggregationExpr(&syntax.VectorAggregationExpr{ Left: expr.Left, Grouping: expr.Grouping, Operation: syntax.OpTypeSum, }, r) if err != nil { - return nil, err + return nil, 0, err } - rhs, err := m.mapVectorAggregationExpr(&syntax.VectorAggregationExpr{ + rhs, rhsBytesPerShard, err := m.mapVectorAggregationExpr(&syntax.VectorAggregationExpr{ Left: expr.Left, Grouping: expr.Grouping, Operation: syntax.OpTypeCount, }, r) if err != nil { - return nil, err + return nil, 0, err } + // We take the maximum bytes per shard of both sides of the operation + bytesPerShard := uint64(math.Max(int(lhsBytesPerShard), int(rhsBytesPerShard))) + return &syntax.BinOpExpr{ SampleExpr: lhs, RHS: rhs, Op: syntax.OpTypeDiv, - }, nil + }, bytesPerShard, nil case syntax.OpTypeCount: // count(x) -> sum(count(x, shard=1) ++ count(x, shard=2)...) - sharded, err := m.mapSampleExpr(expr, r) + sharded, bytesPerShard, err := m.mapSampleExpr(expr, r) if err != nil { - return nil, err + return nil, 0, err } return &syntax.VectorAggregationExpr{ Left: sharded, Grouping: expr.Grouping, Operation: syntax.OpTypeSum, - }, nil + }, bytesPerShard, nil default: // this should not be reachable. If an operation is shardable it should // have an optimization listed. 
@@ -252,28 +263,38 @@ func (m ShardMapper) mapVectorAggregationExpr(expr *syntax.VectorAggregationExpr "msg", "unexpected operation which appears shardable, ignoring", "operation", expr.Operation, ) - return expr, nil + exprStats, err := m.shards.GetStats(expr) + if err != nil { + return nil, 0, err + } + return expr, exprStats.Bytes, nil } } -func (m ShardMapper) mapLabelReplaceExpr(expr *syntax.LabelReplaceExpr, r *downstreamRecorder) (syntax.SampleExpr, error) { - subMapped, err := m.Map(expr.Left, r) +func (m ShardMapper) mapLabelReplaceExpr(expr *syntax.LabelReplaceExpr, r *downstreamRecorder) (syntax.SampleExpr, uint64, error) { + subMapped, bytesPerShard, err := m.Map(expr.Left, r) if err != nil { - return nil, err + return nil, 0, err } cpy := *expr cpy.Left = subMapped.(syntax.SampleExpr) - return &cpy, nil + return &cpy, bytesPerShard, nil } -func (m ShardMapper) mapRangeAggregationExpr(expr *syntax.RangeAggregationExpr, r *downstreamRecorder) (syntax.SampleExpr, error) { +func (m ShardMapper) mapRangeAggregationExpr(expr *syntax.RangeAggregationExpr, r *downstreamRecorder) (syntax.SampleExpr, uint64, error) { if hasLabelModifier(expr) { // if an expr can modify labels this means multiple shards can return the same labelset. // When this happens the merge strategy needs to be different from a simple concatenation. // For instance for rates we need to sum data from different shards but same series. // Since we currently support only concatenation as merge strategy, we skip those queries. - return expr, nil + exprStats, err := m.shards.GetStats(expr) + if err != nil { + return nil, 0, err + } + + return expr, exprStats.Bytes, nil } + switch expr.Operation { case syntax.OpRangeTypeCount, syntax.OpRangeTypeRate, syntax.OpRangeTypeBytesRate, syntax.OpRangeTypeBytes: // count_over_time(x) -> count_over_time(x, shard=1) ++ count_over_time(x, shard=2)... 
@@ -281,7 +302,13 @@ func (m ShardMapper) mapRangeAggregationExpr(expr *syntax.RangeAggregationExpr, // same goes for bytes_rate and bytes_over_time return m.mapSampleExpr(expr, r) default: - return expr, nil + // This part of the query is not shardable, so the bytesPerShard is the bytes for all the log matchers in expr + exprStats, err := m.shards.GetStats(expr) + if err != nil { + return nil, 0, err + } + + return expr, exprStats.Bytes, nil } } diff --git a/pkg/logql/shardmapper_test.go b/pkg/logql/shardmapper_test.go index 35d24ddf9b..4b26bf3a33 100644 --- a/pkg/logql/shardmapper_test.go +++ b/pkg/logql/shardmapper_test.go @@ -105,7 +105,7 @@ func TestMapSampleExpr(t *testing.T) { }, } { t.Run(tc.in.String(), func(t *testing.T) { - mapped, err := m.mapSampleExpr(tc.in, nilShardMetrics.downstreamRecorder()) + mapped, _, err := m.mapSampleExpr(tc.in, nilShardMetrics.downstreamRecorder()) require.Nil(t, err) require.Equal(t, tc.out, mapped) }) @@ -299,7 +299,7 @@ func TestMappingStrings(t *testing.T) { ast, err := syntax.ParseExpr(tc.in) require.Nil(t, err) - mapped, err := m.Map(ast, nilShardMetrics.downstreamRecorder()) + mapped, _, err := m.Map(ast, nilShardMetrics.downstreamRecorder()) require.Nil(t, err) require.Equal(t, removeWhiteSpace(tc.out), removeWhiteSpace(mapped.String())) @@ -1205,7 +1205,7 @@ func TestMapping(t *testing.T) { ast, err := syntax.ParseExpr(tc.in) require.Equal(t, tc.err, err) - mapped, err := m.Map(ast, nilShardMetrics.downstreamRecorder()) + mapped, _, err := m.Map(ast, nilShardMetrics.downstreamRecorder()) require.Equal(t, tc.err, err) require.Equal(t, tc.expr.String(), mapped.String()) @@ -1274,7 +1274,7 @@ func TestStringTrimming(t *testing.T) { } { t.Run(tc.expr, func(t *testing.T) { m := NewShardMapper(ConstantShards(tc.shards), nilShardMetrics) - _, mappedExpr, err := m.Parse(tc.expr) + _, _, mappedExpr, err := m.Parse(tc.expr) require.Nil(t, err) require.Equal(t, removeWhiteSpace(tc.expected), removeWhiteSpace(mappedExpr.String())) }) diff --git a/pkg/querier/queryrange/codec.go b/pkg/querier/queryrange/codec.go index 9f7519d104..255ec11e23 100644 --- a/pkg/querier/queryrange/codec.go +++ b/pkg/querier/queryrange/codec.go @@ -26,6 +26,7 @@ import ( "github.com/grafana/loki/pkg/logqlmodel" "github.com/grafana/loki/pkg/logqlmodel/stats" "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" + indexStats "github.com/grafana/loki/pkg/storage/stores/index/stats" "github.com/grafana/loki/pkg/util" "github.com/grafana/loki/pkg/util/httpreq" "github.com/grafana/loki/pkg/util/marshal" @@ -685,6 +686,20 @@ func (Codec) MergeResponse(responses ...queryrangebase.Response) (queryrangebase Data: names, Statistics: mergedStats, }, nil + case *IndexStatsResponse: + headers := responses[0].(*IndexStatsResponse).Headers + stats := make([]*indexStats.Stats, len(responses)) + for i, res := range responses { + stats[i] = res.(*IndexStatsResponse).Response + } + + mergedIndexStats := indexStats.MergeStats(stats...) 
+ + return &IndexStatsResponse{ + Response: &mergedIndexStats, + Headers: headers, + }, nil + default: return nil, errors.New("unknown response in merging responses") } diff --git a/pkg/querier/queryrange/limits.go b/pkg/querier/queryrange/limits.go index f537f4027b..e6f1eb38b0 100644 --- a/pkg/querier/queryrange/limits.go +++ b/pkg/querier/queryrange/limits.go @@ -9,9 +9,12 @@ import ( "sync" "time" + "github.com/dustin/go-humanize" + "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/tenant" "github.com/opentracing/opentracing-go" + "github.com/pkg/errors" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/model/timestamp" @@ -20,8 +23,10 @@ import ( "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" + "github.com/grafana/loki/pkg/logql/syntax" "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/storage/config" + "github.com/grafana/loki/pkg/storage/stores/index/stats" "github.com/grafana/loki/pkg/util" util_log "github.com/grafana/loki/pkg/util/log" "github.com/grafana/loki/pkg/util/spanlogger" @@ -32,6 +37,8 @@ const ( limitErrTmpl = "maximum of series (%d) reached for a single query" maxSeriesErrTmpl = "max entries limit per query exceeded, limit > max_entries_limit (%d > %d)" requiredLabelsErrTmpl = "stream selector is missing required matchers [%s], labels present in the query were [%s]" + limErrQueryTooManyBytesTmpl = "the query would read too many bytes (query: %s, limit: %s). Consider adding more specific stream selectors or reduce the time range of the query" + limErrQuerierTooManyBytesTmpl = "query too large to execute on a single querier, either because parallelization is not enabled, the query is unshardable, or a shard query is too big to execute: (query: %s, limit: %s). Consider adding more specific stream selectors or reduce the time range of the query" ) var ( @@ -50,6 +57,8 @@ type Limits interface { // frontend will process in parallel for TSDB queries. TSDBMaxQueryParallelism(context.Context, string) int RequiredLabels(context.Context, string) []string + MaxQueryBytesRead(context.Context, string) int + MaxQuerierBytesRead(context.Context, string) int } type limits struct { @@ -57,6 +66,7 @@ type limits struct { // Use pointers so nil value can indicate if the value was set. splitDuration *time.Duration maxQueryParallelism *int + maxQueryBytesRead *int } func (l limits) QuerySplitDuration(user string) time.Duration { @@ -184,6 +194,190 @@ func (l limitsMiddleware) Do(ctx context.Context, r queryrangebase.Request) (que return l.next.Do(ctx, r) } +type querySizeLimiter struct { + logger log.Logger + next queryrangebase.Handler + statsHandler queryrangebase.Handler + cfg []config.PeriodConfig + maxLookBackPeriod time.Duration + limitFunc func(context.Context, string) int + limitErrorTmpl string +} + +func newQuerySizeLimiter( + next queryrangebase.Handler, + cfg []config.PeriodConfig, + logger log.Logger, + limits Limits, + codec queryrangebase.Codec, + limitFunc func(context.Context, string) int, + limitErrorTmpl string, + statsHandler ...queryrangebase.Handler, +) *querySizeLimiter { + q := &querySizeLimiter{ + logger: logger, + next: next, + cfg: cfg, + limitFunc: limitFunc, + limitErrorTmpl: limitErrorTmpl, + } + + q.statsHandler = next + if len(statsHandler) > 0 { + q.statsHandler = statsHandler[0] + } + + // Parallelize the index stats requests, so it doesn't send a huge request to a single index-gw (i.e. 
{app=~".+"} for 30d). + // Indices are sharded by 24 hours, so we split the stats request in 24h intervals. + statsSplitTimeMiddleware := SplitByIntervalMiddleware(cfg, WithSplitByLimits(limits, 24*time.Hour), codec, splitByTime, nil) + q.statsHandler = statsSplitTimeMiddleware.Wrap(q.statsHandler) + + // Get MaxLookBackPeriod from downstream engine. This is needed for instant limited queries at getStatsForMatchers + ng := logql.NewDownstreamEngine(logql.EngineOpts{LogExecutingQuery: false}, DownstreamHandler{next: next, limits: limits}, limits, logger) + q.maxLookBackPeriod = ng.Opts().MaxLookBackPeriod + + return q +} + +// NewQuerierSizeLimiterMiddleware creates a new Middleware that enforces query size limits after sharding and splitting. +// The errorTemplate should format two strings: the bytes that would be read and the bytes limit. +func NewQuerierSizeLimiterMiddleware( + cfg []config.PeriodConfig, + logger log.Logger, + limits Limits, + codec queryrangebase.Codec, + statsHandler ...queryrangebase.Handler, +) queryrangebase.Middleware { + return queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler { + return newQuerySizeLimiter(next, cfg, logger, limits, codec, limits.MaxQuerierBytesRead, limErrQuerierTooManyBytesTmpl, statsHandler...) + }) +} + +// NewQuerySizeLimiterMiddleware creates a new Middleware that enforces query size limits. +// The errorTemplate should format two strings: the bytes that would be read and the bytes limit. +func NewQuerySizeLimiterMiddleware( + cfg []config.PeriodConfig, + logger log.Logger, + limits Limits, + codec queryrangebase.Codec, + statsHandler ...queryrangebase.Handler, +) queryrangebase.Middleware { + return queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler { + return newQuerySizeLimiter(next, cfg, logger, limits, codec, limits.MaxQueryBytesRead, limErrQueryTooManyBytesTmpl, statsHandler...) + }) +} + +// getBytesReadForRequest returns the number of bytes that would be read for the query in r. +// Since the query expression may contain multiple stream matchers, this function sums up the +// bytes that will be read for each stream. +// E.g. for the following query: +// +// count_over_time({job="foo"}[5m]) / count_over_time({job="bar"}[5m] offset 10m) +// +// this function will sum the bytes read for each of the following streams, taking into account +// individual intervals and offsets +// - {job="foo"} +// - {job="bar"} +func (q *querySizeLimiter) getBytesReadForRequest(ctx context.Context, r queryrangebase.Request) (uint64, error) { + sp, ctx := spanlogger.NewWithLogger(ctx, q.logger, "querySizeLimiter.getBytesReadForRequest") + defer sp.Finish() + + expr, err := syntax.ParseExpr(r.GetQuery()) + if err != nil { + return 0, err + } + + matcherGroups, err := syntax.MatcherGroups(expr) + if err != nil { + return 0, err + } + + // TODO: Set concurrency dynamically as in shardResolverForConf? + start := time.Now() + const maxConcurrentIndexReq = 10 + matcherStats, err := getStatsForMatchers(ctx, q.logger, q.statsHandler, model.Time(r.GetStart()), model.Time(r.GetEnd()), matcherGroups, maxConcurrentIndexReq, q.maxLookBackPeriod) + if err != nil { + return 0, err + } + + combinedStats := stats.MergeStats(matcherStats...) 
+ + level.Debug(sp).Log( + append( + combinedStats.LoggingKeyValues(), + "msg", "queried index", + "type", "combined", + "len", len(matcherStats), + "max_parallelism", maxConcurrentIndexReq, + "duration", time.Since(start), + "total_bytes", strings.Replace(humanize.Bytes(combinedStats.Bytes), " ", "", 1), + )..., + ) + + return combinedStats.Bytes, nil +} + +func (q *querySizeLimiter) getSchemaCfg(r queryrangebase.Request) (config.PeriodConfig, error) { + maxRVDuration, maxOffset, err := maxRangeVectorAndOffsetDuration(r.GetQuery()) + if err != nil { + return config.PeriodConfig{}, errors.New("failed to get range-vector and offset duration: " + err.Error()) + } + + adjustedStart := int64(model.Time(r.GetStart()).Add(-maxRVDuration).Add(-maxOffset)) + adjustedEnd := int64(model.Time(r.GetEnd()).Add(-maxOffset)) + + return ShardingConfigs(q.cfg).ValidRange(adjustedStart, adjustedEnd) +} + +func (q *querySizeLimiter) guessLimitName() string { + if q.limitErrorTmpl == limErrQueryTooManyBytesTmpl { + return "MaxQueryBytesRead" + } + if q.limitErrorTmpl == limErrQuerierTooManyBytesTmpl { + return "MaxQuerierBytesRead" + } + return "unknown" +} + +func (q *querySizeLimiter) Do(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { + log, ctx := spanlogger.New(ctx, "query_size_limits") + defer log.Finish() + + // Only support TSDB + schemaCfg, err := q.getSchemaCfg(r) + if err != nil { + return nil, httpgrpc.Errorf(http.StatusInternalServerError, "Failed to get schema config: %s", err.Error()) + } + if schemaCfg.IndexType != config.TSDBType { + return q.next.Do(ctx, r) + } + + tenantIDs, err := tenant.TenantIDs(ctx) + if err != nil { + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) + } + + limitFuncCapture := func(id string) int { return q.limitFunc(ctx, id) } + if maxBytesRead := validation.SmallestPositiveNonZeroIntPerTenant(tenantIDs, limitFuncCapture); maxBytesRead > 0 { + bytesRead, err := q.getBytesReadForRequest(ctx, r) + if err != nil { + return nil, httpgrpc.Errorf(http.StatusInternalServerError, "Failed to get bytes read stats for query: %s", err.Error()) + } + + statsBytesStr := humanize.Bytes(bytesRead) + maxBytesReadStr := humanize.Bytes(uint64(maxBytesRead)) + + if bytesRead > uint64(maxBytesRead) { + level.Warn(log).Log("msg", "Query exceeds limits", "status", "rejected", "limit_name", q.guessLimitName(), "limit_bytes", maxBytesReadStr, "resolved_bytes", statsBytesStr) + return nil, httpgrpc.Errorf(http.StatusBadRequest, q.limitErrorTmpl, statsBytesStr, maxBytesReadStr) + } + + level.Debug(log).Log("msg", "Query is within limits", "status", "accepted", "limit_name", q.guessLimitName(), "limit_bytes", maxBytesReadStr, "resolved_bytes", statsBytesStr) + } + + return q.next.Do(ctx, r) +} + type seriesLimiter struct { hashes map[uint64]struct{} rw sync.RWMutex diff --git a/pkg/querier/queryrange/limits_test.go b/pkg/querier/queryrange/limits_test.go index fb960a4caf..28c26d2f2d 100644 --- a/pkg/querier/queryrange/limits_test.go +++ b/pkg/querier/queryrange/limits_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/grafana/loki/pkg/util/math" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/promql" @@ -432,3 +433,215 @@ func Test_WeightedParallelism_DivideByZeroError(t *testing.T) { require.Equal(t, 1, result) }) } + +func getFakeStatsHandler(retBytes uint64) (queryrangebase.Handler, *int, error) { + fakeRT, err := newfakeRoundTripper() + if err != nil { + return 
nil, nil, err + } + + count, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: retBytes}) + + fakeRT.setHandler(statsHandler) + + return queryrangebase.NewRoundTripperHandler(fakeRT, LokiCodec), count, nil +} + +func Test_MaxQuerySize(t *testing.T) { + const statsBytes = 1000 + + schemas := []config.PeriodConfig{ + { + // BoltDB -> Time -4 days + From: config.DayTime{Time: model.TimeFromUnix(testTime.Add(-96 * time.Hour).Unix())}, + IndexType: config.BoltDBShipperType, + }, + { + // TSDB -> Time -2 days + From: config.DayTime{Time: model.TimeFromUnix(testTime.Add(-48 * time.Hour).Unix())}, + IndexType: config.TSDBType, + }, + } + + for _, tc := range []struct { + desc string + schema string + query string + queryRange time.Duration + queryStart time.Time + queryEnd time.Time + limits Limits + + shouldErr bool + expectedQueryStatsHits int + expectedQuerierStatsHits int + }{ + { + desc: "No TSDB", + schema: config.BoltDBShipperType, + query: `{app="foo"} |= "foo"`, + queryRange: 1 * time.Hour, + + queryStart: testTime.Add(-96 * time.Hour), + queryEnd: testTime.Add(-90 * time.Hour), + limits: fakeLimits{ + maxQueryBytesRead: 1, + maxQuerierBytesRead: 1, + }, + + shouldErr: false, + expectedQueryStatsHits: 0, + expectedQuerierStatsHits: 0, + }, + { + desc: "Unlimited", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-48 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: 0, + maxQuerierBytesRead: 0, + }, + + shouldErr: false, + expectedQueryStatsHits: 0, + expectedQuerierStatsHits: 0, + }, + { + desc: "1 hour range", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-1 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: statsBytes, + maxQuerierBytesRead: statsBytes, + }, + + shouldErr: false, + // [testTime-1h, testTime) + expectedQueryStatsHits: 1, + expectedQuerierStatsHits: 1, + }, + { + desc: "24 hour range", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-24 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: statsBytes, + maxQuerierBytesRead: statsBytes, + }, + + shouldErr: false, + // [testTime-24h, midnight) and [midnight, testTime] + expectedQueryStatsHits: 2, + expectedQuerierStatsHits: 2, + }, + { + desc: "48 hour range", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-48 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: statsBytes, + maxQuerierBytesRead: statsBytes, + }, + + shouldErr: false, + // [testTime-48h, midnight-1d), [midnight-1d, midnight) and [midnight, testTime] + expectedQueryStatsHits: 3, + expectedQuerierStatsHits: 3, + }, + { + desc: "Query size too big", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-1 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: statsBytes - 1, + maxQuerierBytesRead: statsBytes, + }, + + shouldErr: true, + expectedQueryStatsHits: 1, + expectedQuerierStatsHits: 0, + }, + { + desc: "Querier size too big", + query: `{app="foo"} |= "foo"`, + queryStart: testTime.Add(-1 * time.Hour), + queryEnd: testTime, + limits: fakeLimits{ + maxQueryBytesRead: statsBytes, + maxQuerierBytesRead: statsBytes - 1, + }, + + shouldErr: true, + expectedQueryStatsHits: 1, + expectedQuerierStatsHits: 1, + }, + { + desc: "Multi-matchers with offset", + query: `sum_over_time ({app="foo"} |= "foo" | unwrap foo [5m] ) / sum_over_time ({app="bar"} |= "bar" | unwrap bar [5m] offset 1h)`, + queryStart: testTime.Add(-1 * time.Hour), + queryEnd: testTime, + 
limits: fakeLimits{ + maxQueryBytesRead: statsBytes, + maxQuerierBytesRead: statsBytes, + }, + + shouldErr: false, + // *2 since we have two matcher groups + expectedQueryStatsHits: 1 * 2, + expectedQuerierStatsHits: 1 * 2, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + queryStatsHandler, queryStatsHits, err := getFakeStatsHandler(uint64(statsBytes / math.Max(tc.expectedQueryStatsHits, 1))) + require.NoError(t, err) + + querierStatsHandler, querierStatsHits, err := getFakeStatsHandler(uint64(statsBytes / math.Max(tc.expectedQuerierStatsHits, 1))) + require.NoError(t, err) + + fakeRT, err := newfakeRoundTripper() + require.NoError(t, err) + + _, promHandler := promqlResult(matrix) + fakeRT.setHandler(promHandler) + + lokiReq := &LokiRequest{ + Query: tc.query, + Limit: 1000, + StartTs: tc.queryStart, + EndTs: tc.queryEnd, + Direction: logproto.FORWARD, + Path: "/query_range", + } + + ctx := user.InjectOrgID(context.Background(), "foo") + req, err := LokiCodec.EncodeRequest(ctx, lokiReq) + require.NoError(t, err) + + req = req.WithContext(ctx) + err = user.InjectOrgIDIntoHTTPRequest(ctx, req) + require.NoError(t, err) + + middlewares := []queryrangebase.Middleware{ + NewQuerySizeLimiterMiddleware(schemas, util_log.Logger, tc.limits, LokiCodec, queryStatsHandler), + NewQuerierSizeLimiterMiddleware(schemas, util_log.Logger, tc.limits, LokiCodec, querierStatsHandler), + } + + _, err = queryrangebase.NewRoundTripper(fakeRT, LokiCodec, nil, middlewares...).RoundTrip(req) + + if tc.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + require.Equal(t, tc.expectedQueryStatsHits, *queryStatsHits) + require.Equal(t, tc.expectedQuerierStatsHits, *querierStatsHits) + }) + } + +} diff --git a/pkg/querier/queryrange/queryrangebase/roundtrip.go b/pkg/querier/queryrange/queryrangebase/roundtrip.go index c5c701346c..ad41b1676a 100644 --- a/pkg/querier/queryrange/queryrangebase/roundtrip.go +++ b/pkg/querier/queryrange/queryrangebase/roundtrip.go @@ -122,9 +122,8 @@ func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { } type roundTripper struct { - next http.RoundTripper + roundTripperHandler handler Handler - codec Codec headers []string } @@ -132,8 +131,10 @@ type roundTripper struct { // using the codec to translate requests and responses. func NewRoundTripper(next http.RoundTripper, codec Codec, headers []string, middlewares ...Middleware) http.RoundTripper { transport := roundTripper{ - next: next, - codec: codec, + roundTripperHandler: roundTripperHandler{ + next: next, + codec: codec, + }, headers: headers, } transport.handler = MergeMiddlewares(middlewares...).Wrap(&transport) @@ -159,8 +160,22 @@ func (q roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { return q.codec.EncodeResponse(r.Context(), response) } +type roundTripperHandler struct { + next http.RoundTripper + codec Codec +} + +// NewRoundTripperHandler returns a handler that translates Loki requests into http requests +// and passes down these to the next RoundTripper. +func NewRoundTripperHandler(next http.RoundTripper, codec Codec) Handler { + return roundTripperHandler{ + next: next, + codec: codec, + } +} + // Do implements Handler. 
-func (q roundTripper) Do(ctx context.Context, r Request) (Response, error) { +func (q roundTripperHandler) Do(ctx context.Context, r Request) (Response, error) { request, err := q.codec.EncodeRequest(ctx, r) if err != nil { return nil, err diff --git a/pkg/querier/queryrange/querysharding.go b/pkg/querier/queryrange/querysharding.go index 83b5b12a7d..8e2a98b924 100644 --- a/pkg/querier/queryrange/querysharding.go +++ b/pkg/querier/queryrange/querysharding.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/dustin/go-humanize" "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/tenant" @@ -33,6 +34,7 @@ var errInvalidShardingRange = errors.New("Query does not fit in a single shardin func NewQueryShardMiddleware( logger log.Logger, confs ShardingConfigs, + codec queryrangebase.Codec, middlewareMetrics *queryrangebase.InstrumentMiddlewareMetrics, shardingMetrics *logql.MapperMetrics, limits Limits, @@ -95,6 +97,28 @@ type astMapperware struct { maxShards int } +func (ast *astMapperware) checkQuerySizeLimit(ctx context.Context, bytesPerShard uint64) error { + tenantIDs, err := tenant.TenantIDs(ctx) + if err != nil { + return httpgrpc.Errorf(http.StatusBadRequest, err.Error()) + } + + maxQuerierBytesReadCapture := func(id string) int { return ast.limits.MaxQuerierBytesRead(ctx, id) } + if maxBytesRead := validation.SmallestPositiveNonZeroIntPerTenant(tenantIDs, maxQuerierBytesReadCapture); maxBytesRead > 0 { + statsBytesStr := humanize.Bytes(bytesPerShard) + maxBytesReadStr := humanize.Bytes(uint64(maxBytesRead)) + + if bytesPerShard > uint64(maxBytesRead) { + level.Warn(ast.logger).Log("msg", "Query exceeds limits", "status", "rejected", "limit_name", "MaxQuerierBytesRead", "limit_bytes", maxBytesReadStr, "resolved_bytes", statsBytesStr) + return httpgrpc.Errorf(http.StatusBadRequest, limErrQuerierTooManyBytesTmpl, statsBytesStr, maxBytesReadStr) + } + + level.Debug(ast.logger).Log("msg", "Query is within limits", "status", "accepted", "limit_name", "MaxQuerierBytesRead", "limit_bytes", maxBytesReadStr, "resolved_bytes", statsBytesStr) + } + + return nil +} + func (ast *astMapperware) Do(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { logger := spanlogger.FromContextWithFallback( ctx, @@ -128,6 +152,7 @@ func (ast *astMapperware) Do(ctx context.Context, r queryrangebase.Request) (que ast.maxShards, r, ast.next, + ast.limits, ) if !ok { return ast.next.Do(ctx, r) @@ -135,16 +160,21 @@ func (ast *astMapperware) Do(ctx context.Context, r queryrangebase.Request) (que mapper := logql.NewShardMapper(resolver, ast.metrics) - noop, parsed, err := mapper.Parse(r.GetQuery()) + noop, bytesPerShard, parsed, err := mapper.Parse(r.GetQuery()) if err != nil { level.Warn(logger).Log("msg", "failed mapping AST", "err", err.Error(), "query", r.GetQuery()) return nil, err } level.Debug(logger).Log("no-op", noop, "mapped", parsed.String()) + // Note, even if noop, bytesPerShard contains the bytes that'd be read for the whole expr without sharding + if err = ast.checkQuerySizeLimit(ctx, bytesPerShard); err != nil { + return nil, err + } + + // If the ast can't be mapped to a sharded equivalent, + // we can bypass the sharding engine and forward the request downstream. if noop { - // the ast can't be mapped to a sharded equivalent - // so we can bypass the sharding engine. 
return ast.next.Do(ctx, r) } diff --git a/pkg/querier/queryrange/querysharding_test.go b/pkg/querier/queryrange/querysharding_test.go index 54d8c7664d..1f0b424945 100644 --- a/pkg/querier/queryrange/querysharding_test.go +++ b/pkg/querier/queryrange/querysharding_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math" "sort" + "strings" "sync" "testing" "time" @@ -185,6 +186,140 @@ func Test_astMapper(t *testing.T) { require.Equal(t, expected.(*LokiResponse).Data, resp.(*LokiResponse).Data) } +func Test_astMapper_QuerySizeLimits(t *testing.T) { + noErr := "" + for _, tc := range []struct { + desc string + query string + maxQuerierBytesSize int + + err string + expectedStatsHandlerHits int + }{ + { + desc: "Non shardable query", + query: `sum_over_time({app="foo"} |= "foo" | unwrap foo [1h])`, + maxQuerierBytesSize: 100, + + err: noErr, + expectedStatsHandlerHits: 1, + }, + { + desc: "Non shardable query too big", + query: `sum_over_time({app="foo"} |= "foo" | unwrap foo [1h])`, + maxQuerierBytesSize: 10, + err: fmt.Sprintf(limErrQuerierTooManyBytesTmpl, "100 B", "10 B"), + expectedStatsHandlerHits: 1, + }, + { + desc: "Shardable query", + query: `count_over_time({app="foo"} |= "foo" [1h])`, + maxQuerierBytesSize: 100, + + err: noErr, + expectedStatsHandlerHits: 1, + }, + { + desc: "Shardable query too big", + query: `count_over_time({app="foo"} |= "foo" [1h])`, + maxQuerierBytesSize: 10, + + err: fmt.Sprintf(limErrQuerierTooManyBytesTmpl, "100 B", "10 B"), + expectedStatsHandlerHits: 1, + }, + { + desc: "Partially Shardable query fitting", + query: `count_over_time({app="foo"} |= "foo" [1h]) - sum_over_time({app="foo"} |= "foo" | unwrap foo [1h])`, + maxQuerierBytesSize: 100, + + err: noErr, + expectedStatsHandlerHits: 2, + }, + { + desc: "Partially Shardable LHS too big", + query: `count_over_time({app="bar"} |= "bar" [1h]) - sum_over_time({app="foo"} |= "foo" | unwrap foo [1h])`, + maxQuerierBytesSize: 100, + + err: fmt.Sprintf(limErrQuerierTooManyBytesTmpl, "500 B", "100 B"), + expectedStatsHandlerHits: 2, + }, + { + desc: "Partially Shardable RHS too big", + query: `count_over_time({app="foo"} |= "foo" [1h]) - sum_over_time({app="bar"} |= "bar" | unwrap foo [1h])`, + maxQuerierBytesSize: 100, + + err: fmt.Sprintf(limErrQuerierTooManyBytesTmpl, "500 B", "100 B"), + expectedStatsHandlerHits: 2, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + statsCalled := 0 + handler := queryrangebase.HandlerFunc(func(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { + if casted, ok := req.(*logproto.IndexStatsRequest); ok { + statsCalled++ + + var bytes uint64 + if strings.Contains(casted.Matchers, `app="foo"`) { + bytes = 100 + } + if strings.Contains(casted.Matchers, `app="bar"`) { + bytes = 500 + } + + return &IndexStatsResponse{ + Response: &logproto.IndexStatsResponse{ + Bytes: bytes, + }, + }, nil + } + if _, ok := req.(*LokiRequest); ok { + return &LokiPromResponse{Response: &queryrangebase.PrometheusResponse{ + Data: queryrangebase.PrometheusData{ + ResultType: loghttp.ResultTypeVector, + Result: []queryrangebase.SampleStream{ + { + Labels: []logproto.LabelAdapter{{Name: "foo", Value: "bar"}}, + Samples: []logproto.LegacySample{{Value: 10, TimestampMs: 10}}, + }, + }, + }, + }}, nil + } + + return nil, nil + }) + + mware := newASTMapperware( + ShardingConfigs{ + config.PeriodConfig{ + RowShards: 2, + IndexType: config.TSDBType, + }, + }, + handler, + log.NewNopLogger(), + nilShardingMetrics, + fakeLimits{ + maxSeries: math.MaxInt32, + maxQueryParallelism: 1, + 
tsdbMaxQueryParallelism: 1, + queryTimeout: time.Minute, + maxQuerierBytesRead: tc.maxQuerierBytesSize, + }, + 0, + ) + + _, err := mware.Do(user.InjectOrgID(context.Background(), "1"), defaultReq().WithQuery(tc.query)) + if err != nil { + require.ErrorContains(t, err, tc.err) + } + + require.Equal(t, tc.expectedStatsHandlerHits, statsCalled) + + }) + } +} + func Test_ShardingByPass(t *testing.T) { called := 0 handler := queryrangebase.HandlerFunc(func(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { @@ -269,7 +404,7 @@ func Test_InstantSharding(t *testing.T) { cpyPeriodConf.RowShards = 3 sharding := NewQueryShardMiddleware(log.NewNopLogger(), ShardingConfigs{ cpyPeriodConf, - }, queryrangebase.NewInstrumentMiddlewareMetrics(nil), + }, LokiCodec, queryrangebase.NewInstrumentMiddlewareMetrics(nil), nilShardingMetrics, fakeLimits{ maxSeries: math.MaxInt32, diff --git a/pkg/querier/queryrange/roundtrip.go b/pkg/querier/queryrange/roundtrip.go index 1c1a86c2d4..2cb45a00ab 100644 --- a/pkg/querier/queryrange/roundtrip.go +++ b/pkg/querier/queryrange/roundtrip.go @@ -287,52 +287,66 @@ func NewLogFilterTripperware( c cache.Cache, metrics *Metrics, ) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{ - StatsCollectorMiddleware(), - NewLimitsMiddleware(limits), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - SplitByIntervalMiddleware(schema.Configs, limits, codec, splitByTime, metrics.SplitByMetrics), - } + return func(next http.RoundTripper) http.RoundTripper { + skipMiddleware := queryrangebase.NewRoundTripperHandler(next, codec) + if cfg.MaxRetries > 0 { + skipMiddleware = queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics).Wrap(skipMiddleware) + }

 - if cfg.CacheResults { - queryCacheMiddleware := NewLogResultCache( - log, - limits, - c, - func(r queryrangebase.Request) bool { - return !r.GetCachingOptions().Disabled - }, - cfg.Transformer, - metrics.LogResultCacheMetrics, - ) - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), - queryCacheMiddleware, - ) - } + queryRangeMiddleware := []queryrangebase.Middleware{ + StatsCollectorMiddleware(), + NewLimitsMiddleware(limits), + NewQuerySizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + SplitByIntervalMiddleware(schema.Configs, limits, codec, splitByTime, metrics.SplitByMetrics), + }

 - if cfg.ShardedQueries { - queryRangeMiddleware = append(queryRangeMiddleware, - NewQueryShardMiddleware( + if cfg.CacheResults { + queryCacheMiddleware := NewLogResultCache( log, - schema.Configs, - metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware - metrics.MiddlewareMapperMetrics.shardMapper, limits, - 0, // 0 is unlimited shards - ), - ) - } + c, + func(r queryrangebase.Request) bool { + return !r.GetCachingOptions().Disabled + }, + cfg.Transformer, + metrics.LogResultCacheMetrics, + ) + queryRangeMiddleware = append( + queryRangeMiddleware, + queryrangebase.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), + queryCacheMiddleware, + ) + }

 - if cfg.MaxRetries > 0 { - queryRangeMiddleware = append( -
queryRangeMiddleware, queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), - ) - } + if cfg.ShardedQueries { + queryRangeMiddleware = append(queryRangeMiddleware, + NewQueryShardMiddleware( + log, + schema.Configs, + codec, + metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware + metrics.MiddlewareMapperMetrics.shardMapper, + limits, + 0, // 0 is unlimited shards + ), + ) + } else { + // The sharding middleware takes care of enforcing this limit for both shardable and non-shardable queries. + // If we are not using sharding, we enforce the limit by adding this middleware after time splitting. + queryRangeMiddleware = append(queryRangeMiddleware, + NewQuerierSizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + ) + } + + if cfg.MaxRetries > 0 { + queryRangeMiddleware = append( + queryRangeMiddleware, queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + ) + } - return func(next http.RoundTripper) http.RoundTripper { if len(queryRangeMiddleware) > 0 { return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) } @@ -350,59 +364,26 @@ func NewLimitedTripperware( c cache.Cache, metrics *Metrics, ) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{ - StatsCollectorMiddleware(), - NewLimitsMiddleware(limits), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - // Limited queries only need to fetch up to the requested line limit worth of logs, - // Our defaults for splitting and parallelism are much too aggressive for large customers and result in - // potentially GB of logs being returned by all the shards and splits which will overwhelm the frontend - // Therefore we force max parallelism to one so that these queries are executed sequentially. - // Below we also fix the number of shards to a static number. - SplitByIntervalMiddleware(schema.Configs, WithMaxParallelism(limits, 1), codec, splitByTime, metrics.SplitByMetrics), - } - - if cfg.CacheResults { - queryCacheMiddleware := NewLogResultCache( - log, - limits, - c, - func(r queryrangebase.Request) bool { - return !r.GetCachingOptions().Disabled - }, - cfg.Transformer, - metrics.LogResultCacheMetrics, - ) - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), - queryCacheMiddleware, - ) - } - - if cfg.ShardedQueries { - queryRangeMiddleware = append(queryRangeMiddleware, - NewQueryShardMiddleware( - log, - schema.Configs, - metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware - metrics.MiddlewareMapperMetrics.shardMapper, - limits, - // Too many shards on limited queries results in slowing down this type of query - // and overwhelming the frontend, therefore we fix the number of shards to prevent this. 
- 32, - ), - ) - } + return func(next http.RoundTripper) http.RoundTripper { + skipMiddleware := queryrangebase.NewRoundTripperHandler(next, codec) + if cfg.MaxRetries > 0 { + skipMiddleware = queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics).Wrap(skipMiddleware) + } - if cfg.MaxRetries > 0 { - queryRangeMiddleware = append( - queryRangeMiddleware, queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), - ) - } + queryRangeMiddleware := []queryrangebase.Middleware{ + StatsCollectorMiddleware(), + NewLimitsMiddleware(limits), + NewQuerySizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + // Limited queries only need to fetch up to the requested line limit worth of logs, + // Our defaults for splitting and parallelism are much too aggressive for large customers and result in + // potentially GB of logs being returned by all the shards and splits which will overwhelm the frontend + // Therefore we force max parallelism to one so that these queries are executed sequentially. + // Below we also fix the number of shards to a static number. + SplitByIntervalMiddleware(schema.Configs, WithMaxParallelism(limits, 1), codec, splitByTime, metrics.SplitByMetrics), + NewQuerierSizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + } - return func(next http.RoundTripper) http.RoundTripper { if len(queryRangeMiddleware) > 0 { return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) } @@ -505,24 +486,12 @@ func NewMetricTripperware( metrics *Metrics, registerer prometheus.Registerer, ) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{StatsCollectorMiddleware(), NewLimitsMiddleware(limits)} - if cfg.AlignQueriesWithStep { - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("step_align", metrics.InstrumentMiddlewareMetrics), - queryrangebase.StepAlignMiddleware, - ) - } - - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - SplitByIntervalMiddleware(schema.Configs, limits, codec, splitMetricByTime, metrics.SplitByMetrics), - ) cacheKey := cacheKeyLimits{limits, cfg.Transformer} + var queryCacheMiddleware queryrangebase.Middleware if cfg.CacheResults { - queryCacheMiddleware, err := queryrangebase.NewResultsCacheMiddleware( + var err error + queryCacheMiddleware, err = queryrangebase.NewResultsCacheMiddleware( log, c, cacheKey, @@ -549,35 +518,70 @@ func NewMetricTripperware( if err != nil { return nil, err } - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("results_cache", metrics.InstrumentMiddlewareMetrics), - queryCacheMiddleware, - ) } - if cfg.ShardedQueries { - queryRangeMiddleware = append(queryRangeMiddleware, - NewQueryShardMiddleware( - log, - schema.Configs, - metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware - metrics.MiddlewareMapperMetrics.shardMapper, - limits, - 0, // 0 is unlimited shards - ), - ) - } + return func(next http.RoundTripper) http.RoundTripper { + skipMiddleware := queryrangebase.NewRoundTripperHandler(next, codec) + if cfg.MaxRetries > 0 { + skipMiddleware = 
queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics).Wrap(skipMiddleware) + } + + queryRangeMiddleware := []queryrangebase.Middleware{ + StatsCollectorMiddleware(), + NewLimitsMiddleware(limits), + } + + if cfg.AlignQueriesWithStep { + queryRangeMiddleware = append( + queryRangeMiddleware, + queryrangebase.InstrumentMiddleware("step_align", metrics.InstrumentMiddlewareMetrics), + queryrangebase.StepAlignMiddleware, + ) + } - if cfg.MaxRetries > 0 { queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + NewQuerySizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + SplitByIntervalMiddleware(schema.Configs, limits, codec, splitMetricByTime, metrics.SplitByMetrics), ) - } - return func(next http.RoundTripper) http.RoundTripper { + if cfg.CacheResults { + queryRangeMiddleware = append( + queryRangeMiddleware, + queryrangebase.InstrumentMiddleware("results_cache", metrics.InstrumentMiddlewareMetrics), + queryCacheMiddleware, + ) + } + + if cfg.ShardedQueries { + queryRangeMiddleware = append(queryRangeMiddleware, + NewQueryShardMiddleware( + log, + schema.Configs, + codec, + metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware + metrics.MiddlewareMapperMetrics.shardMapper, + limits, + 0, // 0 is unlimited shards + ), + ) + } else { + // The sharding middleware takes care of enforcing this limit for both shardable and non-shardable queries. + // If we are not using sharding, we enforce the limit by adding this middleware after time splitting. + queryRangeMiddleware = append(queryRangeMiddleware, + NewQuerierSizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + ) + } + + if cfg.MaxRetries > 0 { + queryRangeMiddleware = append( + queryRangeMiddleware, + queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + ) + } + // Finally, if the user selected any query range middleware, stitch it in. if len(queryRangeMiddleware) > 0 { rt := NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) 
@@ -601,31 +605,41 @@ func NewInstantMetricTripperware( codec queryrangebase.Codec, metrics *Metrics, ) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{StatsCollectorMiddleware(), NewLimitsMiddleware(limits)} + return func(next http.RoundTripper) http.RoundTripper { + skipMiddleware := queryrangebase.NewRoundTripperHandler(next, codec) + if cfg.MaxRetries > 0 { + skipMiddleware = queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics).Wrap(skipMiddleware) + } - if cfg.ShardedQueries { - queryRangeMiddleware = append(queryRangeMiddleware, - NewSplitByRangeMiddleware(log, limits, metrics.MiddlewareMapperMetrics.rangeMapper), - NewQueryShardMiddleware( - log, - schema.Configs, - metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware - metrics.MiddlewareMapperMetrics.shardMapper, - limits, - 0, // 0 is unlimited shards - ), - ) - } + queryRangeMiddleware := []queryrangebase.Middleware{ + StatsCollectorMiddleware(), + NewLimitsMiddleware(limits), + NewQuerySizeLimiterMiddleware(schema.Configs, log, limits, codec, skipMiddleware), + } - if cfg.MaxRetries > 0 { - queryRangeMiddleware = append( - queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), - ) - } + if cfg.ShardedQueries { + queryRangeMiddleware = append(queryRangeMiddleware, + NewSplitByRangeMiddleware(log, limits, metrics.MiddlewareMapperMetrics.rangeMapper), + NewQueryShardMiddleware( + log, + schema.Configs, + codec, + metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware + metrics.MiddlewareMapperMetrics.shardMapper, + limits, + 0, // 0 is unlimited shards + ), + ) + } + + if cfg.MaxRetries > 0 { + queryRangeMiddleware = append( + queryRangeMiddleware, + queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + ) + } - return func(next http.RoundTripper) http.RoundTripper { if len(queryRangeMiddleware) > 0 { return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) } diff --git a/pkg/querier/queryrange/roundtrip_test.go b/pkg/querier/queryrange/roundtrip_test.go index ffd5c5f109..4597132734 100644 --- a/pkg/querier/queryrange/roundtrip_test.go +++ b/pkg/querier/queryrange/roundtrip_test.go @@ -108,11 +108,34 @@ var ( } ) +func getQueryAndStatsHandler(queryHandler, statsHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/loki/api/v1/index/stats" { + statsHandler.ServeHTTP(w, r) + return + } + + if r.URL.Path == "/loki/api/v1/query_range" || r.URL.Path == "/loki/api/v1/query" { + queryHandler.ServeHTTP(w, r) + return + } + + panic("Request not supported") + }) +} + // those tests are mostly for testing the glue between all component and make sure they activate correctly. 
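
The tests below drive a real tripperware against two fake downstream handlers, one per path (`getQueryAndStatsHandler` panics on anything else, so stray requests fail loudly). Their `statsCount`/`queryCount` assertions follow from the two-stage enforcement; here is a rough model of the accounting under the limits the tests configure (`maxQueryBytesRead: 1000`, `maxQuerierBytesRead: 100`). `admit` is a hypothetical helper for illustration, not code from this PR:

```go
package main

import "fmt"

// admit models the two checks a query passes through in the frontend.
// It returns how many index stats requests were issued and whether the
// actual query was ever sent downstream.
func admit(estimatedBytes uint64) (statsCalls int, querySent bool, err error) {
	statsCalls++ // pre-splitting check against max_query_bytes_read (1000 here)
	if estimatedBytes > 1000 {
		return statsCalls, false, fmt.Errorf("max_query_bytes_read exceeded")
	}
	statsCalls++ // post-splitting/sharding check against max_querier_bytes_read (100 here)
	if estimatedBytes > 100 {
		return statsCalls, false, fmt.Errorf("max_querier_bytes_read exceeded")
	}
	return statsCalls, true, nil
}

func main() {
	for _, bytes := range []uint64{2000, 200, 10} {
		stats, sent, err := admit(bytes)
		fmt.Printf("stats=%d sent=%v err=%v\n", stats, sent, err)
	}
	// stats=1 sent=false err=max_query_bytes_read exceeded
	// stats=2 sent=false err=max_querier_bytes_read exceeded
	// stats=2 sent=true err=<nil>
}
```

This matches the assertions that follow: a 2000-byte estimate is refused after a single stats request, a 200-byte estimate after two, and only a small enough estimate lets the query (and the retry, split, and cache stages) run at all.
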
func TestMetricsTripperware(t *testing.T) { - l := WithSplitByLimits(fakeLimits{maxSeries: math.MaxInt32, maxQueryParallelism: 1}, 4*time.Hour) + var l Limits = fakeLimits{ + maxSeries: math.MaxInt32, + maxQueryParallelism: 1, + tsdbMaxQueryParallelism: 1, + maxQueryBytesRead: 1000, + maxQuerierBytesRead: 100, + } + l = WithSplitByLimits(l, 4*time.Hour) tpw, stopper, err := NewTripperware(testConfig, util_log.Logger, l, config.SchemaConfig{ - Configs: testSchemas, + Configs: testSchemasTSDB, }, nil, false, nil) if stopper != nil { defer stopper.Stop() @@ -139,9 +162,28 @@ func TestMetricsTripperware(t *testing.T) { rt, err := newfakeRoundTripper() require.NoError(t, err) + // Test MaxQueryBytesRead limit + statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) + queryCount, queryHandler := counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + require.Error(t, err) + require.Equal(t, 1, *statsCount) + require.Equal(t, 0, *queryCount) + + // Test MaxQuerierBytesRead limit + statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) + queryCount, queryHandler = counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + require.Error(t, err) + require.Equal(t, 2, *statsCount) + require.Equal(t, 0, *queryCount) + // testing retry - retries, h := counter() - rt.setHandler(h) + _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) + retries, queryHandler := counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) _, err = tpw(rt).RoundTrip(req) // 3 retries configured. require.GreaterOrEqual(t, *retries, 3) @@ -153,8 +195,9 @@ func TestMetricsTripperware(t *testing.T) { defer rt.Close() // testing split interval - count, h := promqlResult(matrix) - rt.setHandler(h) + _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) + count, queryHandler := promqlResult(matrix) + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) resp, err := tpw(rt).RoundTrip(req) // 2 queries require.Equal(t, 2, *count) @@ -163,8 +206,8 @@ func TestMetricsTripperware(t *testing.T) { require.NoError(t, err) // testing cache - count, h = counter() - rt.setHandler(h) + count, queryHandler = counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) cacheResp, err := tpw(rt).RoundTrip(req) // 0 queries result are cached. 
require.Equal(t, 0, *count) @@ -176,7 +219,13 @@ func TestMetricsTripperware(t *testing.T) { } func TestLogFilterTripperware(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, util_log.Logger, fakeLimits{maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + var l Limits = fakeLimits{ + maxQueryParallelism: 1, + tsdbMaxQueryParallelism: 1, + maxQueryBytesRead: 1000, + maxQuerierBytesRead: 100, + } + tpw, stopper, err := NewTripperware(testConfig, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) if stopper != nil { defer stopper.Stop() } @@ -215,17 +264,44 @@ func TestLogFilterTripperware(t *testing.T) { require.NoError(t, err) // testing retry - retries, h := counter() - rt.setHandler(h) + _, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) + retries, queryHandler := counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) _, err = tpw(rt).RoundTrip(req) require.GreaterOrEqual(t, *retries, 3) require.Error(t, err) + + // Test MaxQueryBytesRead limit + statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) + queryCount, queryHandler := counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + require.Error(t, err) + require.Equal(t, 1, *statsCount) + require.Equal(t, 0, *queryCount) + + // Test MaxQuerierBytesRead limit + statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) + queryCount, queryHandler = counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + require.Error(t, err) + require.Equal(t, 2, *statsCount) + require.Equal(t, 0, *queryCount) } func TestInstantQueryTripperware(t *testing.T) { testShardingConfig := testConfig testShardingConfig.ShardedQueries = true - tpw, stopper, err := NewTripperware(testShardingConfig, util_log.Logger, fakeLimits{maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + var l Limits = fakeLimits{ + maxQueryParallelism: 1, + tsdbMaxQueryParallelism: 1, + maxQueryBytesRead: 1000, + maxQuerierBytesRead: 100, + queryTimeout: 1 * time.Minute, + maxSeries: 1, + } + tpw, stopper, err := NewTripperware(testShardingConfig, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) if stopper != nil { defer stopper.Stop() } @@ -237,6 +313,7 @@ func TestInstantQueryTripperware(t *testing.T) { lreq := &LokiInstantRequest{ Query: `sum by (job) (bytes_rate({cluster="dev-us-central-0"}[15m]))`, Limit: 1000, + TimeTs: testTime, Direction: logproto.FORWARD, Path: "/loki/api/v1/query", } @@ -249,8 +326,27 @@ func TestInstantQueryTripperware(t *testing.T) { err = user.InjectOrgIDIntoHTTPRequest(ctx, req) require.NoError(t, err) - count, h := promqlResult(vector) - rt.setHandler(h) + // Test MaxQueryBytesRead limit + statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) + queryCount, queryHandler := counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + require.Error(t, err) + require.Equal(t, 1, *statsCount) + require.Equal(t, 0, *queryCount) + + // Test MaxQuerierBytesRead limit + statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) + queryCount, queryHandler = counter() + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) + _, err = tpw(rt).RoundTrip(req) + 
require.Error(t, err) + require.Equal(t, 2, *statsCount) + require.Equal(t, 0, *queryCount) + + count, queryHandler := promqlResult(vector) + _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) + rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) resp, err := tpw(rt).RoundTrip(req) require.Equal(t, 1, *count) require.NoError(t, err) @@ -639,6 +735,8 @@ type fakeLimits struct { minShardingLookback time.Duration queryTimeout time.Duration requiredLabels []string + maxQueryBytesRead int + maxQuerierBytesRead int } func (f fakeLimits) QuerySplitDuration(key string) time.Duration { @@ -683,6 +781,14 @@ func (f fakeLimits) MinShardingLookback(string) time.Duration { return f.minShardingLookback } +func (f fakeLimits) MaxQueryBytesRead(context.Context, string) int { + return f.maxQueryBytesRead +} + +func (f fakeLimits) MaxQuerierBytesRead(context.Context, string) int { + return f.maxQuerierBytesRead +} + func (f fakeLimits) QueryTimeout(context.Context, string) time.Duration { return f.queryTimeout } @@ -731,6 +837,19 @@ func seriesResult(v logproto.SeriesResponse) (*int, http.Handler) { }) } +func indexStatsResult(v logproto.IndexStatsResponse) (*int, http.Handler) { + count := 0 + var lock sync.Mutex + return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lock.Lock() + defer lock.Unlock() + if err := marshal.WriteIndexStatsResponseJSON(&v, w); err != nil { + panic(err) + } + count++ + }) +} + type fakeHandler struct { count int lock sync.Mutex diff --git a/pkg/querier/queryrange/shard_resolver.go b/pkg/querier/queryrange/shard_resolver.go index 5f196550a5..a40c6d0c9b 100644 --- a/pkg/querier/queryrange/shard_resolver.go +++ b/pkg/querier/queryrange/shard_resolver.go @@ -11,8 +11,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/concurrency" - "github.com/prometheus/common/model" - "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" "github.com/grafana/loki/pkg/logql/syntax" @@ -20,6 +18,7 @@ import ( "github.com/grafana/loki/pkg/storage/config" "github.com/grafana/loki/pkg/storage/stores/index/stats" "github.com/grafana/loki/pkg/util/spanlogger" + "github.com/prometheus/common/model" ) func shardResolverForConf( @@ -31,12 +30,14 @@ func shardResolverForConf( maxShards int, r queryrangebase.Request, handler queryrangebase.Handler, + limits Limits, ) (logql.ShardResolver, bool) { if conf.IndexType == config.TSDBType { return &dynamicShardResolver{ ctx: ctx, logger: logger, handler: handler, + limits: limits, from: model.Time(r.GetStart()), through: model.Time(r.GetEnd()), maxParallelism: maxParallelism, @@ -54,6 +55,7 @@ type dynamicShardResolver struct { ctx context.Context handler queryrangebase.Handler logger log.Logger + limits Limits from, through model.Time maxParallelism int @@ -61,37 +63,34 @@ type dynamicShardResolver struct { defaultLookback time.Duration } -func (r *dynamicShardResolver) Shards(e syntax.Expr) (int, error) { - sp, ctx := spanlogger.NewWithLogger(r.ctx, r.logger, "dynamicShardResolver.Shards") - defer sp.Finish() - // We try to shard subtrees in the AST independently if possible, although - // nested binary expressions can make this difficult. In this case, - // we query the index stats for all matcher groups then sum the results. 
- grps, err := syntax.MatcherGroups(e)
- if err != nil {
- return 0, err
- }
-
- // If there are zero matchers groups, we'll inject one to query everything
- if len(grps) == 0 {
- grps = append(grps, syntax.MatcherRange{})
- }
-
- results := make([]*stats.Stats, 0, len(grps))
-
- start := time.Now()
- if err := concurrency.ForEachJob(ctx, len(grps), r.maxParallelism, func(ctx context.Context, i int) error {
- matchers := syntax.MatchersString(grps[i].Matchers)
- diff := grps[i].Interval + grps[i].Offset
- adjustedFrom := r.from.Add(-diff)
- if grps[i].Interval == 0 {
- adjustedFrom = adjustedFrom.Add(-r.defaultLookback)
+// getStatsForMatchers returns the index stats for all the groups in matcherGroups.
+func getStatsForMatchers(
+ ctx context.Context,
+ logger log.Logger,
+ statsHandler queryrangebase.Handler,
+ start, end model.Time,
+ matcherGroups []syntax.MatcherRange,
+ parallelism int,
+ defaultLookback ...time.Duration,
+) ([]*stats.Stats, error) {
+ startTime := time.Now()
+
+ results := make([]*stats.Stats, len(matcherGroups))
+ if err := concurrency.ForEachJob(ctx, len(matcherGroups), parallelism, func(ctx context.Context, i int) error {
+ matchers := syntax.MatchersString(matcherGroups[i].Matchers)
+ diff := matcherGroups[i].Interval + matcherGroups[i].Offset
+ adjustedFrom := start.Add(-diff)
+ if matcherGroups[i].Interval == 0 && len(defaultLookback) > 0 {
+ // For limited instant queries, when start == end, the queries would return
+ // zero results. Prometheus has a concept of "look back amount of time for instant queries"
+ // since metric data is sampled at some configurable scrape_interval (commonly 15s, 30s, or 1m).
+ // We copy that idea and say "find me logs from the past when start=end".
+ adjustedFrom = adjustedFrom.Add(-defaultLookback[0])
 }
- adjustedThrough := r.through.Add(-grps[i].Offset)
+ adjustedThrough := end.Add(-matcherGroups[i].Offset)
- start := time.Now()
- resp, err := r.handler.Do(r.ctx, &logproto.IndexStatsRequest{
+ resp, err := statsHandler.Do(ctx, &logproto.IndexStatsRequest{
 From: adjustedFrom,
 Through: adjustedThrough,
 Matchers: matchers,
+ grps, err := syntax.MatcherGroups(e)
+ if err != nil {
+ return stats.Stats{}, err
 }
+
+ // If there are zero matcher groups, we'll inject one to query everything
+ if len(grps) == 0 {
+ grps = append(grps, syntax.MatcherRange{})
+ }
+
+ results, err := getStatsForMatchers(ctx, sp, r.handler, r.from, r.through, grps, r.maxParallelism, r.defaultLookback)
+ if err != nil {
+ return stats.Stats{}, err
+ }
+
+ combined := stats.MergeStats(results...)
+
 level.Debug(sp).Log(
 append(
 combined.LoggingKeyValues(),
@@ -138,11 +161,37 @@ func (r *dynamicShardResolver) Shards(e syntax.Expr) (int, error) {
 "len", len(results),
 "max_parallelism", r.maxParallelism,
 "duration", time.Since(start),
+ )...,
+ )
+
+ return combined, nil
+}
+
+func (r *dynamicShardResolver) Shards(e syntax.Expr) (int, uint64, error) {
+ sp, _ := spanlogger.NewWithLogger(r.ctx, r.logger, "dynamicShardResolver.Shards")
+ defer sp.Finish()
+
+ combined, err := r.GetStats(e)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ factor := guessShardFactor(combined, r.maxShards)
+
+ var bytesPerShard = combined.Bytes
+ if factor > 0 {
+ bytesPerShard = combined.Bytes / uint64(factor)
+ }
+
+ level.Debug(sp).Log(
+ append(
+ combined.LoggingKeyValues(),
+ "msg", "Got shard factor",
+ "factor", factor,
+ "bytes_per_shard", strings.Replace(humanize.Bytes(bytesPerShard), " ", "", 1),
 )...,
 )
- return factor, nil
+ return factor, bytesPerShard, nil
 }

 const (
diff --git a/pkg/querier/queryrange/shard_resolver_test.go b/pkg/querier/queryrange/shard_resolver_test.go
index 92fea236d5..148a0f3093 100644
--- a/pkg/querier/queryrange/shard_resolver_test.go
+++ b/pkg/querier/queryrange/shard_resolver_test.go
@@ -4,9 +4,8 @@ import (
 "fmt"
 "testing"

- "github.com/stretchr/testify/require"
-
 "github.com/grafana/loki/pkg/storage/stores/index/stats"
+ "github.com/stretchr/testify/require"
 )

 func TestGuessShardFactor(t *testing.T) {
diff --git a/pkg/querier/queryrange/split_by_interval.go b/pkg/querier/queryrange/split_by_interval.go
index aaae5ac842..0ce46bad8c 100644
--- a/pkg/querier/queryrange/split_by_interval.go
+++ b/pkg/querier/queryrange/split_by_interval.go
@@ -5,6 +5,7 @@ import (
 "net/http"
 "time"

+ "github.com/grafana/loki/pkg/util/math"
 "github.com/opentracing/opentracing-go"
 otlog "github.com/opentracing/opentracing-go/log"
 "github.com/prometheus/client_golang/prometheus"
@@ -60,6 +61,10 @@ type Splitter func(req queryrangebase.Request, interval time.Duration) ([]queryr

// SplitByIntervalMiddleware creates a new Middleware that splits log requests by a given interval.
 func SplitByIntervalMiddleware(configs []config.PeriodConfig, limits Limits, merger queryrangebase.Merger, splitter Splitter, metrics *SplitByMetrics) queryrangebase.Middleware {
+ if metrics == nil {
+ metrics = NewSplitByMetrics(nil)
+ }
+
 return queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler {
 return &splitByInterval{
 configs: configs,
@@ -109,8 +114,9 @@ func (h *splitByInterval) Process(
 unlimited = true
 }

+ // Parallelism will be at least 1
+ p := math.Max(parallelism, 1)
 // don't spawn unnecessary goroutines
- p := parallelism
 if len(input) < parallelism {
 p = len(input)
 }
@@ -181,6 +187,7 @@ func (h *splitByInterval) Do(ctx context.Context, r queryrangebase.Request) (que
 if err != nil {
 return nil, err
 }
+ h.metrics.splits.Observe(float64(len(intervals)))

 // no interval should not be processed by the frontend.
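
Two details of the refactor above are worth calling out. First, `getStatsForMatchers` fans out one index stats request per matcher group with bounded parallelism, and writes each result into its own slice slot (`results[i] = ...`), which, unlike the `append` in the old code, is safe under concurrency. Second, `Shards` now also returns `bytesPerShard`, the figure the sharding middleware checks against `max_querier_bytes_read`. A self-contained sketch of both ideas follows; it uses `errgroup` as a stand-in for dskit's `concurrency.ForEachJob`, and `fetchStats`, `matcherRange`, and `statsResult` are hypothetical types for illustration:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

type statsResult struct{ Bytes uint64 }

type matcherRange struct {
	matchers         string
	interval, offset time.Duration
}

// fetchStats stands in for the index stats call made through the handler.
func fetchStats(_ context.Context, _ string, _, _ time.Time) (statsResult, error) {
	return statsResult{Bytes: 4 << 30}, nil // pretend each group reads 4 GiB
}

// statsForGroups mirrors the shape of getStatsForMatchers: widen the time
// range by interval+offset, apply an instant-query lookback when the
// interval is zero, and fan out with bounded parallelism.
func statsForGroups(ctx context.Context, groups []matcherRange, start, end time.Time, parallelism int, lookback time.Duration) ([]statsResult, error) {
	results := make([]statsResult, len(groups)) // one slot per job: no locking needed
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(parallelism)
	for i, grp := range groups {
		i, grp := i, grp // capture loop variables (pre-Go 1.22)
		g.Go(func() error {
			from := start.Add(-(grp.interval + grp.offset))
			if grp.interval == 0 {
				from = from.Add(-lookback) // "find me logs from the past when start == end"
			}
			through := end.Add(-grp.offset)
			res, err := fetchStats(ctx, grp.matchers, from, through)
			if err != nil {
				return err
			}
			results[i] = res
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}

// bytesPerShard mirrors the computation at the end of Shards().
func bytesPerShard(totalBytes uint64, factor int) uint64 {
	if factor > 0 {
		return totalBytes / uint64(factor)
	}
	return totalBytes // factor 0: the query is not sharded
}

func main() {
	now := time.Now()
	results, _ := statsForGroups(context.Background(), []matcherRange{{matchers: `{app="foo"}`}}, now.Add(-time.Hour), now, 4, 5*time.Minute)
	total := uint64(0)
	for _, r := range results {
		total += r.Bytes
	}
	// A 4 GiB estimate split over 16 shards puts 256 MiB on each querier;
	// max_querier_bytes_read is enforced against that per-shard figure.
	fmt.Println(bytesPerShard(total, 16)>>20, "MiB per shard")
}
```
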
@@ -205,8 +212,8 @@ func (h *splitByInterval) Do(ctx context.Context, r queryrangebase.Request) (que intervals[i], intervals[j] = intervals[j], intervals[i] } } - case *LokiSeriesRequest, *LokiLabelNamesRequest: - // Set this to 0 since this is not used in Series/Labels Request. + case *LokiSeriesRequest, *LokiLabelNamesRequest, *logproto.IndexStatsRequest: + // Set this to 0 since this is not used in Series/Labels/Index Request. limit = 0 default: return nil, httpgrpc.Errorf(http.StatusBadRequest, "unknown request type") @@ -271,6 +278,16 @@ func splitByTime(req queryrangebase.Request, interval time.Duration) ([]queryran EndTs: end, }) }) + case *logproto.IndexStatsRequest: + startTS := model.Time(r.GetStart()).Time() + endTS := model.Time(r.GetEnd()).Time() + util.ForInterval(interval, startTS, endTS, true, func(start, end time.Time) { + reqs = append(reqs, &logproto.IndexStatsRequest{ + From: model.TimeFromUnix(start.Unix()), + Through: model.TimeFromUnix(end.Unix()), + Matchers: r.GetMatchers(), + }) + }) default: return nil, nil } diff --git a/pkg/querier/queryrange/split_by_interval_test.go b/pkg/querier/queryrange/split_by_interval_test.go index a6d7044822..5af1a3b50b 100644 --- a/pkg/querier/queryrange/split_by_interval_test.go +++ b/pkg/querier/queryrange/split_by_interval_test.go @@ -37,6 +37,21 @@ var testSchemas = func() []config.PeriodConfig { return confs }() +var testSchemasTSDB = func() []config.PeriodConfig { + confS := ` +- from: "1950-01-01" + store: tsdb + object_store: gcs + schema: v12 +` + + var confs []config.PeriodConfig + if err := yaml.Unmarshal([]byte(confS), &confs); err != nil { + panic(err) + } + return confs +}() + func Test_splitQuery(t *testing.T) { buildLokiRequest := func(start, end time.Time) queryrangebase.Request { return &LokiRequest{ diff --git a/pkg/util/querylimits/limiter.go b/pkg/util/querylimits/limiter.go index 82340ede06..9ce84a19fb 100644 --- a/pkg/util/querylimits/limiter.go +++ b/pkg/util/querylimits/limiter.go @@ -92,3 +92,13 @@ func (l *Limiter) RequiredLabels(ctx context.Context, userID string) []string { } return union } + +func (l *Limiter) MaxQueryBytesRead(ctx context.Context, userID string) int { + original := l.CombinedLimits.MaxQueryBytesRead(ctx, userID) + requestLimits := ExtractQueryLimitsContext(ctx) + if requestLimits == nil || requestLimits.MaxQueryBytesRead.Val() == 0 || requestLimits.MaxQueryBytesRead.Val() > original { + level.Debug(logutil.WithContext(ctx, l.logger)).Log("msg", "using original limit") + return original + } + return requestLimits.MaxQueryBytesRead.Val() +} diff --git a/pkg/util/querylimits/limiter_test.go b/pkg/util/querylimits/limiter_test.go index fcb825913b..77c9f93fcc 100644 --- a/pkg/util/querylimits/limiter_test.go +++ b/pkg/util/querylimits/limiter_test.go @@ -39,6 +39,8 @@ func TestLimiter_Defaults(t *testing.T) { MaxQueryLength: model.Duration(30 * time.Second), MaxEntriesLimitPerQuery: 10, RequiredLabels: []string{"foo", "bar"}, + MaxQueryBytesRead: 10, + MaxQuerierBytesRead: 10, } overrides, _ := validation.NewOverrides(validation.Limits{}, newMockTenantLimits(tLimits)) @@ -49,6 +51,7 @@ func TestLimiter_Defaults(t *testing.T) { MaxQueryLookback: model.Duration(30 * time.Second), MaxEntriesLimitPerQuery: 10, QueryTimeout: model.Duration(30 * time.Second), + MaxQueryBytesRead: 10, RequiredLabels: []string{"foo", "bar"}, } ctx := context.Background() @@ -60,6 +63,8 @@ func TestLimiter_Defaults(t *testing.T) { require.Equal(t, expectedLimits.MaxEntriesLimitPerQuery, maxEntries) queryTimeout := 
l.QueryTimeout(ctx, "fake") require.Equal(t, time.Duration(expectedLimits.QueryTimeout), queryTimeout) + maxQueryBytesRead := l.MaxQueryBytesRead(ctx, "fake") + require.Equal(t, expectedLimits.MaxQueryBytesRead.Val(), maxQueryBytesRead) var limits QueryLimits @@ -69,6 +74,7 @@ func TestLimiter_Defaults(t *testing.T) { MaxEntriesLimitPerQuery: 10, QueryTimeout: model.Duration(29 * time.Second), RequiredLabels: []string{"foo", "bar"}, + MaxQueryBytesRead: 10, } { ctx2 := InjectQueryLimitsContext(context.Background(), limits) @@ -80,6 +86,8 @@ func TestLimiter_Defaults(t *testing.T) { require.Equal(t, expectedLimits2.MaxEntriesLimitPerQuery, maxEntries) queryTimeout := l.QueryTimeout(ctx2, "fake") require.Equal(t, time.Duration(expectedLimits.QueryTimeout), queryTimeout) + maxQueryBytesRead := l.MaxQueryBytesRead(ctx2, "fake") + require.Equal(t, expectedLimits2.MaxQueryBytesRead.Val(), maxQueryBytesRead) } } @@ -92,6 +100,8 @@ func TestLimiter_RejectHighLimits(t *testing.T) { MaxQueryLength: model.Duration(30 * time.Second), MaxEntriesLimitPerQuery: 10, QueryTimeout: model.Duration(30 * time.Second), + MaxQueryBytesRead: 10, + MaxQuerierBytesRead: 10, } overrides, _ := validation.NewOverrides(validation.Limits{}, newMockTenantLimits(tLimits)) @@ -101,12 +111,14 @@ func TestLimiter_RejectHighLimits(t *testing.T) { MaxQueryLookback: model.Duration(14 * 24 * time.Hour), MaxEntriesLimitPerQuery: 100, QueryTimeout: model.Duration(100 * time.Second), + MaxQueryBytesRead: 100, } expectedLimits := QueryLimits{ MaxQueryLength: model.Duration(30 * time.Second), MaxQueryLookback: model.Duration(30 * time.Second), MaxEntriesLimitPerQuery: 10, QueryTimeout: model.Duration(30 * time.Second), + MaxQueryBytesRead: 10, } ctx := InjectQueryLimitsContext(context.Background(), limits) @@ -114,6 +126,7 @@ func TestLimiter_RejectHighLimits(t *testing.T) { require.Equal(t, time.Duration(expectedLimits.MaxQueryLength), l.MaxQueryLength(ctx, "fake")) require.Equal(t, expectedLimits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake")) require.Equal(t, time.Duration(expectedLimits.QueryTimeout), l.QueryTimeout(ctx, "fake")) + require.Equal(t, expectedLimits.MaxQueryBytesRead.Val(), l.MaxQueryBytesRead(ctx, "fake")) } func TestLimiter_AcceptLowerLimits(t *testing.T) { @@ -124,6 +137,8 @@ func TestLimiter_AcceptLowerLimits(t *testing.T) { MaxQueryLength: model.Duration(30 * time.Second), MaxEntriesLimitPerQuery: 10, QueryTimeout: model.Duration(30 * time.Second), + MaxQueryBytesRead: 10, + MaxQuerierBytesRead: 10, } overrides, _ := validation.NewOverrides(validation.Limits{}, newMockTenantLimits(tLimits)) @@ -133,6 +148,7 @@ func TestLimiter_AcceptLowerLimits(t *testing.T) { MaxQueryLookback: model.Duration(29 * time.Second), MaxEntriesLimitPerQuery: 9, QueryTimeout: model.Duration(29 * time.Second), + MaxQueryBytesRead: 9, } ctx := InjectQueryLimitsContext(context.Background(), limits) @@ -140,6 +156,7 @@ func TestLimiter_AcceptLowerLimits(t *testing.T) { require.Equal(t, time.Duration(limits.MaxQueryLength), l.MaxQueryLength(ctx, "fake")) require.Equal(t, limits.MaxEntriesLimitPerQuery, l.MaxEntriesLimitPerQuery(ctx, "fake")) require.Equal(t, time.Duration(limits.QueryTimeout), l.QueryTimeout(ctx, "fake")) + require.Equal(t, limits.MaxQueryBytesRead.Val(), l.MaxQueryBytesRead(ctx, "fake")) } func TestLimiter_MergeLimits(t *testing.T) { diff --git a/pkg/util/querylimits/middleware_test.go b/pkg/util/querylimits/middleware_test.go index 5b8293f9db..b2ead1b740 100644 --- 
a/pkg/util/querylimits/middleware_test.go +++ b/pkg/util/querylimits/middleware_test.go @@ -32,6 +32,7 @@ func Test_MiddlewareWithHeader(t *testing.T) { 1, model.Duration(1 * time.Second), []string{"foo", "bar"}, + 10, } nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/util/querylimits/propagation.go b/pkg/util/querylimits/propagation.go index 9eea2db94e..94f720c598 100644 --- a/pkg/util/querylimits/propagation.go +++ b/pkg/util/querylimits/propagation.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" + "github.com/grafana/loki/pkg/util/flagext" "github.com/prometheus/common/model" ) @@ -25,6 +26,7 @@ type QueryLimits struct { MaxEntriesLimitPerQuery int `json:"maxEntriesLimitPerQuery,omitempty"` QueryTimeout model.Duration `json:"queryTimeout,omitempty"` RequiredLabels []string `json:"requiredLabels,omitempty"` + MaxQueryBytesRead flagext.ByteSize `json:"maxQueryBytesRead,omitempty"` } func UnmarshalQueryLimits(data []byte) (*QueryLimits, error) { diff --git a/pkg/util/querylimits/propagation_test.go b/pkg/util/querylimits/propagation_test.go index a80faf9464..b5fbdca834 100644 --- a/pkg/util/querylimits/propagation_test.go +++ b/pkg/util/querylimits/propagation_test.go @@ -25,13 +25,14 @@ func TestInjectAndExtractQueryLimits(t *testing.T) { func TestDeserializingQueryLimits(t *testing.T) { // full limits - payload := `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w", "queryTimeout": "5s"}` + payload := `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w", "queryTimeout": "5s", "maxQueryBytesRead": "1MB", "maxQuerierBytesRead": "1MB"}` limits, err := UnmarshalQueryLimits([]byte(payload)) require.NoError(t, err) require.Equal(t, model.Duration(2*24*time.Hour), limits.MaxQueryLength) require.Equal(t, model.Duration(14*24*time.Hour), limits.MaxQueryLookback) require.Equal(t, model.Duration(5*time.Second), limits.QueryTimeout) require.Equal(t, 100, limits.MaxEntriesLimitPerQuery) + require.Equal(t, 1*1024*1024, limits.MaxQueryBytesRead.Val()) // some limits are empty payload = `{"maxQueryLength":"1h"}` limits, err = UnmarshalQueryLimits([]byte(payload)) @@ -39,6 +40,7 @@ func TestDeserializingQueryLimits(t *testing.T) { require.Equal(t, model.Duration(3600000000000), limits.MaxQueryLength) require.Equal(t, model.Duration(0), limits.MaxQueryLookback) require.Equal(t, 0, limits.MaxEntriesLimitPerQuery) + require.Equal(t, 0, limits.MaxQueryBytesRead.Val()) } func TestSerializingQueryLimits(t *testing.T) { @@ -48,11 +50,12 @@ func TestSerializingQueryLimits(t *testing.T) { MaxQueryLookback: model.Duration(14 * 24 * time.Hour), MaxEntriesLimitPerQuery: 100, QueryTimeout: model.Duration(5 * time.Second), + MaxQueryBytesRead: 1 * 1024 * 1024, } actual, err := MarshalQueryLimits(&limits) require.NoError(t, err) - expected := `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w", "queryTimeout": "5s"}` + expected := `{"maxEntriesLimitPerQuery": 100, "maxQueryLength": "2d", "maxQueryLookback": "2w", "queryTimeout": "5s", "maxQueryBytesRead": "1MB"}` require.JSONEq(t, expected, string(actual)) // some limits are empty diff --git a/pkg/validation/limits.go b/pkg/validation/limits.go index 17d1434304..fe795fba42 100644 --- a/pkg/validation/limits.go +++ b/pkg/validation/limits.go @@ -97,8 +97,10 @@ type Limits struct { QueryTimeout model.Duration `yaml:"query_timeout" json:"query_timeout"` // Query frontend enforced limits. 
The default is actually parameterized by the queryrange config.
- QuerySplitDuration model.Duration `yaml:"split_queries_by_interval" json:"split_queries_by_interval"`
- MinShardingLookback model.Duration `yaml:"min_sharding_lookback" json:"min_sharding_lookback"`
+ QuerySplitDuration model.Duration `yaml:"split_queries_by_interval" json:"split_queries_by_interval"`
+ MinShardingLookback model.Duration `yaml:"min_sharding_lookback" json:"min_sharding_lookback"`
+ MaxQueryBytesRead flagext.ByteSize `yaml:"max_query_bytes_read" json:"max_query_bytes_read"`
+ MaxQuerierBytesRead flagext.ByteSize `yaml:"max_querier_bytes_read" json:"max_querier_bytes_read"`

 // Ruler defaults and limits.
@@ -232,6 +234,9 @@ func (l *Limits) RegisterFlags(f *flag.FlagSet) {
 _ = l.MinShardingLookback.Set("0s")
 f.Var(&l.MinShardingLookback, "frontend.min-sharding-lookback", "Limit queries that can be sharded. Queries within the time range of now and now minus this sharding lookback are not sharded. The default value of 0s disables the lookback, causing sharding of all queries at all times.")

+ f.Var(&l.MaxQueryBytesRead, "frontend.max-query-bytes-read", "Max number of bytes a query can fetch. Enforced in log and metric queries only when TSDB is used. The default value of 0 disables this limit.")
+ f.Var(&l.MaxQuerierBytesRead, "frontend.max-querier-bytes-read", "Max number of bytes a query can fetch after splitting and sharding. Enforced in log and metric queries only when TSDB is used. The default value of 0 disables this limit.")
+
 _ = l.MaxCacheFreshness.Set("1m")
 f.Var(&l.MaxCacheFreshness, "frontend.max-cache-freshness", "Most recent allowed cacheable result per-tenant, to prevent caching very recent results that might still be in flux.")
@@ -483,6 +488,16 @@ func (o *Overrides) QuerySplitDuration(userID string) time.Duration {
 return time.Duration(o.getOverridesForUser(userID).QuerySplitDuration)
 }

+// MaxQueryBytesRead returns the maximum bytes a query can read.
+func (o *Overrides) MaxQueryBytesRead(_ context.Context, userID string) int {
+ return o.getOverridesForUser(userID).MaxQueryBytesRead.Val()
+}
+
+// MaxQuerierBytesRead returns the maximum bytes a subquery can read after splitting and sharding.
+func (o *Overrides) MaxQuerierBytesRead(_ context.Context, userID string) int {
+ return o.getOverridesForUser(userID).MaxQuerierBytesRead.Val()
+}
+
 // MaxConcurrentTailRequests returns the limit to number of concurrent tail requests.
 func (o *Overrides) MaxConcurrentTailRequests(ctx context.Context, userID string) int {
 return o.getOverridesForUser(userID).MaxConcurrentTailRequests
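
Taken together, the pieces above resolve a limit in two steps: the tenant value from `limits_config` (or the `-frontend.max-query-bytes-read` / `-frontend.max-querier-bytes-read` flags) is the ceiling, and a per-request value propagated through `pkg/util/querylimits` may only tighten it, never raise it; an unset (zero) request value falls back to the tenant value. A sketch of that resolution rule, simplified from `Limiter.MaxQueryBytesRead` above (the nil-context case is dropped, and `resolveLimit` is a hypothetical name):

```go
package main

import "fmt"

// resolveLimit mirrors the clamp in Limiter.MaxQueryBytesRead: a per-request
// limit of 0 means "unset", and a request can never raise the tenant limit.
func resolveLimit(tenantLimit, requestLimit int) int {
	if requestLimit == 0 || requestLimit > tenantLimit {
		return tenantLimit
	}
	return requestLimit
}

func main() {
	fmt.Println(resolveLimit(1000, 0))    // 1000: request limit unset
	fmt.Println(resolveLimit(1000, 5000)) // 1000: requests cannot loosen the limit
	fmt.Println(resolveLimit(1000, 100))  // 100: requests may tighten it
	fmt.Println(resolveLimit(0, 100))     // 0: a disabled tenant limit stays disabled
}
```

Per-request values arrive as human-readable sizes (e.g. `"maxQueryBytesRead": "1MB"`) through the `QueryLimits` JSON shown earlier; per the propagation tests, `flagext.ByteSize` parses `1MB` as 1048576 bytes.
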