fix: sql expressions - sql parser table name case (#87196)

* fix: sql parser table name case
pull/87242/head
Scott Lepper 1 year ago committed by GitHub
parent ac07a9794b
commit 4fd2cb6014
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 49
      pkg/expr/sql/parser.go
  2. 21
      pkg/expr/sql/parser_test.go
  3. 6
      pkg/expr/sql_command.go

@ -26,16 +26,35 @@ func TablesList(rawSQL string) ([]string, error) {
buf := sqlparser.NewTrackedBuffer(nil) buf := sqlparser.NewTrackedBuffer(nil)
t.Format(buf) t.Format(buf)
table := buf.String() table := buf.String()
if table != "dual" { if table != "dual" && !strings.HasPrefix(table, "(") {
if strings.Contains(table, " as") {
name := stripAlias(table)
tables = append(tables, name)
continue
}
tables = append(tables, buf.String()) tables = append(tables, buf.String())
} }
} }
default: default:
return nil, errors.New("not a select statement") return parseTables(rawSQL)
}
if len(tables) == 0 {
return parseTables(rawSQL)
} }
return tables, nil return tables, nil
} }
func stripAlias(table string) string {
tableParts := []string{}
for _, part := range strings.Split(table, " ") {
if part == "as" {
break
}
tableParts = append(tableParts, part)
}
return strings.Join(tableParts, " ")
}
// uses a simple tokenizer // uses a simple tokenizer
func parse(rawSQL string) ([]string, error) { func parse(rawSQL string) ([]string, error) {
query, err := parser.Parse(rawSQL) query, err := parser.Parse(rawSQL)
@ -53,17 +72,20 @@ func parse(rawSQL string) ([]string, error) {
func parseTables(rawSQL string) ([]string, error) { func parseTables(rawSQL string) ([]string, error) {
checkSql := strings.ToUpper(rawSQL) checkSql := strings.ToUpper(rawSQL)
rawSQL = strings.ReplaceAll(rawSQL, "\n", " ")
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") { if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") {
tables := []string{} tables := []string{}
tokens := strings.Split(rawSQL, " ") tokens := strings.Split(rawSQL, " ")
checkNext := false checkNext := false
takeNext := false takeNext := false
for _, t := range tokens { for _, token := range tokens {
t = strings.ToUpper(t) t := strings.ToUpper(token)
t = strings.TrimSpace(t) t = strings.TrimSpace(t)
if takeNext { if takeNext {
tables = append(tables, t) if !existsInList(token, tables) {
tables = append(tables, token)
}
checkNext = false checkNext = false
takeNext = false takeNext = false
continue continue
@ -74,11 +96,13 @@ func parseTables(rawSQL string) ([]string, error) {
continue continue
} }
if strings.Contains(t, ",") { if strings.Contains(t, ",") {
values := strings.Split(t, ",") values := strings.Split(token, ",")
for _, v := range values { for _, v := range values {
v := strings.TrimSpace(v) v := strings.TrimSpace(v)
if v != "" { if v != "" {
if !existsInList(token, tables) {
tables = append(tables, v) tables = append(tables, v)
}
} else { } else {
takeNext = true takeNext = true
break break
@ -86,7 +110,9 @@ func parseTables(rawSQL string) ([]string, error) {
} }
continue continue
} }
tables = append(tables, t) if !existsInList(token, tables) {
tables = append(tables, token)
}
checkNext = false checkNext = false
} }
if t == "FROM" { if t == "FROM" {
@ -97,3 +123,12 @@ func parseTables(rawSQL string) ([]string, error) {
} }
return nil, errors.New("not a select statement") return nil, errors.New("not a select statement")
} }
func existsInList(table string, list []string) bool {
for _, t := range list {
if t == table {
return true
}
}
return false
}

@ -11,7 +11,7 @@ func TestParse(t *testing.T) {
tables, err := parseTables((sql)) tables, err := parseTables((sql))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "FOO", tables[0]) assert.Equal(t, "foo", tables[0])
} }
func TestParseWithComma(t *testing.T) { func TestParseWithComma(t *testing.T) {
@ -19,8 +19,8 @@ func TestParseWithComma(t *testing.T) {
tables, err := parseTables((sql)) tables, err := parseTables((sql))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "FOO", tables[0]) assert.Equal(t, "foo", tables[0])
assert.Equal(t, "BAR", tables[1]) assert.Equal(t, "bar", tables[1])
} }
func TestParseWithCommas(t *testing.T) { func TestParseWithCommas(t *testing.T) {
@ -28,9 +28,9 @@ func TestParseWithCommas(t *testing.T) {
tables, err := parseTables((sql)) tables, err := parseTables((sql))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "FOO", tables[0]) assert.Equal(t, "foo", tables[0])
assert.Equal(t, "BAR", tables[1]) assert.Equal(t, "bar", tables[1])
assert.Equal(t, "BAZ", tables[2]) assert.Equal(t, "baz", tables[2])
} }
func TestArray(t *testing.T) { func TestArray(t *testing.T) {
@ -56,3 +56,12 @@ func TestXxx(t *testing.T) {
assert.Equal(t, 0, len(tables)) assert.Equal(t, 0, len(tables))
} }
func TestParseSubquery(t *testing.T) {
sql := "select * from (select * from people limit 1)"
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 1, len(tables))
assert.Equal(t, "people", tables[0])
}

@ -74,7 +74,11 @@ func (gr *SQLCommand) Execute(ctx context.Context, now time.Time, vars mathexp.V
allFrames := []*data.Frame{} allFrames := []*data.Frame{}
for _, ref := range gr.varsToQuery { for _, ref := range gr.varsToQuery {
results := vars[ref] results, ok := vars[ref]
if !ok {
logger.Warn("no results found for", "ref", ref)
continue
}
frames := results.Values.AsDataFrames(ref) frames := results.Values.AsDataFrames(ref)
allFrames = append(allFrames, frames...) allFrames = append(allFrames, frames...)
} }

Loading…
Cancel
Save