|
|
@ -26,16 +26,35 @@ func TablesList(rawSQL string) ([]string, error) { |
|
|
|
buf := sqlparser.NewTrackedBuffer(nil) |
|
|
|
buf := sqlparser.NewTrackedBuffer(nil) |
|
|
|
t.Format(buf) |
|
|
|
t.Format(buf) |
|
|
|
table := buf.String() |
|
|
|
table := buf.String() |
|
|
|
if table != "dual" { |
|
|
|
if table != "dual" && !strings.HasPrefix(table, "(") { |
|
|
|
|
|
|
|
if strings.Contains(table, " as") { |
|
|
|
|
|
|
|
name := stripAlias(table) |
|
|
|
|
|
|
|
tables = append(tables, name) |
|
|
|
|
|
|
|
continue |
|
|
|
|
|
|
|
} |
|
|
|
tables = append(tables, buf.String()) |
|
|
|
tables = append(tables, buf.String()) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
default: |
|
|
|
default: |
|
|
|
return nil, errors.New("not a select statement") |
|
|
|
return parseTables(rawSQL) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if len(tables) == 0 { |
|
|
|
|
|
|
|
return parseTables(rawSQL) |
|
|
|
} |
|
|
|
} |
|
|
|
return tables, nil |
|
|
|
return tables, nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func stripAlias(table string) string { |
|
|
|
|
|
|
|
tableParts := []string{} |
|
|
|
|
|
|
|
for _, part := range strings.Split(table, " ") { |
|
|
|
|
|
|
|
if part == "as" { |
|
|
|
|
|
|
|
break |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
tableParts = append(tableParts, part) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return strings.Join(tableParts, " ") |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// uses a simple tokenizer
|
|
|
|
// uses a simple tokenizer
|
|
|
|
func parse(rawSQL string) ([]string, error) { |
|
|
|
func parse(rawSQL string) ([]string, error) { |
|
|
|
query, err := parser.Parse(rawSQL) |
|
|
|
query, err := parser.Parse(rawSQL) |
|
|
@ -53,17 +72,20 @@ func parse(rawSQL string) ([]string, error) { |
|
|
|
|
|
|
|
|
|
|
|
func parseTables(rawSQL string) ([]string, error) { |
|
|
|
func parseTables(rawSQL string) ([]string, error) { |
|
|
|
checkSql := strings.ToUpper(rawSQL) |
|
|
|
checkSql := strings.ToUpper(rawSQL) |
|
|
|
|
|
|
|
rawSQL = strings.ReplaceAll(rawSQL, "\n", " ") |
|
|
|
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") { |
|
|
|
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") { |
|
|
|
tables := []string{} |
|
|
|
tables := []string{} |
|
|
|
tokens := strings.Split(rawSQL, " ") |
|
|
|
tokens := strings.Split(rawSQL, " ") |
|
|
|
checkNext := false |
|
|
|
checkNext := false |
|
|
|
takeNext := false |
|
|
|
takeNext := false |
|
|
|
for _, t := range tokens { |
|
|
|
for _, token := range tokens { |
|
|
|
t = strings.ToUpper(t) |
|
|
|
t := strings.ToUpper(token) |
|
|
|
t = strings.TrimSpace(t) |
|
|
|
t = strings.TrimSpace(t) |
|
|
|
|
|
|
|
|
|
|
|
if takeNext { |
|
|
|
if takeNext { |
|
|
|
tables = append(tables, t) |
|
|
|
if !existsInList(token, tables) { |
|
|
|
|
|
|
|
tables = append(tables, token) |
|
|
|
|
|
|
|
} |
|
|
|
checkNext = false |
|
|
|
checkNext = false |
|
|
|
takeNext = false |
|
|
|
takeNext = false |
|
|
|
continue |
|
|
|
continue |
|
|
@ -74,11 +96,13 @@ func parseTables(rawSQL string) ([]string, error) { |
|
|
|
continue |
|
|
|
continue |
|
|
|
} |
|
|
|
} |
|
|
|
if strings.Contains(t, ",") { |
|
|
|
if strings.Contains(t, ",") { |
|
|
|
values := strings.Split(t, ",") |
|
|
|
values := strings.Split(token, ",") |
|
|
|
for _, v := range values { |
|
|
|
for _, v := range values { |
|
|
|
v := strings.TrimSpace(v) |
|
|
|
v := strings.TrimSpace(v) |
|
|
|
if v != "" { |
|
|
|
if v != "" { |
|
|
|
|
|
|
|
if !existsInList(token, tables) { |
|
|
|
tables = append(tables, v) |
|
|
|
tables = append(tables, v) |
|
|
|
|
|
|
|
} |
|
|
|
} else { |
|
|
|
} else { |
|
|
|
takeNext = true |
|
|
|
takeNext = true |
|
|
|
break |
|
|
|
break |
|
|
@ -86,7 +110,9 @@ func parseTables(rawSQL string) ([]string, error) { |
|
|
|
} |
|
|
|
} |
|
|
|
continue |
|
|
|
continue |
|
|
|
} |
|
|
|
} |
|
|
|
tables = append(tables, t) |
|
|
|
if !existsInList(token, tables) { |
|
|
|
|
|
|
|
tables = append(tables, token) |
|
|
|
|
|
|
|
} |
|
|
|
checkNext = false |
|
|
|
checkNext = false |
|
|
|
} |
|
|
|
} |
|
|
|
if t == "FROM" { |
|
|
|
if t == "FROM" { |
|
|
@ -97,3 +123,12 @@ func parseTables(rawSQL string) ([]string, error) { |
|
|
|
} |
|
|
|
} |
|
|
|
return nil, errors.New("not a select statement") |
|
|
|
return nil, errors.New("not a select statement") |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func existsInList(table string, list []string) bool { |
|
|
|
|
|
|
|
for _, t := range list { |
|
|
|
|
|
|
|
if t == table { |
|
|
|
|
|
|
|
return true |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return false |
|
|
|
|
|
|
|
} |
|
|
|