mirror of https://github.com/grafana/grafana
parent
c1798320d2
commit
8f44e1a349
@ -1 +0,0 @@ |
||||
package auth |
||||
@ -0,0 +1,176 @@ |
||||
package identity |
||||
|
||||
import "fmt" |
||||
|
||||
var _ Requester = &StaticRequester{} |
||||
|
||||
// StaticRequester is helpful for tests
|
||||
// This is mostly copied from:
|
||||
// https://github.com/grafana/grafana/blob/v11.0.0/pkg/services/user/identity.go#L16
|
||||
type StaticRequester struct { |
||||
Namespace Namespace |
||||
UserID int64 |
||||
UserUID string |
||||
OrgID int64 |
||||
OrgName string |
||||
OrgRole RoleType |
||||
Login string |
||||
Name string |
||||
DisplayName string |
||||
Email string |
||||
EmailVerified bool |
||||
AuthID string |
||||
AuthenticatedBy string |
||||
IsGrafanaAdmin bool |
||||
IsAnonymous bool |
||||
IsDisabled bool |
||||
// Permissions grouped by orgID and actions
|
||||
Permissions map[int64]map[string][]string |
||||
IDToken string |
||||
CacheKey string |
||||
} |
||||
|
||||
func (u *StaticRequester) HasRole(role RoleType) bool { |
||||
if u.IsGrafanaAdmin { |
||||
return true |
||||
} |
||||
|
||||
return u.OrgRole.Includes(role) |
||||
} |
||||
|
||||
// GetIsGrafanaAdmin returns true if the user is a server admin
|
||||
func (u *StaticRequester) GetIsGrafanaAdmin() bool { |
||||
return u.IsGrafanaAdmin |
||||
} |
||||
|
||||
// GetLogin returns the login of the active entity
|
||||
// Can be empty if the user is anonymous
|
||||
func (u *StaticRequester) GetLogin() string { |
||||
return u.Login |
||||
} |
||||
|
||||
// GetOrgID returns the ID of the active organization
|
||||
func (u *StaticRequester) GetOrgID() int64 { |
||||
return u.OrgID |
||||
} |
||||
|
||||
// DEPRECATED: GetOrgName returns the name of the active organization
|
||||
// Retrieve the organization name from the organization service instead of using this method.
|
||||
func (u *StaticRequester) GetOrgName() string { |
||||
return u.OrgName |
||||
} |
||||
|
||||
// GetPermissions returns the permissions of the active entity
|
||||
func (u *StaticRequester) GetPermissions() map[string][]string { |
||||
if u.Permissions == nil { |
||||
return make(map[string][]string) |
||||
} |
||||
|
||||
if u.Permissions[u.GetOrgID()] == nil { |
||||
return make(map[string][]string) |
||||
} |
||||
|
||||
return u.Permissions[u.GetOrgID()] |
||||
} |
||||
|
||||
// GetGlobalPermissions returns the permissions of the active entity that are available across all organizations
|
||||
func (u *StaticRequester) GetGlobalPermissions() map[string][]string { |
||||
if u.Permissions == nil { |
||||
return make(map[string][]string) |
||||
} |
||||
|
||||
const globalOrgID = 0 |
||||
|
||||
if u.Permissions[globalOrgID] == nil { |
||||
return make(map[string][]string) |
||||
} |
||||
|
||||
return u.Permissions[globalOrgID] |
||||
} |
||||
|
||||
// DEPRECATED: GetTeams returns the teams the entity is a member of
|
||||
// Retrieve the teams from the team service instead of using this method.
|
||||
func (u *StaticRequester) GetTeams() []int64 { |
||||
return []int64{} // Not implemented
|
||||
} |
||||
|
||||
// GetOrgRole returns the role of the active entity in the active organization
|
||||
func (u *StaticRequester) GetOrgRole() RoleType { |
||||
return u.OrgRole |
||||
} |
||||
|
||||
// HasUniqueId returns true if the entity has a unique id
|
||||
func (u *StaticRequester) HasUniqueId() bool { |
||||
return u.UserID > 0 |
||||
} |
||||
|
||||
// GetID returns namespaced id for the entity
|
||||
func (u *StaticRequester) GetID() NamespaceID { |
||||
return NewNamespaceIDString(u.Namespace, fmt.Sprintf("%d", u.UserID)) |
||||
} |
||||
|
||||
// GetUID returns namespaced uid for the entity
|
||||
func (u *StaticRequester) GetUID() NamespaceID { |
||||
return NewNamespaceIDString(u.Namespace, u.UserUID) |
||||
} |
||||
|
||||
// GetNamespacedID returns the namespace and ID of the active entity
|
||||
// The namespace is one of the constants defined in pkg/apimachinery/identity
|
||||
func (u *StaticRequester) GetNamespacedID() (Namespace, string) { |
||||
return u.Namespace, fmt.Sprintf("%d", u.UserID) |
||||
} |
||||
|
||||
func (u *StaticRequester) GetAuthID() string { |
||||
return u.AuthID |
||||
} |
||||
|
||||
func (u *StaticRequester) GetAuthenticatedBy() string { |
||||
return u.AuthenticatedBy |
||||
} |
||||
|
||||
func (u *StaticRequester) IsAuthenticatedBy(providers ...string) bool { |
||||
for _, p := range providers { |
||||
if u.AuthenticatedBy == p { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// FIXME: remove this method once all services are using an interface
|
||||
func (u *StaticRequester) IsNil() bool { |
||||
return u == nil |
||||
} |
||||
|
||||
// GetEmail returns the email of the active entity
|
||||
// Can be empty.
|
||||
func (u *StaticRequester) GetEmail() string { |
||||
return u.Email |
||||
} |
||||
|
||||
func (u *StaticRequester) IsEmailVerified() bool { |
||||
return u.EmailVerified |
||||
} |
||||
|
||||
func (u *StaticRequester) GetCacheKey() string { |
||||
return u.CacheKey |
||||
} |
||||
|
||||
// GetDisplayName returns the display name of the active entity
|
||||
// The display name is the name if it is set, otherwise the login or email
|
||||
func (u *StaticRequester) GetDisplayName() string { |
||||
if u.DisplayName != "" { |
||||
return u.DisplayName |
||||
} |
||||
if u.Name != "" { |
||||
return u.Name |
||||
} |
||||
if u.Login != "" { |
||||
return u.Login |
||||
} |
||||
return u.Email |
||||
} |
||||
|
||||
func (u *StaticRequester) GetIDToken() string { |
||||
return u.IDToken |
||||
} |
||||
@ -1,222 +0,0 @@ |
||||
package sqlstash |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"embed" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"text/template" |
||||
|
||||
"github.com/grafana/grafana/pkg/services/store/entity/db" |
||||
"github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate" |
||||
) |
||||
|
||||
// Templates setup.
|
||||
var ( |
||||
//go:embed data/*.sql
|
||||
sqlTemplatesFS embed.FS |
||||
|
||||
// all templates
|
||||
helpers = template.FuncMap{ |
||||
"listSep": helperListSep, |
||||
"join": helperJoin, |
||||
} |
||||
sqlTemplates = template.Must(template.New("sql").Funcs(helpers).ParseFS(sqlTemplatesFS, `data/*.sql`)) |
||||
) |
||||
|
||||
func mustTemplate(filename string) *template.Template { |
||||
if t := sqlTemplates.Lookup(filename); t != nil { |
||||
return t |
||||
} |
||||
panic(fmt.Sprintf("template file not found: %s", filename)) |
||||
} |
||||
|
||||
// Templates.
|
||||
var ( |
||||
sqlResourceVersionGet = mustTemplate("rv_get.sql") |
||||
sqlResourceVersionInc = mustTemplate("rv_inc.sql") |
||||
sqlResourceVersionInsert = mustTemplate("rv_insert.sql") |
||||
sqlResourceVersionLock = mustTemplate("rv_lock.sql") |
||||
|
||||
sqlResourceInsert = mustTemplate("resource_insert.sql") |
||||
sqlResourceGet = mustTemplate("resource_get.sql") |
||||
) |
||||
|
||||
// TxOptions.
|
||||
var ( |
||||
ReadCommitted = &sql.TxOptions{ |
||||
Isolation: sql.LevelReadCommitted, |
||||
} |
||||
ReadCommittedRO = &sql.TxOptions{ |
||||
Isolation: sql.LevelReadCommitted, |
||||
ReadOnly: true, |
||||
} |
||||
) |
||||
|
||||
// SQLError is an error returned by the database, which includes additionally
|
||||
// debugging information about what was sent to the database.
|
||||
type SQLError struct { |
||||
Err error |
||||
CallType string // either Query, QueryRow or Exec
|
||||
TemplateName string |
||||
Query string |
||||
RawQuery string |
||||
ScanDest []any |
||||
|
||||
// potentially regulated information is not exported and only directly
|
||||
// available for local testing and local debugging purposes, making sure it
|
||||
// is never marshaled to JSON or any other serialization.
|
||||
|
||||
arguments []any |
||||
} |
||||
|
||||
func (e SQLError) Unwrap() error { |
||||
return e.Err |
||||
} |
||||
|
||||
func (e SQLError) Error() string { |
||||
return fmt.Sprintf("%s: %s with %d input arguments and %d output "+ |
||||
"destination arguments: %v", e.TemplateName, e.CallType, |
||||
len(e.arguments), len(e.ScanDest), e.Err) |
||||
} |
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Resource Version table support
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
type returnsResourceVersion struct { |
||||
ResourceVersion int64 |
||||
} |
||||
|
||||
func (r *returnsResourceVersion) Results() (*returnsResourceVersion, error) { |
||||
return r, nil |
||||
} |
||||
|
||||
type sqlResourceVersionGetRequest struct { |
||||
*sqltemplate.SQLTemplate |
||||
Group string |
||||
Resource string |
||||
*returnsResourceVersion |
||||
} |
||||
|
||||
func (r sqlResourceVersionGetRequest) Validate() error { |
||||
return nil // TODO
|
||||
} |
||||
|
||||
type sqlResourceVersionLockRequest struct { |
||||
*sqltemplate.SQLTemplate |
||||
Group string |
||||
Resource string |
||||
*returnsResourceVersion |
||||
} |
||||
|
||||
func (r sqlResourceVersionLockRequest) Validate() error { |
||||
return nil // TODO
|
||||
} |
||||
|
||||
type sqlResourceVersionIncRequest struct { |
||||
*sqltemplate.SQLTemplate |
||||
Group string |
||||
Resource string |
||||
ResourceVersion int64 |
||||
} |
||||
|
||||
func (r sqlResourceVersionIncRequest) Validate() error { |
||||
return nil // TODO
|
||||
} |
||||
|
||||
type sqlResourceVersionInsertRequest struct { |
||||
*sqltemplate.SQLTemplate |
||||
Group string |
||||
Resource string |
||||
} |
||||
|
||||
func (r sqlResourceVersionInsertRequest) Validate() error { |
||||
return nil // TODO
|
||||
} |
||||
|
||||
// resourceVersionAtomicInc atomically increases the version of a kind within a
|
||||
// transaction.
|
||||
func resourceVersionAtomicInc(ctx context.Context, x db.ContextExecer, d sqltemplate.Dialect, group, resource string) (newVersion int64, err error) { |
||||
// 1. Lock the kind and get the latest version
|
||||
lockReq := sqlResourceVersionLockRequest{ |
||||
SQLTemplate: sqltemplate.New(d), |
||||
Group: group, |
||||
Resource: resource, |
||||
returnsResourceVersion: new(returnsResourceVersion), |
||||
} |
||||
kindv, err := queryRow(ctx, x, sqlResourceVersionLock, lockReq) |
||||
|
||||
// if there wasn't a row associated with the given kind, we create one with
|
||||
// version 1
|
||||
if errors.Is(err, sql.ErrNoRows) { |
||||
// NOTE: there is a marginal chance that we race with another writer
|
||||
// trying to create the same row. This is only possible when onboarding
|
||||
// a new (Group, Resource) to the cell, which should be very unlikely,
|
||||
// and the workaround is simply retrying. The alternative would be to
|
||||
// use INSERT ... ON CONFLICT DO UPDATE ..., but that creates a
|
||||
// requirement for support in Dialect only for this marginal case, and
|
||||
// we would rather keep Dialect as small as possible. Another
|
||||
// alternative is to simply check if the INSERT returns a DUPLICATE KEY
|
||||
// error and then retry the original SELECT, but that also adds some
|
||||
// complexity to the code. That would be preferrable to changing
|
||||
// Dialect, though. The current alternative, just retrying, seems to be
|
||||
// enough for now.
|
||||
insReq := sqlResourceVersionInsertRequest{ |
||||
SQLTemplate: sqltemplate.New(d), |
||||
Group: group, |
||||
Resource: resource, |
||||
} |
||||
if _, err = exec(ctx, x, sqlResourceVersionInsert, insReq); err != nil { |
||||
return 0, fmt.Errorf("insert into kind_version: %w", err) |
||||
} |
||||
|
||||
return 1, nil |
||||
} |
||||
|
||||
if err != nil { |
||||
return 0, fmt.Errorf("lock kind: %w", err) |
||||
} |
||||
|
||||
incReq := sqlResourceVersionIncRequest{ |
||||
SQLTemplate: sqltemplate.New(d), |
||||
Group: group, |
||||
Resource: resource, |
||||
ResourceVersion: kindv.ResourceVersion, |
||||
} |
||||
if _, err = exec(ctx, x, sqlResourceVersionInc, incReq); err != nil { |
||||
return 0, fmt.Errorf("increase kind version: %w", err) |
||||
} |
||||
|
||||
return kindv.ResourceVersion + 1, nil |
||||
} |
||||
|
||||
// Template helpers.
|
||||
|
||||
// helperListSep is a helper that helps writing simpler loops in SQL templates.
|
||||
// Example usage:
|
||||
//
|
||||
// {{ $comma := listSep ", " }}
|
||||
// {{ range .Values }}
|
||||
// {{/* here we put "-" on each end to remove extra white space */}}
|
||||
// {{- call $comma -}}
|
||||
// {{ .Value }}
|
||||
// {{ end }}
|
||||
func helperListSep(sep string) func() string { |
||||
var addSep bool |
||||
|
||||
return func() string { |
||||
if addSep { |
||||
return sep |
||||
} |
||||
addSep = true |
||||
|
||||
return "" |
||||
} |
||||
} |
||||
|
||||
func helperJoin(sep string, elems ...string) string { |
||||
return strings.Join(elems, sep) |
||||
} |
||||
@ -1,25 +0,0 @@ |
||||
{ |
||||
"apiVersion": "playlist.grafana.app/v0alpha1", |
||||
"kind": "Playlist", |
||||
"metadata": { |
||||
"name": "fdgsv37qslr0ga", |
||||
"namespace": "default", |
||||
"annotations": { |
||||
"grafana.app/originName": "elsewhere", |
||||
"grafana.app/originPath": "path/to/item", |
||||
"grafana.app/originTimestamp": "2024-02-02T00:00:00Z" |
||||
}, |
||||
"creationTimestamp": "2024-03-03T00:00:00Z", |
||||
"uid": "8tGrXJgGbFI0" |
||||
}, |
||||
"spec": { |
||||
"title": "hello", |
||||
"interval": "5m", |
||||
"items": [ |
||||
{ |
||||
"type": "dashboard_by_uid", |
||||
"value": "vmie2cmWz" |
||||
} |
||||
] |
||||
} |
||||
} |
||||
@ -1,503 +0,0 @@ |
||||
package sqlstash |
||||
|
||||
import ( |
||||
"database/sql" |
||||
"database/sql/driver" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"regexp" |
||||
"strings" |
||||
"testing" |
||||
"text/template" |
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/grafana/grafana/pkg/services/store/entity/db" |
||||
"github.com/grafana/grafana/pkg/services/store/entity/db/dbimpl" |
||||
"github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate" |
||||
sqltemplateMocks "github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate/mocks" |
||||
"github.com/grafana/grafana/pkg/util/testutil" |
||||
) |
||||
|
||||
// newMockDBNopSQL returns a db.DB and a sqlmock.Sqlmock that doesn't validates
|
||||
// SQL. This is only meant to be used to test wrapping utilities exec, query and
|
||||
// queryRow, where the actual SQL is not relevant to the unit tests, but rather
|
||||
// how the possible derived error conditions handled.
|
||||
func newMockDBNopSQL(t *testing.T) (db.DB, sqlmock.Sqlmock) { |
||||
t.Helper() |
||||
|
||||
db, mock, err := sqlmock.New( |
||||
sqlmock.MonitorPingsOption(true), |
||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherFunc( |
||||
func(expectedSQL, actualSQL string) error { |
||||
return nil |
||||
}, |
||||
)), |
||||
) |
||||
|
||||
return newUnitTestDB(t, db, mock, err) |
||||
} |
||||
|
||||
// newMockDBMatchWords returns a db.DB and a sqlmock.Sqlmock that will match SQL
|
||||
// by splitting the expected SQL string into words, and then try to find all of
|
||||
// them in the actual SQL, in the given order, case insensitively. Prepend a
|
||||
// word with a `!` to say that word should not be found.
|
||||
func newMockDBMatchWords(t *testing.T) (db.DB, sqlmock.Sqlmock) { |
||||
t.Helper() |
||||
|
||||
db, mock, err := sqlmock.New( |
||||
sqlmock.MonitorPingsOption(true), |
||||
sqlmock.QueryMatcherOption( |
||||
sqlmock.QueryMatcherFunc(func(expectedSQL, actualSQL string) error { |
||||
actualSQL = strings.ToLower(sqltemplate.FormatSQL(actualSQL)) |
||||
expectedSQL = strings.ToLower(expectedSQL) |
||||
|
||||
var offset int |
||||
for _, vv := range mockDBMatchWordsRE.FindAllStringSubmatch(expectedSQL, -1) { |
||||
v := vv[1] |
||||
|
||||
var shouldNotMatch bool |
||||
if v != "" && v[0] == '!' { |
||||
v = v[1:] |
||||
shouldNotMatch = true |
||||
} |
||||
if v == "" { |
||||
return fmt.Errorf("invalid expected word %q in %q", v, |
||||
expectedSQL) |
||||
} |
||||
|
||||
reWord, err := regexp.Compile(`\b` + regexp.QuoteMeta(v) + `\b`) |
||||
if err != nil { |
||||
return fmt.Errorf("compile word %q from expected SQL: %s", v, |
||||
expectedSQL) |
||||
} |
||||
|
||||
if shouldNotMatch { |
||||
if reWord.MatchString(actualSQL[offset:]) { |
||||
return fmt.Errorf("actual SQL fragent should not cont"+ |
||||
"ain %q but it does\n\tFragment: %s\n\tFull SQL: %s", |
||||
v, actualSQL[offset:], actualSQL) |
||||
} |
||||
} else { |
||||
loc := reWord.FindStringIndex(actualSQL[offset:]) |
||||
if len(loc) == 0 { |
||||
return fmt.Errorf("actual SQL fragment should contain "+ |
||||
"%q but it doesn't\n\tFragment: %s\n\tFull SQL: %s", |
||||
v, actualSQL[offset:], actualSQL) |
||||
} |
||||
offset = loc[1] // advance the offset
|
||||
} |
||||
} |
||||
|
||||
return nil |
||||
}, |
||||
), |
||||
), |
||||
) |
||||
|
||||
return newUnitTestDB(t, db, mock, err) |
||||
} |
||||
|
||||
var mockDBMatchWordsRE = regexp.MustCompile(`(?:\W|\A)(!?\w+)\b`) |
||||
|
||||
func newUnitTestDB(t *testing.T, db *sql.DB, mock sqlmock.Sqlmock, err error) (db.DB, sqlmock.Sqlmock) { |
||||
t.Helper() |
||||
|
||||
require.NoError(t, err) |
||||
|
||||
return dbimpl.NewDB(db, "sqlmock"), mock |
||||
} |
||||
|
||||
// mockResults aids in testing code paths with queries returning large number of
|
||||
// values, like those returning *entity.Entity. This is because we want to
|
||||
// emulate returning the same row columns and row values the same as a real
|
||||
// database would do. This utility the same template SQL that is expected to be
|
||||
// used to help populate all the expected fields.
|
||||
// fileds
|
||||
type mockResults[T any] struct { |
||||
t *testing.T |
||||
tmpl *template.Template |
||||
data sqltemplate.WithResults[T] |
||||
rows *sqlmock.Rows |
||||
} |
||||
|
||||
// newMockResults returns a new *mockResults. If you want to emulate a call
|
||||
// returning zero rows, then immediately call the Row method afterward.
|
||||
func newMockResults[T any](t *testing.T, mock sqlmock.Sqlmock, tmpl *template.Template, data sqltemplate.WithResults[T]) *mockResults[T] { |
||||
t.Helper() |
||||
|
||||
data.Reset() |
||||
err := tmpl.Execute(io.Discard, data) |
||||
require.NoError(t, err) |
||||
rows := mock.NewRows(data.GetColNames()) |
||||
|
||||
return &mockResults[T]{ |
||||
t: t, |
||||
tmpl: tmpl, |
||||
data: data, |
||||
rows: rows, |
||||
} |
||||
} |
||||
|
||||
// AddCurrentData uses the values contained in the `data` argument used during
|
||||
// creation to populate a new expected row. It will access `data` with pointers,
|
||||
// so you should replace the internal values of `data` with freshly allocated
|
||||
// results to return different rows.
|
||||
func (r *mockResults[T]) AddCurrentData() *mockResults[T] { |
||||
r.t.Helper() |
||||
|
||||
r.data.Reset() |
||||
err := r.tmpl.Execute(io.Discard, r.data) |
||||
require.NoError(r.t, err) |
||||
|
||||
d := r.data.GetScanDest() |
||||
dv := make([]driver.Value, len(d)) |
||||
for i, v := range d { |
||||
dv[i] = v |
||||
} |
||||
r.rows.AddRow(dv...) |
||||
|
||||
return r |
||||
} |
||||
|
||||
// Rows returns the *sqlmock.Rows object built.
|
||||
func (r *mockResults[T]) Rows() *sqlmock.Rows { |
||||
return r.rows |
||||
} |
||||
|
||||
func TestPtrOr(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
p := ptrOr[*int]() |
||||
require.NotNil(t, p) |
||||
require.Zero(t, *p) |
||||
|
||||
p = ptrOr[*int](nil, nil, nil, nil, nil, nil) |
||||
require.NotNil(t, p) |
||||
require.Zero(t, *p) |
||||
|
||||
v := 42 |
||||
v2 := 5 |
||||
p = ptrOr(nil, nil, nil, &v, nil, &v2, nil, nil) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, *p) |
||||
|
||||
p = ptrOr(nil, nil, nil, &v) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, *p) |
||||
} |
||||
|
||||
func TestSliceOr(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
p := sliceOr[[]int]() |
||||
require.NotNil(t, p) |
||||
require.Len(t, p, 0) |
||||
|
||||
p = sliceOr[[]int](nil, nil, nil, nil) |
||||
require.NotNil(t, p) |
||||
require.Len(t, p, 0) |
||||
|
||||
p = sliceOr([]int{}, []int{}, []int{}, []int{}) |
||||
require.NotNil(t, p) |
||||
require.Len(t, p, 0) |
||||
|
||||
v := []int{1, 2} |
||||
p = sliceOr([]int{}, nil, []int{}, v, nil, []int{}, []int{10}, nil) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, p) |
||||
|
||||
p = sliceOr([]int{}, nil, []int{}, v) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, p) |
||||
} |
||||
|
||||
func TestMapOr(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
p := mapOr[map[string]int]() |
||||
require.NotNil(t, p) |
||||
require.Len(t, p, 0) |
||||
|
||||
p = mapOr(nil, map[string]int(nil), nil, map[string]int{}, nil) |
||||
require.NotNil(t, p) |
||||
require.Len(t, p, 0) |
||||
|
||||
v := map[string]int{"a": 0, "b": 1} |
||||
v2 := map[string]int{"c": 2, "d": 3} |
||||
|
||||
p = mapOr(nil, map[string]int(nil), v, v2, nil, map[string]int{}, nil) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, p) |
||||
|
||||
p = mapOr(nil, map[string]int(nil), v) |
||||
require.NotNil(t, p) |
||||
require.Equal(t, v, p) |
||||
} |
||||
|
||||
var ( |
||||
validTestTmpl = template.Must(template.New("test").Parse("nothing special")) |
||||
invalidTestTmpl = template.New("no definition should fail to exec") |
||||
errTest = errors.New("because of reasons") |
||||
) |
||||
|
||||
// expectRows is a testing helper to keep mocks in sync when adding rows to a
|
||||
// mocked SQL result. This is a helper to test `query` and `queryRow`.
|
||||
type expectRows[T any] struct { |
||||
*sqlmock.Rows |
||||
ExpectedResults []T |
||||
|
||||
req *sqltemplateMocks.WithResults[T] |
||||
} |
||||
|
||||
func newReturnsRow[T any](dbmock sqlmock.Sqlmock, req *sqltemplateMocks.WithResults[T]) *expectRows[T] { |
||||
return &expectRows[T]{ |
||||
Rows: dbmock.NewRows(nil), |
||||
req: req, |
||||
} |
||||
} |
||||
|
||||
// Add adds a new value that should be returned by the `query` or `queryRow`
|
||||
// operation.
|
||||
func (r *expectRows[T]) Add(value T, err error) *expectRows[T] { |
||||
r.req.EXPECT().GetScanDest().Return(nil).Once() |
||||
r.req.EXPECT().Results().Return(value, err).Once() |
||||
r.Rows.AddRow() |
||||
r.ExpectedResults = append(r.ExpectedResults, value) |
||||
|
||||
return r |
||||
} |
||||
|
||||
func TestQueryRow(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
t.Run("happy path", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
db, dbmock := newMockDBNopSQL(t) |
||||
rows := newReturnsRow(dbmock, req) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
req.EXPECT().GetArgs().Return(nil).Once() |
||||
rows.Add(1, nil) |
||||
dbmock.ExpectQuery("").WillReturnRows(rows.Rows) |
||||
|
||||
// execute and assert
|
||||
res, err := queryRow(ctx, db, validTestTmpl, req) |
||||
require.NoError(t, err) |
||||
require.Equal(t, rows.ExpectedResults[0], res) |
||||
}) |
||||
|
||||
t.Run("invalid request", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
db, _ := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(errTest).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := queryRow(ctx, db, invalidTestTmpl, req) |
||||
require.Zero(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorContains(t, err, "invalid request") |
||||
}) |
||||
|
||||
t.Run("error executing template", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
db, _ := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := queryRow(ctx, db, invalidTestTmpl, req) |
||||
require.Zero(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorContains(t, err, "execute template") |
||||
}) |
||||
|
||||
t.Run("error executing query", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
db, dbmock := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
req.EXPECT().GetArgs().Return(nil) |
||||
req.EXPECT().GetScanDest().Return(nil).Maybe() |
||||
dbmock.ExpectQuery("").WillReturnError(errTest) |
||||
|
||||
// execute and assert
|
||||
res, err := queryRow(ctx, db, validTestTmpl, req) |
||||
require.Zero(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorAs(t, err, new(SQLError)) |
||||
}) |
||||
} |
||||
|
||||
// scannerFunc is an adapter for the `scanner` interface.
|
||||
type scannerFunc func(dest ...any) error |
||||
|
||||
func (f scannerFunc) Scan(dest ...any) error { |
||||
return f(dest...) |
||||
} |
||||
|
||||
func TestScanRow(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
const value int64 = 1 |
||||
|
||||
t.Run("happy path", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
sc := scannerFunc(func(dest ...any) error { |
||||
return nil |
||||
}) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().GetScanDest().Return(nil).Once() |
||||
req.EXPECT().Results().Return(value, nil).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := scanRow(sc, req) |
||||
require.NoError(t, err) |
||||
require.Equal(t, value, res) |
||||
}) |
||||
|
||||
t.Run("scan error", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
sc := scannerFunc(func(dest ...any) error { |
||||
return errTest |
||||
}) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().GetScanDest().Return(nil).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := scanRow(sc, req) |
||||
require.Zero(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorIs(t, err, errTest) |
||||
}) |
||||
|
||||
t.Run("results error", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
req := sqltemplateMocks.NewWithResults[int64](t) |
||||
sc := scannerFunc(func(dest ...any) error { |
||||
return nil |
||||
}) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().GetScanDest().Return(nil).Once() |
||||
req.EXPECT().Results().Return(0, errTest).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := scanRow(sc, req) |
||||
require.Zero(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorIs(t, err, errTest) |
||||
}) |
||||
} |
||||
|
||||
func TestExec(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
t.Run("happy path", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
||||
db, dbmock := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
req.EXPECT().GetArgs().Return(nil).Once() |
||||
dbmock.ExpectExec("").WillReturnResult(sqlmock.NewResult(0, 0)) |
||||
|
||||
// execute and assert
|
||||
res, err := exec(ctx, db, validTestTmpl, req) |
||||
require.NoError(t, err) |
||||
require.NotNil(t, res) |
||||
}) |
||||
|
||||
t.Run("invalid request", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
||||
db, _ := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(errTest).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := exec(ctx, db, invalidTestTmpl, req) |
||||
require.Nil(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorContains(t, err, "invalid request") |
||||
}) |
||||
|
||||
t.Run("error executing template", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
||||
db, _ := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
|
||||
// execute and assert
|
||||
res, err := exec(ctx, db, invalidTestTmpl, req) |
||||
require.Nil(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorContains(t, err, "execute template") |
||||
}) |
||||
|
||||
t.Run("error executing SQL", func(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
// test declarations
|
||||
ctx := testutil.NewDefaultTestContext(t) |
||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
||||
db, dbmock := newMockDBNopSQL(t) |
||||
|
||||
// setup expectations
|
||||
req.EXPECT().Validate().Return(nil).Once() |
||||
req.EXPECT().GetArgs().Return(nil) |
||||
dbmock.ExpectExec("").WillReturnError(errTest) |
||||
|
||||
// execute and assert
|
||||
res, err := exec(ctx, db, validTestTmpl, req) |
||||
require.Nil(t, res) |
||||
require.Error(t, err) |
||||
require.ErrorAs(t, err, new(SQLError)) |
||||
}) |
||||
} |
||||
Loading…
Reference in new issue