Loki: Shard streams based on desired rate (#7199)

**What this PR does / why we need it**:
Modify our sharding mechanism so that the number of shards is dictated by the desired rate and the stream's current ingestion rate.
Once the number of shards is calculated, the stream entries are distributed across the shards. If the number of entries is lower than the number of shards, we report a metric and use the number of entries as the number of shards.
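As a rough illustration of the mechanism (the name `shardsFor` and the constants are ours, not the PR's exact code), the shard count is derived from the combined rate and capped by the entry count:

```go
package main

import (
	"fmt"
	"math"
)

// shardsFor sketches the PR's formula: n = ceil((streamSize + rate) / desiredRate),
// capped so we never create more shards than there are entries.
// All sizes and rates are in bytes.
func shardsFor(rate, streamSize, desiredRate, entries int) int {
	n := int(math.Ceil(float64(rate+streamSize) / float64(desiredRate)))
	if n <= 1 {
		return 1
	}
	if n > entries {
		// the PR reports a metric here and uses the entry count as the shard count.
		return entries
	}
	return n
}

func main() {
	// Docs example: 10MB ingested rate, 1MB incoming stream, 3MB desired rate.
	fmt.Println(shardsFor(10_000_000, 1_000_000, 3_000_000, 20)) // 4
}
```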
Dylan Guedes 3 years ago committed by GitHub
parent eb949e2907
commit dcfba366de
Files changed:
1. docs/sources/configuration/_index.md (12)
2. pkg/distributor/distributor.go (89)
3. pkg/distributor/distributor_test.go (305)
4. pkg/distributor/ratestore.go (25)
5. pkg/distributor/streamsharder.go (49)
6. pkg/distributor/streamsharder_test.go (57)

@@ -311,15 +311,23 @@ ring:
shard_streams:
# Whether to enable stream sharding
#
# CLI flag: -distributor.stream-sharding.enabled
# CLI flag: -distributor.shard-streams.enabled
[enabled: <boolean> | default = false]
# Enable logging when sharding streams. Disabled by default because this
# logging may impact performance. When disabled, stream sharding will emit
# no logs regardless of the log level.
#
# CLI flag: -distributor.stream-sharding.logging-enabled
# CLI flag: -distributor.shard-streams.logging-enabled
[logging_enabled: <boolean> | default = false]
# Threshold that determines how much a stream should be sharded.
# The formula used is n = ceil((stream size + ingested rate) / desired rate), where n is the number of shards.
# For instance, if a stream's ingestion rate is 10MB, the desired rate is 3MB (default), and a stream of size 1MB is
# received, the given stream will be split into n = ceil((1 + 10)/3) = 4 shards.
#
# CLI flag: -distributor.shard-streams.desired-rate
[desired_rate: <string> | default = 3MB]
```
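For completeness, a minimal configuration that turns the feature on could look like the following (a sketch based on the reference above; the placement and values are illustrative):

```yaml
distributor:
  shard_streams:
    enabled: true
    # Keep logging disabled unless debugging, since it may impact performance.
    logging_enabled: false
    # Streams whose size plus ingestion rate exceed this value get sharded.
    desired_rate: 3MB
```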
## querier

@@ -3,6 +3,7 @@ package distributor
import (
"context"
"flag"
"math"
"net/http"
"strconv"
"strings"
@@ -33,6 +34,7 @@ import (
"github.com/grafana/loki/pkg/storage/stores/indexshipper/compactor/retention"
"github.com/grafana/loki/pkg/usagestats"
"github.com/grafana/loki/pkg/util"
"github.com/grafana/loki/pkg/util/flagext"
util_log "github.com/grafana/loki/pkg/util/log"
"github.com/grafana/loki/pkg/validation"
)
@@ -54,6 +56,16 @@ var (
type ShardStreamsConfig struct {
Enabled bool `yaml:"enabled"`
LoggingEnabled bool `yaml:"logging_enabled"`
// DesiredRate is the threshold used to shard the stream into smaller pieces.
// Expected to be in bytes.
DesiredRate flagext.ByteSize `yaml:"desired_rate"`
}
func (cfg *ShardStreamsConfig) RegisterFlagsWithPrefix(prefix string, fs *flag.FlagSet) {
fs.BoolVar(&cfg.Enabled, prefix+".enabled", false, "Automatically shard streams to keep them under the per-stream rate limit")
fs.BoolVar(&cfg.LoggingEnabled, prefix+".logging-enabled", false, "Enable logging when sharding streams")
fs.Var(&cfg.DesiredRate, prefix+".desired-rate", "Threshold used to cut a new shard. Default (3MB) means a stream with a rate above 3MB will be sharded.")
}
// Config for a Distributor.
@@ -71,14 +83,12 @@ type Config struct {
// RegisterFlags registers distributor-related flags.
func (cfg *Config) RegisterFlags(fs *flag.FlagSet) {
cfg.DistributorRing.RegisterFlags(fs)
fs.BoolVar(&cfg.ShardStreams.Enabled, "distributor.stream-sharding.enabled", false, "Automatically shard streams to keep them under the per-stream rate limit")
fs.BoolVar(&cfg.ShardStreams.LoggingEnabled, "distributor.stream-sharding.logging-enabled", false, "Enable logging when sharding streams")
cfg.ShardStreams.RegisterFlagsWithPrefix("distributor.shard-streams", fs)
}
// StreamSharder manages the state necessary to shard streams.
type StreamSharder interface {
ShardCountFor(stream logproto.Stream) (int, bool)
IncreaseShardsFor(stream logproto.Stream)
// RateStore manages the ingestion rate of streams, populated by data fetched from ingesters.
type RateStore interface {
RateFor(stream *logproto.Stream) (int, error)
}
// Distributor coordinates the replication and distribution of log streams.
@@ -92,7 +102,8 @@ type Distributor struct {
ingestersRing ring.ReadRing
validator *Validator
pool *ring_client.Pool
streamSharder StreamSharder
rateStore RateStore
// The global rate limiter requires a distributors ring to count
// the number of healthy instances.
@@ -109,6 +120,7 @@ type Distributor struct {
ingesterAppends *prometheus.CounterVec
ingesterAppendFailures *prometheus.CounterVec
replicationFactor prometheus.Gauge
streamShardingFailures *prometheus.CounterVec
}
// New creates a distributor.
@@ -186,6 +198,13 @@ func New(
Name: "distributor_replication_factor",
Help: "The configured replication factor.",
}),
streamShardingFailures: promauto.With(registerer).NewCounterVec(prometheus.CounterOpts{
Namespace: "loki",
Name: "stream_sharding_failures",
Help: "Total number of failures when sharding a stream",
}, []string{
"reason",
}),
}
d.replicationFactor.Set(float64(ingestersRing.ReplicationFactor()))
rfStats.Set(int64(ingestersRing.ReplicationFactor()))
@@ -199,7 +218,7 @@ func New(
d.subservicesWatcher.WatchManager(d.subservices)
d.Service = services.NewBasicService(d.starting, d.running, d.stopping)
d.streamSharder = NewStreamSharder()
d.rateStore = &noopRateStore{}
return &d, nil
}
@@ -284,6 +303,7 @@ func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*logproto.PushResponse, error) {
}
n := 0
streamSize := 0
for _, entry := range stream.Entries {
if err := d.validator.ValidateEntry(validationContext, stream.Labels, entry); err != nil {
validationErr = err
@@ -307,11 +327,12 @@ func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*logproto.PushResponse, error) {
n++
validatedLineSize += len(entry.Line)
validatedLineCount++
streamSize += len(entry.Line)
}
stream.Entries = stream.Entries[:n]
if d.cfg.ShardStreams.Enabled {
derivedKeys, derivedStreams := d.shardStream(stream, userID)
derivedKeys, derivedStreams := d.shardStream(stream, streamSize, userID)
keys = append(keys, derivedKeys...)
streams = append(streams, derivedStreams...)
} else {
@@ -389,14 +410,15 @@ func min(x1, x2 int) int {
// shardStream shards (divides) the given stream into N smaller streams, where
// N is the sharding size for the given stream. shardStream returns the smaller
// streams and their associated keys for hashing to ingesters.
func (d *Distributor) shardStream(stream logproto.Stream, userID string) ([]uint32, []streamTracker) {
shardCount, ok := d.streamSharder.ShardCountFor(stream)
if !ok || shardCount <= 1 {
func (d *Distributor) shardStream(stream logproto.Stream, streamSize int, userID string) ([]uint32, []streamTracker) {
shardCount := d.shardCountFor(&stream, streamSize, d.cfg.ShardStreams.DesiredRate.Val(), d.rateStore)
if shardCount <= 1 {
return []uint32{util.TokenFor(userID, stream.Labels)}, []streamTracker{{stream: stream}}
}
if d.cfg.ShardStreams.LoggingEnabled {
level.Info(util_log.Logger).Log("msg", "sharding request", "stream", stream.Labels)
level.Info(util_log.Logger).Log("msg", "sharding request", "stream", stream.Labels, "shard_count", shardCount)
}
streamLabels := labelTemplate(stream.Labels)
@@ -407,6 +429,7 @@ func (d *Distributor) shardStream(stream logproto.Stream, userID string) ([]uint32, []streamTracker) {
for i := 0; i < shardCount; i++ {
shard, ok := d.createShard(stream, streamLabels, streamPattern, shardCount, i)
if !ok {
level.Error(util_log.Logger).Log("msg", "couldn't create shard", "stream", stream.Labels, "idx", i)
continue
}
@@ -570,3 +593,43 @@ func (d *Distributor) parseStreamLabels(vContext validationContext, key string,
d.labelCache.Add(key, lsVal)
return lsVal, nil
}
// shardCountFor returns the right number of shards to be used by the given stream.
//
// It fetches the stream's current ingestion rate from the rate store and combines it with the
// size of the incoming push to calculate the number of shards.
//
// desiredRate is expected to be given in bytes.
func (d *Distributor) shardCountFor(stream *logproto.Stream, streamSize, desiredRate int, rateStore RateStore) int {
rate, err := rateStore.RateFor(stream)
if err != nil {
d.streamShardingFailures.WithLabelValues("rate_not_found").Inc()
if d.cfg.ShardStreams.LoggingEnabled {
level.Error(util_log.Logger).Log("msg", "couldn't shard stream because rate wasn't found", "stream", stream.Labels)
}
return 1
}
shards := calculateShards(rate, streamSize, desiredRate)
if shards > len(stream.Entries) {
d.streamShardingFailures.WithLabelValues("too_many_shards").Inc()
if d.cfg.ShardStreams.LoggingEnabled {
level.Error(util_log.Logger).Log("msg", "number of shards bigger than number of entries", "stream", stream.Labels, "shards", shards, "entries", len(stream.Entries))
}
return len(stream.Entries)
}
if shards == 0 {
// 1 shard is enough for the given stream.
return 1
}
return shards
}
func calculateShards(rate, streamSize, desiredRate int) int {
shards := float64(rate+streamSize) / float64(desiredRate)
if shards <= 1 {
return 1
}
return int(math.Ceil(shards))
}
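Once shardCountFor settles on a count, the entries are spread across the shards. Below is a hedged sketch of that distribution step, using contiguous roughly-equal chunks as the test expectations below suggest (splitEntries is our name; the actual createShard logic may differ in detail):

```go
package main

import "fmt"

// splitEntries divides entries into contiguous, roughly equal chunks,
// one per shard; assumes shardCount >= 1.
func splitEntries(entries []string, shardCount int) [][]string {
	perShard := (len(entries) + shardCount - 1) / shardCount // ceil division
	var shards [][]string
	for start := 0; start < len(entries); start += perShard {
		end := start + perShard
		if end > len(entries) {
			end = len(entries)
		}
		shards = append(shards, entries[start:end])
	}
	return shards
}

func main() {
	// 5 entries across 2 shards -> [a b c] and [d e].
	fmt.Println(splitEntries([]string{"a", "b", "c", "d", "e"}, 2))
}
```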

@@ -20,6 +20,7 @@ import (
ring_client "github.com/grafana/dskit/ring/client"
"github.com/grafana/dskit/services"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/model/labels"
"github.com/stretchr/testify/assert"
@@ -35,6 +36,7 @@ import (
"github.com/grafana/loki/pkg/logql/syntax"
"github.com/grafana/loki/pkg/runtime"
fe "github.com/grafana/loki/pkg/util/flagext"
loki_flagext "github.com/grafana/loki/pkg/util/flagext"
loki_net "github.com/grafana/loki/pkg/util/net"
"github.com/grafana/loki/pkg/util/test"
"github.com/grafana/loki/pkg/validation"
@@ -480,22 +482,35 @@ func TestStreamShard(t *testing.T) {
totalEntries := generateEntries(100)
shardingFailureMetric := promauto.With(prometheus.DefaultRegisterer).NewCounterVec(
prometheus.CounterOpts{
Namespace: "loki",
Name: "stream_sharding_failures",
Help: "Total number of failures when sharding a stream",
}, []string{
"reason",
},
)
desiredRate := loki_flagext.ByteSize(300)
for _, tc := range []struct {
name string
entries []logproto.Entry
shards int // stub call to ShardCountFor.
name string
entries []logproto.Entry
streamSize int
wantDerivedStream []streamTracker
}{
{
name: "one shard with no entries",
name: "zero shard because no entries",
entries: nil,
shards: 1,
streamSize: 50,
wantDerivedStream: []streamTracker{{stream: baseStream}},
},
{
name: "one shard with one entry",
shards: 1,
entries: totalEntries[0:1],
name: "one shard with one entry",
streamSize: 1,
entries: totalEntries[0:1],
wantDerivedStream: []streamTracker{
{
stream: logproto.Stream{
@@ -507,9 +522,9 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "two shards with 3 entries",
shards: 2,
entries: totalEntries[0:3],
name: "two shards with 3 entries",
streamSize: desiredRate.Val() + 1, // exceed the desired rate by 1 byte to force two shards.
entries: totalEntries[0:3],
wantDerivedStream: []streamTracker{
{ // shard 1.
stream: logproto.Stream{
@@ -528,9 +543,9 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "two shards with 5 entries",
shards: 2,
entries: totalEntries[0:5],
name: "two shards with 5 entries",
entries: totalEntries[0:5],
streamSize: desiredRate.Val() + 1, // exceed the desired rate by 1 byte to force two shards.
wantDerivedStream: []streamTracker{
{ // shard 1.
stream: logproto.Stream{
@@ -549,9 +564,9 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "one shard with 20 entries",
shards: 1,
entries: totalEntries[0:20],
name: "one shard with 20 entries",
entries: totalEntries[0:20],
streamSize: 1,
wantDerivedStream: []streamTracker{
{ // shard 1.
stream: logproto.Stream{
@@ -563,9 +578,9 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "two shards with 20 entries",
shards: 2,
entries: totalEntries[0:20],
name: "two shards with 20 entries",
entries: totalEntries[0:20],
streamSize: desiredRate.Val() + 1, // exceed the desired rate by 1 byte to force two shards.
wantDerivedStream: []streamTracker{
{ // shard 1.
stream: logproto.Stream{
@@ -584,9 +599,9 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "four shards with 20 entries",
shards: 4,
entries: totalEntries[0:20],
name: "four shards with 20 entries",
entries: totalEntries[0:20],
streamSize: 1 + (desiredRate.Val() * 3), // force 4 shards.
wantDerivedStream: []streamTracker{
{ // shard 1.
stream: logproto.Stream{
@@ -619,84 +634,50 @@ func TestStreamShard(t *testing.T) {
},
},
{
name: "four shards with 2 entries",
shards: 4,
entries: totalEntries[0:2],
name: "size for four shards with 2 entries, ends up with 4 shards ",
streamSize: 1 + (desiredRate.Val() * 3), // force 4 shards.
entries: totalEntries[0:2],
wantDerivedStream: []streamTracker{
{
stream: logproto.Stream{
Entries: []logproto.Entry{},
Entries: totalEntries[0:1],
Labels: generateShardLabels(baseLabels, 0).String(),
Hash: generateShardLabels(baseLabels, 0).Hash(),
},
},
{
stream: logproto.Stream{
Entries: totalEntries[0:1],
Entries: totalEntries[1:2],
Labels: generateShardLabels(baseLabels, 1).String(),
Hash: generateShardLabels(baseLabels, 1).Hash(),
},
},
{
stream: logproto.Stream{
Entries: []logproto.Entry{},
Labels: generateShardLabels(baseLabels, 2).String(),
Hash: generateShardLabels(baseLabels, 2).Hash(),
},
},
{
stream: logproto.Stream{
Entries: totalEntries[1:2],
Labels: generateShardLabels(baseLabels, 3).String(),
Hash: generateShardLabels(baseLabels, 3).Hash(),
},
},
},
},
{
name: "four shards with 1 entry",
shards: 4,
name: "four shards with 1 entry, ends up with 1 shard only",
entries: totalEntries[0:1],
wantDerivedStream: []streamTracker{
{
stream: logproto.Stream{
Labels: generateShardLabels(baseLabels, 0).String(),
Hash: generateShardLabels(baseLabels, 0).Hash(),
Entries: []logproto.Entry{},
},
},
{
stream: logproto.Stream{
Labels: generateShardLabels(baseLabels, 1).String(),
Hash: generateShardLabels(baseLabels, 1).Hash(),
Entries: []logproto.Entry{},
},
},
{
stream: logproto.Stream{
Labels: generateShardLabels(baseLabels, 2).String(),
Hash: generateShardLabels(baseLabels, 2).Hash(),
Entries: []logproto.Entry{},
},
},
{
stream: logproto.Stream{
stream: logproto.Stream{ // with only one shard we don't even add the stream_shard label.
Labels: baseStream.Labels,
Hash: baseStream.Hash,
Entries: totalEntries[0:1],
Labels: generateShardLabels(baseLabels, 3).String(),
Hash: generateShardLabels(baseLabels, 3).Hash(),
},
},
},
},
} {
t.Run(tc.name, func(t *testing.T) {
d := Distributor{
streamSharder: NewStreamSharderMock(tc.shards),
}
baseStream.Entries = tc.entries
_, derivedStreams := d.shardStream(baseStream, "fake")
d := Distributor{
rateStore: &noopRateStore{},
streamShardingFailures: shardingFailureMetric,
}
d.cfg.ShardStreams.DesiredRate = desiredRate
_, derivedStreams := d.shardStream(baseStream, tc.streamSize, "fake")
require.Equal(t, tc.wantDerivedStream, derivedStreams)
})
}
@@ -709,6 +690,15 @@ func BenchmarkShardStream(b *testing.B) {
require.NoError(b, err)
stream.Hash = lbs.Hash()
stream.Labels = lbs.String()
shardingFailureMetric := promauto.With(prometheus.DefaultRegisterer).NewCounterVec(
prometheus.CounterOpts{
Namespace: "loki",
Name: "stream_sharding_failures",
Help: "Total number of failures when sharding a stream",
}, []string{
"reason",
},
)
// helper funcs
generateEntries := func(n int) []logproto.Entry {
@@ -723,51 +713,52 @@ func BenchmarkShardStream(b *testing.B) {
}
allEntries := generateEntries(25000)
desiredRate := 3000
distributorBuilder := func(shards int) *Distributor {
d := &Distributor{streamShardingFailures: shardingFailureMetric}
// streamSize is always zero, so the number of shards is dictated solely by the rate returned from the store.
d.rateStore = &noopRateStore{rate: desiredRate*shards - 1}
return d
}
b.Run("high number of entries, low number of shards", func(b *testing.B) {
d := Distributor{
streamSharder: NewStreamSharderMock(2),
}
d := distributorBuilder(2)
stream.Entries = allEntries
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.shardStream(stream, "fake")
d.shardStream(stream, 0, "fake") //nolint:errcheck
}
})
b.Run("low number of entries, low number of shards", func(b *testing.B) {
d := Distributor{
streamSharder: NewStreamSharderMock(2),
}
d := distributorBuilder(2)
stream.Entries = nil
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.shardStream(stream, "fake")
d.shardStream(stream, 0, "fake") //nolint:errcheck
}
})
b.Run("high number of entries, high number of shards", func(b *testing.B) {
d := Distributor{
streamSharder: NewStreamSharderMock(64),
}
d := distributorBuilder(64)
stream.Entries = allEntries
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.shardStream(stream, "fake")
d.shardStream(stream, 0, "fake") //nolint:errcheck
}
})
b.Run("low number of entries, high number of shards", func(b *testing.B) {
d := Distributor{
streamSharder: NewStreamSharderMock(64),
}
d := distributorBuilder(64)
stream.Entries = nil
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.shardStream(stream, "fake")
d.shardStream(stream, 0, "fake") //nolint:errcheck
}
})
}
@@ -815,6 +806,144 @@ func Benchmark_Push(b *testing.B) {
}
}
func TestShardCalculation(t *testing.T) {
megabyte := 1_000_000
desiredRate := 3 * megabyte
for _, tc := range []struct {
name string
streamSize int
rate int
wantShards int
}{
{
name: "not enough data to be sharded, stream size (1mb) + ingested rate (0mb) < 3mb",
streamSize: 1 * megabyte,
rate: 0,
wantShards: 1,
},
{
name: "enough data to have two shards, stream size (1mb) + ingested rate (4mb) > 3mb",
streamSize: 1 * megabyte,
rate: desiredRate + 1,
wantShards: 2,
},
{
name: "enough data to have two shards, stream size (4mb) + ingested rate (0mb) > 3mb",
streamSize: 4 * megabyte,
rate: 0,
wantShards: 2,
},
{
name: "a lot of shards, stream size (1mb) + ingested rate (300mb) > 3mb",
streamSize: 1 * megabyte,
rate: 300 * megabyte,
wantShards: 101,
},
} {
t.Run(tc.name, func(t *testing.T) {
got := calculateShards(tc.rate, tc.streamSize, desiredRate)
require.Equal(t, tc.wantShards, got)
})
}
}
func TestShardCountFor(t *testing.T) {
shardingFailureMetric := promauto.With(prometheus.DefaultRegisterer).NewCounterVec(
prometheus.CounterOpts{
Namespace: "loki",
Name: "test_shard_count_for",
Help: "Total number of failures when sharding a stream",
}, []string{
"reason",
},
)
for _, tc := range []struct {
name string
stream *logproto.Stream
rate int
desiredRate int
wantStreamSize int // stream size passed to shardCountFor.
wantShards int
wantErr bool
}{
{
name: "0 entries, return 0 shards always",
stream: &logproto.Stream{Hash: 1},
rate: 0,
desiredRate: 3, // in bytes
wantStreamSize: 2, // in bytes
wantShards: 0,
wantErr: true,
},
{
// the stream has a single entry and the size calculation also yields a single shard,
// so the stream is ingested without further sharding.
name: "not enough data to be sharded, stream size (2b) + ingested rate (0b) < 3b = 1 shard with 1 entry",
stream: &logproto.Stream{Hash: 1, Entries: []logproto.Entry{{Line: "abcde"}}},
rate: 0,
desiredRate: 3, // in bytes
wantStreamSize: 2, // in bytes
wantShards: 1,
wantErr: true,
},
{
name: "not enough data to be sharded, stream size (18b) + ingested rate (0b) < 20b",
stream: &logproto.Stream{Entries: []logproto.Entry{{Line: "a"}}},
rate: 0,
desiredRate: 20, // in bytes
wantStreamSize: 18, // in bytes
wantShards: 1,
wantErr: false,
},
{
name: "enough data to have two shards, stream size (36b) + ingested rate (24b) > 40b",
stream: &logproto.Stream{Entries: []logproto.Entry{{Line: "a"}, {Line: "b"}}},
rate: 24, // in bytes
desiredRate: 40, // in bytes
wantStreamSize: 36, // in bytes
wantShards: 2,
wantErr: false,
},
{
// although the ingested rate by an ingester is 0, the stream is big enough to be sharded.
name: "enough data to have two shards, stream size (36b) + ingested rate (0b) > 22b",
stream: &logproto.Stream{Entries: []logproto.Entry{{Line: "a"}, {Line: "b"}}},
rate: 0, // in bytes
desiredRate: 22, // in bytes
wantStreamSize: 36, // in bytes
wantShards: 2,
wantErr: false,
},
{
name: "a lot of shards, stream size (1mb) + ingested rate (300mb) > 3mb",
stream: &logproto.Stream{Entries: []logproto.Entry{
{Line: "a"}, {Line: "b"}, {Line: "c"}, {Line: "d"}, {Line: "e"},
}},
rate: 0, // in bytes
desiredRate: 22, // in bytes
wantStreamSize: 90, // in bytes
wantShards: 5,
wantErr: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
limits := &validation.Limits{}
flagext.DefaultValues(limits)
limits.EnforceMetricName = false
d := &Distributor{
streamShardingFailures: shardingFailureMetric,
}
got := d.shardCountFor(tc.stream, tc.wantStreamSize, tc.desiredRate, &noopRateStore{tc.rate})
require.Equal(t, tc.wantShards, got)
})
}
}
func Benchmark_PushWithLineTruncation(b *testing.B) {
limits := &validation.Limits{}
flagext.DefaultValues(limits)

@@ -0,0 +1,25 @@
package distributor
import (
"fmt"
"github.com/grafana/loki/pkg/logproto"
)
type unshardableStreamErr struct {
labels string
entriesNum int
shardNum int
}
func (u *unshardableStreamErr) Error() string {
return fmt.Sprintf("couldn't shard stream %s. number of shards (%d) is higher than number of entries (%d)", u.labels, u.shardNum, u.entriesNum)
}
type noopRateStore struct {
rate int
}
func (n *noopRateStore) RateFor(stream *logproto.Stream) (int, error) {
return n.rate, nil
}
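The noop store above always answers with a fixed rate. Purely for illustration, a hypothetical in-memory implementation of the same RateStore interface could look like this (mapRateStore is our name; the real store, which the interface comment says is populated from ingesters, is not part of this commit):

```go
package distributor

import (
	"sync"

	"github.com/grafana/loki/pkg/logproto"
)

// mapRateStore is a hypothetical RateStore keyed by stream labels.
type mapRateStore struct {
	mu    sync.RWMutex
	rates map[string]int // labels -> ingestion rate in bytes
}

func (m *mapRateStore) RateFor(stream *logproto.Stream) (int, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.rates[stream.Labels], nil
}
```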

@@ -1,49 +0,0 @@
package distributor
import (
"sync"
"github.com/grafana/loki/pkg/logproto"
)
type streamSharder struct {
mu sync.RWMutex
streams map[string]int
}
func NewStreamSharder() StreamSharder {
return &streamSharder{
streams: make(map[string]int),
}
}
func (s *streamSharder) ShardCountFor(stream logproto.Stream) (int, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
shards := s.streams[stream.Labels]
if shards > 0 {
return shards, true
}
return 0, false
}
// IncreaseShardsFor shards the given stream by doubling its number of shards.
func (s *streamSharder) IncreaseShardsFor(stream logproto.Stream) {
s.mu.Lock()
defer s.mu.Unlock()
shards := s.streams[stream.Labels]
// Since the number of shards of a stream that is being sharded for the first time is 0,
// we assign to it shards = max(shards*2, 2) such that its number of shards will be no less than 2.
s.streams[stream.Labels] = max(shards*2, 2)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}

@@ -1,73 +1,24 @@
package distributor
import (
"testing"
"fmt"
"github.com/grafana/loki/pkg/logproto"
"github.com/stretchr/testify/require"
)
func TestStreamSharder(t *testing.T) {
stream := logproto.Stream{Entries: make([]logproto.Entry, 11), Labels: "test-stream"}
stream2 := logproto.Stream{Entries: make([]logproto.Entry, 11), Labels: "test-stream-2"}
t.Run("it returns not ok when a stream should not be sharded", func(t *testing.T) {
sharder := NewStreamSharder()
shards, ok := sharder.ShardCountFor(stream)
require.Equal(t, shards, 0)
require.False(t, ok)
})
t.Run("it keeps track of multiple streams", func(t *testing.T) {
sharder := NewStreamSharder()
sharder.IncreaseShardsFor(stream)
sharder.IncreaseShardsFor(stream)
sharder.IncreaseShardsFor(stream2)
shards, ok := sharder.ShardCountFor(stream)
require.True(t, ok)
require.Equal(t, 4, shards)
shards, ok = sharder.ShardCountFor(stream2)
require.True(t, ok)
require.Equal(t, 2, shards)
})
}
type StreamSharderMock struct {
calls map[string]int
wantShards int
}
func NewStreamSharderMock(shards int) *StreamSharderMock {
return &StreamSharderMock{
calls: make(map[string]int),
wantShards: shards,
}
}
func (s *StreamSharderMock) IncreaseShardsFor(stream logproto.Stream) {
s.increaseCallsFor("IncreaseShardsFor")
}
func (s *StreamSharderMock) ShardCountFor(stream logproto.Stream) (int, bool) {
s.increaseCallsFor("ShardCountFor")
func (s *StreamSharderMock) ShardCountFor(*logproto.Stream, int, RateStore) (int, error) {
if s.wantShards < 0 {
return 0, false
return 0, fmt.Errorf("unshardable stream")
}
return s.wantShards, true
}
func (s *StreamSharderMock) increaseCallsFor(funcName string) {
if _, ok := s.calls[funcName]; ok {
s.calls[funcName]++
return
}
s.calls[funcName] = 1
return s.wantShards, nil
}
