|
|
|
|
@ -1,199 +1,72 @@ |
|
|
|
|
package sql |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"errors" |
|
|
|
|
"encoding/json" |
|
|
|
|
"fmt" |
|
|
|
|
"sort" |
|
|
|
|
"strings" |
|
|
|
|
|
|
|
|
|
parser "github.com/krasun/gosqlparser" |
|
|
|
|
"github.com/xwb1989/sqlparser" |
|
|
|
|
"github.com/jeremywohl/flatten" |
|
|
|
|
"github.com/scottlepp/go-duck/duck" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// Key fragments used when scanning the flattened, dot-separated SQL AST.
const (
	// TABLE_NAME is the key fragment whose string values are collected as
	// table names.
	TABLE_NAME = "table_name"
	// ERROR is the key suffix of the boolean flag that, when true, is
	// reported as a SQL error.
	ERROR = ".error"
	// ERROR_MESSAGE replaces the ERROR suffix of a key to look up the
	// message that accompanies the error flag.
	ERROR_MESSAGE = ".error_message"
)
|
|
|
|
|
|
|
|
|
// TablesList returns a list of tables for the sql statement
|
|
|
|
|
// TODO: should we just return all query refs instead of trying to parse them from the sql?
|
|
|
|
|
func TablesList(rawSQL string) ([]string, error) { |
|
|
|
|
stmt, err := sqlparser.Parse(rawSQL) |
|
|
|
|
duckDB := duck.NewInMemoryDB() |
|
|
|
|
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL) |
|
|
|
|
ret, err := duckDB.RunCommands([]string{cmd}) |
|
|
|
|
if err != nil { |
|
|
|
|
tables, err := parse(rawSQL) |
|
|
|
|
if err != nil { |
|
|
|
|
return parseTables(rawSQL) |
|
|
|
|
} |
|
|
|
|
return tables, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tables := []string{} |
|
|
|
|
switch kind := stmt.(type) { |
|
|
|
|
case *sqlparser.Select: |
|
|
|
|
for _, from := range kind.From { |
|
|
|
|
tables = append(tables, getTables(from)...) |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
return parseTables(rawSQL) |
|
|
|
|
} |
|
|
|
|
if len(tables) == 0 { |
|
|
|
|
return parseTables(rawSQL) |
|
|
|
|
return nil, fmt.Errorf("error serializing sql: %s", err.Error()) |
|
|
|
|
} |
|
|
|
|
return validateTables(tables), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func validateTables(tables []string) []string { |
|
|
|
|
validTables := []string{} |
|
|
|
|
for _, table := range tables { |
|
|
|
|
if strings.ToUpper(table) != "DUAL" { |
|
|
|
|
validTables = append(validTables, table) |
|
|
|
|
} |
|
|
|
|
ast := []map[string]any{} |
|
|
|
|
err = json.Unmarshal([]byte(ret), &ast) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("error converting json to ast: %s", err.Error()) |
|
|
|
|
} |
|
|
|
|
return validTables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func joinTables(join *sqlparser.JoinTableExpr) []string { |
|
|
|
|
t := getTables(join.LeftExpr) |
|
|
|
|
t = append(t, getTables(join.RightExpr)...) |
|
|
|
|
return t |
|
|
|
|
return tablesFromAST(ast) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getTables(te sqlparser.TableExpr) []string { |
|
|
|
|
tables := []string{} |
|
|
|
|
switch v := te.(type) { |
|
|
|
|
case *sqlparser.AliasedTableExpr: |
|
|
|
|
tables = append(tables, nodeValue(v.Expr)) |
|
|
|
|
return tables |
|
|
|
|
case *sqlparser.JoinTableExpr: |
|
|
|
|
tables = append(tables, joinTables(v)...) |
|
|
|
|
return tables |
|
|
|
|
case *sqlparser.ParenTableExpr: |
|
|
|
|
for _, e := range v.Exprs { |
|
|
|
|
tables = getTables(e) |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
tables = append(tables, unknownExpr(te)...) |
|
|
|
|
} |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func unknownExpr(te sqlparser.TableExpr) []string { |
|
|
|
|
tables := []string{} |
|
|
|
|
fromClause := nodeValue(te) |
|
|
|
|
upperFromClause := strings.ToUpper(fromClause) |
|
|
|
|
if strings.Contains(upperFromClause, "JOIN") { |
|
|
|
|
return extractTablesFrom(fromClause) |
|
|
|
|
} |
|
|
|
|
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") { |
|
|
|
|
if strings.Contains(upperFromClause, " AS") { |
|
|
|
|
name := stripAlias(fromClause) |
|
|
|
|
tables = append(tables, name) |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
tables = append(tables, fromClause) |
|
|
|
|
func tablesFromAST(ast []map[string]any) ([]string, error) { |
|
|
|
|
flat, err := flatten.Flatten(ast[0], "", flatten.DotStyle) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("error flattening ast: %s", err.Error()) |
|
|
|
|
} |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func nodeValue(node sqlparser.SQLNode) string { |
|
|
|
|
buf := sqlparser.NewTrackedBuffer(nil) |
|
|
|
|
node.Format(buf) |
|
|
|
|
return buf.String() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func extractTablesFrom(stmt string) []string { |
|
|
|
|
// example: A join B on A.name = B.name
|
|
|
|
|
tables := []string{} |
|
|
|
|
parts := strings.Split(stmt, " ") |
|
|
|
|
for _, part := range parts { |
|
|
|
|
part = strings.ToUpper(part) |
|
|
|
|
if isJoin(part) { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if strings.Contains(part, "ON") { |
|
|
|
|
break |
|
|
|
|
for k, v := range flat { |
|
|
|
|
if strings.HasSuffix(k, ERROR) { |
|
|
|
|
v, ok := v.(bool) |
|
|
|
|
if ok && v { |
|
|
|
|
return nil, astError(k, flat) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if part != "" { |
|
|
|
|
if !existsInList(part, tables) { |
|
|
|
|
tables = append(tables, part) |
|
|
|
|
if strings.Contains(k, TABLE_NAME) { |
|
|
|
|
table, ok := v.(string) |
|
|
|
|
if ok && !existsInList(table, tables) { |
|
|
|
|
tables = append(tables, v.(string)) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
sort.Strings(tables) |
|
|
|
|
|
|
|
|
|
// stripAlias removes an "AS <alias>" suffix from a table reference.
//
// The reference is split on single spaces and everything before the first
// stand-alone AS token (matched case-insensitively) is returned, re-joined
// with single spaces. A reference without an AS token is returned unchanged.
func stripAlias(table string) string {
	parts := strings.Split(table, " ")
	for i, part := range parts {
		if strings.EqualFold(part, "AS") {
			return strings.Join(parts[:i], " ")
		}
	}
	return table
}
|
|
|
|
|
|
|
|
|
// uses a simple tokenizer
|
|
|
|
|
func parse(rawSQL string) ([]string, error) { |
|
|
|
|
query, err := parser.Parse(rawSQL) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
func astError(k string, flat map[string]any) error { |
|
|
|
|
key := strings.Replace(k, ERROR, "", 1) |
|
|
|
|
message, ok := flat[key+ERROR_MESSAGE] |
|
|
|
|
if !ok { |
|
|
|
|
message = "unknown error in sql" |
|
|
|
|
} |
|
|
|
|
if query.GetType() == parser.StatementSelect { |
|
|
|
|
sel, ok := query.(*parser.Select) |
|
|
|
|
if ok { |
|
|
|
|
return []string{sel.Table}, nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func parseTables(rawSQL string) ([]string, error) { |
|
|
|
|
checkSql := strings.ToUpper(rawSQL) |
|
|
|
|
rawSQL = strings.ReplaceAll(rawSQL, "\n", " ") |
|
|
|
|
rawSQL = strings.ReplaceAll(rawSQL, "\r", " ") |
|
|
|
|
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") { |
|
|
|
|
tables := []string{} |
|
|
|
|
tokens := strings.Split(rawSQL, " ") |
|
|
|
|
checkNext := false |
|
|
|
|
takeNext := false |
|
|
|
|
for _, token := range tokens { |
|
|
|
|
t := strings.ToUpper(token) |
|
|
|
|
t = strings.TrimSpace(t) |
|
|
|
|
|
|
|
|
|
if takeNext { |
|
|
|
|
if !existsInList(token, tables) { |
|
|
|
|
tables = append(tables, token) |
|
|
|
|
} |
|
|
|
|
checkNext = false |
|
|
|
|
takeNext = false |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if checkNext { |
|
|
|
|
if strings.Contains(t, "(") { |
|
|
|
|
checkNext = false |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if strings.Contains(t, ",") { |
|
|
|
|
values := strings.Split(token, ",") |
|
|
|
|
for _, v := range values { |
|
|
|
|
v := strings.TrimSpace(v) |
|
|
|
|
if v != "" { |
|
|
|
|
if !existsInList(token, tables) { |
|
|
|
|
tables = append(tables, v) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
takeNext = true |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if !existsInList(token, tables) { |
|
|
|
|
tables = append(tables, token) |
|
|
|
|
} |
|
|
|
|
checkNext = false |
|
|
|
|
} |
|
|
|
|
if t == "FROM" { |
|
|
|
|
checkNext = true |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return tables, nil |
|
|
|
|
} |
|
|
|
|
return nil, errors.New("not a select statement") |
|
|
|
|
return fmt.Errorf("error in sql: %s", message) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func existsInList(table string, list []string) bool { |
|
|
|
|
@ -204,15 +77,3 @@ func existsInList(table string, list []string) bool { |
|
|
|
|
} |
|
|
|
|
return false |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// joins lists the keywords that isJoin treats as part of a JOIN clause.
var joins = []string{"JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER"}

// isJoin reports whether token is, ignoring case, one of the keywords in
// joins. The whole token must match; substrings do not count.
func isJoin(token string) bool {
	for _, keyword := range joins {
		if strings.EqualFold(token, keyword) {
			return true
		}
	}
	return false
}
|
|
|
|
|