guard against divide by 0 when splitting parallelism (#7831)

**What this PR does / why we need it**:

**Which issue(s) this PR fixes**:
We saw a spike in divide-by-zero panics in the code introduced in #7769.
I was able to reproduce the error with a test that calls
`WeightedParallelism` with a start time that is after the end time. It's not
clear how such a query reaches this code path, but we definitely saw it
happening in our ops environment, so something is causing it, and the fix
should guard against it in any case.
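
For context, the panic comes from the weighted split at the end of `WeightedParallelism`, where the combined schema-overlap duration is used as a divisor; a query whose end precedes its start overlaps nothing, so that divisor is zero. A minimal sketch of the failure mode, heavily simplified from the real function and using made-up numbers:

```go
package main

import "fmt"

// weightedSplit mimics the tail of WeightedParallelism: the per-schema
// durations are weighted by their parallelism limits and divided by the
// combined duration. When the query overlaps no schema config (for example,
// because end < start), both durations are zero.
func weightedSplit(tsdbDur, otherDur, tsdbMax, regMax int) int {
	totalDur := tsdbDur + otherDur
	// Guard added by this PR; without it the divisions below panic with
	// "runtime error: integer divide by zero" when totalDur == 0.
	if totalDur == 0 {
		return 1
	}
	tsdbPart := tsdbDur * tsdbMax / totalDur
	regPart := otherDur * regMax / totalDur
	return tsdbPart + regPart
}

func main() {
	// Degenerate query: zero overlap with every period config.
	fmt.Println(weightedSplit(0, 0, 50, 0)) // 1 with the guard; panicked before
}
```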

**Special notes for your reviewer**:

**Checklist**
- [X] Tests updated

Co-authored-by: Sandeep Sukhani <sandeep.d.sukhani@gmail.com>
Branch: pull/7833/head
Author: Trevor Whitney, committed via GitHub 3 years ago
Commit: 37b1c0fce0 (parent: ea9ad336dc)
Files changed (7):
- pkg/querier/queryrange/limits.go (36 lines changed)
- pkg/querier/queryrange/limits_test.go (58 lines changed)
- pkg/querier/queryrange/queryrangebase/results_cache.go (6 lines changed)
- pkg/querier/queryrange/queryrangebase/results_cache_test.go (12 lines changed)
- pkg/querier/queryrange/querysharding.go (4 lines changed)
- pkg/querier/queryrange/roundtrip.go (4 lines changed)
- pkg/querier/queryrange/split_by_interval.go (2 lines changed)

**pkg/querier/queryrange/limits.go**

@@ -9,19 +9,19 @@ import (
 	"time"
 
 	"github.com/go-kit/log/level"
+	"github.com/grafana/dskit/tenant"
 	"github.com/opentracing/opentracing-go"
 	"github.com/prometheus/common/model"
 	"github.com/prometheus/prometheus/model/timestamp"
 	"github.com/weaveworks/common/httpgrpc"
 	"github.com/weaveworks/common/user"
-	"github.com/grafana/dskit/tenant"
 
 	"github.com/grafana/loki/pkg/logproto"
 	"github.com/grafana/loki/pkg/logql"
 	"github.com/grafana/loki/pkg/querier/queryrange/queryrangebase"
 	"github.com/grafana/loki/pkg/storage/config"
 	"github.com/grafana/loki/pkg/util"
+	util_log "github.com/grafana/loki/pkg/util/log"
 	"github.com/grafana/loki/pkg/util/spanlogger"
 	"github.com/grafana/loki/pkg/util/validation"
 )

@@ -281,6 +281,7 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)
 	}
 	parallelism := MinWeightedParallelism(
+		ctx,
 		tenantIDs,
 		rt.configs,
 		rt.limits,

@@ -358,11 +359,27 @@ func (rt limitedRoundTripper) do(ctx context.Context, r queryrangebase.Request)
 // the resulting parallelism will be
 // 0.5 * 10 + 0.5 * 100 = 60
 func WeightedParallelism(
+	ctx context.Context,
 	configs []config.PeriodConfig,
 	user string,
 	l Limits,
 	start, end model.Time,
 ) int {
+	logger := util_log.WithContext(ctx, util_log.Logger)
+
+	tsdbMaxQueryParallelism := l.TSDBMaxQueryParallelism(user)
+	regMaxQueryParallelism := l.MaxQueryParallelism(user)
+	if tsdbMaxQueryParallelism+regMaxQueryParallelism == 0 {
+		level.Info(logger).Log("msg", "querying disabled for tenant")
+		return 0
+	}
+
+	// query end before start would anyways error out so just short circuit and return 1
+	if end < start {
+		level.Warn(logger).Log("msg", "query end time before start, letting downstream code handle it gracefully", "start", start, "end", end)
+		return 1
+	}
+
 	// Return first index of desired period configs
 	i := sort.Search(len(configs), func(i int) bool {
 		// return true when there is no overlap with query & current

@@ -419,8 +436,14 @@ func WeightedParallelism(
 	}
 
 	totalDur := int(tsdbDur + otherDur)
-	tsdbMaxQueryParallelism := l.TSDBMaxQueryParallelism(user)
-	regMaxQueryParallelism := l.MaxQueryParallelism(user)
+	// If totalDur is 0, the query likely does not overlap any of the schema configs so just use parallelism of 1 and
+	// let the downstream code handle it.
+	if totalDur == 0 {
+		level.Warn(logger).Log("msg", "could not determine query overlaps on tsdb vs non-tsdb schemas, likely due to query not overlapping any of the schema configs,"+
+			"letting downstream code handle it gracefully", "start", start, "end", end)
+		return 1
+	}
+
 	tsdbPart := int(tsdbDur) * tsdbMaxQueryParallelism / totalDur
 	regPart := int(otherDur) * regMaxQueryParallelism / totalDur
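
To make the weighted split above concrete, here is a worked example with assumed inputs (illustrative, not taken from the PR): a query spending 3 hours in a TSDB period whose `tsdb_max_query_parallelism` is 10 and 1 hour in a non-TSDB period whose `max_query_parallelism` is 100.

```go
package main

import "fmt"

func main() {
	// Assumed durations (in hours) and per-schema parallelism limits.
	tsdbDur, otherDur := 3, 1
	tsdbMax, regMax := 10, 100

	totalDur := tsdbDur + otherDur           // 4
	tsdbPart := tsdbDur * tsdbMax / totalDur // 3*10/4 = 7 (integer division)
	regPart := otherDur * regMax / totalDur  // 1*100/4 = 25
	fmt.Println(tsdbPart + regPart)          // 32
}
```

If both durations are zero, `totalDur` becomes a zero divisor, which is exactly the case the new guard short-circuits.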
@@ -435,8 +458,8 @@ func WeightedParallelism(
 	if (tsdbMaxQueryParallelism > 0 && tsdbDur > 0) || (regMaxQueryParallelism > 0 && otherDur > 0) {
 		return 1
 	}
-	return 0
 
+	return 0
 }
 
 func minMaxModelTime(a, b model.Time) (min, max model.Time) {
@@ -446,9 +469,10 @@ func minMaxModelTime(a, b model.Time) (min, max model.Time) {
 	return b, a
 }
 
-func MinWeightedParallelism(tenantIDs []string, configs []config.PeriodConfig, l Limits, start, end model.Time) int {
+func MinWeightedParallelism(ctx context.Context, tenantIDs []string, configs []config.PeriodConfig, l Limits, start, end model.Time) int {
 	return validation.SmallestPositiveIntPerTenant(tenantIDs, func(user string) int {
 		return WeightedParallelism(
+			ctx,
 			configs,
 			user,
 			l,

**pkg/querier/queryrange/limits_test.go**

@@ -370,9 +370,65 @@ func Test_WeightedParallelism(t *testing.T) {
 			},
 		} {
 			t.Run(cfgs.desc+tc.desc, func(t *testing.T) {
-				require.Equal(t, tc.exp, WeightedParallelism(confs, "fake", limits, tc.start, tc.end))
+				require.Equal(t, tc.exp, WeightedParallelism(context.Background(), confs, "fake", limits, tc.start, tc.end))
 			})
 		}
 	}
 }
+
+func Test_WeightedParallelism_DivideByZeroError(t *testing.T) {
+	t.Run("query end before start", func(t *testing.T) {
+		parsed, err := time.Parse("2006-01-02", "2022-01-02")
+		require.NoError(t, err)
+		borderTime := model.TimeFromUnix(parsed.Unix())
+		confs := []config.PeriodConfig{
+			{
+				From: config.DayTime{
+					Time: borderTime.Add(-1 * time.Hour),
+				},
+				IndexType: config.TSDBType,
+			},
+		}
+
+		result := WeightedParallelism(context.Background(), confs, "fake", &fakeLimits{tsdbMaxQueryParallelism: 50}, borderTime, borderTime.Add(-1*time.Hour))
+		require.Equal(t, 1, result)
+	})
+
+	t.Run("negative start and end time", func(t *testing.T) {
+		parsed, err := time.Parse("2006-01-02", "2022-01-02")
+		require.NoError(t, err)
+		borderTime := model.TimeFromUnix(parsed.Unix())
+		confs := []config.PeriodConfig{
+			{
+				From: config.DayTime{
+					Time: borderTime.Add(-1 * time.Hour),
+				},
+				IndexType: config.TSDBType,
+			},
+		}
+
+		result := WeightedParallelism(context.Background(), confs, "fake", &fakeLimits{maxQueryParallelism: 50}, -100, -50)
+		require.Equal(t, 1, result)
+	})
+
+	t.Run("query start and end time before config start", func(t *testing.T) {
+		parsed, err := time.Parse("2006-01-02", "2022-01-02")
+		require.NoError(t, err)
+		borderTime := model.TimeFromUnix(parsed.Unix())
+		confs := []config.PeriodConfig{
+			{
+				From: config.DayTime{
+					Time: borderTime.Add(-1 * time.Hour),
+				},
+				IndexType: config.TSDBType,
+			},
+		}
+
+		result := WeightedParallelism(context.Background(), confs, "fake", &fakeLimits{maxQueryParallelism: 50}, confs[0].From.Add(-24*time.Hour), confs[0].From.Add(-12*time.Hour))
+		require.Equal(t, 1, result)
+	})
+}
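
The new tests rely on the package's existing `fakeLimits` test double and its `tsdbMaxQueryParallelism` / `maxQueryParallelism` fields. For readers who only see this diff, a minimal sketch of the two methods such a stub needs to provide for `WeightedParallelism`; the real `fakeLimits` in `limits_test.go` implements the full `Limits` interface, and the trimmed type shown here (and its name) is illustrative only:

```go
package queryrange

// fakeParallelismLimits is an illustrative, trimmed-down stand-in for the
// package's fakeLimits helper; only the two methods that WeightedParallelism
// consults for the weighted split are shown.
type fakeParallelismLimits struct {
	tsdbMaxQueryParallelism int
	maxQueryParallelism     int
}

// TSDBMaxQueryParallelism returns the per-tenant parallelism for TSDB periods.
func (f *fakeParallelismLimits) TSDBMaxQueryParallelism(_ string) int {
	return f.tsdbMaxQueryParallelism
}

// MaxQueryParallelism returns the per-tenant parallelism for non-TSDB periods.
func (f *fakeParallelismLimits) MaxQueryParallelism(_ string) int {
	return f.maxQueryParallelism
}
```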

**pkg/querier/queryrange/queryrangebase/results_cache.go**

@@ -164,7 +164,7 @@ type resultsCache struct {
 	merger               Merger
 	cacheGenNumberLoader CacheGenNumberLoader
 	shouldCache          ShouldCacheFn
-	parallelismForReq    func(tenantIDs []string, r Request) int
+	parallelismForReq    func(ctx context.Context, tenantIDs []string, r Request) int
 	retentionEnabled     bool
 	metrics              *ResultsCacheMetrics
 }

@@ -184,7 +184,7 @@ func NewResultsCacheMiddleware(
 	extractor Extractor,
 	cacheGenNumberLoader CacheGenNumberLoader,
 	shouldCache ShouldCacheFn,
-	parallelismForReq func(tenantIDs []string, r Request) int,
+	parallelismForReq func(ctx context.Context, tenantIDs []string, r Request) int,
 	retentionEnabled bool,
 	metrics *ResultsCacheMetrics,
 ) (Middleware, error) {

@@ -410,7 +410,7 @@ func (s resultsCache) handleHit(ctx context.Context, r Request, extents []Extent
 	if err != nil {
 		return nil, nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error())
 	}
-	reqResps, err = DoRequests(ctx, s.next, requests, s.parallelismForReq(tenantIDs, r))
+	reqResps, err = DoRequests(ctx, s.next, requests, s.parallelismForReq(ctx, tenantIDs, r))
 	if err != nil {
 		return nil, nil, err
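
The `parallelismForReq` signature change above is breaking for any caller that constructs the results-cache middleware directly. A minimal sketch of how such a callback looks after this PR, mirroring what `NewMetricTripperware` does in `roundtrip.go` further down; the helper name `newParallelismForReq` is illustrative, not part of the PR:

```go
package queryrange

import (
	"context"

	"github.com/prometheus/common/model"

	"github.com/grafana/loki/pkg/querier/queryrange/queryrangebase"
	"github.com/grafana/loki/pkg/storage/config"
)

// newParallelismForReq shows the shape a parallelismForReq callback takes
// after this change: it now receives the request context and forwards it to
// MinWeightedParallelism, which attaches it to its log lines.
func newParallelismForReq(configs []config.PeriodConfig, limits Limits) func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int {
	return func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int {
		return MinWeightedParallelism(
			ctx,
			tenantIDs,
			configs,
			limits,
			model.Time(r.GetStart()),
			model.Time(r.GetEnd()),
		)
	}
}
```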

**pkg/querier/queryrange/queryrangebase/results_cache_test.go**

@@ -731,7 +731,7 @@ func TestHandleHit(t *testing.T) {
 			minCacheExtent: 10,
 			limits:         mockLimits{},
 			merger:         PrometheusCodec,
-			parallelismForReq: func(tenantIDs []string, r Request) int { return 1 },
+			parallelismForReq: func(_ context.Context, tenantIDs []string, r Request) int { return 1 },
 			next: HandlerFunc(func(_ context.Context, req Request) (Response, error) {
 				return mkAPIResponse(req.GetStart(), req.GetEnd(), req.GetStep()), nil
 			}),

@@ -766,7 +766,7 @@ func TestResultsCache(t *testing.T) {
 		PrometheusResponseExtractor{},
 		nil,
 		nil,
-		func(tenantIDs []string, r Request) int {
+		func(_ context.Context, tenantIDs []string, r Request) int {
 			return mockLimits{}.MaxQueryParallelism("fake")
 		},
 		false,

@@ -812,7 +812,7 @@ func TestResultsCacheRecent(t *testing.T) {
 		PrometheusResponseExtractor{},
 		nil,
 		nil,
-		func(tenantIDs []string, r Request) int {
+		func(_ context.Context, tenantIDs []string, r Request) int {
 			return mockLimits{}.MaxQueryParallelism("fake")
 		},
 		false,

@@ -880,7 +880,7 @@ func TestResultsCacheMaxFreshness(t *testing.T) {
 			PrometheusResponseExtractor{},
 			nil,
 			nil,
-			func(tenantIDs []string, r Request) int {
+			func(_ context.Context, tenantIDs []string, r Request) int {
 				return tc.fakeLimits.MaxQueryParallelism("fake")
 			},
 			false,

@@ -923,7 +923,7 @@ func Test_resultsCache_MissingData(t *testing.T) {
 		PrometheusResponseExtractor{},
 		nil,
 		nil,
-		func(tenantIDs []string, r Request) int {
+		func(_ context.Context, tenantIDs []string, r Request) int {
 			return mockLimits{}.MaxQueryParallelism("fake")
 		},
 		false,

@@ -1038,7 +1038,7 @@ func TestResultsCacheShouldCacheFunc(t *testing.T) {
 			PrometheusResponseExtractor{},
 			nil,
 			tc.shouldCache,
-			func(tenantIDs []string, r Request) int {
+			func(_ context.Context, tenantIDs []string, r Request) int {
 				return mockLimits{}.MaxQueryParallelism("fake")
 			},
 			false,

**pkg/querier/queryrange/querysharding.go**

@@ -110,7 +110,7 @@ func (ast *astMapperware) Do(ctx context.Context, r queryrangebase.Request) (que
 		conf,
 		ast.ng.Opts().MaxLookBackPeriod,
 		ast.logger,
-		MinWeightedParallelism(tenants, ast.confs, ast.limits, model.Time(r.GetStart()), model.Time(r.GetEnd())),
+		MinWeightedParallelism(ctx, tenants, ast.confs, ast.limits, model.Time(r.GetStart()), model.Time(r.GetEnd())),
 		r,
 		ast.next,
 	)

@@ -362,7 +362,7 @@ func (ss *seriesShardingHandler) Do(ctx context.Context, r queryrangebase.Reques
 		ctx,
 		ss.next,
 		requests,
-		MinWeightedParallelism(tenantIDs, ss.confs, ss.limits, model.Time(req.GetStart()), model.Time(req.GetEnd())),
+		MinWeightedParallelism(ctx, tenantIDs, ss.confs, ss.limits, model.Time(req.GetStart()), model.Time(req.GetEnd())),
 	)
 	if err != nil {
 		return nil, err

**pkg/querier/queryrange/roundtrip.go**

@@ -1,6 +1,7 @@
 package queryrange
 
 import (
+	"context"
 	"flag"
 	"net/http"
 	"strings"

@@ -431,8 +432,9 @@ func NewMetricTripperware(
 			func(r queryrangebase.Request) bool {
 				return !r.GetCachingOptions().Disabled
 			},
-			func(tenantIDs []string, r queryrangebase.Request) int {
+			func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int {
 				return MinWeightedParallelism(
+					ctx,
 					tenantIDs,
 					schema.Configs,
 					limits,

**pkg/querier/queryrange/split_by_interval.go**

@@ -221,7 +221,7 @@ func (h *splitByInterval) Do(ctx context.Context, r queryrangebase.Request) (que
 	}
 
 	maxSeries := validation.SmallestPositiveIntPerTenant(tenantIDs, h.limits.MaxQuerySeries)
-	maxParallelism := MinWeightedParallelism(tenantIDs, h.configs, h.limits, model.Time(r.GetStart()), model.Time(r.GetEnd()))
+	maxParallelism := MinWeightedParallelism(ctx, tenantIDs, h.configs, h.limits, model.Time(r.GetStart()), model.Time(r.GetEnd()))
 	resps, err := h.Process(ctx, maxParallelism, limit, input, maxSeries)
 	if err != nil {
 		return nil, err
