diff --git a/pkg/limits/frontend/ring.go b/pkg/limits/frontend/ring.go index 47c7549d31..e31f7e2446 100644 --- a/pkg/limits/frontend/ring.go +++ b/pkg/limits/frontend/ring.go @@ -2,6 +2,7 @@ package frontend import ( "context" + "iter" "slices" "sort" "strings" @@ -76,24 +77,10 @@ func (r *ringLimitsClient) ExceedsLimits(ctx context.Context, req *proto.Exceeds if len(req.Streams) == 0 { return nil, nil } - rs, err := r.ring.GetAllHealthy(LimitsRead) - if err != nil { - return nil, err - } - // Get the partition consumers for each zone. - zonesPartitions, err := r.getZoneAwarePartitionConsumers(ctx, rs.Instances) + zonesIter, err := r.allZones(ctx) if err != nil { return nil, err } - // In practice we want zones to be queried in random order to spread - // reads. However, in tests we want a deterministic order so test cases - // are stable and reproducible. Having a custom sort func supports both - // use cases as zoneCmp can be switched out in tests. - zonesToQuery := make([]string, 0, len(zonesPartitions)) - for zone := range zonesPartitions { - zonesToQuery = append(zonesToQuery, zone) - } - slices.SortFunc(zonesToQuery, r.zoneCmp) // Make a copy of the streams from the request. We will prune this slice // each time we receive the responses from a zone. streams := make([]*proto.StreamMetadata, 0, len(req.Streams)) @@ -107,12 +94,12 @@ func (r *ringLimitsClient) ExceedsLimits(ctx context.Context, req *proto.Exceeds // process until all streams have been queried or we have exhausted all // zones. responses := make([]*proto.ExceedsLimitsResponse, 0) - for _, zone := range zonesToQuery { + for zone, consumers := range zonesIter { // All streams been checked against per-tenant limits. if len(streams) == 0 { break } - resps, answered, err := r.doExceedsLimitsRPCs(ctx, req.Tenant, streams, zonesPartitions[zone], zone) + resps, answered, err := r.doExceedsLimitsRPCs(ctx, req.Tenant, streams, zone, consumers) if err != nil { continue } @@ -143,12 +130,12 @@ func (r *ringLimitsClient) ExceedsLimits(ctx context.Context, req *proto.Exceeds return responses, nil } -func (r *ringLimitsClient) doExceedsLimitsRPCs(ctx context.Context, tenant string, streams []*proto.StreamMetadata, partitions map[int32]string, zone string) ([]*proto.ExceedsLimitsResponse, []uint64, error) { +func (r *ringLimitsClient) doExceedsLimitsRPCs(ctx context.Context, tenant string, streams []*proto.StreamMetadata, zone string, consumers map[int32]string) ([]*proto.ExceedsLimitsResponse, []uint64, error) { // For each stream, figure out which instance consume its partition. instancesForStreams := make(map[string][]*proto.StreamMetadata) for _, stream := range streams { partition := int32(stream.StreamHash % uint64(r.numPartitions)) - addr, ok := partitions[partition] + addr, ok := consumers[partition] if !ok { r.partitionsMissing.WithLabelValues(zone).Inc() continue @@ -194,9 +181,37 @@ func (r *ringLimitsClient) doExceedsLimitsRPCs(ctx context.Context, tenant strin return responses, answered, nil } -type zonePartitionConsumersResult struct { - zone string - partitions map[int32]string +// allZones returns an iterator over all zones and the consumers for each +// partition in each zone. If a zone has no active partition consumers, the +// zone will still be returned but its partition consumers will be nil. +// If ZoneAwarenessEnabled is false, it returns all partition consumers under +// a pseudo-zone (""). +func (r *ringLimitsClient) allZones(ctx context.Context) (iter.Seq2[string, map[int32]string], error) { + rs, err := r.ring.GetAllHealthy(LimitsRead) + if err != nil { + return nil, err + } + // Get the partition consumers for each zone. + zonesPartitions, err := r.getZoneAwarePartitionConsumers(ctx, rs.Instances) + if err != nil { + return nil, err + } + // In practice we want zones to be queried in random order to spread + // reads. However, in tests we want a deterministic order so test cases + // are stable and reproducible. Having a custom sort func supports both + // use cases as zoneCmp can be switched out in tests. + zonesToQuery := make([]string, 0, len(zonesPartitions)) + for zone := range zonesPartitions { + zonesToQuery = append(zonesToQuery, zone) + } + slices.SortFunc(zonesToQuery, r.zoneCmp) + return func(yield func(string, map[int32]string) bool) { + for _, zone := range zonesToQuery { + if !yield(zone, zonesPartitions[zone]) { + return + } + } + }, nil } // getZoneAwarePartitionConsumers returns partition consumers for each zone @@ -210,6 +225,10 @@ func (r *ringLimitsClient) getZoneAwarePartitionConsumers(ctx context.Context, i zoneDescs[instance.Zone] = append(zoneDescs[instance.Zone], instance) } // Get the partition consumers for each zone. + type zonePartitionConsumersResult struct { + zone string + partitions map[int32]string + } resultsCh := make(chan zonePartitionConsumersResult, len(zoneDescs)) errg, ctx := errgroup.WithContext(ctx) for zone, instances := range zoneDescs { @@ -236,11 +255,6 @@ func (r *ringLimitsClient) getZoneAwarePartitionConsumers(ctx context.Context, i return results, nil } -type getAssignedPartitionsResponse struct { - addr string - response *proto.GetAssignedPartitionsResponse -} - // getPartitionConsumers returns the consumer for each partition. // In some cases, it might not be possible to know the consumer for a @@ -259,6 +273,10 @@ type getAssignedPartitionsResponse struct { // to find the most up to date consumer for each partition across all zones. func (r *ringLimitsClient) getPartitionConsumers(ctx context.Context, instances []ring.InstanceDesc) (map[int32]string, error) { errg, ctx := errgroup.WithContext(ctx) + type getAssignedPartitionsResponse struct { + addr string + response *proto.GetAssignedPartitionsResponse + } responseCh := make(chan getAssignedPartitionsResponse, len(instances)) for _, instance := range instances { errg.Go(func() error {