The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
grafana/pkg/expr/sql/parser.go

219 lines
4.7 KiB

package sql
import (
"errors"
"strings"
parser "github.com/krasun/gosqlparser"
"github.com/xwb1989/sqlparser"
)
// TablesList returns the table names referenced by the FROM clause(s) of rawSQL.
//
// Three strategies are tried, in order:
//  1. the MySQL-dialect parser (sqlparser), for plain SELECT statements;
//  2. a simple tokenizer-based parser (gosqlparser), when (1) cannot parse rawSQL;
//  3. a keyword-scanning heuristic (parseTables), for any other statement type
//     or whenever the earlier strategies yield no table names.
//
// Names equal to DUAL (case-insensitive) are removed from the results of (1) and (2).
// TODO: should we just return all query refs instead of trying to parse them from the sql?
func TablesList(rawSQL string) ([]string, error) {
	stmt, err := sqlparser.Parse(rawSQL)
	if err != nil {
		// The MySQL parser rejected the statement. Only trust the simpler
		// parser when it actually produced a table name; it can otherwise
		// report "success" with no tables, which would hide the heuristic scan.
		if tables, perr := parse(rawSQL); perr == nil && len(tables) > 0 {
			return validateTables(tables), nil
		}
		return parseTables(rawSQL)
	}

	sel, ok := stmt.(*sqlparser.Select)
	if !ok {
		// UNIONs and non-SELECT statements are handled by the heuristic scan.
		return parseTables(rawSQL)
	}

	tables := []string{}
	for _, from := range sel.From {
		tables = append(tables, getTables(from)...)
	}
	if len(tables) == 0 {
		return parseTables(rawSQL)
	}
	return validateTables(tables), nil
}
// validateTables drops the DUAL pseudo-table (matched case-insensitively) and
// returns the remaining names in their original order. The result is never nil.
func validateTables(tables []string) []string {
	kept := make([]string, 0, len(tables))
	for _, name := range tables {
		if strings.ToUpper(name) == "DUAL" {
			continue
		}
		kept = append(kept, name)
	}
	return kept
}
// joinTables returns the tables found on the left side of a join followed by
// those found on its right side.
func joinTables(join *sqlparser.JoinTableExpr) []string {
	left := getTables(join.LeftExpr)
	right := getTables(join.RightExpr)
	return append(left, right...)
}
// getTables returns the table names referenced by a single FROM-clause
// expression. It recurses into joins, parenthesised table lists and derived
// tables (subqueries). Unrecognised expression types go to unknownExpr.
func getTables(te sqlparser.TableExpr) []string {
	switch v := te.(type) {
	case *sqlparser.AliasedTableExpr:
		// A derived table such as "(SELECT ... FROM A) AS x" is not a table
		// itself; report the tables its inner statement reads rather than the
		// subquery's SQL text.
		if sub, ok := v.Expr.(*sqlparser.Subquery); ok {
			return selectStatementTables(sub.Select)
		}
		return []string{nodeValue(v.Expr)}
	case *sqlparser.JoinTableExpr:
		return joinTables(v)
	case *sqlparser.ParenTableExpr:
		// Accumulate across every expression inside the parentheses;
		// assigning instead of appending would keep only the last one.
		tables := []string{}
		for _, e := range v.Exprs {
			tables = append(tables, getTables(e)...)
		}
		return tables
	default:
		return unknownExpr(te)
	}
}

