diff --git a/pkg/querier/queryrange/limits.go b/pkg/querier/queryrange/limits.go index 82bef4bf95..b6f5c4d51f 100644 --- a/pkg/querier/queryrange/limits.go +++ b/pkg/querier/queryrange/limits.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/dskit/tenant" "github.com/opentracing/opentracing-go" + otlog "github.com/opentracing/opentracing-go/log" "github.com/pkg/errors" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" @@ -452,6 +453,27 @@ func NewLimitedRoundTripper(next queryrangebase.Handler, limits Limits, configs return transport } +type SemaphoreWithTiming struct { + sem *semaphore.Weighted +} + +func NewSemaphoreWithTiming(max int64) *SemaphoreWithTiming { + return &SemaphoreWithTiming{ + sem: semaphore.NewWeighted(max), + } +} + +// acquires the semaphore and records the time it takes. +func (s *SemaphoreWithTiming) Acquire(ctx context.Context, n int64) (time.Duration, error) { + start := time.Now() + + if err := s.sem.Acquire(ctx, n); err != nil { + return 0, err + } + + return time.Since(start), nil +} + func (rt limitedRoundTripper) Do(c context.Context, request queryrangebase.Request) (queryrangebase.Response, error) { var ( ctx, cancel = context.WithCancel(c) @@ -460,9 +482,12 @@ func (rt limitedRoundTripper) Do(c context.Context, request queryrangebase.Reque cancel() }() - if span := opentracing.SpanFromContext(ctx); span != nil { + span := opentracing.SpanFromContext(ctx) + + if span != nil { request.LogToSpan(span) } + tenantIDs, err := tenant.TenantIDs(ctx) if err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) @@ -481,7 +506,7 @@ func (rt limitedRoundTripper) Do(c context.Context, request queryrangebase.Reque return nil, httpgrpc.Errorf(http.StatusTooManyRequests, ErrMaxQueryParalellism.Error()) } - sem := semaphore.NewWeighted(int64(parallelism)) + semWithTiming := NewSemaphoreWithTiming(int64(parallelism)) return rt.middleware.Wrap( queryrangebase.HandlerFunc(func(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { @@ -492,10 +517,20 @@ func (rt limitedRoundTripper) Do(c context.Context, request queryrangebase.Reque // the thousands. // Note: It is the responsibility of the caller to run // the handler in parallel. - if err := sem.Acquire(ctx, int64(1)); err != nil { + elapsed, err := semWithTiming.Acquire(ctx, int64(1)) + + if err != nil { return nil, fmt.Errorf("could not acquire work: %w", err) } - defer sem.Release(int64(1)) + + if span != nil { + span.LogFields( + otlog.String("wait_time", elapsed.String()), + otlog.Int64("max_parallelism", int64(parallelism)), + ) + } + + defer semWithTiming.sem.Release(int64(1)) return rt.next.Do(ctx, r) })).Do(ctx, request) diff --git a/pkg/querier/queryrange/limits_test.go b/pkg/querier/queryrange/limits_test.go index 4ab81ec4ac..24253892ca 100644 --- a/pkg/querier/queryrange/limits_test.go +++ b/pkg/querier/queryrange/limits_test.go @@ -623,3 +623,59 @@ func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) { }) } } + +func TestAcquireWithTiming(t *testing.T) { + + ctx := context.Background() + sem := NewSemaphoreWithTiming(2) + + // Channel to collect waiting times + waitingTimes := make(chan struct { + GoroutineID int + WaitingTime int64 + }, 3) + + tryAcquire := func(n int64, goroutineID int) { + elapsed, err := sem.Acquire(ctx, n) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + waitingTimes <- struct { + GoroutineID int + WaitingTime int64 + }{goroutineID, elapsed.Milliseconds()} + + defer sem.sem.Release(n) + + time.Sleep(10 * time.Millisecond) + } + + go tryAcquire(1, 1) + go tryAcquire(1, 2) + + // Sleep briefly to allow the first two goroutines to start running + time.Sleep(5 * time.Millisecond) + + go tryAcquire(1, 3) + + // Collect and sort waiting times + var waitingDurations []struct { + GoroutineID int + WaitingTime int64 + } + for i := 0; i < 3; i++ { + waitingDurations = append(waitingDurations, <-waitingTimes) + } + // Find and check the waiting time for the third goroutine + var waiting3 int64 + for _, waiting := range waitingDurations { + if waiting.GoroutineID == 3 { + waiting3 = waiting.WaitingTime + break + } + } + + // Check that the waiting time for the third request is larger than 0 milliseconds and less than or equal to 10-5=5 milliseconds + require.Greater(t, waiting3, 0*time.Millisecond) + require.LessOrEqual(t, waiting3, 5*time.Millisecond) +}