|
|
|
@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) { |
|
|
|
|
switch kind := stmt.(type) { |
|
|
|
|
case *sqlparser.Select: |
|
|
|
|
for _, from := range kind.From { |
|
|
|
|
buf := sqlparser.NewTrackedBuffer(nil) |
|
|
|
|
from.Format(buf) |
|
|
|
|
fromClause := buf.String() |
|
|
|
|
upperFromClause := strings.ToUpper(fromClause) |
|
|
|
|
if strings.Contains(upperFromClause, "JOIN") { |
|
|
|
|
return extractTablesFrom(fromClause), nil |
|
|
|
|
} |
|
|
|
|
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") { |
|
|
|
|
if strings.Contains(upperFromClause, " AS") { |
|
|
|
|
name := stripAlias(fromClause) |
|
|
|
|
tables = append(tables, name) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
tables = append(tables, fromClause) |
|
|
|
|
} |
|
|
|
|
tables = append(tables, getTables(from)...) |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
return parseTables(rawSQL) |
|
|
|
@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) { |
|
|
|
|
if len(tables) == 0 { |
|
|
|
|
return parseTables(rawSQL) |
|
|
|
|
} |
|
|
|
|
return tables, nil |
|
|
|
|
return validateTables(tables), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func validateTables(tables []string) []string { |
|
|
|
|
validTables := []string{} |
|
|
|
|
for _, table := range tables { |
|
|
|
|
if strings.ToUpper(table) != "DUAL" { |
|
|
|
|
validTables = append(validTables, table) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return validTables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func joinTables(join *sqlparser.JoinTableExpr) []string { |
|
|
|
|
t := getTables(join.LeftExpr) |
|
|
|
|
t = append(t, getTables(join.RightExpr)...) |
|
|
|
|
return t |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getTables(te sqlparser.TableExpr) []string { |
|
|
|
|
tables := []string{} |
|
|
|
|
switch v := te.(type) { |
|
|
|
|
case *sqlparser.AliasedTableExpr: |
|
|
|
|
tables = append(tables, nodeValue(v.Expr)) |
|
|
|
|
return tables |
|
|
|
|
case *sqlparser.JoinTableExpr: |
|
|
|
|
tables = append(tables, joinTables(v)...) |
|
|
|
|
return tables |
|
|
|
|
case *sqlparser.ParenTableExpr: |
|
|
|
|
for _, e := range v.Exprs { |
|
|
|
|
tables = getTables(e) |
|
|
|
|
} |
|
|
|
|
default: |
|
|
|
|
tables = append(tables, unknownExpr(te)...) |
|
|
|
|
} |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func unknownExpr(te sqlparser.TableExpr) []string { |
|
|
|
|
tables := []string{} |
|
|
|
|
fromClause := nodeValue(te) |
|
|
|
|
upperFromClause := strings.ToUpper(fromClause) |
|
|
|
|
if strings.Contains(upperFromClause, "JOIN") { |
|
|
|
|
return extractTablesFrom(fromClause) |
|
|
|
|
} |
|
|
|
|
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") { |
|
|
|
|
if strings.Contains(upperFromClause, " AS") { |
|
|
|
|
name := stripAlias(fromClause) |
|
|
|
|
tables = append(tables, name) |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
tables = append(tables, fromClause) |
|
|
|
|
} |
|
|
|
|
return tables |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func nodeValue(node sqlparser.SQLNode) string { |
|
|
|
|
buf := sqlparser.NewTrackedBuffer(nil) |
|
|
|
|
node.Format(buf) |
|
|
|
|
return buf.String() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func extractTablesFrom(stmt string) []string { |
|
|
|
|