diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/args.go b/pkg/services/store/entity/sqlstash/sqltemplate/args.go index a44eed66758..9f95a6ff5f2 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/args.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/args.go @@ -12,3 +12,7 @@ func (a *Args) Arg(x any) string { *a = append(*a, x) return "?" } + +func (a *Args) GetArgs() Args { + return *a +} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/args_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/args_test.go index 6394868d218..ac0e80f66be 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/args_test.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/args_test.go @@ -1,8 +1,6 @@ package sqltemplate -import ( - "testing" -) +import "testing" func TestArgs_Arg(t *testing.T) { t.Parallel() @@ -22,7 +20,7 @@ func TestArgs_Arg(t *testing.T) { shouldBeQuestionMark(t, a.Arg(3)) shouldBeQuestionMark(t, a.Arg(4)) - for i, arg := range *a { + for i, arg := range a.GetArgs() { v, ok := arg.(int) if !ok { t.Fatalf("unexpected value: %T(%v)", arg, arg) diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect.go index 18dc0e9c7f7..5bb0967f0ff 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect.go @@ -29,7 +29,7 @@ type Dialect interface { // SELECT * // FROM mytab // WHERE id = ? - // {{ .SelectFor Update NoWait }}; -- will be uppercased + // {{ .SelectFor "Update NoWait" }}; -- will be uppercased SelectFor(...string) (string, error) } @@ -85,7 +85,7 @@ func (rlc rowLockingClauseAll) SelectFor(s ...string) (string, error) { return "", nil } - return string(o), nil + return "FOR " + string(o), nil } // standardIdent provides standard SQL escaping of identifiers. diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql.go index 540b7b21870..3db4dbc5802 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql.go @@ -5,14 +5,13 @@ package sqltemplate // Modes see: // // https://dev.mysql.com/doc/refman/8.4/en/sql-mode.html#sqlmode_ansi_quotes -var MySQL mysql +var MySQL = mysql{ + rowLockingClauseAll: true, +} var _ Dialect = MySQL type mysql struct { standardIdent -} - -func (mysql) SelectFor(s ...string) (string, error) { - return rowLockingClauseAll(true).SelectFor(s...) + rowLockingClauseAll } diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql_test.go deleted file mode 100644 index 58a78c73d9a..00000000000 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_mysql_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package sqltemplate - -import "testing" - -func TestMySQL_SelectFor(t *testing.T) { - MySQL.SelectFor() //nolint: errcheck,gosec -} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql.go index 2d5c21fef1b..054746c4d33 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql.go @@ -6,7 +6,9 @@ import ( ) // PostgreSQL is an implementation of Dialect for the PostgreSQL DMBS. -var PostgreSQL postgresql +var PostgreSQL = postgresql{ + rowLockingClauseAll: true, +} var _ Dialect = PostgreSQL @@ -17,6 +19,7 @@ var ( type postgresql struct { standardIdent + rowLockingClauseAll } func (p postgresql) Ident(s string) (string, error) { @@ -28,7 +31,3 @@ func (p postgresql) Ident(s string) (string, error) { return p.standardIdent.Ident(s) } - -func (postgresql) SelectFor(s ...string) (string, error) { - return rowLockingClauseAll(true).SelectFor(s...) -} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql_test.go index 3ffae1a3bbc..11a7dcc2f41 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql_test.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_postgresql_test.go @@ -5,10 +5,6 @@ import ( "testing" ) -func TestPostgreSQL_SelectFor(t *testing.T) { - PostgreSQL.SelectFor() //nolint: errcheck,gosec -} - func TestPostgreSQL_Ident(t *testing.T) { t.Parallel() diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite.go index 4540dd3b797..b55cc42a868 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite.go @@ -1,7 +1,9 @@ package sqltemplate // SQLite is an implementation of Dialect for the SQLite DMBS. -var SQLite sqlite +var SQLite = sqlite{ + rowLockingClauseAll: false, +} var _ Dialect = SQLite @@ -9,8 +11,5 @@ type sqlite struct { // See: // https://www.sqlite.org/lang_keywords.html standardIdent -} - -func (sqlite) SelectFor(s ...string) (string, error) { - return rowLockingClauseAll(false).SelectFor(s...) + rowLockingClauseAll } diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite_test.go deleted file mode 100644 index 8869cff7a5a..00000000000 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_sqlite_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package sqltemplate - -import "testing" - -func TestSQLite_SelectFor(t *testing.T) { - SQLite.SelectFor() //nolint: errcheck,gosec -} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_test.go index a481d436601..4f775ab77aa 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/dialect_test.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/dialect_test.go @@ -91,7 +91,7 @@ func TestRowLockingClauseAll_SelectFor(t *testing.T) { { input: splitSpace(string(SelectForShare)), - output: SelectForShare, + output: "FOR " + SelectForShare, }, } diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/example_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/example_test.go index 25c6b8f486d..f8edb6017d6 100644 --- a/pkg/services/store/entity/sqlstash/sqltemplate/example_test.go +++ b/pkg/services/store/entity/sqlstash/sqltemplate/example_test.go @@ -2,6 +2,7 @@ package sqltemplate import ( "fmt" + "regexp" "strings" "text/template" ) @@ -17,105 +18,139 @@ import ( // To learn more about Go's runnable tests, which are a core builtin feature of // Go's standard testing library, see: // https://pkg.go.dev/testing#hdr-Examples +// +// If you're unfamiliar with Go text templating language, please, consider +// reading that library's documentation first. // In this example we will use both Args and Dialect to dynamically and securely // build SQL queries, while also keeping track of the arguments that need to be // passed to the database methods to replace the placeholder "?" with the -// correct values. If you're not familiar with Go text templating language, -// please, consider reading that library's documentation first. - -// We will start with creating a simple text template to insert a new row into a -// users table: -var createUserTmpl = template.Must(template.New("query").Parse(` - INSERT INTO users (id, {{ .Ident "type" }}, name) - VALUES ({{ .Arg .ID }}, {{ .Arg .Type }}, {{ .Arg .Name}}); -`)) +// correct values. + +// We will start by assuming we receive a request to retrieve a user's +// information and that we need to provide a certain response. + +type GetUserRequest struct { + ID int +} -// The two interesting methods here are Arg and Ident. Note that now we have a -// reusable text template, that will dynamically create the SQL code when -// executed, which is interesting because we have a SQL-implementation dependant -// code handled for us within the template (escaping the reserved word "type"), -// but also because the arguments to the database Exec method will be handled -// for us. The struct with the data needed to create a new user could be -// something like the following: -type CreateUserRequest struct { +type GetUserResponse struct { ID int - Name string Type string + Name string } -// Note that this struct could actually come from a different definition, for -// example, from a DTO. We can reuse this DTO and create a smaller struct for -// the purpose of writing to the database without the need of mapping: -type DBCreateUserRequest struct { - Dialect // provides access to all Dialect methods, like Ident - *Args // provides access to Arg method, to keep track of db arguments - *CreateUserRequest +// Our template will take care for us of taking the request to build the query, +// and then sort the arguments for execution as well as preparing the values +// that need to be read for the response. We wil create a struct to pass the +// request and an empty response, as well as a *SQLTemplate that will provide +// the methods to achieve our purpose:: + +type GetUserQuery struct { + *SQLTemplate + Request *GetUserRequest + Response *GetUserResponse } +// And finally we will define our template, that is free to use all the power of +// the Go templating language, plus the methods we added with *SQLTemplate: +var getUserTmpl = template.Must(template.New("example").Parse(` + SELECT + {{ .Ident "id" | .Into .Response.ID }}, + {{ .Ident "type" | .Into .Response.Type }}, + {{ .Ident "name" | .Into .Response.Name }} + + FROM {{ .Ident "users" }} + WHERE + {{ .Ident "id" }} = {{ .Arg .Request.ID }}; +`)) + +// There are three interesting methods used in the above template: +// 1. Ident: safely escape a SQL identifier. Even though here the only +// identifier that may be problematic is "type" (because it is a reserved +// word in many dialects), it is a good practice to escape all identifiers +// just to make sure we're accounting for all variability in dialects, and +// also for consistency. +// 2. Into: this causes the selected field to be saved to the corresponding +// field of GetUserQuery. +// 3. Arg: this allows us to state that at this point will be a "?" that has to +// be populated with the value of the given field of GetUserQuery. + func Example() { - // Finally, we can take a request received from a user like the following: - dto := &CreateUserRequest{ - ID: 1, - Name: "root", - Type: "admin", - } + // Let's pretend this example function is the handler of the GetUser method + // of our service to see how it all works together. + + queryData := &GetUserQuery{ + // The dialect (in this case we chose MySQL) should be set in your + // service at startup when you connect to your database + SQLTemplate: New(MySQL), + + // This is a synthetic request for our test + Request: &GetUserRequest{ + ID: 1, + }, - // Put it into a database request: - req := DBCreateUserRequest{ - Dialect: SQLite, // set at runtime, the template is agnostic - Args: new(Args), - CreateUserRequest: dto, + // Create an empty response to be populated + Response: new(GetUserResponse), } - // Then we finally execute the template to both generate the SQL code and to - // populate req.Args with the arguments: - var b strings.Builder - err := createUserTmpl.Execute(&b, req) + // The next step is to execute the query template for our queryData, and + // generate the arguments for the db.QueryRow and row.Scan methods later + query, err := Execute(getUserTmpl, queryData) if err != nil { panic(err) // terminate the runnable example on error } - // And we should finally be able to see the SQL generated, as well as - // getting the arguments populated for execution in a database. To execute - // it in the databse, we could run: - // db.ExecContext(ctx, b.String(), req.Args...) + // Assuming that we have a *sql.DB object named "db", we could now make our + // query with: + // row := db.QueryRowContext(ctx, query, queryData.Args...) + // // and check row.Err() here - // To provide the runnable example with some code to test, we will now print - // the values to standard output: - fmt.Println(b.String()) - fmt.Printf("%#v", req.Args) - - // Output: - // INSERT INTO users (id, "type", name) - // VALUES (?, ?, ?); - // - // &sqltemplate.Args{1, "admin", "root"} -} - -// A more complex template example follows, which should be self-explanatory -// given the previous example. It is left as an exercise to the reader how the -// code should be implemented, based on the ExampleCreateUser function. + // As we're not actually running a database in this example, let's verify + // that we find our arguments populated as expected instead: + if len(queryData.Args) != 1 { + panic(fmt.Sprintf("unexpected number of args: %#v", queryData.Args)) + } + id, ok := queryData.Args[0].(int) + if !ok || id != queryData.Request.ID { + panic(fmt.Sprintf("unexpected args: %#v", queryData.Args)) + } -// List users example. -var _ = template.Must(template.New("query").Parse(` - SELECT id, {{ .Ident "type" }}, name - FROM users - WHERE - {{ if eq .By "type" }} - {{ .Ident "type" }} = {{ .Arg .Value }} - {{ else if eq .By "name" }} - name LIKE {{ .Arg .Value }} - {{ end }}; -`)) + // In your code you would now have "row" populated with the row data, + // assuming that the operation succeeded, so you would now scan the row data + // abd populate the values of our response: + // err := row.Scan(queryData.ScanDest...) + // // and check err here + + // Again, as we're not actually running a database in this example, we will + // instead run the code to assert that queryData.ScanDest was populated with + // the expected data, which should be pointers to each of the fields of + // Response so that the Scan method can write to them: + if len(queryData.ScanDest) != 3 { + panic(fmt.Sprintf("unexpected number of scan dest: %#v", queryData.ScanDest)) + } + idPtr, ok := queryData.ScanDest[0].(*int) + if !ok || idPtr != &queryData.Response.ID { + panic(fmt.Sprintf("unexpected response 'id' pointer: %#v", queryData.ScanDest)) + } + typePtr, ok := queryData.ScanDest[1].(*string) + if !ok || typePtr != &queryData.Response.Type { + panic(fmt.Sprintf("unexpected response 'type' pointer: %#v", queryData.ScanDest)) + } + namePtr, ok := queryData.ScanDest[2].(*string) + if !ok || namePtr != &queryData.Response.Name { + panic(fmt.Sprintf("unexpected response 'name' pointer: %#v", queryData.ScanDest)) + } -type ListUsersRequest struct { - By string - Value string -} + // Remember the variable "query"? Well, we didn't check it. We will now make + // use of Go's runnable examples and print its contents to standard output + // so Go's tooling verify this example's output each time we run tests. + // By the way, to make the result more stable, we will remove some + // unnecessary white space from the query. + whiteSpaceRE := regexp.MustCompile(`\s+`) + query = strings.TrimSpace(whiteSpaceRE.ReplaceAllString(query, " ")) + fmt.Println(query) -type DBListUsersRequest struct { - Dialect - *Args - ListUsersRequest + // Output: + // SELECT "id", "type", "name" FROM "users" WHERE "id" = ?; } diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/into.go b/pkg/services/store/entity/sqlstash/sqltemplate/into.go new file mode 100644 index 00000000000..22b98423c71 --- /dev/null +++ b/pkg/services/store/entity/sqlstash/sqltemplate/into.go @@ -0,0 +1,22 @@ +package sqltemplate + +import ( + "fmt" + "reflect" +) + +type ScanDest []any + +func (i *ScanDest) Into(v reflect.Value, colName string) (string, error) { + if !v.IsValid() || !v.CanAddr() || !v.Addr().CanInterface() { + return "", fmt.Errorf("invalid or unaddressable value: %v", colName) + } + + *i = append(*i, v.Addr().Interface()) + + return colName, nil +} + +func (i *ScanDest) GetScanDest() ScanDest { + return *i +} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/into_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/into_test.go new file mode 100644 index 00000000000..dcb11383fd6 --- /dev/null +++ b/pkg/services/store/entity/sqlstash/sqltemplate/into_test.go @@ -0,0 +1,36 @@ +package sqltemplate + +import ( + "reflect" + "testing" +) + +func TestScanDest_Into(t *testing.T) { + t.Parallel() + + var d ScanDest + + colName, err := d.Into(reflect.Value{}, "some field") + if colName != "" || err == nil || len(d.GetScanDest()) != 0 { + t.Fatalf("unexpected outcome, got colname %q, err: %v, scan dest: %#v", + colName, err, d) + } + + data := struct { + X int + Y byte + }{} + dataVal := reflect.ValueOf(&data).Elem() + + colName, err = d.Into(dataVal.FieldByName("X"), "some int") + if err != nil || colName != "some int" || len(d) != 1 || d[0] != &data.X { + t.Fatalf("unexpected outcome, got colname %q, err: %v, scan dest: %#v", + colName, err, d) + } + + colName, err = d.Into(dataVal.FieldByName("Y"), "some byte") + if err != nil || colName != "some byte" || len(d) != 2 || d[1] != &data.Y { + t.Fatalf("unexpected outcome, got colname %q, err: %v, scan dest: %#v", + colName, err, d) + } +} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate.go b/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate.go new file mode 100644 index 00000000000..0b2ee997015 --- /dev/null +++ b/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate.go @@ -0,0 +1,35 @@ +package sqltemplate + +import ( + "strings" + "text/template" +) + +type SQLTemplate struct { + Dialect + Args + ScanDest +} + +func New(d Dialect) *SQLTemplate { + return &SQLTemplate{ + Dialect: d, + } +} + +type SQLTemplateIface interface { + Dialect + GetArgs() Args + GetScanDest() ScanDest +} + +// Execute is a trivial utility to execute and return the results of any +// text/template as a string and an error. +func Execute(t *template.Template, data any) (string, error) { + var b strings.Builder + if err := t.Execute(&b, data); err != nil { + return "", err + } + + return b.String(), nil +} diff --git a/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate_test.go b/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate_test.go new file mode 100644 index 00000000000..f5f0a9e5d70 --- /dev/null +++ b/pkg/services/store/entity/sqlstash/sqltemplate/sqltemplate_test.go @@ -0,0 +1,28 @@ +package sqltemplate + +import ( + "testing" + "text/template" +) + +func TestExecute(t *testing.T) { + t.Parallel() + + tmpl := template.Must(template.New("test").Parse(`{{ .ID }}`)) + + data := struct { + ID int + }{ + ID: 1, + } + + txt, err := Execute(tmpl, data) + if txt != "1" || err != nil { + t.Fatalf("unexpected error, txt: %q, err: %v", txt, err) + } + + txt, err = Execute(tmpl, 1) + if txt != "" || err == nil { + t.Fatalf("unexpected result, txt: %q, err: %v", txt, err) + } +}