From 59eea3920d860f63f6506623688d26016cf5ff11 Mon Sep 17 00:00:00 2001 From: Scott Lepper Date: Fri, 3 May 2024 14:52:13 +0100 Subject: [PATCH] sql expressions - Backport 87277 to v11.0.x (#87315) sql expressions: improve parser (#87277) (cherry picked from commit 1a2bbd61fd7c8c04befd8ef8a3f5c63aa494ca98) --- pkg/expr/sql/parser.go | 77 +++++++++++++++++++++++++++++-------- pkg/expr/sql/parser_test.go | 39 +++++++++++++++++++ 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/pkg/expr/sql/parser.go b/pkg/expr/sql/parser.go index 1ba6f9eee72..5aa3b07c577 100644 --- a/pkg/expr/sql/parser.go +++ b/pkg/expr/sql/parser.go @@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) { switch kind := stmt.(type) { case *sqlparser.Select: for _, from := range kind.From { - buf := sqlparser.NewTrackedBuffer(nil) - from.Format(buf) - fromClause := buf.String() - upperFromClause := strings.ToUpper(fromClause) - if strings.Contains(upperFromClause, "JOIN") { - return extractTablesFrom(fromClause), nil - } - if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") { - if strings.Contains(upperFromClause, " AS") { - name := stripAlias(fromClause) - tables = append(tables, name) - continue - } - tables = append(tables, fromClause) - } + tables = append(tables, getTables(from)...) } default: return parseTables(rawSQL) @@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) { if len(tables) == 0 { return parseTables(rawSQL) } - return tables, nil + return validateTables(tables), nil +} + +func validateTables(tables []string) []string { + validTables := []string{} + for _, table := range tables { + if strings.ToUpper(table) != "DUAL" { + validTables = append(validTables, table) + } + } + return validTables +} + +func joinTables(join *sqlparser.JoinTableExpr) []string { + t := getTables(join.LeftExpr) + t = append(t, getTables(join.RightExpr)...) + return t +} + +func getTables(te sqlparser.TableExpr) []string { + tables := []string{} + switch v := te.(type) { + case *sqlparser.AliasedTableExpr: + tables = append(tables, nodeValue(v.Expr)) + return tables + case *sqlparser.JoinTableExpr: + tables = append(tables, joinTables(v)...) + return tables + case *sqlparser.ParenTableExpr: + for _, e := range v.Exprs { + tables = getTables(e) + } + default: + tables = append(tables, unknownExpr(te)...) + } + return tables +} + +func unknownExpr(te sqlparser.TableExpr) []string { + tables := []string{} + fromClause := nodeValue(te) + upperFromClause := strings.ToUpper(fromClause) + if strings.Contains(upperFromClause, "JOIN") { + return extractTablesFrom(fromClause) + } + if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") { + if strings.Contains(upperFromClause, " AS") { + name := stripAlias(fromClause) + tables = append(tables, name) + return tables + } + tables = append(tables, fromClause) + } + return tables +} + +func nodeValue(node sqlparser.SQLNode) string { + buf := sqlparser.NewTrackedBuffer(nil) + node.Format(buf) + return buf.String() } func extractTablesFrom(stmt string) []string { diff --git a/pkg/expr/sql/parser_test.go b/pkg/expr/sql/parser_test.go index 200dd07195e..ca3b685e34b 100644 --- a/pkg/expr/sql/parser_test.go +++ b/pkg/expr/sql/parser_test.go @@ -89,3 +89,42 @@ func TestRightJoin(t *testing.T) { assert.Equal(t, "A", tables[0]) assert.Equal(t, "B", tables[1]) } + +func TestAliasWithJoin(t *testing.T) { + sql := `select * from A as X + RIGHT JOIN B ON A.name = X.name + LIMIT 10` + tables, err := TablesList((sql)) + assert.Nil(t, err) + + assert.Equal(t, 2, len(tables)) + assert.Equal(t, "A", tables[0]) + assert.Equal(t, "B", tables[1]) +} + +func TestAlias(t *testing.T) { + sql := `select * from A as X LIMIT 10` + tables, err := TablesList((sql)) + assert.Nil(t, err) + + assert.Equal(t, 1, len(tables)) + assert.Equal(t, "A", tables[0]) +} + +func TestParens(t *testing.T) { + sql := `SELECT t1.Col1, + t2.Col1, + t3.Col1 + FROM table1 AS t1 + LEFT JOIN ( + table2 AS t2 + INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1 + ) ON t2.Col1 = t1.Col1;` + tables, err := TablesList((sql)) + assert.Nil(t, err) + + assert.Equal(t, 3, len(tables)) + assert.Equal(t, "table1", tables[0]) + assert.Equal(t, "table2", tables[1]) + assert.Equal(t, "table3", tables[2]) +}