mirror of https://github.com/grafana/loki
Extract results cache into new pkg (#11343)
**What this PR does / why we need it**: This extracts the results cache from `queryrangebase` into its own pkg so we can reuse it in other components such as the bloom-gateway without having to import `queryrangebase`. - Most of the logic inside `pkg/querier/queryrange/queryrangebase/results_cache.go` now lives in `pkg/storage/chunk/cache/results_cache/cache.go`. - Some of the tests in `pkg/querier/queryrange/queryrangebase/results_cache.go` are moved into `pkg/storage/chunk/cache/results_cache/cache_test.go`. - Note that here we don't have access to the types we use in `queryrangebase`, so we created a new set of mock request/response types to test with.
parent
0e433f304e
commit
489ac8d529
@ -0,0 +1,467 @@ |
||||
package resultscache |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/http" |
||||
"sort" |
||||
"time" |
||||
|
||||
"github.com/go-kit/log" |
||||
"github.com/go-kit/log/level" |
||||
"github.com/gogo/protobuf/proto" |
||||
"github.com/gogo/protobuf/types" |
||||
"github.com/grafana/dskit/httpgrpc" |
||||
"github.com/opentracing/opentracing-go" |
||||
otlog "github.com/opentracing/opentracing-go/log" |
||||
"github.com/prometheus/common/model" |
||||
"github.com/uber/jaeger-client-go" |
||||
|
||||
"github.com/grafana/dskit/tenant" |
||||
|
||||
"github.com/grafana/loki/pkg/storage/chunk/cache" |
||||
"github.com/grafana/loki/pkg/util/math" |
||||
"github.com/grafana/loki/pkg/util/spanlogger" |
||||
"github.com/grafana/loki/pkg/util/validation" |
||||
) |
||||
|
||||
// ConstSplitter is a utility for using a constant split interval when determining cache keys
|
||||
type ConstSplitter time.Duration |
||||
|
||||
// GenerateCacheKey generates a cache key based on the userID, Request and interval.
|
||||
func (t ConstSplitter) GenerateCacheKey(_ context.Context, userID string, r Request) string { |
||||
currentInterval := r.GetStart().UnixMilli() / int64(time.Duration(t)/time.Millisecond) |
||||
return fmt.Sprintf("%s:%s:%d:%d", userID, r.GetQuery(), r.GetStep(), currentInterval) |
||||
} |
||||
|
||||
// ShouldCacheReqFn checks whether the current request should go to cache or not.
// If not, just send the request to next handler.
type ShouldCacheReqFn func(ctx context.Context, r Request) bool

// ShouldCacheResFn checks whether the current response should go to cache or not.
// maxCacheTime is the most recent timestamp (in milliseconds) allowed to be cached.
type ShouldCacheResFn func(ctx context.Context, r Request, res Response, maxCacheTime int64) bool

