diff --git a/pkg/services/sqlstore/migrator/dialect.go b/pkg/services/sqlstore/migrator/dialect.go index 740704c6f40..fa3c8587ec5 100644 --- a/pkg/services/sqlstore/migrator/dialect.go +++ b/pkg/services/sqlstore/migrator/dialect.go @@ -49,6 +49,7 @@ type Dialect interface { ColumnCheckSQL(tableName, columnName string) (string, []interface{}) // UpsertSQL returns the upsert sql statement for a dialect UpsertSQL(tableName string, keyCols, updateCols []string) string + UpsertMultipleSQL(tableName string, keyCols, updateCols []string, count int) (string, error) ColString(*Column) string ColStringNoPk(*Column) string diff --git a/pkg/services/sqlstore/migrator/mysql_dialect.go b/pkg/services/sqlstore/migrator/mysql_dialect.go index aefdc81c67c..425928d05da 100644 --- a/pkg/services/sqlstore/migrator/mysql_dialect.go +++ b/pkg/services/sqlstore/migrator/mysql_dialect.go @@ -210,8 +210,16 @@ func (db *MySQLDialect) IsDeadlock(err error) bool { return db.isThisError(err, mysqlerr.ER_LOCK_DEADLOCK) } -// UpsertSQL returns the upsert sql statement for PostgreSQL dialect +// UpsertSQL returns the upsert sql statement for MySQL dialect func (db *MySQLDialect) UpsertSQL(tableName string, keyCols, updateCols []string) string { + q, _ := db.UpsertMultipleSQL(tableName, keyCols, updateCols, 1) + return q +} + +func (db *MySQLDialect) UpsertMultipleSQL(tableName string, keyCols, updateCols []string, count int) (string, error) { + if count < 1 { + return "", fmt.Errorf("upsert statement must have count >= 1. Got %v", count) + } columnsStr := strings.Builder{} colPlaceHoldersStr := strings.Builder{} setStr := strings.Builder{} @@ -226,13 +234,23 @@ func (db *MySQLDialect) UpsertSQL(tableName string, keyCols, updateCols []string setStr.WriteString(fmt.Sprintf("%s=VALUES(%s)%s", db.Quote(c), db.Quote(c), separator)) } - s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s`, + valuesStr := strings.Builder{} + separator = ", " + colPlaceHolders := colPlaceHoldersStr.String() + for i := 0; i < count; i++ { + if i == count-1 { + separator = "" + } + valuesStr.WriteString(fmt.Sprintf("(%s)%s", colPlaceHolders, separator)) + } + + s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES %s ON DUPLICATE KEY UPDATE %s`, tableName, columnsStr.String(), - colPlaceHoldersStr.String(), + valuesStr.String(), setStr.String(), ) - return s + return s, nil } func (db *MySQLDialect) Lock(cfg LockCfg) error { diff --git a/pkg/services/sqlstore/migrator/postgres_dialect.go b/pkg/services/sqlstore/migrator/postgres_dialect.go index dbe661d0be8..b2b0e53884c 100644 --- a/pkg/services/sqlstore/migrator/postgres_dialect.go +++ b/pkg/services/sqlstore/migrator/postgres_dialect.go @@ -224,6 +224,15 @@ func (db *PostgresDialect) PostInsertId(table string, sess *xorm.Session) error // UpsertSQL returns the upsert sql statement for PostgreSQL dialect func (db *PostgresDialect) UpsertSQL(tableName string, keyCols, updateCols []string) string { + str, _ := db.UpsertMultipleSQL(tableName, keyCols, updateCols, 1) + return str +} + +// UpsertMultipleSQL returns the upsert sql statement for PostgreSQL dialect +func (db *PostgresDialect) UpsertMultipleSQL(tableName string, keyCols, updateCols []string, count int) (string, error) { + if count < 1 { + return "", fmt.Errorf("upsert statement must have count >= 1. Got %v", count) + } columnsStr := strings.Builder{} onConflictStr := strings.Builder{} colPlaceHoldersStr := strings.Builder{} @@ -249,14 +258,24 @@ func (db *PostgresDialect) UpsertSQL(tableName string, keyCols, updateCols []str onConflictStr.WriteString(fmt.Sprintf("%s%s", db.Quote(c), separatorVar)) } - s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES (%s) ON CONFLICT(%s) DO UPDATE SET %s`, + valuesStr := strings.Builder{} + separatorVar = separator + colPlaceHolders := colPlaceHoldersStr.String() + for i := 0; i < count; i++ { + if i == count-1 { + separatorVar = "" + } + valuesStr.WriteString(fmt.Sprintf("(%s)%s", colPlaceHolders, separatorVar)) + } + + s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES %s ON CONFLICT(%s) DO UPDATE SET %s`, tableName, columnsStr.String(), - colPlaceHoldersStr.String(), + valuesStr.String(), onConflictStr.String(), setStr.String(), ) - return s + return s, nil } func (db *PostgresDialect) Lock(cfg LockCfg) error { diff --git a/pkg/services/sqlstore/migrator/sqlite_dialect.go b/pkg/services/sqlstore/migrator/sqlite_dialect.go index 2ea68656af4..25d3976a56c 100644 --- a/pkg/services/sqlstore/migrator/sqlite_dialect.go +++ b/pkg/services/sqlstore/migrator/sqlite_dialect.go @@ -151,6 +151,15 @@ func (db *SQLite3) IsDeadlock(err error) bool { // UpsertSQL returns the upsert sql statement for SQLite dialect func (db *SQLite3) UpsertSQL(tableName string, keyCols, updateCols []string) string { + str, _ := db.UpsertMultipleSQL(tableName, keyCols, updateCols, 1) + return str +} + +// UpsertMultipleSQL returns the upsert sql statement for PostgreSQL dialect +func (db *SQLite3) UpsertMultipleSQL(tableName string, keyCols, updateCols []string, count int) (string, error) { + if count < 1 { + return "", fmt.Errorf("upsert statement must have count >= 1. Got %v", count) + } columnsStr := strings.Builder{} onConflictStr := strings.Builder{} colPlaceHoldersStr := strings.Builder{} @@ -176,12 +185,22 @@ func (db *SQLite3) UpsertSQL(tableName string, keyCols, updateCols []string) str onConflictStr.WriteString(fmt.Sprintf("%s%s", db.Quote(c), separatorVar)) } - s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES (%s) ON CONFLICT(%s) DO UPDATE SET %s`, + valuesStr := strings.Builder{} + separatorVar = separator + colPlaceHolders := colPlaceHoldersStr.String() + for i := 0; i < count; i++ { + if i == count-1 { + separatorVar = "" + } + valuesStr.WriteString(fmt.Sprintf("(%s)%s", colPlaceHolders, separatorVar)) + } + + s := fmt.Sprintf(`INSERT INTO %s (%s) VALUES %s ON CONFLICT(%s) DO UPDATE SET %s`, tableName, columnsStr.String(), - colPlaceHoldersStr.String(), + valuesStr.String(), onConflictStr.String(), setStr.String(), ) - return s + return s, nil } diff --git a/pkg/services/sqlstore/migrator/upsert_test.go b/pkg/services/sqlstore/migrator/upsert_test.go new file mode 100644 index 00000000000..9c39673d8d6 --- /dev/null +++ b/pkg/services/sqlstore/migrator/upsert_test.go @@ -0,0 +1,74 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpsertMultiple(t *testing.T) { + tests := []struct { + name string + keyCols []string + updateCols []string + count int + expectedErr bool + expectedPostgresQuery string + expectedMySQLQuery string + expectedSQLiteQuery string + }{ + { + "upsert one", + []string{"key1", "key2"}, + []string{"key1", "key2", "val1", "val2"}, + 1, + false, + "INSERT INTO test_table (\"key1\", \"key2\", \"val1\", \"val2\") VALUES (?, ?, ?, ?) ON CONFLICT(\"key1\", \"key2\") DO UPDATE SET \"key1\"=excluded.\"key1\", \"key2\"=excluded.\"key2\", \"val1\"=excluded.\"val1\", \"val2\"=excluded.\"val2\"", + "INSERT INTO test_table (`key1`, `key2`, `val1`, `val2`) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE `key1`=VALUES(`key1`), `key2`=VALUES(`key2`), `val1`=VALUES(`val1`), `val2`=VALUES(`val2`)", + "INSERT INTO test_table (`key1`, `key2`, `val1`, `val2`) VALUES (?, ?, ?, ?) ON CONFLICT(`key1`, `key2`) DO UPDATE SET `key1`=excluded.`key1`, `key2`=excluded.`key2`, `val1`=excluded.`val1`, `val2`=excluded.`val2`", + }, + { + "upsert two", + []string{"key1", "key2"}, + []string{"key1", "key2", "val1", "val2"}, + 2, + false, + "INSERT INTO test_table (\"key1\", \"key2\", \"val1\", \"val2\") VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON CONFLICT(\"key1\", \"key2\") DO UPDATE SET \"key1\"=excluded.\"key1\", \"key2\"=excluded.\"key2\", \"val1\"=excluded.\"val1\", \"val2\"=excluded.\"val2\"", + "INSERT INTO test_table (`key1`, `key2`, `val1`, `val2`) VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON DUPLICATE KEY UPDATE `key1`=VALUES(`key1`), `key2`=VALUES(`key2`), `val1`=VALUES(`val1`), `val2`=VALUES(`val2`)", + "INSERT INTO test_table (`key1`, `key2`, `val1`, `val2`) VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON CONFLICT(`key1`, `key2`) DO UPDATE SET `key1`=excluded.`key1`, `key2`=excluded.`key2`, `val1`=excluded.`val1`, `val2`=excluded.`val2`", + }, + { + "count error", + []string{"key1", "key2"}, + []string{"key1", "key2", "val1", "val2"}, + 0, + true, + "", + "", + "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var db Dialect + db = &PostgresDialect{} + q, err := db.UpsertMultipleSQL("test_table", tc.keyCols, tc.updateCols, tc.count) + + require.True(t, (err != nil) == tc.expectedErr) + require.Equal(t, tc.expectedPostgresQuery, q, "Postgres query incorrect") + + db = &MySQLDialect{} + q, err = db.UpsertMultipleSQL("test_table", tc.keyCols, tc.updateCols, tc.count) + + require.True(t, (err != nil) == tc.expectedErr) + require.Equal(t, tc.expectedMySQLQuery, q, "MySQL query incorrect") + + db = &SQLite3{} + q, err = db.UpsertMultipleSQL("test_table", tc.keyCols, tc.updateCols, tc.count) + + require.True(t, (err != nil) == tc.expectedErr) + require.Equal(t, tc.expectedSQLiteQuery, q, "SQLite query incorrect") + }) + } +}