// selectStatementTables collects the tables read by the FROM clauses of a
// (possibly UNION-ed or parenthesised) select statement. The result is never nil.
func selectStatementTables(stmt sqlparser.SelectStatement) []string {
	tables := []string{}
	switch s := stmt.(type) {
	case *sqlparser.Select:
		for _, from := range s.From {
			tables = append(tables, getTables(from)...)
		}
	case *sqlparser.Union:
		tables = append(tables, selectStatementTables(s.Left)...)
		tables = append(tables, selectStatementTables(s.Right)...)
	case *sqlparser.ParenSelect:
		tables = append(tables, selectStatementTables(s.Select)...)
	}
	return tables
}
// unknownExpr is the best-effort fallback for table expressions getTables does
// not recognise. It renders the node back to SQL text and then:
//   - delegates to extractTablesFrom when the text mentions JOIN;
//   - returns no tables for DUAL or for text starting with "(";
//   - strips an alias when the text contains " AS";
//   - otherwise returns the text itself as the single table name.
func unknownExpr(te sqlparser.TableExpr) []string {
	clause := nodeValue(te)
	upper := strings.ToUpper(clause)

	switch {
	case strings.Contains(upper, "JOIN"):
		return extractTablesFrom(clause)
	case upper == "DUAL", strings.HasPrefix(clause, "("):
		return []string{}
	case strings.Contains(upper, " AS"):
		return []string{stripAlias(clause)}
	default:
		return []string{clause}
	}
}
// nodeValue renders a parsed SQL node back to text using the parser's own
// formatter (node.Format writing into a TrackedBuffer).
func nodeValue(node sqlparser.SQLNode) string {
	out := sqlparser.NewTrackedBuffer(nil)
	node.Format(out)
	return out.String()
}
// extractTablesFrom is a best-effort scan of a rendered join clause, for
// example "A join B on A.name = B.name".
//
// It skips join keywords (see isJoin) and collects every other
// whitespace-separated token until the ON or USING condition starts. Names
// keep their original spelling. Duplicates are removed case-insensitively, so
// "A a join B" yields [A B]. The result is never nil.
func extractTablesFrom(stmt string) []string {
	tables := []string{}
	seen := make(map[string]bool)
	for _, part := range strings.Fields(stmt) {
		upper := strings.ToUpper(part)
		if isJoin(upper) {
			continue
		}
		// Compare whole tokens: a substring test would also stop on table
		// names that merely contain "ON" (e.g. MONTHLY).
		if upper == "ON" || upper == "USING" {
			break
		}
		if !seen[upper] {
			seen[upper] = true
			tables = append(tables, part)
		}
	}
	return tables
}
// stripAlias returns the part of a single-space-separated table expression
// that precedes the first standalone AS keyword (case-insensitive), e.g.
// "db.t AS x" -> "db.t". Input without such a keyword is returned unchanged.
func stripAlias(table string) string {
	words := strings.Split(table, " ")
	for idx, word := range words {
		if strings.ToUpper(word) == "AS" {
			return strings.Join(words[:idx], " ")
		}
	}
	return table
}
// parse extracts the table of a SELECT statement using the simple
// tokenizer-based gosqlparser.
//
// It returns an error when rawSQL cannot be parsed, is not a SELECT, or names
// no table. Previously these last two cases returned (nil, nil), which callers
// could not tell apart from a successful parse.
func parse(rawSQL string) ([]string, error) {
	query, err := parser.Parse(rawSQL)
	if err != nil {
		return nil, err
	}
	if query.GetType() != parser.StatementSelect {
		return nil, errors.New("not a select statement")
	}
	sel, ok := query.(*parser.Select)
	if !ok || sel.Table == "" {
		return nil, errors.New("unable to determine table from select statement")
	}
	return []string{sel.Table}, nil
}
// parseTables is a keyword-scanning heuristic used when the real parsers fail.
//
// It accepts statements that start (after leading whitespace, case-insensitive)
// with SELECT or WITH, and returns an error otherwise. For every FROM keyword it
// records the following token as a table name, with three refinements:
//   - a token containing "(" is treated as a subquery and skipped;
//   - a comma-separated token ("A,B") contributes each non-empty name;
//   - a trailing comma ("A,") also takes the next token ("A, B").
//
// Names are de-duplicated (case-sensitive) and kept in first-seen order. The
// result is non-nil on success.
func parseTables(rawSQL string) ([]string, error) {
	upper := strings.ToUpper(strings.TrimSpace(rawSQL))
	if !strings.HasPrefix(upper, "SELECT") && !strings.HasPrefix(upper, "WITH") {
		return nil, errors.New("not a select statement")
	}

	tables := []string{}
	seen := make(map[string]bool)
	add := func(name string) {
		if name == "" || seen[name] {
			return
		}
		seen[name] = true
		tables = append(tables, name)
	}

	checkNext := false // previous token was FROM
	takeNext := false  // previous token ended a comma list with a trailing comma
	// strings.Fields splits on any whitespace (spaces, tabs, CR/LF) and never
	// yields empty tokens, which would otherwise be recorded as "" tables.
	for _, token := range strings.Fields(rawSQL) {
		t := strings.ToUpper(token)
		switch {
		case takeNext:
			add(token)
			checkNext, takeNext = false, false
			continue
		case checkNext:
			// Reset up front so a comma list like "A,B" does not leave the
			// scanner armed and record the next keyword (UNION, WHERE...).
			checkNext = false
			if strings.Contains(t, "(") {
				continue
			}
			if strings.Contains(t, ",") {
				for _, v := range strings.Split(token, ",") {
					v = strings.TrimSpace(v)
					if v == "" {
						takeNext = true
						break
					}
					add(v)
				}
				continue
			}
			add(token)
		}
		if t == "FROM" {
			checkNext = true
		}
	}
	return tables, nil
}
// existsInList reports whether table occurs in list, using exact
// (case-sensitive) string comparison.
func existsInList(table string, list []string) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == table
	}
	return found
}
// joins lists the upper-case SQL keywords that can appear between table names
// in a join clause. CROSS and NATURAL are included so that "A CROSS JOIN B" or
// "A NATURAL JOIN B" do not report the keyword itself as a table.
var joins = []string{"JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "CROSS", "NATURAL"}

// isJoin reports whether token, compared case-insensitively, is one of the
// join keywords listed in joins.
func isJoin(token string) bool {
	token = strings.ToUpper(token)
	for _, join := range joins {
		if token == join {
			return true
		}
	}
	return false
}