From 5e056c2a3f14b3127377d4fa07ad255baf93f9f7 Mon Sep 17 00:00:00 2001 From: Kyle Brandt Date: Tue, 13 May 2025 15:22:20 -0400 Subject: [PATCH] SQL Expressions: Add sql expression specific timeout and output limit (#104834) Adds settings for SQL expressions: sql_expression_output_cell_limit Set the maximum number of cells that can be returned from a SQL expression. Default is 100000. sql_expression_timeout The duration a SQL expression will run before being cancelled. The default is 10s. --- .../setup-grafana/configure-grafana/_index.md | 8 +++ pkg/expr/graph.go | 2 +- pkg/expr/nodes.go | 5 +- pkg/expr/reader.go | 2 +- pkg/expr/service_test.go | 4 +- pkg/expr/sql/db.go | 55 ++++++++++++++++++- pkg/expr/sql/db_test.go | 45 +++++++++++++++ pkg/expr/sql/dummy_arm.go | 19 ++++++- pkg/expr/sql/frame_db_conv.go | 26 ++++++++- pkg/expr/sql_command.go | 25 ++++++--- pkg/expr/sql_command_test.go | 6 +- pkg/setting/setting.go | 8 +++ 12 files changed, 183 insertions(+), 22 deletions(-) diff --git a/docs/sources/setup-grafana/configure-grafana/_index.md b/docs/sources/setup-grafana/configure-grafana/_index.md index ff082322cb8..ddc6a0d193e 100644 --- a/docs/sources/setup-grafana/configure-grafana/_index.md +++ b/docs/sources/setup-grafana/configure-grafana/_index.md @@ -2782,6 +2782,14 @@ Set this to `false` to disable expressions and hide them in the Grafana UI. Defa Set the maximum number of cells that can be passed to a SQL expression. Default is `100000`. +#### `sql_expression_output_cell_limit` + +Set the maximum number of cells that can be returned from a SQL expression. Default is `100000`. + +#### `sql_expression_timeout` + +The duration a SQL expression will run before being cancelled. The default is `10s`. + ### `[geomap]` This section controls the defaults settings for **Geomap Plugin**. 
diff --git a/pkg/expr/graph.go b/pkg/expr/graph.go index ae0ec6f9660..28c5f3d951e 100644 --- a/pkg/expr/graph.go +++ b/pkg/expr/graph.go @@ -277,7 +277,7 @@ func (s *Service) buildGraph(req *Request) (*simple.DirectedGraph, error) { case TypeDatasourceNode: node, err = s.buildDSNode(dp, rn, req) case TypeCMDNode: - node, err = buildCMDNode(rn, s.features, s.cfg.SQLExpressionCellLimit) + node, err = buildCMDNode(rn, s.features, s.cfg) case TypeMLNode: if s.features.IsEnabledGlobally(featuremgmt.FlagMlExpressions) { node, err = s.buildMLNode(dp, rn, req) diff --git a/pkg/expr/nodes.go b/pkg/expr/nodes.go index 660a0a6b8fc..265bfa347b5 100644 --- a/pkg/expr/nodes.go +++ b/pkg/expr/nodes.go @@ -20,6 +20,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/setting" ) // label that is used when all mathexp.Series have 0 labels to make them identifiable by labels. 
The value of this label is extracted from value field names @@ -106,7 +107,7 @@ func (gn *CMDNode) Execute(ctx context.Context, now time.Time, vars mathexp.Vars return gn.Command.Execute(ctx, now, vars, s.tracer, s.metrics) } -func buildCMDNode(rn *rawNode, toggles featuremgmt.FeatureToggles, sqlExpressionCellLimit int64) (*CMDNode, error) { +func buildCMDNode(rn *rawNode, toggles featuremgmt.FeatureToggles, cfg *setting.Cfg) (*CMDNode, error) { commandType, err := GetExpressionCommandType(rn.Query) if err != nil { return nil, fmt.Errorf("invalid command type in expression '%v': %w", rn.RefID, err) @@ -163,7 +164,7 @@ func buildCMDNode(rn *rawNode, toggles featuremgmt.FeatureToggles, sqlExpression case TypeThreshold: node.Command, err = UnmarshalThresholdCommand(rn) case TypeSQL: - node.Command, err = UnmarshalSQLCommand(rn, sqlExpressionCellLimit) + node.Command, err = UnmarshalSQLCommand(rn, cfg) default: return nil, fmt.Errorf("expression command type '%v' in expression '%v' not implemented", commandType, rn.RefID) } diff --git a/pkg/expr/reader.go b/pkg/expr/reader.go index c108b913ccd..62cd6af4539 100644 --- a/pkg/expr/reader.go +++ b/pkg/expr/reader.go @@ -135,7 +135,7 @@ func (h *ExpressionQueryReader) ReadQuery( eq.Properties = q // TODO: Cascade limit from Grafana config in this (new Expression Parser) branch of the code cellLimit := 0 // zero means no limit - eq.Command, err = NewSQLCommand(common.RefID, q.Format, q.Expression, int64(cellLimit)) + eq.Command, err = NewSQLCommand(common.RefID, q.Format, q.Expression, int64(cellLimit), 0, 0) } case QueryTypeThreshold: diff --git a/pkg/expr/service_test.go b/pkg/expr/service_test.go index 306113079db..492f1ce4389 100644 --- a/pkg/expr/service_test.go +++ b/pkg/expr/service_test.go @@ -206,8 +206,8 @@ func TestSQLExpressionCellLimitFromConfig(t *testing.T) { cmdNode := node.(*CMDNode) sqlCmd := cmdNode.Command.(*SQLCommand) - // Verify the SQL command has the correct limit - require.Equal(t, tt.expectedLimit, 
sqlCmd.limit, "SQL command has incorrect cell limit") + // Verify the SQL command has the correct inputLimit + require.Equal(t, tt.expectedLimit, sqlCmd.inputLimit, "SQL command has incorrect cell limit") }) } } diff --git a/pkg/expr/sql/db.go b/pkg/expr/sql/db.go index e92ae6fc331..d8a8934f81e 100644 --- a/pkg/expr/sql/db.go +++ b/pkg/expr/sql/db.go @@ -4,7 +4,9 @@ package sql import ( "context" + "errors" "fmt" + "time" sqle "github.com/dolthub/go-mysql-server" mysql "github.com/dolthub/go-mysql-server/sql" @@ -53,11 +55,30 @@ func isFunctionNotFoundError(err error) bool { return mysql.ErrFunctionNotFound.Is(err) } +type QueryOption func(*QueryOptions) + +type QueryOptions struct { + Timeout time.Duration + MaxOutputCells int64 +} + +func WithTimeout(d time.Duration) QueryOption { + return func(o *QueryOptions) { + o.Timeout = d + } +} + +func WithMaxOutputCells(n int64) QueryOption { + return func(o *QueryOptions) { + o.MaxOutputCells = n + } +} + // QueryFrames runs the sql query query against a database created from frames, and returns the frame. // The RefID of each frame becomes a table in the database. // It is expected that there is only one frame per RefID. // The name becomes the name and RefID of the returned frame. -func (db *DB) QueryFrames(ctx context.Context, tracer tracing.Tracer, name string, query string, frames []*data.Frame) (*data.Frame, error) { +func (db *DB) QueryFrames(ctx context.Context, tracer tracing.Tracer, name string, query string, frames []*data.Frame, opts ...QueryOption) (*data.Frame, error) { // We are parsing twice due to TablesList, but don't care fow now. We can save the parsed query and reuse it later if we want. 
if allow, err := AllowQuery(query); err != nil || !allow { if err != nil { @@ -66,6 +87,16 @@ func (db *DB) QueryFrames(ctx context.Context, tracer tracing.Tracer, name strin return nil, err } + QueryOptions := &QueryOptions{} + for _, opt := range opts { + opt(QueryOptions) + } + + if QueryOptions.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, QueryOptions.Timeout) + defer cancel() + } _, span := tracer.Start(ctx, "SSE.ExecuteGMSQuery") defer span.End() @@ -88,15 +119,35 @@ func (db *DB) QueryFrames(ctx context.Context, tracer tracing.Tracer, name strin IsReadOnly: true, }) + contextErr := func(err error) error { + switch { + case errors.Is(err, context.DeadlineExceeded): + return fmt.Errorf("SQL expression for refId %v did not complete within the timeout of %v: %w", name, QueryOptions.Timeout, err) + case errors.Is(err, context.Canceled): + return fmt.Errorf("SQL expression for refId %v was cancelled before it completed: %w", name, err) + default: + return fmt.Errorf("SQL expression for refId %v ended unexpectedly: %w", name, err) + } + } + + // Execute the query (planning + iterator construction) schema, iter, _, err := engine.Query(mCtx, query) if err != nil { + if ctx.Err() != nil { + return nil, contextErr(ctx.Err()) + } return nil, WrapGoMySQLServerError(err) } - f, err := convertToDataFrame(mCtx, iter, schema) + // Convert the iterator into a Grafana data.Frame + f, err := convertToDataFrame(mCtx, iter, schema, QueryOptions.MaxOutputCells) if err != nil { + if ctx.Err() != nil { + return nil, contextErr(ctx.Err()) + } return nil, err } + f.Name = name f.RefID = name diff --git a/pkg/expr/sql/db_test.go b/pkg/expr/sql/db_test.go index b7774cbba48..c8f7d03d768 100644 --- a/pkg/expr/sql/db_test.go +++ b/pkg/expr/sql/db_test.go @@ -286,6 +286,51 @@ func TestQueryFrames_JSONFilter(t *testing.T) { } } +func TestQueryFrames_Limits(t *testing.T) { + tests := []struct { + name string + query string + opts []QueryOption + 
expectRows int + expectError string + }{ + { + name: "respects max output cells", + query: `SELECT 1 as x UNION ALL SELECT 2 UNION ALL SELECT 3`, + opts: []QueryOption{WithMaxOutputCells(2)}, + expectRows: 2, + }, + { + name: "timeout with large cross join", + query: ` + SELECT a.val + b.val AS sum + FROM (SELECT 1 AS val UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5) a + CROSS JOIN (SELECT 1 AS val UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5) b + `, + opts: []QueryOption{WithTimeout(5 * time.Microsecond)}, + expectError: "did not complete within the timeout", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := DB{} + ctx := context.Background() + frame, err := db.QueryFrames(ctx, &testTracer{}, "test", tt.query, nil, tt.opts...) + + if tt.expectError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectError) + return + } + + require.NoError(t, err) + require.NotNil(t, frame) + require.Equal(t, tt.expectRows, frame.Rows()) + }) + } +} + // p is a utility for pointers from constants func p[T any](v T) *T { return &v diff --git a/pkg/expr/sql/dummy_arm.go b/pkg/expr/sql/dummy_arm.go index 879ba965f7d..89c0d701545 100644 --- a/pkg/expr/sql/dummy_arm.go +++ b/pkg/expr/sql/dummy_arm.go @@ -5,6 +5,7 @@ package sql import ( "context" "fmt" + "time" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana/pkg/infra/tracing" @@ -14,6 +15,22 @@ type DB struct{} // Stub out the QueryFrames method for ARM builds // See github.com/dolthub/go-mysql-server/issues/2837 -func (db *DB) QueryFrames(_ context.Context, _ tracing.Tracer, _, _ string, _ []*data.Frame) (*data.Frame, error) { +func (db *DB) QueryFrames(_ context.Context, _ tracing.Tracer, _, _ string, _ []*data.Frame, _...QueryOption) (*data.Frame, error) { return nil, fmt.Errorf("sql expressions not supported in arm") } + +func WithTimeout(_ time.Duration) QueryOption { + return func(_ 
*QueryOptions) { + // no-op + } +} + +func WithMaxOutputCells(_ int64) QueryOption { + return func(_ *QueryOptions) { + // no-op + } +} + +type QueryOptions struct{} + +type QueryOption func(*QueryOptions) diff --git a/pkg/expr/sql/frame_db_conv.go b/pkg/expr/sql/frame_db_conv.go index c34a36ee483..323a461f39b 100644 --- a/pkg/expr/sql/frame_db_conv.go +++ b/pkg/expr/sql/frame_db_conv.go @@ -16,8 +16,9 @@ import ( ) // TODO: Should this accept a row limit and converters, like sqlutil.FrameFromRows? -func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Schema) (*data.Frame, error) { +func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Schema, maxOutputCells int64) (*data.Frame, error) { f := &data.Frame{} + // Create fields based on the schema for _, col := range schema { fT, err := MySQLColToFieldType(col) @@ -29,8 +30,17 @@ func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Sch f.Fields = append(f.Fields, field) } + cellCount := int64(0) + // Iterate through the rows and append data to fields for { + // Check for context cancellation or timeout + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + row, err := iter.Next(ctx) if errors.Is(err, io.EOF) { break @@ -39,6 +49,20 @@ func convertToDataFrame(ctx *mysql.Context, iter mysql.RowIter, schema mysql.Sch return nil, fmt.Errorf("error reading row: %v", err) } + // We check the cell count here to avoid appending an incomplete row, so the + // number returned may be less than the maxOutputCells. + // If the maxOutputCells is 0, we don't check the cell count. + if maxOutputCells > 0 { + cellCount += int64(len(row)) + if cellCount > maxOutputCells { + f.AppendNotices(data.Notice{ + Severity: data.NoticeSeverityWarning, + Text: fmt.Sprintf("Query exceeded max output cells (%d). 
Only %d cells returned.", maxOutputCells, cellCount-int64(len(row))), + }) + return f, nil + } + } + for i, val := range row { // Run val through mysql.Type.Convert to normalize underlying value // of the interface diff --git a/pkg/expr/sql_command.go b/pkg/expr/sql_command.go index 2bbac8fcb04..3fc0e8d0466 100644 --- a/pkg/expr/sql_command.go +++ b/pkg/expr/sql_command.go @@ -13,6 +13,7 @@ import ( "github.com/grafana/grafana/pkg/expr/metrics" "github.com/grafana/grafana/pkg/expr/sql" "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/grafana/grafana/pkg/setting" ) var ( @@ -30,12 +31,16 @@ type SQLCommand struct { query string varsToQuery []string refID string - limit int64 - format string + + format string + + inputLimit int64 + outputLimit int64 + timeout time.Duration } // NewSQLCommand creates a new SQLCommand. -func NewSQLCommand(refID, format, rawSQL string, limit int64) (*SQLCommand, error) { +func NewSQLCommand(refID, format, rawSQL string, intputLimit, outputLimit int64, timeout time.Duration) (*SQLCommand, error) { if rawSQL == "" { return nil, ErrMissingSQLQuery } @@ -63,13 +68,15 @@ func NewSQLCommand(refID, format, rawSQL string, limit int64) (*SQLCommand, erro query: rawSQL, varsToQuery: tables, refID: refID, - limit: limit, + inputLimit: intputLimit, + outputLimit: outputLimit, + timeout: timeout, format: format, }, nil } // UnmarshalSQLCommand creates a SQLCommand from Grafana's frontend query. 
-func UnmarshalSQLCommand(rn *rawNode, limit int64) (*SQLCommand, error) { +func UnmarshalSQLCommand(rn *rawNode, cfg *setting.Cfg) (*SQLCommand, error) { if rn.TimeRange == nil { logger.Error("time range must be specified for refID", "refID", rn.RefID) return nil, fmt.Errorf("time range must be specified for refID %s", rn.RefID) @@ -89,7 +96,7 @@ func UnmarshalSQLCommand(rn *rawNode, limit int64) (*SQLCommand, error) { formatRaw := rn.Query["format"] format, _ := formatRaw.(string) - return NewSQLCommand(rn.RefID, format, expression, limit) + return NewSQLCommand(rn.RefID, format, expression, cfg.SQLExpressionCellLimit, cfg.SQLExpressionOutputCellLimit, cfg.SQLExpressionTimeout) } // NeedsVars returns the variable names (refIds) that are dependencies @@ -131,11 +138,11 @@ func (gr *SQLCommand) Execute(ctx context.Context, now time.Time, vars mathexp.V tc = totalCells(allFrames) // limit of 0 or less means no limit (following convention) - if gr.limit > 0 && tc > gr.limit { + if gr.inputLimit > 0 && tc > gr.inputLimit { return mathexp.Results{}, fmt.Errorf( "SQL expression: total cell count across all input tables exceeds limit of %d. 
Total cells: %d", - gr.limit, + gr.inputLimit, tc, ) } @@ -143,7 +150,7 @@ func (gr *SQLCommand) Execute(ctx context.Context, now time.Time, vars mathexp.V logger.Debug("Executing query", "query", gr.query, "frames", len(allFrames)) db := sql.DB{} - frame, err := db.QueryFrames(ctx, tracer, gr.refID, gr.query, allFrames) + frame, err := db.QueryFrames(ctx, tracer, gr.refID, gr.query, allFrames, sql.WithMaxOutputCells(gr.outputLimit), sql.WithTimeout(gr.timeout)) rsp := mathexp.Results{} if err != nil { diff --git a/pkg/expr/sql_command_test.go b/pkg/expr/sql_command_test.go index 32ed9ee9934..42c4173bf91 100644 --- a/pkg/expr/sql_command_test.go +++ b/pkg/expr/sql_command_test.go @@ -17,7 +17,7 @@ import ( ) func TestNewCommand(t *testing.T) { - cmd, err := NewSQLCommand("a", "", "select a from foo, bar", 0) + cmd, err := NewSQLCommand("a", "", "select a from foo, bar", 0, 0, 0) if err != nil && strings.Contains(err.Error(), "feature is not enabled") { return } @@ -125,7 +125,7 @@ func TestSQLCommandCellLimits(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd, err := NewSQLCommand("a", "", "select a from foo, bar", tt.limit) + cmd, err := NewSQLCommand("a", "", "select a from foo, bar", tt.limit, 0, 0) require.NoError(t, err, "Failed to create SQL command") vars := mathexp.Vars{} @@ -153,7 +153,7 @@ func TestSQLCommandMetrics(t *testing.T) { m := metrics.NewTestMetrics() // Create a command - cmd, err := NewSQLCommand("A", "someformat", "select * from foo", 0) + cmd, err := NewSQLCommand("A", "someformat", "select * from foo", 0, 0, 0) require.NoError(t, err) // Execute successful command diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index 0ad3d757354..511e7e6a657 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -428,6 +428,12 @@ type Cfg struct { // SQLExpressionCellLimit is the maximum number of cells (rows × columns, across all frames) that can be accepted by a SQL expression. 
SQLExpressionCellLimit int64 + // SQLExpressionOutputCellLimit is the maximum number of cells (rows × columns) that can be outputted by a SQL expression. + SQLExpressionOutputCellLimit int64 + + // SQLExpressionTimeout is the duration a SQL expression will run before timing out + SQLExpressionTimeout time.Duration + ImageUploadProvider string // LiveMaxConnections is a maximum number of WebSocket connections to @@ -800,6 +806,8 @@ func (cfg *Cfg) readExpressionsSettings() { expressions := cfg.Raw.Section("expressions") cfg.ExpressionsEnabled = expressions.Key("enabled").MustBool(true) cfg.SQLExpressionCellLimit = expressions.Key("sql_expression_cell_limit").MustInt64(100000) + cfg.SQLExpressionOutputCellLimit = expressions.Key("sql_expression_output_cell_limit").MustInt64(100000) + cfg.SQLExpressionTimeout = expressions.Key("sql_expression_timeout").MustDuration(time.Second * 10) } type AnnotationCleanupSettings struct {