// ParallelismForReqFn returns the parallelism for a given request,
// i.e. how many downstream sub-requests may run concurrently.
type ParallelismForReqFn func(ctx context.Context, tenantIDs []string, r Request) int
||||
|
||||
type ResultsCache struct { |
||||
logger log.Logger |
||||
next Handler |
||||
cache cache.Cache |
||||
limits Limits |
||||
splitter KeyGenerator |
||||
cacheGenNumberLoader CacheGenNumberLoader |
||||
retentionEnabled bool |
||||
extractor Extractor |
||||
minCacheExtent int64 // discard any cache extent smaller than this
|
||||
merger ResponseMerger |
||||
shouldCacheReq ShouldCacheReqFn |
||||
shouldCacheRes ShouldCacheResFn |
||||
parallelismForReq func(ctx context.Context, tenantIDs []string, r Request) int |
||||
} |
||||
|
||||
// NewResultsCache creates results cache from config.
|
||||
// The middleware cache result using a unique cache key for a given request (step,query,user) and interval.
|
||||
// The cache assumes that each request length (end-start) is below or equal the interval.
|
||||
// Each request starting from within the same interval will hit the same cache entry.
|
||||
// If the cache doesn't have the entire duration of the request cached, it will query the uncached parts and append them to the cache entries.
|
||||
// see `generateKey`.
|
||||
func NewResultsCache( |
||||
logger log.Logger, |
||||
c cache.Cache, |
||||
next Handler, |
||||
keyGen KeyGenerator, |
||||
limits Limits, |
||||
merger ResponseMerger, |
||||
extractor Extractor, |
||||
shouldCacheReq ShouldCacheReqFn, |
||||
shouldCacheRes ShouldCacheResFn, |
||||
parallelismForReq func(ctx context.Context, tenantIDs []string, r Request) int, |
||||
cacheGenNumberLoader CacheGenNumberLoader, |
||||
retentionEnabled bool, |
||||
) *ResultsCache { |
||||
return &ResultsCache{ |
||||
logger: logger, |
||||
next: next, |
||||
cache: c, |
||||
limits: limits, |
||||
splitter: keyGen, |
||||
cacheGenNumberLoader: cacheGenNumberLoader, |
||||
retentionEnabled: retentionEnabled, |
||||
extractor: extractor, |
||||
minCacheExtent: (5 * time.Minute).Milliseconds(), |
||||
merger: merger, |
||||
shouldCacheReq: shouldCacheReq, |
||||
shouldCacheRes: shouldCacheRes, |
||||
parallelismForReq: parallelismForReq, |
||||
} |
||||
} |
||||
|
||||
func (s ResultsCache) Do(ctx context.Context, r Request) (Response, error) { |
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "resultsCache.Do") |
||||
defer sp.Finish() |
||||
tenantIDs, err := tenant.TenantIDs(ctx) |
||||
if err != nil { |
||||
return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) |
||||
} |
||||
|
||||
if s.shouldCacheReq != nil && !s.shouldCacheReq(ctx, r) { |
||||
return s.next.Do(ctx, r) |
||||
} |
||||
|
||||
if s.cacheGenNumberLoader != nil && s.retentionEnabled { |
||||
ctx = cache.InjectCacheGenNumber(ctx, s.cacheGenNumberLoader.GetResultsCacheGenNumber(tenantIDs)) |
||||
} |
||||
|
||||
var ( |
||||
key = s.splitter.GenerateCacheKey(ctx, tenant.JoinTenantIDs(tenantIDs), r) |
||||
extents []Extent |
||||
response Response |
||||
) |
||||
|
||||
sp.LogKV( |
||||
"query", r.GetQuery(), |
||||
"step", time.UnixMilli(r.GetStep()), |
||||
"start", r.GetStart(), |
||||
"end", r.GetEnd(), |
||||
"key", key, |
||||
) |
||||
|
||||
cacheFreshnessCapture := func(id string) time.Duration { return s.limits.MaxCacheFreshness(ctx, id) } |
||||
maxCacheFreshness := validation.MaxDurationPerTenant(tenantIDs, cacheFreshnessCapture) |
||||
maxCacheTime := int64(model.Now().Add(-maxCacheFreshness)) |
||||
if r.GetStart().UnixMilli() > maxCacheTime { |
||||
return s.next.Do(ctx, r) |
||||
} |
||||
|
||||
cached, ok := s.get(ctx, key) |
||||
if ok { |
||||
response, extents, err = s.handleHit(ctx, r, cached, maxCacheTime) |
||||
} else { |
||||
response, extents, err = s.handleMiss(ctx, r, maxCacheTime) |
||||
} |
||||
|
||||
if err == nil && len(extents) > 0 { |
||||
extents, err := s.filterRecentExtents(r, maxCacheFreshness, extents) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
s.put(ctx, key, extents) |
||||
} |
||||
|
||||
return response, err |
||||
} |
||||
|
||||
func (s ResultsCache) handleMiss(ctx context.Context, r Request, maxCacheTime int64) (Response, []Extent, error) { |
||||
response, err := s.next.Do(ctx, r) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
if !s.shouldCacheRes(ctx, r, response, maxCacheTime) { |
||||
return response, []Extent{}, nil |
||||
} |
||||
|
||||
extent, err := toExtent(ctx, r, response) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
extents := []Extent{ |
||||
extent, |
||||
} |
||||
return response, extents, nil |
||||
} |
||||
|
||||
func (s ResultsCache) handleHit(ctx context.Context, r Request, extents []Extent, maxCacheTime int64) (Response, []Extent, error) { |
||||
var ( |
||||
reqResps []RequestResponse |
||||
err error |
||||
) |
||||
sp, ctx := opentracing.StartSpanFromContext(ctx, "handleHit") |
||||
defer sp.Finish() |
||||
log := spanlogger.FromContext(ctx) |
||||
defer log.Finish() |
||||
|
||||
requests, responses, err := s.partition(r, extents) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
if len(requests) == 0 { |
||||
response, err := s.merger.MergeResponse(responses...) |
||||
// No downstream requests so no need to write back to the cache.
|
||||
return response, nil, err |
||||
} |
||||
|
||||
tenantIDs, err := tenant.TenantIDs(ctx) |
||||
if err != nil { |
||||
return nil, nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) |
||||
} |
||||
reqResps, err = DoRequests(ctx, s.next, requests, s.parallelismForReq(ctx, tenantIDs, r)) |
||||
|
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
for _, reqResp := range reqResps { |
||||
responses = append(responses, reqResp.Response) |
||||
if s.shouldCacheRes != nil && !s.shouldCacheRes(ctx, r, reqResp.Response, maxCacheTime) { |
||||
continue |
||||
} |
||||
extent, err := toExtent(ctx, reqResp.Request, reqResp.Response) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
extents = append(extents, extent) |
||||
} |
||||
sort.Slice(extents, func(i, j int) bool { |
||||
if extents[i].Start == extents[j].Start { |
||||
// as an optimization, for two extents starts at the same time, we
|
||||
// put bigger extent at the front of the slice, which helps
|
||||
// to reduce the amount of merge we have to do later.
|
||||
return extents[i].End > extents[j].End |
||||
} |
||||
|
||||
return extents[i].Start < extents[j].Start |
||||
}) |
||||
|
||||
// Merge any extents - potentially overlapping
|
||||
accumulator, err := newAccumulator(extents[0]) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
mergedExtents := make([]Extent, 0, len(extents)) |
||||
|
||||
for i := 1; i < len(extents); i++ { |
||||
if accumulator.End+r.GetStep() < extents[i].Start { |
||||
mergedExtents, err = merge(mergedExtents, accumulator) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
accumulator, err = newAccumulator(extents[i]) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
continue |
||||
} |
||||
|
||||
if accumulator.End >= extents[i].End { |
||||
continue |
||||
} |
||||
|
||||
accumulator.TraceId = jaegerTraceID(ctx) |
||||
accumulator.End = extents[i].End |
||||
currentRes, err := extents[i].toResponse() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
merged, err := s.merger.MergeResponse(accumulator.Response, currentRes) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
accumulator.Response = merged |
||||
} |
||||
|
||||
mergedExtents, err = merge(mergedExtents, accumulator) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
response, err := s.merger.MergeResponse(responses...) |
||||
return response, mergedExtents, err |
||||
} |
||||
|
||||
// accumulator pairs a decoded Response with its Extent metadata while
// neighboring extents are merged in handleHit.
type accumulator struct {
	Response
	Extent
}
||||
|
||||
func merge(extents []Extent, acc *accumulator) ([]Extent, error) { |
||||
anyResp, err := types.MarshalAny(acc.Response) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return append(extents, Extent{ |
||||
Start: acc.Extent.Start, |
||||
End: acc.Extent.End, |
||||
Response: anyResp, |
||||
TraceId: acc.Extent.TraceId, |
||||
}), nil |
||||
} |
||||
|
||||
func newAccumulator(base Extent) (*accumulator, error) { |
||||
res, err := base.toResponse() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &accumulator{ |
||||
Response: res, |
||||
Extent: base, |
||||
}, nil |
||||
} |
||||
|
||||
func toExtent(ctx context.Context, req Request, res Response) (Extent, error) { |
||||
anyResp, err := types.MarshalAny(res) |
||||
if err != nil { |
||||
return Extent{}, err |
||||
} |
||||
return Extent{ |
||||
Start: req.GetStart().UnixMilli(), |
||||
End: req.GetEnd().UnixMilli(), |
||||
Response: anyResp, |
||||
TraceId: jaegerTraceID(ctx), |
||||
}, nil |
||||
} |
||||
|
||||
// partition calculates the required requests to satisfy req given the cached data.
// extents must be in order by start time.
// It returns the sub-requests that still need to be executed downstream and
// the cached responses (trimmed to the requested range) that can be reused.
func (s ResultsCache) partition(req Request, extents []Extent) ([]Request, []Response, error) {
	var requests []Request
	var cachedResponses []Response
	// Working bounds in milliseconds; start advances past each consumed extent.
	start := req.GetStart().UnixMilli()
	end := req.GetEnd().UnixMilli()

	for _, extent := range extents {
		// If there is no overlap, ignore this extent.
		if extent.GetEnd() < start || extent.GetStart() > end {
			continue
		}

		// If this extent is tiny and request is not tiny, discard it: more efficient to do a few larger queries.
		// Hopefully tiny request can make tiny extent into not-so-tiny extent.

		// However if the step is large enough, the split_query_by_interval middleware would generate a query with same start and end.
		// For example, if the step size is more than 12h and the interval is 24h.
		// This means the extent's start and end time would be same, even if the timerange covers several hours.
		if (req.GetStart() != req.GetEnd()) && ((end - start) > s.minCacheExtent) && (extent.End-extent.Start < s.minCacheExtent) {
			continue
		}

		// If there is a bit missing at the front, make a request for that.
		if start < extent.Start {
			r := req.WithStartEndForCache(time.UnixMilli(start), time.UnixMilli(extent.Start))
			requests = append(requests, r)
		}
		res, err := extent.toResponse()
		if err != nil {
			return nil, nil, err
		}
		// extract the overlap from the cached extent.
		cachedResponses = append(cachedResponses, s.extractor.Extract(start, end, res, extent.GetStart(), extent.GetEnd()))
		start = extent.End
	}

	// Lastly, make a request for any data missing at the end.
	if start < req.GetEnd().UnixMilli() {
		r := req.WithStartEndForCache(time.UnixMilli(start), time.UnixMilli(end))
		requests = append(requests, r)
	}

	// If start and end are the same (valid in promql), start == req.GetEnd() and we won't do the query.
	// But we should only do the request if we don't have a valid cached response for it.
	if req.GetStart() == req.GetEnd() && len(cachedResponses) == 0 {
		requests = append(requests, req)
	}

	return requests, cachedResponses, nil
}
||||
|
||||
// filterRecentExtents trims extents so that data inside the freshness window
// (now - maxCacheFreshness) is never written to the cache: recent results may
// still change as new samples arrive.
func (s ResultsCache) filterRecentExtents(req Request, maxCacheFreshness time.Duration, extents []Extent) ([]Extent, error) {
	// Align the cut-off down to a step boundary so trimmed extents stay
	// aligned with the request's step (step of 0 is clamped to 1).
	step := math.Max64(1, req.GetStep())
	maxCacheTime := (int64(model.Now().Add(-maxCacheFreshness)) / step) * step
	for i := range extents {
		// Never cache data for the latest freshness period.
		if extents[i].End > maxCacheTime {
			extents[i].End = maxCacheTime
			res, err := extents[i].toResponse()
			if err != nil {
				return nil, err
			}
			// Re-extract and re-serialize the response so the stored payload
			// matches the shortened [Start, maxCacheTime] range.
			extracted := s.extractor.Extract(extents[i].GetStart(), maxCacheTime, res, extents[i].GetStart(), extents[i].GetEnd())
			anyResp, err := types.MarshalAny(extracted)
			if err != nil {
				return nil, err
			}
			extents[i].Response = anyResp
		}
	}
	return extents, nil
}
||||
|
||||
// get fetches the cached extents for key. It returns false when the entry is
// missing, fails to unmarshal, was stored under a different (hash-colliding)
// key, or contains extents written with an old proto schema.
func (s ResultsCache) get(ctx context.Context, key string) ([]Extent, bool) {
	found, bufs, _, _ := s.cache.Fetch(ctx, []string{cache.HashKey(key)})
	if len(found) != 1 {
		return nil, false
	}

	var resp CachedResponse
	sp, ctx := opentracing.StartSpanFromContext(ctx, "unmarshal-extent") //nolint:ineffassign,staticcheck
	defer sp.Finish()
	log := spanlogger.FromContext(ctx)
	defer log.Finish()

	log.LogFields(otlog.Int("bytes", len(bufs[0])))

	if err := proto.Unmarshal(bufs[0], &resp); err != nil {
		level.Error(log).Log("msg", "error unmarshalling cached value", "err", err)
		// Also record the error on the span.
		log.Error(err)
		return nil, false
	}

	// The stored key guards against collisions of cache.HashKey.
	if resp.Key != key {
		return nil, false
	}

	// Refreshes the cache if it contains an old proto schema.
	for _, e := range resp.Extents {
		if e.Response == nil {
			return nil, false
		}
	}

	return resp.Extents, true
}
||||
|
||||
func (s ResultsCache) put(ctx context.Context, key string, extents []Extent) { |
||||
buf, err := proto.Marshal(&CachedResponse{ |
||||
Key: key, |
||||
Extents: extents, |
||||
}) |
||||
if err != nil { |
||||
level.Error(s.logger).Log("msg", "error marshalling cached value", "err", err) |
||||
return |
||||
} |
||||
|
||||
_ = s.cache.Store(ctx, []string{cache.HashKey(key)}, [][]byte{buf}) |
||||
} |
||||
|
||||
func jaegerTraceID(ctx context.Context) string { |
||||
span := opentracing.SpanFromContext(ctx) |
||||
if span == nil { |
||||
return "" |
||||
} |
||||
|
||||
spanContext, ok := span.Context().(jaeger.SpanContext) |
||||
if !ok { |
||||
return "" |
||||
} |
||||
|
||||
return spanContext.TraceID().String() |
||||
} |
||||
|
||||
func (e *Extent) toResponse() (Response, error) { |
||||
msg, err := types.EmptyAny(e.Response) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err := types.UnmarshalAny(e.Response, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
resp, ok := msg.(Response) |
||||
if !ok { |
||||
return nil, fmt.Errorf("bad cached type") |
||||
} |
||||
return resp, nil |
||||
} |
||||
@ -0,0 +1,605 @@ |
||||
package resultscache |
||||
|
||||
import ( |
||||
"context" |
||||
"strconv" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/go-kit/log" |
||||
"github.com/gogo/protobuf/types" |
||||
"github.com/grafana/dskit/flagext" |
||||
"github.com/grafana/dskit/user" |
||||
"github.com/prometheus/common/model" |
||||
"github.com/stretchr/testify/require" |
||||
"golang.org/x/exp/slices" |
||||
|
||||
"github.com/grafana/loki/pkg/logqlmodel/stats" |
||||
"github.com/grafana/loki/pkg/storage/chunk/cache" |
||||
"github.com/grafana/loki/pkg/util/constants" |
||||
) |
||||
|
||||
// day is the constant split interval used by the ConstSplitter in these tests.
const day = 24 * time.Hour

var (
	// parsedRequest is a representative query request with a ~12h range.
	parsedRequest = &MockRequest{
		Start: time.UnixMilli(1536673680 * 1e3),
		End:   time.UnixMilli(1536716898 * 1e3),
		Step:  120 * 1e3,
		Query: "sum(container_memory_rss) by (namespace)",
	}

	// parsedResponse is a canned response used as a downstream result.
	parsedResponse = &MockResponse{
		Labels: []*MockLabelsPair{
			{Name: "foo", Value: "bar"},
		},
		Samples: []*MockSample{
			{Value: 137, TimestampMs: 1536673680000},
			{Value: 137, TimestampMs: 1536673780000},
		},
	}
)
||||
|
||||
// TestPartition exercises ResultsCache.partition: splitting an incoming
// request into the sub-requests that must still be executed downstream and
// the cached responses that can be reused, given previously cached extents.
func TestPartition(t *testing.T) {
	for _, tc := range []struct {
		name                   string
		input                  Request
		prevCachedResponse     []Extent
		expectedRequests       []Request
		expectedCachedResponse []Response
	}{
		{
			name: "Test a complete hit.",
			input: &MockRequest{
				Start: time.UnixMilli(0),
				End:   time.UnixMilli(100),
			},
			prevCachedResponse: []Extent{
				mkExtent(0, 100),
			},
			expectedCachedResponse: []Response{
				mkAPIResponse(0, 100, 10),
			},
		},

		{
			name: "Test with a complete miss.",
			input: &MockRequest{
				Start: time.UnixMilli(0),
				End:   time.UnixMilli(100),
			},
			prevCachedResponse: []Extent{
				mkExtent(110, 210),
			},
			expectedRequests: []Request{
				&MockRequest{
					Start: time.UnixMilli(0),
					End:   time.UnixMilli(100),
				},
			},
		},
		{
			name: "Test a partial hit.",
			input: &MockRequest{
				Start: time.UnixMilli(0),
				End:   time.UnixMilli(100),
			},
			prevCachedResponse: []Extent{
				mkExtent(50, 100),
			},
			expectedRequests: []Request{
				&MockRequest{
					Start: time.UnixMilli(0),
					End:   time.UnixMilli(50),
				},
			},
			expectedCachedResponse: []Response{
				mkAPIResponse(50, 100, 10),
			},
		},
		{
			name: "Test multiple partial hits.",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(200),
			},
			prevCachedResponse: []Extent{
				mkExtent(50, 120),
				mkExtent(160, 250),
			},
			expectedRequests: []Request{
				&MockRequest{
					Start: time.UnixMilli(120),
					End:   time.UnixMilli(160),
				},
			},
			expectedCachedResponse: []Response{
				mkAPIResponse(100, 120, 10),
				mkAPIResponse(160, 200, 10),
			},
		},
		{
			name: "Partial hits with tiny gap.",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(160),
			},
			prevCachedResponse: []Extent{
				mkExtent(50, 120),
				mkExtent(122, 130),
			},
			expectedRequests: []Request{
				&MockRequest{
					Start: time.UnixMilli(120),
					End:   time.UnixMilli(160),
				},
			},
			expectedCachedResponse: []Response{
				mkAPIResponse(100, 120, 10),
			},
		},
		{
			name: "Extent is outside the range and the request has a single step (same start and end).",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(100),
			},
			prevCachedResponse: []Extent{
				mkExtent(50, 90),
			},
			expectedRequests: []Request{
				&MockRequest{
					Start: time.UnixMilli(100),
					End:   time.UnixMilli(100),
				},
			},
		},
		{
			name: "Test when hit has a large step and only a single sample extent.",
			// If there is a only a single sample in the split interval, start and end will be the same.
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(100),
			},
			prevCachedResponse: []Extent{
				mkExtent(100, 100),
			},
			expectedCachedResponse: []Response{
				mkAPIResponse(100, 105, 10),
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Only extractor and minCacheExtent are needed by partition.
			s := ResultsCache{
				extractor:      MockExtractor{},
				minCacheExtent: 10,
			}
			reqs, resps, err := s.partition(tc.input, tc.prevCachedResponse)
			require.Nil(t, err)
			require.Equal(t, tc.expectedRequests, reqs)
			require.Equal(t, tc.expectedCachedResponse, resps)
		})
	}
}
||||
|
||||
// TestHandleHit exercises ResultsCache.handleHit: serving a request from a
// mix of cached extents and downstream sub-requests, and verifying both the
// merged response and the updated extent list written back to the cache.
func TestHandleHit(t *testing.T) {
	for _, tc := range []struct {
		name                       string
		input                      Request
		cachedEntry                []Extent
		expectedUpdatedCachedEntry []Extent
	}{
		{
			name: "Should drop tiny extent that overlaps with non-tiny request only",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(120),
				Step:  5,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(0, 50, 5),
				mkExtentWithStep(60, 65, 5),
				mkExtentWithStep(100, 105, 5),
				mkExtentWithStep(110, 150, 5),
				mkExtentWithStep(160, 165, 5),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(0, 50, 5),
				mkExtentWithStep(60, 65, 5),
				mkExtentWithStep(100, 150, 5),
				mkExtentWithStep(160, 165, 5),
			},
		},
		{
			name: "Should replace tiny extents that are cover by bigger request",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(200),
				Step:  5,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(0, 50, 5),
				mkExtentWithStep(60, 65, 5),
				mkExtentWithStep(100, 105, 5),
				mkExtentWithStep(110, 115, 5),
				mkExtentWithStep(120, 125, 5),
				mkExtentWithStep(220, 225, 5),
				mkExtentWithStep(240, 250, 5),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(0, 50, 5),
				mkExtentWithStep(60, 65, 5),
				mkExtentWithStep(100, 200, 5),
				mkExtentWithStep(220, 225, 5),
				mkExtentWithStep(240, 250, 5),
			},
		},
		{
			name: "Should not drop tiny extent that completely overlaps with tiny request",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(105),
				Step:  5,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(0, 50, 5),
				mkExtentWithStep(60, 65, 5),
				mkExtentWithStep(100, 105, 5),
				mkExtentWithStep(160, 165, 5),
			},
			expectedUpdatedCachedEntry: nil, // no cache update need, request fulfilled using cache
		},
		{
			name: "Should not drop tiny extent that partially center-overlaps with tiny request",
			input: &MockRequest{
				Start: time.UnixMilli(106),
				End:   time.UnixMilli(108),
				Step:  2,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(60, 64, 2),
				mkExtentWithStep(104, 110, 2),
				mkExtentWithStep(160, 166, 2),
			},
			expectedUpdatedCachedEntry: nil, // no cache update need, request fulfilled using cache
		},
		{
			name: "Should not drop tiny extent that partially left-overlaps with tiny request",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(106),
				Step:  2,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(60, 64, 2),
				mkExtentWithStep(104, 110, 2),
				mkExtentWithStep(160, 166, 2),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(60, 64, 2),
				mkExtentWithStep(100, 110, 2),
				mkExtentWithStep(160, 166, 2),
			},
		},
		{
			name: "Should not drop tiny extent that partially right-overlaps with tiny request",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(106),
				Step:  2,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(60, 64, 2),
				mkExtentWithStep(98, 102, 2),
				mkExtentWithStep(160, 166, 2),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(60, 64, 2),
				mkExtentWithStep(98, 106, 2),
				mkExtentWithStep(160, 166, 2),
			},
		},
		{
			name: "Should merge fragmented extents if request fills the hole",
			input: &MockRequest{
				Start: time.UnixMilli(40),
				End:   time.UnixMilli(80),
				Step:  20,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(0, 20, 20),
				mkExtentWithStep(80, 100, 20),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(0, 100, 20),
			},
		},
		{
			name: "Should left-extend extent if request starts earlier than extent in cache",
			input: &MockRequest{
				Start: time.UnixMilli(40),
				End:   time.UnixMilli(80),
				Step:  20,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(60, 160, 20),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(40, 160, 20),
			},
		},
		{
			name: "Should right-extend extent if request ends later than extent in cache",
			input: &MockRequest{
				Start: time.UnixMilli(100),
				End:   time.UnixMilli(180),
				Step:  20,
			},
			cachedEntry: []Extent{
				mkExtentWithStep(60, 160, 20),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(60, 180, 20),
			},
		},
		{
			name: "Should not throw error if complete-overlapped smaller Extent is erroneous",
			input: &MockRequest{
				// This request is carefully crated such that cachedEntry is not used to fulfill
				// the request.
				Start: time.UnixMilli(160),
				End:   time.UnixMilli(180),
				Step:  20,
			},
			cachedEntry: []Extent{
				{
					Start: 60,
					End:   80,

					// if the optimization of "sorting by End when Start of 2 Extents are equal" is not there, this nil
					// response would cause error during Extents merge phase. With the optimization
					// this bad Extent should be dropped. The good Extent below can be used instead.
					Response: nil,
				},
				mkExtentWithStep(60, 160, 20),
			},
			expectedUpdatedCachedEntry: []Extent{
				mkExtentWithStep(60, 180, 20),
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// The next handler echoes a synthetic response covering exactly
			// the sub-request's range, so merged output is fully predictable.
			sut := ResultsCache{
				extractor:         MockExtractor{},
				minCacheExtent:    10,
				limits:            mockLimits{},
				merger:            MockMerger{},
				parallelismForReq: func(_ context.Context, tenantIDs []string, r Request) int { return 1 },
				next: HandlerFunc(func(_ context.Context, req Request) (Response, error) {
					return mkAPIResponse(req.GetStart().UnixMilli(), req.GetEnd().UnixMilli(), req.GetStep()), nil
				}),
			}

			ctx := user.InjectOrgID(context.Background(), "1")
			response, updatedExtents, err := sut.handleHit(ctx, tc.input, tc.cachedEntry, 0)
			require.NoError(t, err)

			expectedResponse := mkAPIResponse(tc.input.GetStart().UnixMilli(), tc.input.GetEnd().UnixMilli(), tc.input.GetStep())
			require.Equal(t, expectedResponse, response, "response does not match the expectation")
			require.Equal(t, tc.expectedUpdatedCachedEntry, updatedExtents, "updated cache entry does not match the expectation")
		})
	}
}
||||
|
||||
// TestResultsCacheMaxFreshness verifies that a per-tenant max-cache-freshness
// override makes Do bypass the cache (forwarding straight to the handler),
// while a small freshness window allows serving from the pre-filled cache.
func TestResultsCacheMaxFreshness(t *testing.T) {
	modelNow := model.Now()
	for i, tc := range []struct {
		fakeLimits       Limits
		Handler          HandlerFunc
		expectedResponse *MockResponse
	}{
		{
			// Served from the pre-filled cache; the nil handler would panic
			// if the cache were bypassed.
			fakeLimits:       mockLimits{maxCacheFreshness: 5 * time.Second},
			Handler:          nil,
			expectedResponse: mkAPIResponse(int64(modelNow)-(50*1e3), int64(modelNow)-(10*1e3), 10),
		},
		{
			// should not lookup cache because per-tenant override will be applied
			fakeLimits: mockLimits{maxCacheFreshness: 10 * time.Minute},
			Handler: HandlerFunc(func(_ context.Context, _ Request) (Response, error) {
				return parsedResponse, nil
			}),
			expectedResponse: parsedResponse,
		},
	} {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var cfg Config
			flagext.DefaultValues(&cfg)
			cfg.CacheConfig.Cache = cache.NewMockCache()
			c, err := cache.New(cfg.CacheConfig, nil, log.NewNopLogger(), stats.ResultCache, constants.Loki)
			require.NoError(t, err)
			fakeLimits := tc.fakeLimits
			rc := NewResultsCache(
				log.NewNopLogger(),
				c,
				tc.Handler,
				ConstSplitter(day),
				fakeLimits,
				MockMerger{},
				MockExtractor{},
				nil,
				nil,
				func(_ context.Context, tenantIDs []string, r Request) int {
					return 10
				},
				nil,
				false,
			)
			require.NoError(t, err)

			// create cache with handler
			ctx := user.InjectOrgID(context.Background(), "1")

			// create request with start end within the key extents
			req := parsedRequest.WithStartEndForCache(time.UnixMilli(int64(modelNow)-(50*1e3)), time.UnixMilli(int64(modelNow)-(10*1e3)))

			// fill cache
			key := ConstSplitter(day).GenerateCacheKey(context.Background(), "1", req)
			rc.put(ctx, key, []Extent{mkExtent(int64(modelNow)-(600*1e3), int64(modelNow))})

			resp, err := rc.Do(ctx, req)
			require.NoError(t, err)
			require.Equal(t, tc.expectedResponse, resp)
		})
	}
}
||||
|
||||
// Test_resultsCache_MissingData verifies that get treats entries containing
// any extent with a nil Response (an old proto schema) as cache misses, so
// they are refreshed rather than served.
func Test_resultsCache_MissingData(t *testing.T) {
	cfg := Config{
		CacheConfig: cache.Config{
			Cache: cache.NewMockCache(),
		},
	}
	c, err := cache.New(cfg.CacheConfig, nil, log.NewNopLogger(), stats.ResultCache, constants.Loki)
	require.NoError(t, err)
	rc := NewResultsCache(
		log.NewNopLogger(),
		c,
		nil,
		ConstSplitter(day),
		mockLimits{},
		MockMerger{},
		MockExtractor{},
		nil,
		nil,
		func(_ context.Context, tenantIDs []string, r Request) int {
			return 10
		},
		nil,
		false,
	)
	require.NoError(t, err)
	ctx := context.Background()

	// fill up the cache
	rc.put(ctx, "empty", []Extent{{
		Start:    100,
		End:      200,
		Response: nil,
	}})
	rc.put(ctx, "notempty", []Extent{mkExtent(100, 120)})
	rc.put(ctx, "mixed", []Extent{mkExtent(100, 120), {
		Start:    120,
		End:      200,
		Response: nil,
	}})

	// nil-response extent: treated as a miss.
	extents, hit := rc.get(ctx, "empty")
	require.Empty(t, extents)
	require.False(t, hit)

	// fully populated entry: served as a hit.
	extents, hit = rc.get(ctx, "notempty")
	require.Equal(t, len(extents), 1)
	require.True(t, hit)

	// one bad extent poisons the whole entry: treated as a miss.
	extents, hit = rc.get(ctx, "mixed")
	require.Equal(t, len(extents), 0)
	require.False(t, hit)
}
||||
|
||||
func mkAPIResponse(start, end, step int64) *MockResponse { |
||||
var samples []*MockSample |
||||
for i := start; i <= end; i += step { |
||||
samples = append(samples, &MockSample{ |
||||
TimestampMs: i, |
||||
Value: float64(i), |
||||
}) |
||||
} |
||||
|
||||
return &MockResponse{ |
||||
Labels: []*MockLabelsPair{ |
||||
{Name: "foo", Value: "bar"}, |
||||
}, |
||||
Samples: samples, |
||||
} |
||||
} |
||||
|
||||
func mkExtent(start, end int64) Extent { |
||||
return mkExtentWithStep(start, end, 10) |
||||
} |
||||
|
||||
func mkExtentWithStep(start, end, step int64) Extent { |
||||
res := mkAPIResponse(start, end, step) |
||||
anyRes, err := types.MarshalAny(res) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return Extent{ |
||||
Start: start, |
||||
End: end, |
||||
Response: anyRes, |
||||
} |
||||
} |
||||
|
||||
func (r *MockRequest) WithStartEndForCache(start time.Time, end time.Time) Request { |
||||
m := *r |
||||
m.Start = start |
||||
m.End = end |
||||
return &m |
||||
} |
||||
|
||||
type MockMerger struct{} |
||||
|
||||
func (m MockMerger) MergeResponse(responses ...Response) (Response, error) { |
||||
samples := make([]*MockSample, 0, len(responses)*2) |
||||
for _, response := range responses { |
||||
samples = append(samples, response.(*MockResponse).Samples...) |
||||
} |
||||
|
||||
// Merge samples by:
|
||||
// 1. Sorting them by time.
|
||||
// 2. Removing duplicates.
|
||||
slices.SortFunc(samples, func(a, b *MockSample) int { |
||||
if a.TimestampMs == b.TimestampMs { |
||||
return 0 |
||||
} |
||||
if a.TimestampMs < b.TimestampMs { |
||||
return -1 |
||||
} |
||||
return 1 |
||||
}) |
||||
samples = slices.CompactFunc(samples, func(a, b *MockSample) bool { |
||||
return a.TimestampMs == b.TimestampMs |
||||
}) |
||||
|
||||
return &MockResponse{ |
||||
Labels: responses[0].(*MockResponse).Labels, |
||||
Samples: samples, |
||||
}, nil |
||||
} |
||||
|
||||
type MockExtractor struct{} |
||||
|
||||
func (m MockExtractor) Extract(start, end int64, res Response, _, _ int64) Response { |
||||
mockRes := res.(*MockResponse) |
||||
|
||||
result := MockResponse{ |
||||
Labels: mockRes.Labels, |
||||
Samples: make([]*MockSample, 0, len(mockRes.Samples)), |
||||
} |
||||
|
||||
for _, sample := range mockRes.Samples { |
||||
if start <= sample.TimestampMs && sample.TimestampMs <= end { |
||||
result.Samples = append(result.Samples, sample) |
||||
} |
||||
} |
||||
return &result |
||||
} |
||||
|
||||
// mockLimits is a test implementation of Limits with a fixed freshness window.
type mockLimits struct {
	maxCacheFreshness time.Duration
}

