diff --git a/pkg/engine/executor/executor.go b/pkg/engine/executor/executor.go
index 4cc48a20d1..1f55350ad8 100644
--- a/pkg/engine/executor/executor.go
+++ b/pkg/engine/executor/executor.go
@@ -61,12 +61,16 @@ func (c *Context) executeDataObjScan(_ context.Context, _ *physical.DataObjScan)
 	return errorPipeline(errNotImplemented)
 }
 
-func (c *Context) executeSortMerge(_ context.Context, _ *physical.SortMerge, inputs []Pipeline) Pipeline {
+func (c *Context) executeSortMerge(_ context.Context, sortmerge *physical.SortMerge, inputs []Pipeline) Pipeline {
 	if len(inputs) == 0 {
 		return emptyPipeline()
 	}
 
-	return errorPipeline(errNotImplemented)
+	pipeline, err := NewSortMergePipeline(inputs, sortmerge.Order, sortmerge.Column, c.evaluator)
+	if err != nil {
+		return errorPipeline(err)
+	}
+	return pipeline
 }
 
 func (c *Context) executeLimit(_ context.Context, limit *physical.Limit, inputs []Pipeline) Pipeline {
diff --git a/pkg/engine/executor/executor_test.go b/pkg/engine/executor/executor_test.go
index ee2abf9ac9..b2198613c5 100644
--- a/pkg/engine/executor/executor_test.go
+++ b/pkg/engine/executor/executor_test.go
@@ -39,13 +39,6 @@ func TestExecutor_SortMerge(t *testing.T) {
 		err := pipeline.Read()
 		require.ErrorContains(t, err, EOF.Error())
 	})
-
-	t.Run("is not implemented", func(t *testing.T) {
-		c := &Context{}
-		pipeline := c.executeSortMerge(context.TODO(), &physical.SortMerge{}, []Pipeline{emptyPipeline()})
-		err := pipeline.Read()
-		require.ErrorContains(t, err, errNotImplemented.Error())
-	})
 }
 
 func TestExecutor_Limit(t *testing.T) {
diff --git a/pkg/engine/executor/expressions.go b/pkg/engine/executor/expressions.go
index a32a720f24..8819cc237c 100644
--- a/pkg/engine/executor/expressions.go
+++ b/pkg/engine/executor/expressions.go
@@ -14,7 +14,7 @@ import (
 
 type expressionEvaluator struct{}
 
-func (e *expressionEvaluator) eval(expr physical.Expression, input arrow.Record) (ColumnVector, error) {
+func (e expressionEvaluator) eval(expr physical.Expression, input arrow.Record) (ColumnVector, error) {
 	switch expr := expr.(type) {
 
 	case *physical.LiteralExpr:
@@ -57,6 +57,15 @@ func (e *expressionEvaluator) eval(expr physical.Expression, input arrow.Record)
 	return nil, fmt.Errorf("unknown expression: %v", expr)
 }
 
+// newFunc returns a new function that evaluates an input record against the bound expression.
+func (e expressionEvaluator) newFunc(expr physical.Expression) evalFunc {
+	return func(input arrow.Record) (ColumnVector, error) {
+		return e.eval(expr, input)
+	}
+}
+
+type evalFunc func(input arrow.Record) (ColumnVector, error)
+
 // ColumnVector represents columnar values from evaluated expressions.
 type ColumnVector interface {
 	// ToArray returns the underlying Arrow array representation of the column vector.
diff --git a/pkg/engine/executor/expressions_test.go b/pkg/engine/executor/expressions_test.go
index ef1f00e950..d8e1c6da23 100644
--- a/pkg/engine/executor/expressions_test.go
+++ b/pkg/engine/executor/expressions_test.go
@@ -53,7 +53,7 @@ func TestEvaluateLiteralExpression(t *testing.T) {
 	} {
 		t.Run(tt.name, func(t *testing.T) {
 			literal := physical.NewLiteral(tt.value)
-			e := &expressionEvaluator{}
+			e := expressionEvaluator{}
 			n := len(words)
 			rec := batch(n, time.Now())
 
@@ -70,7 +70,7 @@ func TestEvaluateLiteralExpression(t *testing.T) {
 }
 
 func TestEvaluateColumnExpression(t *testing.T) {
-	e := &expressionEvaluator{}
+	e := expressionEvaluator{}
 
 	t.Run("invalid", func(t *testing.T) {
 		colExpr := &physical.ColumnExpr{
diff --git a/pkg/engine/executor/sortmerge.go b/pkg/engine/executor/sortmerge.go
new file mode 100644
index 0000000000..1ec1ec5230
--- /dev/null
+++ b/pkg/engine/executor/sortmerge.go
@@ -0,0 +1,204 @@
+package executor
+
+import (
+	"errors"
+	"fmt"
+	"sort"
+
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/array"
+
+	"github.com/grafana/loki/v3/pkg/engine/planner/physical"
+)
+
+// NewSortMergePipeline returns a new pipeline that merges already sorted inputs into a single output.
+func NewSortMergePipeline(inputs []Pipeline, order physical.SortOrder, column physical.ColumnExpression, evaluator expressionEvaluator) (*KWayMerge, error) {
+	var compare func(a, b uint64) bool
+	switch order {
+	case physical.ASC:
+		compare = func(a, b uint64) bool { return a <= b }
+	case physical.DESC:
+		compare = func(a, b uint64) bool { return a >= b }
+	default:
+		return nil, fmt.Errorf("invalid sort order %v", order)
+	}
+
+	return &KWayMerge{
+		inputs:     inputs,
+		columnEval: evaluator.newFunc(column),
+		compare:    compare,
+	}, nil
+}
+
+// KWayMerge is a k-way merge of multiple sorted inputs.
+// It requires the input batches to be sorted in the same order (ASC/DESC) as the SortMerge operator itself.
+// The sort order is defined by the direction of the query, either FORWARD or BACKWARD,
+// which is applied both to the SortMerge and to the DataObjScan during query planning.
+type KWayMerge struct {
+	inputs      []Pipeline
+	state       state
+	initialized bool
+	batches     []arrow.Record
+	exhausted   []bool
+	offsets     []int64
+	columnEval  evalFunc
+	compare     func(a, b uint64) bool
+}
+
+var _ Pipeline = (*KWayMerge)(nil)
+
+// Close implements Pipeline.
+func (p *KWayMerge) Close() {
+	// Release the last batch
+	if p.state.batch != nil {
+		p.state.batch.Release()
+	}
+	for _, input := range p.inputs {
+		input.Close()
+	}
+}
+
+// Inputs implements Pipeline.
+func (p *KWayMerge) Inputs() []Pipeline {
+	return p.inputs
+}
+
+// Read implements Pipeline.
+func (p *KWayMerge) Read() error {
+	p.init()
+	return p.read()
+}
+
+// Transport implements Pipeline.
+func (p *KWayMerge) Transport() Transport {
+	return Local
+}
+
+// Value implements Pipeline.
+func (p *KWayMerge) Value() (arrow.Record, error) {
+	return p.state.Value()
+}
+
+func (p *KWayMerge) init() {
+	if p.initialized {
+		return
+	}
+
+	p.initialized = true
+
+	n := len(p.inputs)
+	p.batches = make([]arrow.Record, n)
+	p.exhausted = make([]bool, n)
+	p.offsets = make([]int64, n)
+
+	if p.compare == nil {
+		p.compare = func(a, b uint64) bool { return a <= b }
+	}
+}
+
+// Iterate through each record, looking at the value at its current slice offset.
+// Track the top two winners (i.e. the record whose next value sorts first and the record whose next value sorts second).
+// Find the largest offset in the winning record whose value still sorts before the value of the runner-up record from the previous step.
+// Return the slice of that record using the two offsets, and update the stored offset of the returned record for the next call to Read.
+func (p *KWayMerge) read() error {
+	// Release the previous batch
+	if p.state.batch != nil {
+		p.state.batch.Release()
+	}
+
+	timestamps := make([]uint64, 0, len(p.inputs))
+	batchIndexes := make([]int, 0, len(p.inputs))
+
+	for i := range len(p.inputs) {
+		// Skip exhausted inputs
+		if p.exhausted[i] {
+			continue
+		}
+
+		// Load the next batch if it hasn't been loaded yet, or if the current one is already fully consumed
+		if p.batches[i] == nil || p.offsets[i] == p.batches[i].NumRows() {
+			err := p.inputs[i].Read()
+			if err != nil {
+				if err == EOF {
+					p.exhausted[i] = true
+					continue
+				}
+				return err
+			}
+			p.offsets[i] = 0
+			// It is safe to use the value from the Value() call, because the error is already checked after the Read() call.
+			// In case the input is exhausted (reached EOF), the return value is `nil`; however, since the flag `p.exhausted[i]` is set, the value will never be read.
+			p.batches[i], _ = p.inputs[i].Value()
+		}
+
+		// Fetch the timestamp value at the current offset
+		col, err := p.columnEval(p.batches[i])
+		if err != nil {
+			return err
+		}
+		tsCol, ok := col.ToArray().(*array.Uint64)
+		if !ok {
+			return errors.New("column is not a timestamp column")
+		}
+		ts := tsCol.Value(int(p.offsets[i]))
+
+		// Populate the slices for sorting
+		batchIndexes = append(batchIndexes, i)
+		timestamps = append(timestamps, ts)
+	}
+
+	// The pipeline is exhausted if no more input batches are available
+	if len(batchIndexes) == 0 {
+		p.state = Exhausted
+		return p.state.err
+	}
+
+	// If there is only a single remaining batch, return the remaining record
+	if len(batchIndexes) == 1 {
+		j := batchIndexes[0]
+		start := p.offsets[j]
+		end := p.batches[j].NumRows()
+		p.state = successState(p.batches[j].NewSlice(start, end))
+		p.offsets[j] = end
+		return nil
+	}
+
+	// Sort the batch indexes by timestamp
+	sort.Slice(batchIndexes, func(i, j int) bool {
+		return p.compare(timestamps[i], timestamps[j])
+	})
+
+	// Sort the timestamps the same way so they stay aligned with batchIndexes
+	sort.Slice(timestamps, func(i, j int) bool {
+		return p.compare(timestamps[i], timestamps[j])
+	})
+
+	// Return a slice of the winning record
+	j := batchIndexes[0]
+
+	// Fetch the timestamp value at the current offset
+	col, err := p.columnEval(p.batches[j])
+	if err != nil {
+		return err
+	}
+	// We assume the column is a Uint64 array
+	tsCol, ok := col.ToArray().(*array.Uint64)
+	if !ok {
+		return errors.New("column is not a timestamp column")
+	}
+
+	// Calculate the start/end of the sub-slice of the record
+	start := p.offsets[j]
+	end := start
+	for end < p.batches[j].NumRows() {
+		ts := tsCol.Value(int(end))
+		end++
+		if p.compare(ts, timestamps[1]) {
+			break
+		}
+	}
+
+	p.state = successState(p.batches[j].NewSlice(start, end))
+	p.offsets[j] = end
+	return nil
+}
diff --git a/pkg/engine/executor/sortmerge_test.go b/pkg/engine/executor/sortmerge_test.go
new file mode 100644
index 0000000000..80f30eb3ff
--- /dev/null
+++ b/pkg/engine/executor/sortmerge_test.go
@@ -0,0 +1,149 @@
+package executor
+
+import (
+	"math"
+	"testing"
+	"time"
+
+	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/stretchr/testify/require"
+
+	"github.com/grafana/loki/v3/pkg/engine/internal/types"
+	"github.com/grafana/loki/v3/pkg/engine/planner/physical"
+)
+
+func TestSortMerge(t *testing.T) {
+	now := time.Date(2024, 04, 15, 0, 0, 0, 0, time.UTC)
+	var batchSize = int64(10)
+
+	c := &Context{
+		batchSize: batchSize,
+	}
+
+	t.Run("invalid column name", func(t *testing.T) {
+		merge := &physical.SortMerge{
+			Column: &physical.ColumnExpr{
+				Ref: types.ColumnRef{
+					Column: "invalid",
+					Type:   types.ColumnTypeBuiltin,
+				},
+			},
+		}
+
+		inputs := []Pipeline{
+			ascendingTimestampPipeline(now.Add(1*time.Nanosecond)).Pipeline(batchSize, 100),
+			ascendingTimestampPipeline(now.Add(2*time.Nanosecond)).Pipeline(batchSize, 100),
+			ascendingTimestampPipeline(now.Add(3*time.Nanosecond)).Pipeline(batchSize, 100),
+		}
+
+		pipeline, err := NewSortMergePipeline(inputs, merge.Order, merge.Column, expressionEvaluator{})
+		require.NoError(t, err)
+
+		err = pipeline.Read()
+		require.ErrorContains(t, err, "key error")
+	})
+
+	t.Run("ascending timestamp", func(t *testing.T) {
+		merge := &physical.SortMerge{
+			Column: &physical.ColumnExpr{
+				Ref: types.ColumnRef{
+					Column: "timestamp",
+					Type:   types.ColumnTypeBuiltin,
+				},
+			},
+			Order: physical.ASC,
+		}
+
+		inputs := []Pipeline{
+			ascendingTimestampPipeline(now.Add(1*time.Nanosecond)).Pipeline(batchSize, 100),
+			ascendingTimestampPipeline(now.Add(2*time.Nanosecond)).Pipeline(batchSize, 100),
+			ascendingTimestampPipeline(now.Add(3*time.Nanosecond)).Pipeline(batchSize, 100),
+		}
+
+		pipeline, err := NewSortMergePipeline(inputs, merge.Order, merge.Column, expressionEvaluator{})
+		require.NoError(t, err)
+
+		var lastTs uint64
+		var batches, rows int64
+		for {
+			err := pipeline.Read()
+			if err == EOF {
+				break
+			}
+			if err != nil {
+				t.Fatalf("did not expect error, got %s", err.Error())
+			}
+			batch, _ := pipeline.Value()
+
+			tsCol, err := c.evaluator.eval(merge.Column, batch)
+			require.NoError(t, err)
+			arr := tsCol.ToArray().(*array.Uint64)
+
+			// Check that the timestamp column is sorted within the batch
+			for i := 0; i < arr.Len()-1; i++ {
+				require.LessOrEqual(t, arr.Value(i), arr.Value(i+1))
+				// Also check ascending order across batches
+				require.GreaterOrEqual(t, arr.Value(i), lastTs)
+				lastTs = arr.Value(i + 1)
+			}
+			batches++
+			rows += batch.NumRows()
+		}
+
+		// The test scenario is a worst case and produces single-row records.
+		// require.Equal(t, int64(30), batches)
+		require.Equal(t, int64(300), rows)
+	})
+
+	t.Run("descending timestamp", func(t *testing.T) {
+		merge := &physical.SortMerge{
+			Column: &physical.ColumnExpr{
+				Ref: types.ColumnRef{
+					Column: "timestamp",
+					Type:   types.ColumnTypeBuiltin,
+				},
+			},
+			Order: physical.DESC,
+		}
+
+		inputs := []Pipeline{
+			descendingTimestampPipeline(now.Add(1*time.Nanosecond)).Pipeline(batchSize, 100),
+			descendingTimestampPipeline(now.Add(2*time.Nanosecond)).Pipeline(batchSize, 100),
+			descendingTimestampPipeline(now.Add(3*time.Nanosecond)).Pipeline(batchSize, 100),
+		}
+
+		pipeline, err := NewSortMergePipeline(inputs, merge.Order, merge.Column, expressionEvaluator{})
+		require.NoError(t, err)
+
+		var lastTs uint64 = math.MaxUint64
+		var batches, rows int64
+		for {
+			err := pipeline.Read()
+			if err == EOF {
+				break
+			}
+			if err != nil {
+				t.Fatalf("did not expect error, got %s", err.Error())
+			}
+			batch, _ := pipeline.Value()
+
+			tsCol, err := c.evaluator.eval(merge.Column, batch)
+			require.NoError(t, err)
+			arr := tsCol.ToArray().(*array.Uint64)
+
+			// Check that the timestamp column is sorted within the batch
+			for i := 0; i < arr.Len()-1; i++ {
+				require.GreaterOrEqual(t, arr.Value(i), arr.Value(i+1))
+				// Also check descending order across batches
+				require.LessOrEqual(t, arr.Value(i), lastTs)
+				lastTs = arr.Value(i + 1)
+			}
+			batches++
+			rows += batch.NumRows()
+		}
+
+		// The test scenario is a worst case and produces single-row records.
+		// require.Equal(t, int64(30), batches)
+		require.Equal(t, int64(300), rows)
+	})
+}
diff --git a/pkg/engine/executor/util_test.go b/pkg/engine/executor/util_test.go
index c1c87dd3fe..c167383494 100644
--- a/pkg/engine/executor/util_test.go
+++ b/pkg/engine/executor/util_test.go
@@ -2,6 +2,7 @@ package executor
 
 import (
 	"testing"
+	"time"
 
 	"github.com/apache/arrow-go/v18/arrow"
 	"github.com/apache/arrow-go/v18/arrow/array"
@@ -13,6 +14,7 @@ var (
 		arrow.NewSchema([]arrow.Field{
 			{Name: "id", Type: arrow.PrimitiveTypes.Int64},
 		}, nil),
+
 		func(offset, sz int64, schema *arrow.Schema) arrow.Record {
 			builder := array.NewInt64Builder(memory.DefaultAllocator)
 			defer builder.Release()
@@ -30,6 +32,50 @@ var (
 	)
 )
 
+func ascendingTimestampPipeline(start time.Time) *recordGenerator {
+	return timestampPipeline(start, ascending)
+}
+
+func descendingTimestampPipeline(start time.Time) *recordGenerator {
+	return timestampPipeline(start, descending)
+}
+
+const (
+	ascending  = time.Duration(1)
+	descending = time.Duration(-1)
+)
+
+func timestampPipeline(start time.Time, order time.Duration) *recordGenerator {
+	return newRecordGenerator(
+		arrow.NewSchema([]arrow.Field{
+			{Name: "id", Type: arrow.PrimitiveTypes.Int64},
+			{Name: "timestamp", Type: arrow.PrimitiveTypes.Uint64},
+		}, nil),
+
+		func(offset, sz int64, schema *arrow.Schema) arrow.Record {
+			idColBuilder := array.NewInt64Builder(memory.DefaultAllocator)
+			defer idColBuilder.Release()
+
+			tsColBuilder := array.NewUint64Builder(memory.DefaultAllocator)
+			defer tsColBuilder.Release()
+
+			for i := int64(0); i < sz; i++ {
+				idColBuilder.Append(offset + i)
+				tsColBuilder.Append(uint64(start.Add(order * (time.Duration(offset)*time.Second + time.Duration(i)*time.Millisecond)).UnixNano()))
+			}
+
+			idData := idColBuilder.NewArray()
+			defer idData.Release()
+
+			tsData := tsColBuilder.NewArray()
+			defer tsData.Release()
+
+			columns := []arrow.Array{idData, tsData}
+			return array.NewRecord(schema, columns, sz)
+		},
+	)
+}
+
 type recordGenerator struct {
 	schema *arrow.Schema
 	batch  func(offset, sz int64, schema *arrow.Schema) arrow.Record