@ -26,16 +26,35 @@ func TablesList(rawSQL string) ([]string, error) {
buf := sqlparser . NewTrackedBuffer ( nil )
t . Format ( buf )
table := buf . String ( )
if table != "dual" {
if table != "dual" && ! strings . HasPrefix ( table , "(" ) {
if strings . Contains ( table , " as" ) {
name := stripAlias ( table )
tables = append ( tables , name )
continue
}
tables = append ( tables , buf . String ( ) )
}
}
default :
return nil , errors . New ( "not a select statement" )
return parseTables ( rawSQL )
}
if len ( tables ) == 0 {
return parseTables ( rawSQL )
}
return tables , nil
}
func stripAlias ( table string ) string {
tableParts := [ ] string { }
for _ , part := range strings . Split ( table , " " ) {
if part == "as" {
break
}
tableParts = append ( tableParts , part )
}
return strings . Join ( tableParts , " " )
}
// uses a simple tokenizer
func parse ( rawSQL string ) ( [ ] string , error ) {
query , err := parser . Parse ( rawSQL )
@ -51,19 +70,23 @@ func parse(rawSQL string) ([]string, error) {
return nil , err
}
// parseTables uses a simple tokenizer to parse tables from a SQL statement
func parseTables ( rawSQL string ) ( [ ] string , error ) {
checkSql := strings . ToUpper ( rawSQL )
rawSQL = strings . ReplaceAll ( rawSQL , "\n" , " " )
if strings . HasPrefix ( checkSql , "SELECT" ) || strings . HasPrefix ( rawSQL , "WITH" ) {
tables := [ ] string { }
tokens := strings . Split ( rawSQL , " " )
checkNext := false
takeNext := false
for _ , t := range tokens {
t = strings . ToUpper ( t )
for _ , token := range tokens {
t : = strings . ToUpper ( token )
t = strings . TrimSpace ( t )
if takeNext {
tables = append ( tables , t )
if ! existsInList ( token , tables ) {
tables = append ( tables , token )
}
checkNext = false
takeNext = false
continue
@ -74,11 +97,13 @@ func parseTables(rawSQL string) ([]string, error) {
continue
}
if strings . Contains ( t , "," ) {
values := strings . Split ( t , "," )
values := strings . Split ( token , "," )
for _ , v := range values {
v := strings . TrimSpace ( v )
if v != "" {
tables = append ( tables , v )
if ! existsInList ( token , tables ) {
tables = append ( tables , v )
}
} else {
takeNext = true
break
@ -86,7 +111,9 @@ func parseTables(rawSQL string) ([]string, error) {
}
continue
}
tables = append ( tables , t )
if ! existsInList ( token , tables ) {
tables = append ( tables , token )
}
checkNext = false
}
if t == "FROM" {
@ -97,3 +124,12 @@ func parseTables(rawSQL string) ([]string, error) {
}
return nil , errors . New ( "not a select statement" )
}
func existsInList ( table string , list [ ] string ) bool {
for _ , t := range list {
if t == table {
return true
}
}
return false
}