diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go index 16b1ecd248..9069f40a40 100644 --- a/pkg/distributor/distributor.go +++ b/pkg/distributor/distributor.go @@ -601,7 +601,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe var lbs labels.Labels var retentionHours, policy string - lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format) + lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format) if err != nil { d.writeFailuresManager.Log(tenantID, err) validationErrors.Add(err) @@ -915,7 +915,7 @@ func (d *Distributor) trackDiscardedData( if d.usageTracker != nil { for _, stream := range req.Streams { - lbs, _, _, _, _, err := d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format) + lbs, _, _, _, _, err := d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format) if err != nil { continue } @@ -1309,10 +1309,10 @@ type labelData struct { } // parseStreamLabels parses stream labels using a request-scoped policy resolver -func (d *Distributor) parseStreamLabels(vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) { +func (d *Distributor) parseStreamLabels(ctx context.Context, vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) { if val, ok := d.labelCache.Get(key); ok { retentionHours := streamResolver.RetentionHoursFor(val.ls) - policy := streamResolver.PolicyFor(val.ls) + policy := streamResolver.PolicyFor(ctx, val.ls) return val.ls, val.ls.String(), val.hash, retentionHours, policy, nil } @@ -1323,7 +1323,7 @@ func (d *Distributor) parseStreamLabels(vContext validationContext, key string, return labels.EmptyLabels(), "", 0, retentionHours, "", fmt.Errorf(validation.InvalidLabelsErrorMsg, key, err) } - policy := streamResolver.PolicyFor(ls) + policy := streamResolver.PolicyFor(ctx, ls) retentionHours := d.tenantsRetention.RetentionHoursFor(vContext.userID, ls) if err := d.validator.ValidateLabels(vContext, ls, stream, retentionHours, policy, format); err != nil { @@ -1454,8 +1454,8 @@ func (r requestScopedStreamResolver) RetentionHoursFor(lbs labels.Labels) string return r.retention.RetentionHoursFor(lbs) } -func (r requestScopedStreamResolver) PolicyFor(lbs labels.Labels) string { - policies := r.policyStreamMappings.PolicyFor(lbs) +func (r requestScopedStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string { + policies := r.policyStreamMappings.PolicyFor(ctx, lbs) var policy string if len(policies) > 0 { diff --git a/pkg/distributor/distributor_test.go b/pkg/distributor/distributor_test.go index 303ee2a2d1..5e445f8cfe 100644 --- a/pkg/distributor/distributor_test.go +++ b/pkg/distributor/distributor_test.go @@ -1291,7 +1291,7 @@ func Benchmark_SortLabelsOnPush(b *testing.B) { for n := 0; n < b.N; n++ { stream := request.Streams[0] stream.Labels = `{buzz="f", a="b"}` - _, _, _, _, _, err := d.parseStreamLabels(vCtx, stream.Labels, stream, streamResolver, constants.Loki) + _, _, _, _, _, err := d.parseStreamLabels(context.Background(), vCtx, stream.Labels, stream, streamResolver, constants.Loki) if err != nil { panic("parseStreamLabels fail,err:" + err.Error()) } @@ -1331,7 +1331,7 @@ func TestParseStreamLabels(t *testing.T) { vCtx := d.validator.getValidationContextForTime(testTime, "123") streamResolver := newRequestScopedStreamResolver("123", d.validator.Limits, nil) t.Run(tc.name, func(t *testing.T) { - lbs, lbsString, hash, _, _, err := d.parseStreamLabels(vCtx, tc.origLabels, logproto.Stream{ + lbs, lbsString, hash, _, _, err := d.parseStreamLabels(context.Background(), vCtx, tc.origLabels, logproto.Stream{ Labels: tc.origLabels, }, streamResolver, constants.Loki) if tc.expectedErr != nil { @@ -2336,10 +2336,10 @@ func TestRequestScopedStreamResolver(t *testing.T) { retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev")) require.Equal(t, 24*time.Hour, retentionPeriod) - policy := resolver.PolicyFor(labels.FromStrings("env", "prod")) + policy := resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod")) require.Equal(t, "policy0", policy) - policy = resolver.PolicyFor(labels.FromStrings("env", "dev")) + policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev")) require.Empty(t, policy) // We now modify the underlying limits to test that the resolver is not affected by changes to the limits @@ -2378,10 +2378,10 @@ func TestRequestScopedStreamResolver(t *testing.T) { retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev")) require.Equal(t, 24*time.Hour, retentionPeriod) - policy = resolver.PolicyFor(labels.FromStrings("env", "prod")) + policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod")) require.Equal(t, "policy0", policy) - policy = resolver.PolicyFor(labels.FromStrings("env", "dev")) + policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev")) require.Empty(t, policy) // But a new resolver should return the new values @@ -2397,10 +2397,10 @@ func TestRequestScopedStreamResolver(t *testing.T) { retentionPeriod = newResolver.RetentionPeriodFor(labels.FromStrings("env", "dev")) require.Equal(t, 72*time.Hour, retentionPeriod) - policy = newResolver.PolicyFor(labels.FromStrings("env", "prod")) + policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod")) require.Empty(t, policy) - policy = newResolver.PolicyFor(labels.FromStrings("env", "dev")) + policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev")) require.Equal(t, "policy1", policy) } diff --git a/pkg/distributor/http.go b/pkg/distributor/http.go index 60b1e2445d..551c91b8f8 100644 --- a/pkg/distributor/http.go +++ b/pkg/distributor/http.go @@ -132,7 +132,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe "stream", s.Labels, "streamLabelsHash", util.HashedQuery(s.Labels), // this is to make it easier to do searching and grouping "streamSizeBytes", humanize.Bytes(uint64(pushStats.StreamSizeBytes[s.Labels])), - "policy", streamResolver.PolicyFor(lbs), + "policy", streamResolver.PolicyFor(r.Context(), lbs), } if timestamp, ok := pushStats.MostRecentEntryTimestampPerStream[s.Labels]; ok { logValues = append(logValues, "mostRecentLagMs", time.Since(timestamp).Milliseconds()) diff --git a/pkg/ingester/instance.go b/pkg/ingester/instance.go index 2dd3d052f0..e2fcae1351 100644 --- a/pkg/ingester/instance.go +++ b/pkg/ingester/instance.go @@ -201,7 +201,7 @@ func (i *instance) consumeChunk(ctx context.Context, ls labels.Labels, chunk *lo s, _, _ := i.streams.LoadOrStoreNewByFP(fp, func() (*stream, error) { - s, err := i.createStreamByFP(ls, fp) + s, err := i.createStreamByFP(ctx, ls, fp) s.chunkMtx.Lock() // Lock before return, because we have defer that unlocks it. if err != nil { return nil, err @@ -299,7 +299,7 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre } retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, labels)) - policy := i.resolvePolicyForStream(labels) + policy := i.resolvePolicyForStream(ctx, labels) if record != nil { err = i.streamCountLimiter.AssertNewStreamAllowed(i.instanceID, policy) @@ -336,9 +336,9 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre return s, nil } -func (i *instance) resolvePolicyForStream(labels labels.Labels) string { +func (i *instance) resolvePolicyForStream(ctx context.Context, labels labels.Labels) string { mapping := i.limiter.limits.PoliciesStreamMapping(i.instanceID) - policies := mapping.PolicyFor(labels) + policies := mapping.PolicyFor(ctx, labels) // NOTE: We previously resolved the policy on distributors and logged when multiple policies were matched. // As on distributors, we use the first policy by alphabetical order. var policy string @@ -400,7 +400,7 @@ func (i *instance) onStreamCreated(s *stream) { } } -func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*stream, error) { +func (i *instance) createStreamByFP(ctx context.Context, ls labels.Labels, fp model.Fingerprint) (*stream, error) { sortedLabels := i.index.Add(logproto.FromLabelsToLabelAdapters(ls), fp) chunkfmt, headfmt, err := i.chunkFormatAt(model.Now()) @@ -409,7 +409,7 @@ func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*st } retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, ls)) - policy := i.resolvePolicyForStream(ls) + policy := i.resolvePolicyForStream(ctx, ls) s := newStream(chunkfmt, headfmt, i.cfg, i.limiter.rateLimitStrategy, i.instanceID, fp, sortedLabels, i.limiter.UnorderedWrites(i.instanceID), i.streamRateCalculator, i.metrics, i.writeFailures, i.configs, retentionHours, policy) diff --git a/pkg/ingester/instance_test.go b/pkg/ingester/instance_test.go index c97b18d0a1..a77e2d6a9f 100644 --- a/pkg/ingester/instance_test.go +++ b/pkg/ingester/instance_test.go @@ -581,7 +581,7 @@ func Benchmark_instance_addNewTailer(b *testing.B) { chunkfmt, headfmt, err := inst.chunkFormatAt(model.Now()) require.NoError(b, err) retentionHours := util.RetentionHours(tenantsRetention.RetentionPeriodFor("test", lbs)) - policy := inst.resolvePolicyForStream(lbs) + policy := inst.resolvePolicyForStream(context.Background(), lbs) b.Run("addTailersToNewStream", func(b *testing.B) { for n := 0; n < b.N; n++ { diff --git a/pkg/ingester/recalculate_owned_streams_test.go b/pkg/ingester/recalculate_owned_streams_test.go index 071b8929cb..c4031f0037 100644 --- a/pkg/ingester/recalculate_owned_streams_test.go +++ b/pkg/ingester/recalculate_owned_streams_test.go @@ -243,7 +243,7 @@ func createStream(t *testing.T, inst *instance, fingerprint int) *stream { lbls := labels.FromStrings("mock", strconv.Itoa(fingerprint)) stream, _, err := inst.streams.LoadOrStoreNew(lbls.String(), func() (*stream, error) { - return inst.createStreamByFP(lbls, model.Fingerprint(fingerprint)) + return inst.createStreamByFP(context.Background(), lbls, model.Fingerprint(fingerprint)) }, nil) require.NoError(t, err) return stream diff --git a/pkg/loghttp/push/otlp.go b/pkg/loghttp/push/otlp.go index 084de7ccae..88feceee8b 100644 --- a/pkg/loghttp/push/otlp.go +++ b/pkg/loghttp/push/otlp.go @@ -242,7 +242,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl // Calculate resource attributes metadata size for stats resourceAttributesAsStructuredMetadataSize := loki_util.StructuredMetadataSize(resourceAttributesAsStructuredMetadata) retentionPeriodForUser := streamResolver.RetentionPeriodFor(lbs) - policy := streamResolver.PolicyFor(lbs) + policy := streamResolver.PolicyFor(ctx, lbs) // Check if the stream has the exporter=OTLP label; set flag instead of incrementing per stream if value, ok := streamLabels[model.LabelName("exporter")]; ok && value == "OTLP" { @@ -386,7 +386,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl pushRequestsByStream[entryLabelsStr] = stream entryRetentionPeriod := streamResolver.RetentionPeriodFor(entryLbs) - entryPolicy := streamResolver.PolicyFor(entryLbs) + entryPolicy := streamResolver.PolicyFor(ctx, entryLbs) if _, ok := stats.StructuredMetadataBytes[entryPolicy]; !ok { stats.StructuredMetadataBytes[entryPolicy] = make(map[time.Duration]int64) diff --git a/pkg/loghttp/push/otlp_test.go b/pkg/loghttp/push/otlp_test.go index 5c395da71a..670fdb5e3f 100644 --- a/pkg/loghttp/push/otlp_test.go +++ b/pkg/loghttp/push/otlp_test.go @@ -583,7 +583,7 @@ func TestOTLPToLokiPushRequest(t *testing.T) { stats := NewPushStats() tracker := NewMockTracker() streamResolver := newMockStreamResolver("fake", &fakeLimits{}) - streamResolver.policyForOverride = func(lbs labels.Labels) string { + streamResolver.policyForOverride = func(_ context.Context, lbs labels.Labels) string { if lbs.Get("service_name") == "service-1" { return "service-1-policy" } @@ -926,7 +926,7 @@ func TestOTLPLogAttributesAsIndexLabels(t *testing.T) { streamResolver := newMockStreamResolver("fake", &fakeLimits{}) // All logs will use the same policy for simplicity - streamResolver.policyForOverride = func(_ labels.Labels) string { + streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string { return "test-policy" } @@ -1029,7 +1029,7 @@ func TestOTLPStructuredMetadataCalculation(t *testing.T) { tracker := NewMockTracker() streamResolver := newMockStreamResolver("fake", &fakeLimits{}) - streamResolver.policyForOverride = func(_ labels.Labels) string { + streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string { return "test-policy" } @@ -1215,7 +1215,7 @@ func TestOTLPSeverityTextAsLabel(t *testing.T) { streamResolver := newMockStreamResolver("fake", &fakeLimits{}) // All logs will use the same policy for simplicity - streamResolver.policyForOverride = func(_ labels.Labels) string { + streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string { return "test-policy" } diff --git a/pkg/loghttp/push/push.go b/pkg/loghttp/push/push.go index 6c068b52e6..8d4d502254 100644 --- a/pkg/loghttp/push/push.go +++ b/pkg/loghttp/push/push.go @@ -3,6 +3,7 @@ package push import ( "compress/flate" "compress/gzip" + "context" "fmt" "io" "mime" @@ -113,7 +114,7 @@ func (EmptyLimits) PolicyFor(_ string, _ labels.Labels) string { type StreamResolver interface { RetentionPeriodFor(lbs labels.Labels) time.Duration RetentionHoursFor(lbs labels.Labels) string - PolicyFor(lbs labels.Labels) string + PolicyFor(ctx context.Context, lbs labels.Labels) string } type ( @@ -443,7 +444,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi req.Streams[i] = s } - err = CalculateStreamsStats(userID, req, streamResolver, tenantConfigs, pushStats) + err = CalculateStreamsStats(r.Context(), userID, req, streamResolver, tenantConfigs, pushStats) if err != nil { return nil, nil, err } @@ -452,7 +453,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi } // CalculateStreamsStats modifies pushStats with statistics about all the streams from req. -func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error { +func CalculateStreamsStats(ctx context.Context, userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error { logPushRequestStreams := false if tenantConfigs != nil { logPushRequestStreams = tenantConfigs.LogPushRequestStreams(userID) @@ -471,7 +472,7 @@ func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResol var policy string if streamResolver != nil { retentionPeriod = streamResolver.RetentionPeriodFor(lbs) - policy = streamResolver.PolicyFor(lbs) + policy = streamResolver.PolicyFor(ctx, lbs) } if _, ok := pushStats.LogLinesBytes[policy]; !ok { diff --git a/pkg/loghttp/push/push_test.go b/pkg/loghttp/push/push_test.go index efc0c56020..70fa3d4d21 100644 --- a/pkg/loghttp/push/push_test.go +++ b/pkg/loghttp/push/push_test.go @@ -787,7 +787,7 @@ type mockStreamResolver struct { tenant string limits *fakeLimits - policyForOverride func(lbs labels.Labels) string + policyForOverride func(ctx context.Context, lbs labels.Labels) string } func newMockStreamResolver(tenant string, limits *fakeLimits) *mockStreamResolver { @@ -805,9 +805,9 @@ func (m mockStreamResolver) RetentionHoursFor(lbs labels.Labels) string { return m.limits.RetentionHoursFor(m.tenant, lbs) } -func (m mockStreamResolver) PolicyFor(lbs labels.Labels) string { +func (m mockStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string { if m.policyForOverride != nil { - return m.policyForOverride(lbs) + return m.policyForOverride(ctx, lbs) } return m.limits.PolicyFor(m.tenant, lbs) diff --git a/pkg/loki/modules.go b/pkg/loki/modules.go index cb2aa7c333..de81e49de9 100644 --- a/pkg/loki/modules.go +++ b/pkg/loki/modules.go @@ -391,6 +391,7 @@ func (t *Loki) initDistributor() (services.Service, error) { httpPushHandlerMiddleware := middleware.Merge( serverutil.RecoveryHTTPMiddleware, t.HTTPAuthMiddleware, + validation.NewIngestionPolicyMiddleware(util_log.Logger), ) lokiPushHandler := httpPushHandlerMiddleware.Wrap(http.HandlerFunc(t.distributor.PushHandler)) diff --git a/pkg/validation/ingestion_policies.go b/pkg/validation/ingestion_policies.go index efcb5f9d30..4794160e08 100644 --- a/pkg/validation/ingestion_policies.go +++ b/pkg/validation/ingestion_policies.go @@ -1,9 +1,13 @@ package validation import ( + "context" "fmt" + "net/http" "slices" + "github.com/go-kit/log" + "github.com/grafana/dskit/middleware" "github.com/prometheus/prometheus/model/labels" "github.com/grafana/loki/v3/pkg/logql/syntax" @@ -11,6 +15,8 @@ import ( const ( GlobalPolicy = "*" + + HTTPHeaderIngestionPolicyKey = "X-Loki-Ingestion-Policy" ) type PriorityStream struct { @@ -54,7 +60,15 @@ func (p *PolicyStreamMapping) Validate() error { // with the same priority. // Returned policies are sorted alphabetically. // If no policies match, it returns an empty slice. -func (p *PolicyStreamMapping) PolicyFor(lbs labels.Labels) []string { +// If a policy is set via the X-Loki-Ingestion-Policy header (passed through context), it overrides +// all stream-to-policy mappings and returns that policy. +func (p *PolicyStreamMapping) PolicyFor(ctx context.Context, lbs labels.Labels) []string { + // Check if a policy was set via the HTTP header (X-Loki-Ingestion-Policy) + // This overrides any stream-to-policy mappings + if headerPolicy := ExtractIngestionPolicyContext(ctx); headerPolicy != "" { + return []string{headerPolicy} + } + var ( found bool highestPriority int @@ -143,3 +157,54 @@ func (p *PolicyStreamMapping) ApplyDefaultPolicyStreamMappings(defaults PolicySt } return nil } + +// policyContextKey is used as a key for context values to avoid collisions +type policyContextKey int + +const ( + ingestionPolicyContextKey policyContextKey = 1 +) + +// ExtractIngestionPolicyHTTP retrieves the ingestion policy from the HTTP header and returns it. +// If no policy is found, it returns an empty string. +func ExtractIngestionPolicyHTTP(r *http.Request) string { + return r.Header.Get(HTTPHeaderIngestionPolicyKey) +} + +// InjectIngestionPolicyContext returns a derived context containing the provided ingestion policy. +func InjectIngestionPolicyContext(ctx context.Context, policy string) context.Context { + return context.WithValue(ctx, ingestionPolicyContextKey, policy) +} + +// ExtractIngestionPolicyContext gets the embedded ingestion policy from the context. +// If no policy is found, it returns an empty string. +func ExtractIngestionPolicyContext(ctx context.Context) string { + policy, ok := ctx.Value(ingestionPolicyContextKey).(string) + if !ok { + return "" + } + return policy +} + +type ingestionPolicyMiddleware struct { + logger log.Logger +} + +// NewIngestionPolicyMiddleware creates a middleware that extracts the ingestion policy +// from the HTTP header and injects it into the context of the request. +func NewIngestionPolicyMiddleware(logger log.Logger) middleware.Interface { + return &ingestionPolicyMiddleware{ + logger: logger, + } +} + +// Wrap implements the middleware interface +func (m *ingestionPolicyMiddleware) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if policy := ExtractIngestionPolicyHTTP(r); policy != "" { + r = r.Clone(InjectIngestionPolicyContext(r.Context(), policy)) + } + + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/validation/ingestion_policies_test.go b/pkg/validation/ingestion_policies_test.go index 13aac68613..bca8f41240 100644 --- a/pkg/validation/ingestion_policies_test.go +++ b/pkg/validation/ingestion_policies_test.go @@ -1,6 +1,9 @@ package validation import ( + "context" + "net/http" + "net/http/httptest" "testing" "github.com/prometheus/prometheus/model/labels" @@ -102,19 +105,20 @@ func Test_PolicyStreamMapping_PolicyFor(t *testing.T) { require.NoError(t, mapping.Validate()) - require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar"))) + ctx := t.Context() + require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar"))) // matches both policy2 and policy1 but policy1 has higher priority. - require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar", "daz", "baz"))) + require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar", "daz", "baz"))) // matches policy3 and policy4 but policy3 has higher priority.. - require.Equal(t, []string{"policy3"}, mapping.PolicyFor(labels.FromStrings("qyx", "qzx", "qox", "qox"))) + require.Equal(t, []string{"policy3"}, mapping.PolicyFor(ctx, labels.FromStrings("qyx", "qzx", "qox", "qox"))) // matches no policy. - require.Empty(t, mapping.PolicyFor(labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge"))) + require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge"))) // matches policy5 through regex. - require.Equal(t, []string{"policy5"}, mapping.PolicyFor(labels.FromStrings("qab", "qzxqox"))) + require.Equal(t, []string{"policy5"}, mapping.PolicyFor(ctx, labels.FromStrings("qab", "qzxqox"))) - require.Equal(t, []string{"policy6"}, mapping.PolicyFor(labels.FromStrings("env", "prod", "team", "finance"))) + require.Equal(t, []string{"policy6"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod", "team", "finance"))) // Matches policy7 and policy8 which have the same priority. - require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(labels.FromStrings("env", "prod"))) + require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod"))) } func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings(t *testing.T) { @@ -284,3 +288,187 @@ func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings_Validation(t *test // Verify the result is valid require.NoError(t, existing.Validate()) } + +func Test_PolicyStreamMapping_PolicyFor_WithHeaderOverride(t *testing.T) { + mapping := PolicyStreamMapping{ + "policy1": []*PriorityStream{ + { + Selector: `{foo="bar"}`, + Priority: 2, + Matchers: []*labels.Matcher{ + labels.MustNewMatcher(labels.MatchEqual, "foo", "bar"), + }, + }, + }, + "policy2": []*PriorityStream{ + { + Selector: `{env="prod"}`, + Priority: 1, + Matchers: []*labels.Matcher{ + labels.MustNewMatcher(labels.MatchEqual, "env", "prod"), + }, + }, + }, + } + + require.NoError(t, mapping.Validate()) + + t.Run("without header context, uses normal mapping", func(t *testing.T) { + ctx := t.Context() + // Should match policy1 based on labels + require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar"))) + // Should match policy2 based on labels + require.Equal(t, []string{"policy2"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod"))) + // Should match no policy + require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label"))) + }) + + t.Run("with header context, overrides all mappings", func(t *testing.T) { + ctx := InjectIngestionPolicyContext(t.Context(), "override-policy") + + // Even though labels match policy1, header policy overrides + require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar"))) + + // Even though labels match policy2, header policy overrides + require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod"))) + + // Even though labels don't match anything, header policy is used + require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label"))) + }) + + t.Run("empty header context is ignored", func(t *testing.T) { + // Inject empty string - should be treated as not set + ctx := InjectIngestionPolicyContext(t.Context(), "") + + // Should fall back to normal mapping behavior + require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar"))) + }) +} + +func TestExtractInjectIngestionPolicyContext(t *testing.T) { + t.Run("inject and extract policy", func(t *testing.T) { + policy := "test-policy" + + ctx := InjectIngestionPolicyContext(t.Context(), policy) + extracted := ExtractIngestionPolicyContext(ctx) + require.Equal(t, policy, extracted) + }) + + t.Run("extract from empty context", func(t *testing.T) { + extracted := ExtractIngestionPolicyContext(t.Context()) + require.Empty(t, extracted) + }) + + t.Run("inject empty string", func(t *testing.T) { + ctx := InjectIngestionPolicyContext(t.Context(), "") + extracted := ExtractIngestionPolicyContext(ctx) + require.Empty(t, extracted) + }) +} + +func TestExtractIngestionPolicyHTTP(t *testing.T) { + t.Run("extract policy from header", func(t *testing.T) { + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + + req.Header.Set(HTTPHeaderIngestionPolicyKey, "my-policy") + + policy := ExtractIngestionPolicyHTTP(req) + require.Equal(t, "my-policy", policy) + }) + + t.Run("no header present", func(t *testing.T) { + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + + policy := ExtractIngestionPolicyHTTP(req) + require.Empty(t, policy) + }) + + t.Run("empty header value", func(t *testing.T) { + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + + req.Header.Set(HTTPHeaderIngestionPolicyKey, "") + + policy := ExtractIngestionPolicyHTTP(req) + require.Empty(t, policy) + }) + + t.Run("header with whitespace", func(t *testing.T) { + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + + req.Header.Set(HTTPHeaderIngestionPolicyKey, " policy-with-spaces ") + + policy := ExtractIngestionPolicyHTTP(req) + require.Equal(t, " policy-with-spaces ", policy) + }) +} + +func TestIngestionPolicyMiddleware(t *testing.T) { + t.Run("middleware injects policy into context", func(t *testing.T) { + var capturedCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + middleware := NewIngestionPolicyMiddleware(nil) + wrappedHandler := middleware.Wrap(handler) + + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + req.Header.Set(HTTPHeaderIngestionPolicyKey, "test-policy") + + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + policy := ExtractIngestionPolicyContext(capturedCtx) + require.Equal(t, "test-policy", policy) + }) + + t.Run("middleware does not modify context when no header", func(t *testing.T) { + var capturedCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + middleware := NewIngestionPolicyMiddleware(nil) + wrappedHandler := middleware.Wrap(handler) + + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + policy := ExtractIngestionPolicyContext(capturedCtx) + require.Empty(t, policy) + }) + + t.Run("middleware does not inject empty header value", func(t *testing.T) { + var capturedCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + middleware := NewIngestionPolicyMiddleware(nil) + wrappedHandler := middleware.Wrap(handler) + + req, err := http.NewRequest("POST", "/loki/api/v1/push", nil) + require.NoError(t, err) + req.Header.Set(HTTPHeaderIngestionPolicyKey, "") + + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + policy := ExtractIngestionPolicyContext(capturedCtx) + require.Empty(t, policy) + }) +}