// MaxCacheFreshness returns the configured freshness window regardless of
// context or tenant.
func (l mockLimits) MaxCacheFreshness(_ context.Context, _ string) time.Duration {
	return l.maxCacheFreshness
}
||||
@ -0,0 +1,41 @@ |
||||
package resultscache |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"time" |
||||
|
||||
"github.com/pkg/errors" |
||||
|
||||
"github.com/grafana/loki/pkg/storage/chunk/cache" |
||||
) |
||||
|
||||
// Config is the config for the results cache.
type Config struct {
	// CacheConfig configures the backing cache client.
	CacheConfig cache.Config `yaml:"cache"`
	// Compression selects the compression applied to cached entries.
	// Supported values: "snappy" or "" (disabled); see Validate.
	Compression string `yaml:"compression"`
}
||||
|
||||
// RegisterFlagsWithPrefix registers the results-cache flags on f, prefixing
// every flag name with the given prefix.
func (cfg *Config) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix string) {
	cfg.CacheConfig.RegisterFlagsWithPrefix(prefix, "", f)
	f.StringVar(&cfg.Compression, prefix+"compression", "", "Use compression in cache. The default is an empty value '', which disables compression. Supported values are: 'snappy' and ''.")
}
||||
|
||||
// RegisterFlags registers the results-cache flags on f with no prefix.
func (cfg *Config) RegisterFlags(f *flag.FlagSet) {
	cfg.RegisterFlagsWithPrefix(f, "")
}
||||
|
||||
func (cfg *Config) Validate() error { |
||||
switch cfg.Compression { |
||||
case "snappy", "": |
||||
// valid
|
||||
default: |
||||
return errors.Errorf("unsupported compression type: %s", cfg.Compression) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Limits exposes the per-tenant limits consulted by the results cache.
type Limits interface {
	// MaxCacheFreshness returns the freshness window for the given tenant;
	// NOTE(review): presumably results newer than this are not cached —
	// confirm against the cache implementation.
	MaxCacheFreshness(ctx context.Context, tenantID string) time.Duration
}
||||
@ -0,0 +1,56 @@ |
||||
package resultscache |
||||
|
||||
import ( |
||||
"context" |
||||
"time" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
) |
||||
|
||||
// Request is a generic, cacheable query request. It must be a protobuf
// message so it can be serialized alongside cached results.
type Request interface {
	proto.Message
	// GetStart returns the start timestamp of the request.
	GetStart() time.Time
	// GetEnd returns the end timestamp of the request.
	GetEnd() time.Time
	// GetStep returns the step of the request in milliseconds.
	GetStep() int64
	// GetQuery returns the query of the request.
	GetQuery() string
	// GetCachingOptions returns the caching options.
	GetCachingOptions() CachingOptions
	// WithStartEndForCache clones the current request with different start
	// and end timestamps.
	WithStartEndForCache(start time.Time, end time.Time) Request
}
||||
|
||||
// Response is a generic query response; any protobuf message qualifies,
// which allows it to be packed into a cached extent.
type Response interface {
	proto.Message
}
||||
|
||||
// ResponseMerger is used by middlewares making multiple requests to merge
// back all responses into a single one.
type ResponseMerger interface {
	// MergeResponse merges responses from multiple requests into a single
	// Response.
	MergeResponse(...Response) (Response, error)
}
||||
|
||||
// Handler is anything that can serve a cacheable query request.
type Handler interface {
	// Do executes the request and returns its response.
	Do(ctx context.Context, req Request) (Response, error)
}
||||
|
||||
// Extractor is used by the cache to extract a subset of a response from a
// cache entry.
type Extractor interface {
	// Extract extracts the subset of `res` between the `start` and `end`
	// timestamps in milliseconds, where `res` itself spans from `resStart`
	// to `resEnd` (also milliseconds).
	Extract(start, end int64, res Response, resStart, resEnd int64) Response
}
||||
|
||||
// KeyGenerator generates cache keys. This is a useful interface for
// downstream consumers who wish to implement their own strategies.
type KeyGenerator interface {
	// GenerateCacheKey returns the key under which results for request r are
	// cached for the given user.
	GenerateCacheKey(ctx context.Context, userID string, r Request) string
}
||||
|
||||
// CacheGenNumberLoader provides the results-cache generation number for a
// set of tenants.
type CacheGenNumberLoader interface {
	// GetResultsCacheGenNumber returns the current generation number for the
	// given tenants.
	GetResultsCacheGenNumber(tenantIDs []string) string
	// Stop releases any resources held by the loader.
	Stop()
}
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,41 @@ |
||||
syntax = "proto3"; |
||||
|
||||
package resultscache; |
||||
|
||||
import "gogoproto/gogo.proto"; |
||||
import "google/protobuf/timestamp.proto"; |
||||
import "types.proto"; |
||||
|
||||
option go_package = "github.com/grafana/loki/pkg/storage/chunk/cache/resultscache"; |
||||
option (gogoproto.marshaler_all) = true; |
||||
option (gogoproto.unmarshaler_all) = true; |
||||
|
||||
// MockRequest mirrors the shape of a query-range request so the results
// cache can be tested without importing queryrangebase types.
message MockRequest {
  string path = 1;
  google.protobuf.Timestamp start = 2 [
    (gogoproto.stdtime) = true,
    (gogoproto.nullable) = false
  ];
  google.protobuf.Timestamp end = 3 [
    (gogoproto.stdtime) = true,
    (gogoproto.nullable) = false
  ];
  int64 step = 4;
  // NOTE(review): field number 5 is skipped — presumably kept free to match
  // the original queryrangebase request layout; confirm before reusing it.
  string query = 6;
  CachingOptions cachingOptions = 7 [(gogoproto.nullable) = false];
}
||||
|
||||
// MockResponse is a minimal query response for cache tests: a label set
// plus a series of samples.
message MockResponse {
  repeated MockLabelsPair labels = 1;
  repeated MockSample samples = 2;
}
||||
|
||||
// MockLabelsPair is a single name/value label pair.
message MockLabelsPair {
  string name = 1;
  string value = 2;
}
||||
|
||||
// MockSample is a single sampled value at a millisecond timestamp.
message MockSample {
  double value = 1;
  int64 timestamp_ms = 2;
}
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,34 @@ |
||||
syntax = "proto3"; |
||||
|
||||
package resultscache; |
||||
|
||||
import "gogoproto/gogo.proto"; |
||||
import "google/protobuf/any.proto"; |
||||
|
||||
option go_package = "github.com/grafana/loki/pkg/storage/chunk/cache/resultscache"; |
||||
option (gogoproto.marshaler_all) = true; |
||||
option (gogoproto.unmarshaler_all) = true; |
||||
|
||||
// Defined here to prevent circular imports between logproto & queryrangebase
message CachingOptions {
  // disabled turns off results caching for the request when true.
  bool disabled = 1;
}
||||
|
||||
// CachedResponse is the value stored in the results cache for a single key.
message CachedResponse {
  // The cache key this entry was stored under.
  string key = 1 [(gogoproto.jsontag) = "key"];

  // List of cached responses; non-overlapping and in order.
  repeated Extent extents = 2 [
    (gogoproto.nullable) = false,
    (gogoproto.jsontag) = "extents"
  ];
}
||||
|
||||
// Extent is one cached time range together with its response payload.
message Extent {
  // Start and end of the extent, in milliseconds.
  int64 start = 1 [(gogoproto.jsontag) = "start"];
  int64 end = 2 [(gogoproto.jsontag) = "end"];
  // reserved the previous key to ensure cache transition
  reserved 3;
  // Trace ID recorded when the extent was filled; excluded from JSON.
  string trace_id = 4 [(gogoproto.jsontag) = "-"];
  // The cached response, packed as a protobuf Any.
  google.protobuf.Any response = 5 [(gogoproto.jsontag) = "response"];
}
||||
@ -0,0 +1,67 @@ |
||||
package resultscache |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
// HandlerFunc is an adapter that lets a plain function be used as a Handler.
type HandlerFunc func(context.Context, Request) (Response, error)

// Do implements Handler by calling q itself.
func (q HandlerFunc) Do(ctx context.Context, req Request) (Response, error) {
	return q(ctx, req)
}
||||
|
||||
// RequestResponse contains a request response and the respective request that was used.
type RequestResponse struct {
	Request Request
	Response Response
}
||||
|
||||
// DoRequests executes a list of requests in parallel.
//
// It fans reqs out to at most `parallelism` worker goroutines and collects
// exactly one result (or error) per request. On the first error the shared
// context is cancelled so in-flight downstream calls can abort early, but
// the collection loop still drains one message per request, so no worker
// goroutine is left blocked. Responses are returned in completion order,
// not request order, together with the first error encountered (if any).
func DoRequests(ctx context.Context, downstream Handler, reqs []Request, parallelism int) ([]RequestResponse, error) {
	// If one of the requests fail, we want to be able to cancel the rest of them.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Feed all requests to a bounded intermediate channel to limit parallelism.
	intermediate := make(chan Request)
	go func() {
		for _, req := range reqs {
			intermediate <- req
		}
		close(intermediate)
	}()

	respChan, errChan := make(chan RequestResponse), make(chan error)
	// No point spinning up more workers than there are requests.
	if parallelism > len(reqs) {
		parallelism = len(reqs)
	}
	for i := 0; i < parallelism; i++ {
		go func() {
			// Each request produces exactly one message: either a response
			// or an error. The collection loop below relies on this.
			for req := range intermediate {
				resp, err := downstream.Do(ctx, req)
				if err != nil {
					errChan <- err
				} else {
					respChan <- RequestResponse{req, resp}
				}
			}
		}()
	}

	resps := make([]RequestResponse, 0, len(reqs))
	var firstErr error
	// Receive exactly len(reqs) messages; since every request yields one,
	// this loop cannot block forever.
	for range reqs {
		select {
		case resp := <-respChan:
			resps = append(resps, resp)
		case err := <-errChan:
			// Remember only the first error and cancel the rest; later
			// errors are discarded but still drained.
			if firstErr == nil {
				cancel()
				firstErr = err
			}
		}
	}

	return resps, firstErr
}
||||
Loading…
Reference in new issue