sqltemplate, dbimpl: Remove single-method function types (#107525)

* Remove dbProviderFunc function.

This removes one extra indirection that made the code bit more difficult to navigate.

* Remove indirection function types implementing single-method interfaces.

This streamlines the code and makes it bit easier to navigate.

* Update pkg/storage/unified/sql/sqltemplate/dialect_mysql.go

Co-authored-by: Mustafa Sencer Özcan <32759850+mustafasencer@users.noreply.github.com>

---------

Co-authored-by: Mustafa Sencer Özcan <32759850+mustafasencer@users.noreply.github.com>
pull/107389/head^2
Peter Štibraný 3 weeks ago committed by GitHub
parent a68f8107df
commit e076c74869
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 29
      pkg/storage/unified/sql/db/dbimpl/dbimpl.go
  2. 2
      pkg/storage/unified/sql/sqltemplate/args_test.go
  3. 29
      pkg/storage/unified/sql/sqltemplate/dialect.go
  4. 31
      pkg/storage/unified/sql/sqltemplate/dialect_mysql.go
  5. 29
      pkg/storage/unified/sql/sqltemplate/dialect_postgresql.go
  6. 26
      pkg/storage/unified/sql/sqltemplate/dialect_sqlite.go
  7. 46
      pkg/storage/unified/sql/sqltemplate/dialect_test.go

@ -43,21 +43,7 @@ func ProvideResourceDB(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace.Trac
if err != nil {
return nil, fmt.Errorf("provide Resource DB: %w", err)
}
var once sync.Once
var resourceDB db.DB
return dbProviderFunc(func(ctx context.Context) (db.DB, error) {
once.Do(func() {
resourceDB, err = p.init(ctx)
})
return resourceDB, err
}), nil
}
type dbProviderFunc func(context.Context) (db.DB, error)
func (f dbProviderFunc) Init(ctx context.Context) (db.DB, error) {
return f(ctx)
return p, nil
}
type resourceDBProvider struct {
@ -68,6 +54,10 @@ type resourceDBProvider struct {
tracer trace.Tracer
registerMetrics bool
logQueries bool
once sync.Once
resourceDB db.DB
initErr error
}
func newResourceDBProvider(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace.Tracer) (p *resourceDBProvider, err error) {
@ -124,7 +114,14 @@ func newResourceDBProvider(grafanaDB infraDB.DB, cfg *setting.Cfg, tracer trace.
}
}
func (p *resourceDBProvider) init(ctx context.Context) (db.DB, error) {
func (p *resourceDBProvider) Init(ctx context.Context) (db.DB, error) {
p.once.Do(func() {
p.resourceDB, p.initErr = p.initDB(ctx)
})
return p.resourceDB, p.initErr
}
func (p *resourceDBProvider) initDB(ctx context.Context) (db.DB, error) {
p.log.Info("Initializing Resource DB",
"db_type",
p.engine.Dialect().DriverName(),

@ -71,7 +71,7 @@ func TestArg_ArgList(t *testing.T) {
}
var a args
a.d = argFmtSQL92
a.d = MySQL
for i, tc := range testCases {
a.Reset()

@ -3,7 +3,6 @@ package sqltemplate
import (
"bytes"
"errors"
"strconv"
"strings"
)
@ -92,7 +91,6 @@ func ParseRowLockingClause(s ...string) (RowLockingClause, error) {
return opt, nil
}
// Row-locking clause options.
const (
SelectForShare RowLockingClause = "SHARE"
SelectForShareNoWait RowLockingClause = "SHARE NOWAIT"
@ -129,9 +127,6 @@ var rowLockingClauseAll = rowLockingClauseMap{
SelectForUpdateSkipLocked: SelectForUpdateSkipLocked,
}
// standardIdent provides standard SQL escaping of identifiers.
type standardIdent struct{}
func escapeIdentity(s string, quote rune, clean func(string) string) (string, error) {
if s == "" {
return "", ErrEmptyIdent
@ -154,31 +149,11 @@ func escapeIdentity(s string, quote rune, clean func(string) string) (string, er
return buffer.String(), nil
}
func (standardIdent) Ident(s string) (string, error) {
// standardIdent provides standard SQL escaping of identifiers.
func standardIdent(s string) (string, error) {
return escapeIdentity(s, '"', func(s string) string {
// not sure we should support escaping quotes in table/column names,
// but it is valid so we will support it for now
return strings.ReplaceAll(s, `"`, `""`)
})
}
type argPlaceholderFunc func(int) string
func (f argPlaceholderFunc) ArgPlaceholder(argNum int) string {
return f(argNum)
}
var (
argFmtSQL92 = argPlaceholderFunc(func(int) string {
return "?"
})
argFmtPositional = argPlaceholderFunc(func(argNum int) string {
return "$" + strconv.Itoa(argNum)
})
)
type name string
func (n name) DialectName() string {
return string(n)
}

@ -6,26 +6,17 @@ import (
// MySQL is the default implementation of Dialect for the MySQL DMBS,
// currently supporting MySQL-8.x.
var MySQL = mysql{
rowLockingClauseMap: rowLockingClauseAll,
argPlaceholderFunc: argFmtSQL92,
name: "mysql",
}
var MySQL = mysql{}
var _ Dialect = MySQL
type mysql struct{}
type mysql struct {
backtickIdent
rowLockingClauseMap
argPlaceholderFunc
name
func (m mysql) DialectName() string {
return "mysql"
}
// MySQL always supports backticks for identifiers
// https://dev.mysql.com/doc/refman/8.4/en/identifiers.html
type backtickIdent struct{}
func (backtickIdent) Ident(s string) (string, error) {
func (m mysql) Ident(s string) (string, error) {
// MySQL always supports backticks for identifiers
// https://dev.mysql.com/doc/refman/8.4/en/identifiers.html
if strings.ContainsRune(s, '`') {
return "", ErrInvalidIdentInput
}
@ -34,6 +25,14 @@ func (backtickIdent) Ident(s string) (string, error) {
})
}
func (m mysql) ArgPlaceholder(argNum int) string {
return "?"
}
func (m mysql) SelectFor(s ...string) (string, error) {
return rowLockingClauseAll.SelectFor(s...)
}
func (mysql) CurrentEpoch() string {
return "CAST(FLOOR(UNIX_TIMESTAMP(NOW(6)) * 1000000) AS SIGNED)"
}

@ -2,28 +2,29 @@ package sqltemplate
import (
"errors"
"fmt"
"strings"
)
// PostgreSQL is an implementation of Dialect for the PostgreSQL DMBS.
var PostgreSQL = postgresql{
rowLockingClauseMap: rowLockingClauseAll,
argPlaceholderFunc: argFmtPositional,
name: "postgres",
}
var _ Dialect = PostgreSQL
var PostgreSQL = postgresql{}
// PostgreSQL-specific errors.
var (
ErrPostgreSQLUnsupportedIdent = errors.New("identifiers in PostgreSQL cannot contain the character with code zero")
)
type postgresql struct {
standardIdent
rowLockingClauseMap
argPlaceholderFunc
name
type postgresql struct{}
func (p postgresql) DialectName() string {
return "postgres"
}
func (p postgresql) ArgPlaceholder(argNum int) string {
return fmt.Sprintf("$%d", argNum)
}
func (p postgresql) SelectFor(s ...string) (string, error) {
return rowLockingClauseAll.SelectFor(s...)
}
func (p postgresql) Ident(s string) (string, error) {
@ -33,7 +34,7 @@ func (p postgresql) Ident(s string) (string, error) {
return "", ErrPostgreSQLUnsupportedIdent
}
return p.standardIdent.Ident(s)
return standardIdent(s)
}
func (postgresql) CurrentEpoch() string {

@ -1,20 +1,26 @@
package sqltemplate
// SQLite is an implementation of Dialect for the SQLite DMBS.
var SQLite = sqlite{
argPlaceholderFunc: argFmtSQL92,
name: "sqlite",
}
var SQLite = sqlite{}
type sqlite struct{}
var _ Dialect = SQLite
func (s sqlite) DialectName() string {
return "sqlite"
}
type sqlite struct {
func (s sqlite) Ident(i string) (string, error) {
// See:
// https://www.sqlite.org/lang_keywords.html
standardIdent
rowLockingClauseMap
argPlaceholderFunc
name
return standardIdent(i)
}
func (s sqlite) ArgPlaceholder(argNum int) string {
return "?"
}
func (s sqlite) SelectFor(s2 ...string) (string, error) {
return rowLockingClauseMap(nil).SelectFor(s2...)
}
func (sqlite) CurrentEpoch() string {

@ -6,6 +6,10 @@ import (
"testing"
)
var _ Dialect = MySQL
var _ Dialect = SQLite
var _ Dialect = PostgreSQL
func TestSelectForOption_Valid(t *testing.T) {
t.Parallel()
@ -133,7 +137,7 @@ func TestStandardIdent_Ident(t *testing.T) {
}
for i, tc := range testCases {
gotOutput, gotErr := standardIdent{}.Ident(tc.input)
gotOutput, gotErr := standardIdent(tc.input)
if !errors.Is(gotErr, tc.err) {
t.Fatalf("unexpected error %v in test case %d", gotErr, i)
}
@ -142,43 +146,3 @@ func TestStandardIdent_Ident(t *testing.T) {
}
}
}
func TestArgPlaceholderFunc(t *testing.T) {
t.Parallel()
testCases := []struct {
input int
valuePositional string
}{
{
input: 1,
valuePositional: "$1",
},
{
input: 16,
valuePositional: "$16",
},
}
for i, tc := range testCases {
got := argFmtSQL92(tc.input)
if got != "?" {
t.Fatalf("[argFmtSQL92] unexpected value %q in test case %d", got, i)
}
got = argFmtPositional(tc.input)
if got != tc.valuePositional {
t.Fatalf("[argFmtPositional] unexpected value %q in test case %d", got, i)
}
}
}
func TestName_Name(t *testing.T) {
t.Parallel()
const v = "some dialect name"
n := name(v)
if n.DialectName() != v {
t.Fatalf("unexpected dialect name %q", n.DialectName())
}
}

Loading…
Cancel
Save