feat: Resolve ingestion policy via a header (#19548)

pull/19403/head
Salva Corts 3 months ago committed by GitHub
parent dd52234fc5
commit 987840b5d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 14
      pkg/distributor/distributor.go
  2. 16
      pkg/distributor/distributor_test.go
  3. 2
      pkg/distributor/http.go
  4. 12
      pkg/ingester/instance.go
  5. 2
      pkg/ingester/instance_test.go
  6. 2
      pkg/ingester/recalculate_owned_streams_test.go
  7. 4
      pkg/loghttp/push/otlp.go
  8. 8
      pkg/loghttp/push/otlp_test.go
  9. 9
      pkg/loghttp/push/push.go
  10. 6
      pkg/loghttp/push/push_test.go
  11. 1
      pkg/loki/modules.go
  12. 67
      pkg/validation/ingestion_policies.go
  13. 202
      pkg/validation/ingestion_policies_test.go

@ -601,7 +601,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
var lbs labels.Labels
var retentionHours, policy string
lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format)
lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format)
if err != nil {
d.writeFailuresManager.Log(tenantID, err)
validationErrors.Add(err)
@ -915,7 +915,7 @@ func (d *Distributor) trackDiscardedData(
if d.usageTracker != nil {
for _, stream := range req.Streams {
lbs, _, _, _, _, err := d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format)
lbs, _, _, _, _, err := d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format)
if err != nil {
continue
}
@ -1309,10 +1309,10 @@ type labelData struct {
}
// parseStreamLabels parses stream labels using a request-scoped policy resolver
func (d *Distributor) parseStreamLabels(vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) {
func (d *Distributor) parseStreamLabels(ctx context.Context, vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) {
if val, ok := d.labelCache.Get(key); ok {
retentionHours := streamResolver.RetentionHoursFor(val.ls)
policy := streamResolver.PolicyFor(val.ls)
policy := streamResolver.PolicyFor(ctx, val.ls)
return val.ls, val.ls.String(), val.hash, retentionHours, policy, nil
}
@ -1323,7 +1323,7 @@ func (d *Distributor) parseStreamLabels(vContext validationContext, key string,
return labels.EmptyLabels(), "", 0, retentionHours, "", fmt.Errorf(validation.InvalidLabelsErrorMsg, key, err)
}
policy := streamResolver.PolicyFor(ls)
policy := streamResolver.PolicyFor(ctx, ls)
retentionHours := d.tenantsRetention.RetentionHoursFor(vContext.userID, ls)
if err := d.validator.ValidateLabels(vContext, ls, stream, retentionHours, policy, format); err != nil {
@ -1454,8 +1454,8 @@ func (r requestScopedStreamResolver) RetentionHoursFor(lbs labels.Labels) string
return r.retention.RetentionHoursFor(lbs)
}
func (r requestScopedStreamResolver) PolicyFor(lbs labels.Labels) string {
policies := r.policyStreamMappings.PolicyFor(lbs)
func (r requestScopedStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string {
policies := r.policyStreamMappings.PolicyFor(ctx, lbs)
var policy string
if len(policies) > 0 {

@ -1291,7 +1291,7 @@ func Benchmark_SortLabelsOnPush(b *testing.B) {
for n := 0; n < b.N; n++ {
stream := request.Streams[0]
stream.Labels = `{buzz="f", a="b"}`
_, _, _, _, _, err := d.parseStreamLabels(vCtx, stream.Labels, stream, streamResolver, constants.Loki)
_, _, _, _, _, err := d.parseStreamLabels(context.Background(), vCtx, stream.Labels, stream, streamResolver, constants.Loki)
if err != nil {
panic("parseStreamLabels fail,err:" + err.Error())
}
@ -1331,7 +1331,7 @@ func TestParseStreamLabels(t *testing.T) {
vCtx := d.validator.getValidationContextForTime(testTime, "123")
streamResolver := newRequestScopedStreamResolver("123", d.validator.Limits, nil)
t.Run(tc.name, func(t *testing.T) {
lbs, lbsString, hash, _, _, err := d.parseStreamLabels(vCtx, tc.origLabels, logproto.Stream{
lbs, lbsString, hash, _, _, err := d.parseStreamLabels(context.Background(), vCtx, tc.origLabels, logproto.Stream{
Labels: tc.origLabels,
}, streamResolver, constants.Loki)
if tc.expectedErr != nil {
@ -2336,10 +2336,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 24*time.Hour, retentionPeriod)
policy := resolver.PolicyFor(labels.FromStrings("env", "prod"))
policy := resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
require.Equal(t, "policy0", policy)
policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
require.Empty(t, policy)
// We now modify the underlying limits to test that the resolver is not affected by changes to the limits
@ -2378,10 +2378,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 24*time.Hour, retentionPeriod)
policy = resolver.PolicyFor(labels.FromStrings("env", "prod"))
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
require.Equal(t, "policy0", policy)
policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
require.Empty(t, policy)
// But a new resolver should return the new values
@ -2397,10 +2397,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
retentionPeriod = newResolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 72*time.Hour, retentionPeriod)
policy = newResolver.PolicyFor(labels.FromStrings("env", "prod"))
policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
require.Empty(t, policy)
policy = newResolver.PolicyFor(labels.FromStrings("env", "dev"))
policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
require.Equal(t, "policy1", policy)
}

@ -132,7 +132,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe
"stream", s.Labels,
"streamLabelsHash", util.HashedQuery(s.Labels), // this is to make it easier to do searching and grouping
"streamSizeBytes", humanize.Bytes(uint64(pushStats.StreamSizeBytes[s.Labels])),
"policy", streamResolver.PolicyFor(lbs),
"policy", streamResolver.PolicyFor(r.Context(), lbs),
}
if timestamp, ok := pushStats.MostRecentEntryTimestampPerStream[s.Labels]; ok {
logValues = append(logValues, "mostRecentLagMs", time.Since(timestamp).Milliseconds())

@ -201,7 +201,7 @@ func (i *instance) consumeChunk(ctx context.Context, ls labels.Labels, chunk *lo
s, _, _ := i.streams.LoadOrStoreNewByFP(fp,
func() (*stream, error) {
s, err := i.createStreamByFP(ls, fp)
s, err := i.createStreamByFP(ctx, ls, fp)
s.chunkMtx.Lock() // Lock before return, because we have defer that unlocks it.
if err != nil {
return nil, err
@ -299,7 +299,7 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre
}
retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, labels))
policy := i.resolvePolicyForStream(labels)
policy := i.resolvePolicyForStream(ctx, labels)
if record != nil {
err = i.streamCountLimiter.AssertNewStreamAllowed(i.instanceID, policy)
@ -336,9 +336,9 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre
return s, nil
}
func (i *instance) resolvePolicyForStream(labels labels.Labels) string {
func (i *instance) resolvePolicyForStream(ctx context.Context, labels labels.Labels) string {
mapping := i.limiter.limits.PoliciesStreamMapping(i.instanceID)
policies := mapping.PolicyFor(labels)
policies := mapping.PolicyFor(ctx, labels)
// NOTE: We previously resolved the policy on distributors and logged when multiple policies were matched.
// As on distributors, we use the first policy by alphabetical order.
var policy string
@ -400,7 +400,7 @@ func (i *instance) onStreamCreated(s *stream) {
}
}
func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*stream, error) {
func (i *instance) createStreamByFP(ctx context.Context, ls labels.Labels, fp model.Fingerprint) (*stream, error) {
sortedLabels := i.index.Add(logproto.FromLabelsToLabelAdapters(ls), fp)
chunkfmt, headfmt, err := i.chunkFormatAt(model.Now())
@ -409,7 +409,7 @@ func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*st
}
retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, ls))
policy := i.resolvePolicyForStream(ls)
policy := i.resolvePolicyForStream(ctx, ls)
s := newStream(chunkfmt, headfmt, i.cfg, i.limiter.rateLimitStrategy, i.instanceID, fp, sortedLabels, i.limiter.UnorderedWrites(i.instanceID), i.streamRateCalculator, i.metrics, i.writeFailures, i.configs, retentionHours, policy)

@ -581,7 +581,7 @@ func Benchmark_instance_addNewTailer(b *testing.B) {
chunkfmt, headfmt, err := inst.chunkFormatAt(model.Now())
require.NoError(b, err)
retentionHours := util.RetentionHours(tenantsRetention.RetentionPeriodFor("test", lbs))
policy := inst.resolvePolicyForStream(lbs)
policy := inst.resolvePolicyForStream(context.Background(), lbs)
b.Run("addTailersToNewStream", func(b *testing.B) {
for n := 0; n < b.N; n++ {

@ -243,7 +243,7 @@ func createStream(t *testing.T, inst *instance, fingerprint int) *stream {
lbls := labels.FromStrings("mock", strconv.Itoa(fingerprint))
stream, _, err := inst.streams.LoadOrStoreNew(lbls.String(), func() (*stream, error) {
return inst.createStreamByFP(lbls, model.Fingerprint(fingerprint))
return inst.createStreamByFP(context.Background(), lbls, model.Fingerprint(fingerprint))
}, nil)
require.NoError(t, err)
return stream

@ -242,7 +242,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl
// Calculate resource attributes metadata size for stats
resourceAttributesAsStructuredMetadataSize := loki_util.StructuredMetadataSize(resourceAttributesAsStructuredMetadata)
retentionPeriodForUser := streamResolver.RetentionPeriodFor(lbs)
policy := streamResolver.PolicyFor(lbs)
policy := streamResolver.PolicyFor(ctx, lbs)
// Check if the stream has the exporter=OTLP label; set flag instead of incrementing per stream
if value, ok := streamLabels[model.LabelName("exporter")]; ok && value == "OTLP" {
@ -386,7 +386,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl
pushRequestsByStream[entryLabelsStr] = stream
entryRetentionPeriod := streamResolver.RetentionPeriodFor(entryLbs)
entryPolicy := streamResolver.PolicyFor(entryLbs)
entryPolicy := streamResolver.PolicyFor(ctx, entryLbs)
if _, ok := stats.StructuredMetadataBytes[entryPolicy]; !ok {
stats.StructuredMetadataBytes[entryPolicy] = make(map[time.Duration]int64)

@ -583,7 +583,7 @@ func TestOTLPToLokiPushRequest(t *testing.T) {
stats := NewPushStats()
tracker := NewMockTracker()
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
streamResolver.policyForOverride = func(lbs labels.Labels) string {
streamResolver.policyForOverride = func(_ context.Context, lbs labels.Labels) string {
if lbs.Get("service_name") == "service-1" {
return "service-1-policy"
}
@ -926,7 +926,7 @@ func TestOTLPLogAttributesAsIndexLabels(t *testing.T) {
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
// All logs will use the same policy for simplicity
streamResolver.policyForOverride = func(_ labels.Labels) string {
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
return "test-policy"
}
@ -1029,7 +1029,7 @@ func TestOTLPStructuredMetadataCalculation(t *testing.T) {
tracker := NewMockTracker()
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
streamResolver.policyForOverride = func(_ labels.Labels) string {
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
return "test-policy"
}
@ -1215,7 +1215,7 @@ func TestOTLPSeverityTextAsLabel(t *testing.T) {
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
// All logs will use the same policy for simplicity
streamResolver.policyForOverride = func(_ labels.Labels) string {
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
return "test-policy"
}

@ -3,6 +3,7 @@ package push
import (
"compress/flate"
"compress/gzip"
"context"
"fmt"
"io"
"mime"
@ -113,7 +114,7 @@ func (EmptyLimits) PolicyFor(_ string, _ labels.Labels) string {
type StreamResolver interface {
RetentionPeriodFor(lbs labels.Labels) time.Duration
RetentionHoursFor(lbs labels.Labels) string
PolicyFor(lbs labels.Labels) string
PolicyFor(ctx context.Context, lbs labels.Labels) string
}
type (
@ -443,7 +444,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi
req.Streams[i] = s
}
err = CalculateStreamsStats(userID, req, streamResolver, tenantConfigs, pushStats)
err = CalculateStreamsStats(r.Context(), userID, req, streamResolver, tenantConfigs, pushStats)
if err != nil {
return nil, nil, err
}
@ -452,7 +453,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi
}
// CalculateStreamsStats modifies pushStats with statistics about all the streams from req.
func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error {
func CalculateStreamsStats(ctx context.Context, userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error {
logPushRequestStreams := false
if tenantConfigs != nil {
logPushRequestStreams = tenantConfigs.LogPushRequestStreams(userID)
@ -471,7 +472,7 @@ func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResol
var policy string
if streamResolver != nil {
retentionPeriod = streamResolver.RetentionPeriodFor(lbs)
policy = streamResolver.PolicyFor(lbs)
policy = streamResolver.PolicyFor(ctx, lbs)
}
if _, ok := pushStats.LogLinesBytes[policy]; !ok {

@ -787,7 +787,7 @@ type mockStreamResolver struct {
tenant string
limits *fakeLimits
policyForOverride func(lbs labels.Labels) string
policyForOverride func(ctx context.Context, lbs labels.Labels) string
}
func newMockStreamResolver(tenant string, limits *fakeLimits) *mockStreamResolver {
@ -805,9 +805,9 @@ func (m mockStreamResolver) RetentionHoursFor(lbs labels.Labels) string {
return m.limits.RetentionHoursFor(m.tenant, lbs)
}
func (m mockStreamResolver) PolicyFor(lbs labels.Labels) string {
func (m mockStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string {
if m.policyForOverride != nil {
return m.policyForOverride(lbs)
return m.policyForOverride(ctx, lbs)
}
return m.limits.PolicyFor(m.tenant, lbs)

@ -391,6 +391,7 @@ func (t *Loki) initDistributor() (services.Service, error) {
httpPushHandlerMiddleware := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
t.HTTPAuthMiddleware,
validation.NewIngestionPolicyMiddleware(util_log.Logger),
)
lokiPushHandler := httpPushHandlerMiddleware.Wrap(http.HandlerFunc(t.distributor.PushHandler))

@ -1,9 +1,13 @@
package validation
import (
"context"
"fmt"
"net/http"
"slices"
"github.com/go-kit/log"
"github.com/grafana/dskit/middleware"
"github.com/prometheus/prometheus/model/labels"
"github.com/grafana/loki/v3/pkg/logql/syntax"
@ -11,6 +15,8 @@ import (
const (
GlobalPolicy = "*"
HTTPHeaderIngestionPolicyKey = "X-Loki-Ingestion-Policy"
)
type PriorityStream struct {
@ -54,7 +60,15 @@ func (p *PolicyStreamMapping) Validate() error {
// with the same priority.
// Returned policies are sorted alphabetically.
// If no policies match, it returns an empty slice.
func (p *PolicyStreamMapping) PolicyFor(lbs labels.Labels) []string {
// If a policy is set via the X-Loki-Ingestion-Policy header (passed through context), it overrides
// all stream-to-policy mappings and returns that policy.
func (p *PolicyStreamMapping) PolicyFor(ctx context.Context, lbs labels.Labels) []string {
// Check if a policy was set via the HTTP header (X-Loki-Ingestion-Policy)
// This overrides any stream-to-policy mappings
if headerPolicy := ExtractIngestionPolicyContext(ctx); headerPolicy != "" {
return []string{headerPolicy}
}
var (
found bool
highestPriority int
@ -143,3 +157,54 @@ func (p *PolicyStreamMapping) ApplyDefaultPolicyStreamMappings(defaults PolicySt
}
return nil
}
// policyContextKey is used as a key for context values to avoid collisions
type policyContextKey int
const (
ingestionPolicyContextKey policyContextKey = 1
)
// ExtractIngestionPolicyHTTP retrieves the ingestion policy from the HTTP header and returns it.
// If no policy is found, it returns an empty string.
func ExtractIngestionPolicyHTTP(r *http.Request) string {
return r.Header.Get(HTTPHeaderIngestionPolicyKey)
}
// InjectIngestionPolicyContext returns a derived context containing the provided ingestion policy.
func InjectIngestionPolicyContext(ctx context.Context, policy string) context.Context {
return context.WithValue(ctx, ingestionPolicyContextKey, policy)
}
// ExtractIngestionPolicyContext gets the embedded ingestion policy from the context.
// If no policy is found, it returns an empty string.
func ExtractIngestionPolicyContext(ctx context.Context) string {
policy, ok := ctx.Value(ingestionPolicyContextKey).(string)
if !ok {
return ""
}
return policy
}
type ingestionPolicyMiddleware struct {
logger log.Logger
}
// NewIngestionPolicyMiddleware creates a middleware that extracts the ingestion policy
// from the HTTP header and injects it into the context of the request.
func NewIngestionPolicyMiddleware(logger log.Logger) middleware.Interface {
return &ingestionPolicyMiddleware{
logger: logger,
}
}
// Wrap implements the middleware interface
func (m *ingestionPolicyMiddleware) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if policy := ExtractIngestionPolicyHTTP(r); policy != "" {
r = r.Clone(InjectIngestionPolicyContext(r.Context(), policy))
}
next.ServeHTTP(w, r)
})
}

@ -1,6 +1,9 @@
package validation
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/prometheus/prometheus/model/labels"
@ -102,19 +105,20 @@ func Test_PolicyStreamMapping_PolicyFor(t *testing.T) {
require.NoError(t, mapping.Validate())
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar")))
ctx := t.Context()
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
// matches both policy2 and policy1 but policy1 has higher priority.
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar", "daz", "baz")))
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar", "daz", "baz")))
// matches policy3 and policy4 but policy3 has higher priority..
require.Equal(t, []string{"policy3"}, mapping.PolicyFor(labels.FromStrings("qyx", "qzx", "qox", "qox")))
require.Equal(t, []string{"policy3"}, mapping.PolicyFor(ctx, labels.FromStrings("qyx", "qzx", "qox", "qox")))
// matches no policy.
require.Empty(t, mapping.PolicyFor(labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge")))
require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge")))
// matches policy5 through regex.
require.Equal(t, []string{"policy5"}, mapping.PolicyFor(labels.FromStrings("qab", "qzxqox")))
require.Equal(t, []string{"policy5"}, mapping.PolicyFor(ctx, labels.FromStrings("qab", "qzxqox")))
require.Equal(t, []string{"policy6"}, mapping.PolicyFor(labels.FromStrings("env", "prod", "team", "finance")))
require.Equal(t, []string{"policy6"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod", "team", "finance")))
// Matches policy7 and policy8 which have the same priority.
require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(labels.FromStrings("env", "prod")))
require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
}
func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings(t *testing.T) {
@ -284,3 +288,187 @@ func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings_Validation(t *test
// Verify the result is valid
require.NoError(t, existing.Validate())
}
func Test_PolicyStreamMapping_PolicyFor_WithHeaderOverride(t *testing.T) {
mapping := PolicyStreamMapping{
"policy1": []*PriorityStream{
{
Selector: `{foo="bar"}`,
Priority: 2,
Matchers: []*labels.Matcher{
labels.MustNewMatcher(labels.MatchEqual, "foo", "bar"),
},
},
},
"policy2": []*PriorityStream{
{
Selector: `{env="prod"}`,
Priority: 1,
Matchers: []*labels.Matcher{
labels.MustNewMatcher(labels.MatchEqual, "env", "prod"),
},
},
},
}
require.NoError(t, mapping.Validate())
t.Run("without header context, uses normal mapping", func(t *testing.T) {
ctx := t.Context()
// Should match policy1 based on labels
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
// Should match policy2 based on labels
require.Equal(t, []string{"policy2"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
// Should match no policy
require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label")))
})
t.Run("with header context, overrides all mappings", func(t *testing.T) {
ctx := InjectIngestionPolicyContext(t.Context(), "override-policy")
// Even though labels match policy1, header policy overrides
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
// Even though labels match policy2, header policy overrides
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
// Even though labels don't match anything, header policy is used
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label")))
})
t.Run("empty header context is ignored", func(t *testing.T) {
// Inject empty string - should be treated as not set
ctx := InjectIngestionPolicyContext(t.Context(), "")
// Should fall back to normal mapping behavior
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
})
}
func TestExtractInjectIngestionPolicyContext(t *testing.T) {
t.Run("inject and extract policy", func(t *testing.T) {
policy := "test-policy"
ctx := InjectIngestionPolicyContext(t.Context(), policy)
extracted := ExtractIngestionPolicyContext(ctx)
require.Equal(t, policy, extracted)
})
t.Run("extract from empty context", func(t *testing.T) {
extracted := ExtractIngestionPolicyContext(t.Context())
require.Empty(t, extracted)
})
t.Run("inject empty string", func(t *testing.T) {
ctx := InjectIngestionPolicyContext(t.Context(), "")
extracted := ExtractIngestionPolicyContext(ctx)
require.Empty(t, extracted)
})
}
func TestExtractIngestionPolicyHTTP(t *testing.T) {
t.Run("extract policy from header", func(t *testing.T) {
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
req.Header.Set(HTTPHeaderIngestionPolicyKey, "my-policy")
policy := ExtractIngestionPolicyHTTP(req)
require.Equal(t, "my-policy", policy)
})
t.Run("no header present", func(t *testing.T) {
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
policy := ExtractIngestionPolicyHTTP(req)
require.Empty(t, policy)
})
t.Run("empty header value", func(t *testing.T) {
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
req.Header.Set(HTTPHeaderIngestionPolicyKey, "")
policy := ExtractIngestionPolicyHTTP(req)
require.Empty(t, policy)
})
t.Run("header with whitespace", func(t *testing.T) {
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
req.Header.Set(HTTPHeaderIngestionPolicyKey, " policy-with-spaces ")
policy := ExtractIngestionPolicyHTTP(req)
require.Equal(t, " policy-with-spaces ", policy)
})
}
func TestIngestionPolicyMiddleware(t *testing.T) {
t.Run("middleware injects policy into context", func(t *testing.T) {
var capturedCtx context.Context
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCtx = r.Context()
w.WriteHeader(http.StatusOK)
})
middleware := NewIngestionPolicyMiddleware(nil)
wrappedHandler := middleware.Wrap(handler)
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
req.Header.Set(HTTPHeaderIngestionPolicyKey, "test-policy")
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
policy := ExtractIngestionPolicyContext(capturedCtx)
require.Equal(t, "test-policy", policy)
})
t.Run("middleware does not modify context when no header", func(t *testing.T) {
var capturedCtx context.Context
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCtx = r.Context()
w.WriteHeader(http.StatusOK)
})
middleware := NewIngestionPolicyMiddleware(nil)
wrappedHandler := middleware.Wrap(handler)
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
policy := ExtractIngestionPolicyContext(capturedCtx)
require.Empty(t, policy)
})
t.Run("middleware does not inject empty header value", func(t *testing.T) {
var capturedCtx context.Context
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCtx = r.Context()
w.WriteHeader(http.StatusOK)
})
middleware := NewIngestionPolicyMiddleware(nil)
wrappedHandler := middleware.Wrap(handler)
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
require.NoError(t, err)
req.Header.Set(HTTPHeaderIngestionPolicyKey, "")
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
policy := ExtractIngestionPolicyContext(capturedCtx)
require.Empty(t, policy)
})
}

Loading…
Cancel
Save