mirror of https://github.com/grafana/grafana
parent
c1798320d2
commit
8f44e1a349
@ -1 +0,0 @@ |
|||||||
package auth |
|
||||||
@ -0,0 +1,176 @@ |
|||||||
|
package identity |
||||||
|
|
||||||
|
import "fmt" |
||||||
|
|
||||||
|
var _ Requester = &StaticRequester{} |
||||||
|
|
||||||
|
// StaticRequester is helpful for tests
|
||||||
|
// This is mostly copied from:
|
||||||
|
// https://github.com/grafana/grafana/blob/v11.0.0/pkg/services/user/identity.go#L16
|
||||||
|
type StaticRequester struct { |
||||||
|
Namespace Namespace |
||||||
|
UserID int64 |
||||||
|
UserUID string |
||||||
|
OrgID int64 |
||||||
|
OrgName string |
||||||
|
OrgRole RoleType |
||||||
|
Login string |
||||||
|
Name string |
||||||
|
DisplayName string |
||||||
|
Email string |
||||||
|
EmailVerified bool |
||||||
|
AuthID string |
||||||
|
AuthenticatedBy string |
||||||
|
IsGrafanaAdmin bool |
||||||
|
IsAnonymous bool |
||||||
|
IsDisabled bool |
||||||
|
// Permissions grouped by orgID and actions
|
||||||
|
Permissions map[int64]map[string][]string |
||||||
|
IDToken string |
||||||
|
CacheKey string |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) HasRole(role RoleType) bool { |
||||||
|
if u.IsGrafanaAdmin { |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
return u.OrgRole.Includes(role) |
||||||
|
} |
||||||
|
|
||||||
|
// GetIsGrafanaAdmin returns true if the user is a server admin
|
||||||
|
func (u *StaticRequester) GetIsGrafanaAdmin() bool { |
||||||
|
return u.IsGrafanaAdmin |
||||||
|
} |
||||||
|
|
||||||
|
// GetLogin returns the login of the active entity
|
||||||
|
// Can be empty if the user is anonymous
|
||||||
|
func (u *StaticRequester) GetLogin() string { |
||||||
|
return u.Login |
||||||
|
} |
||||||
|
|
||||||
|
// GetOrgID returns the ID of the active organization
|
||||||
|
func (u *StaticRequester) GetOrgID() int64 { |
||||||
|
return u.OrgID |
||||||
|
} |
||||||
|
|
||||||
|
// DEPRECATED: GetOrgName returns the name of the active organization
|
||||||
|
// Retrieve the organization name from the organization service instead of using this method.
|
||||||
|
func (u *StaticRequester) GetOrgName() string { |
||||||
|
return u.OrgName |
||||||
|
} |
||||||
|
|
||||||
|
// GetPermissions returns the permissions of the active entity
|
||||||
|
func (u *StaticRequester) GetPermissions() map[string][]string { |
||||||
|
if u.Permissions == nil { |
||||||
|
return make(map[string][]string) |
||||||
|
} |
||||||
|
|
||||||
|
if u.Permissions[u.GetOrgID()] == nil { |
||||||
|
return make(map[string][]string) |
||||||
|
} |
||||||
|
|
||||||
|
return u.Permissions[u.GetOrgID()] |
||||||
|
} |
||||||
|
|
||||||
|
// GetGlobalPermissions returns the permissions of the active entity that are available across all organizations
|
||||||
|
func (u *StaticRequester) GetGlobalPermissions() map[string][]string { |
||||||
|
if u.Permissions == nil { |
||||||
|
return make(map[string][]string) |
||||||
|
} |
||||||
|
|
||||||
|
const globalOrgID = 0 |
||||||
|
|
||||||
|
if u.Permissions[globalOrgID] == nil { |
||||||
|
return make(map[string][]string) |
||||||
|
} |
||||||
|
|
||||||
|
return u.Permissions[globalOrgID] |
||||||
|
} |
||||||
|
|
||||||
|
// DEPRECATED: GetTeams returns the teams the entity is a member of
|
||||||
|
// Retrieve the teams from the team service instead of using this method.
|
||||||
|
func (u *StaticRequester) GetTeams() []int64 { |
||||||
|
return []int64{} // Not implemented
|
||||||
|
} |
||||||
|
|
||||||
|
// GetOrgRole returns the role of the active entity in the active organization
|
||||||
|
func (u *StaticRequester) GetOrgRole() RoleType { |
||||||
|
return u.OrgRole |
||||||
|
} |
||||||
|
|
||||||
|
// HasUniqueId returns true if the entity has a unique id
|
||||||
|
func (u *StaticRequester) HasUniqueId() bool { |
||||||
|
return u.UserID > 0 |
||||||
|
} |
||||||
|
|
||||||
|
// GetID returns namespaced id for the entity
|
||||||
|
func (u *StaticRequester) GetID() NamespaceID { |
||||||
|
return NewNamespaceIDString(u.Namespace, fmt.Sprintf("%d", u.UserID)) |
||||||
|
} |
||||||
|
|
||||||
|
// GetUID returns namespaced uid for the entity
|
||||||
|
func (u *StaticRequester) GetUID() NamespaceID { |
||||||
|
return NewNamespaceIDString(u.Namespace, u.UserUID) |
||||||
|
} |
||||||
|
|
||||||
|
// GetNamespacedID returns the namespace and ID of the active entity
|
||||||
|
// The namespace is one of the constants defined in pkg/apimachinery/identity
|
||||||
|
func (u *StaticRequester) GetNamespacedID() (Namespace, string) { |
||||||
|
return u.Namespace, fmt.Sprintf("%d", u.UserID) |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) GetAuthID() string { |
||||||
|
return u.AuthID |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) GetAuthenticatedBy() string { |
||||||
|
return u.AuthenticatedBy |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) IsAuthenticatedBy(providers ...string) bool { |
||||||
|
for _, p := range providers { |
||||||
|
if u.AuthenticatedBy == p { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// FIXME: remove this method once all services are using an interface
|
||||||
|
func (u *StaticRequester) IsNil() bool { |
||||||
|
return u == nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetEmail returns the email of the active entity
|
||||||
|
// Can be empty.
|
||||||
|
func (u *StaticRequester) GetEmail() string { |
||||||
|
return u.Email |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) IsEmailVerified() bool { |
||||||
|
return u.EmailVerified |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) GetCacheKey() string { |
||||||
|
return u.CacheKey |
||||||
|
} |
||||||
|
|
||||||
|
// GetDisplayName returns the display name of the active entity
|
||||||
|
// The display name is the name if it is set, otherwise the login or email
|
||||||
|
func (u *StaticRequester) GetDisplayName() string { |
||||||
|
if u.DisplayName != "" { |
||||||
|
return u.DisplayName |
||||||
|
} |
||||||
|
if u.Name != "" { |
||||||
|
return u.Name |
||||||
|
} |
||||||
|
if u.Login != "" { |
||||||
|
return u.Login |
||||||
|
} |
||||||
|
return u.Email |
||||||
|
} |
||||||
|
|
||||||
|
func (u *StaticRequester) GetIDToken() string { |
||||||
|
return u.IDToken |
||||||
|
} |
||||||
@ -1,222 +0,0 @@ |
|||||||
package sqlstash |
|
||||||
|
|
||||||
import ( |
|
||||||
"context" |
|
||||||
"database/sql" |
|
||||||
"embed" |
|
||||||
"errors" |
|
||||||
"fmt" |
|
||||||
"strings" |
|
||||||
"text/template" |
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/services/store/entity/db" |
|
||||||
"github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate" |
|
||||||
) |
|
||||||
|
|
||||||
// Templates setup.
|
|
||||||
var ( |
|
||||||
//go:embed data/*.sql
|
|
||||||
sqlTemplatesFS embed.FS |
|
||||||
|
|
||||||
// all templates
|
|
||||||
helpers = template.FuncMap{ |
|
||||||
"listSep": helperListSep, |
|
||||||
"join": helperJoin, |
|
||||||
} |
|
||||||
sqlTemplates = template.Must(template.New("sql").Funcs(helpers).ParseFS(sqlTemplatesFS, `data/*.sql`)) |
|
||||||
) |
|
||||||
|
|
||||||
func mustTemplate(filename string) *template.Template { |
|
||||||
if t := sqlTemplates.Lookup(filename); t != nil { |
|
||||||
return t |
|
||||||
} |
|
||||||
panic(fmt.Sprintf("template file not found: %s", filename)) |
|
||||||
} |
|
||||||
|
|
||||||
// Templates.
|
|
||||||
var ( |
|
||||||
sqlResourceVersionGet = mustTemplate("rv_get.sql") |
|
||||||
sqlResourceVersionInc = mustTemplate("rv_inc.sql") |
|
||||||
sqlResourceVersionInsert = mustTemplate("rv_insert.sql") |
|
||||||
sqlResourceVersionLock = mustTemplate("rv_lock.sql") |
|
||||||
|
|
||||||
sqlResourceInsert = mustTemplate("resource_insert.sql") |
|
||||||
sqlResourceGet = mustTemplate("resource_get.sql") |
|
||||||
) |
|
||||||
|
|
||||||
// TxOptions.
|
|
||||||
var ( |
|
||||||
ReadCommitted = &sql.TxOptions{ |
|
||||||
Isolation: sql.LevelReadCommitted, |
|
||||||
} |
|
||||||
ReadCommittedRO = &sql.TxOptions{ |
|
||||||
Isolation: sql.LevelReadCommitted, |
|
||||||
ReadOnly: true, |
|
||||||
} |
|
||||||
) |
|
||||||
|
|
||||||
// SQLError is an error returned by the database, which includes additionally
|
|
||||||
// debugging information about what was sent to the database.
|
|
||||||
type SQLError struct { |
|
||||||
Err error |
|
||||||
CallType string // either Query, QueryRow or Exec
|
|
||||||
TemplateName string |
|
||||||
Query string |
|
||||||
RawQuery string |
|
||||||
ScanDest []any |
|
||||||
|
|
||||||
// potentially regulated information is not exported and only directly
|
|
||||||
// available for local testing and local debugging purposes, making sure it
|
|
||||||
// is never marshaled to JSON or any other serialization.
|
|
||||||
|
|
||||||
arguments []any |
|
||||||
} |
|
||||||
|
|
||||||
func (e SQLError) Unwrap() error { |
|
||||||
return e.Err |
|
||||||
} |
|
||||||
|
|
||||||
func (e SQLError) Error() string { |
|
||||||
return fmt.Sprintf("%s: %s with %d input arguments and %d output "+ |
|
||||||
"destination arguments: %v", e.TemplateName, e.CallType, |
|
||||||
len(e.arguments), len(e.ScanDest), e.Err) |
|
||||||
} |
|
||||||
|
|
||||||
//------------------------------------------------------------------------
|
|
||||||
// Resource Version table support
|
|
||||||
//------------------------------------------------------------------------
|
|
||||||
|
|
||||||
type returnsResourceVersion struct { |
|
||||||
ResourceVersion int64 |
|
||||||
} |
|
||||||
|
|
||||||
func (r *returnsResourceVersion) Results() (*returnsResourceVersion, error) { |
|
||||||
return r, nil |
|
||||||
} |
|
||||||
|
|
||||||
type sqlResourceVersionGetRequest struct { |
|
||||||
*sqltemplate.SQLTemplate |
|
||||||
Group string |
|
||||||
Resource string |
|
||||||
*returnsResourceVersion |
|
||||||
} |
|
||||||
|
|
||||||
func (r sqlResourceVersionGetRequest) Validate() error { |
|
||||||
return nil // TODO
|
|
||||||
} |
|
||||||
|
|
||||||
type sqlResourceVersionLockRequest struct { |
|
||||||
*sqltemplate.SQLTemplate |
|
||||||
Group string |
|
||||||
Resource string |
|
||||||
*returnsResourceVersion |
|
||||||
} |
|
||||||
|
|
||||||
func (r sqlResourceVersionLockRequest) Validate() error { |
|
||||||
return nil // TODO
|
|
||||||
} |
|
||||||
|
|
||||||
type sqlResourceVersionIncRequest struct { |
|
||||||
*sqltemplate.SQLTemplate |
|
||||||
Group string |
|
||||||
Resource string |
|
||||||
ResourceVersion int64 |
|
||||||
} |
|
||||||
|
|
||||||
func (r sqlResourceVersionIncRequest) Validate() error { |
|
||||||
return nil // TODO
|
|
||||||
} |
|
||||||
|
|
||||||
type sqlResourceVersionInsertRequest struct { |
|
||||||
*sqltemplate.SQLTemplate |
|
||||||
Group string |
|
||||||
Resource string |
|
||||||
} |
|
||||||
|
|
||||||
func (r sqlResourceVersionInsertRequest) Validate() error { |
|
||||||
return nil // TODO
|
|
||||||
} |
|
||||||
|
|
||||||
// resourceVersionAtomicInc atomically increases the version of a kind within a
|
|
||||||
// transaction.
|
|
||||||
func resourceVersionAtomicInc(ctx context.Context, x db.ContextExecer, d sqltemplate.Dialect, group, resource string) (newVersion int64, err error) { |
|
||||||
// 1. Lock the kind and get the latest version
|
|
||||||
lockReq := sqlResourceVersionLockRequest{ |
|
||||||
SQLTemplate: sqltemplate.New(d), |
|
||||||
Group: group, |
|
||||||
Resource: resource, |
|
||||||
returnsResourceVersion: new(returnsResourceVersion), |
|
||||||
} |
|
||||||
kindv, err := queryRow(ctx, x, sqlResourceVersionLock, lockReq) |
|
||||||
|
|
||||||
// if there wasn't a row associated with the given kind, we create one with
|
|
||||||
// version 1
|
|
||||||
if errors.Is(err, sql.ErrNoRows) { |
|
||||||
// NOTE: there is a marginal chance that we race with another writer
|
|
||||||
// trying to create the same row. This is only possible when onboarding
|
|
||||||
// a new (Group, Resource) to the cell, which should be very unlikely,
|
|
||||||
// and the workaround is simply retrying. The alternative would be to
|
|
||||||
// use INSERT ... ON CONFLICT DO UPDATE ..., but that creates a
|
|
||||||
// requirement for support in Dialect only for this marginal case, and
|
|
||||||
// we would rather keep Dialect as small as possible. Another
|
|
||||||
// alternative is to simply check if the INSERT returns a DUPLICATE KEY
|
|
||||||
// error and then retry the original SELECT, but that also adds some
|
|
||||||
// complexity to the code. That would be preferrable to changing
|
|
||||||
// Dialect, though. The current alternative, just retrying, seems to be
|
|
||||||
// enough for now.
|
|
||||||
insReq := sqlResourceVersionInsertRequest{ |
|
||||||
SQLTemplate: sqltemplate.New(d), |
|
||||||
Group: group, |
|
||||||
Resource: resource, |
|
||||||
} |
|
||||||
if _, err = exec(ctx, x, sqlResourceVersionInsert, insReq); err != nil { |
|
||||||
return 0, fmt.Errorf("insert into kind_version: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return 1, nil |
|
||||||
} |
|
||||||
|
|
||||||
if err != nil { |
|
||||||
return 0, fmt.Errorf("lock kind: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
incReq := sqlResourceVersionIncRequest{ |
|
||||||
SQLTemplate: sqltemplate.New(d), |
|
||||||
Group: group, |
|
||||||
Resource: resource, |
|
||||||
ResourceVersion: kindv.ResourceVersion, |
|
||||||
} |
|
||||||
if _, err = exec(ctx, x, sqlResourceVersionInc, incReq); err != nil { |
|
||||||
return 0, fmt.Errorf("increase kind version: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return kindv.ResourceVersion + 1, nil |
|
||||||
} |
|
||||||
|
|
||||||
// Template helpers.
|
|
||||||
|
|
||||||
// helperListSep is a helper that helps writing simpler loops in SQL templates.
|
|
||||||
// Example usage:
|
|
||||||
//
|
|
||||||
// {{ $comma := listSep ", " }}
|
|
||||||
// {{ range .Values }}
|
|
||||||
// {{/* here we put "-" on each end to remove extra white space */}}
|
|
||||||
// {{- call $comma -}}
|
|
||||||
// {{ .Value }}
|
|
||||||
// {{ end }}
|
|
||||||
func helperListSep(sep string) func() string { |
|
||||||
var addSep bool |
|
||||||
|
|
||||||
return func() string { |
|
||||||
if addSep { |
|
||||||
return sep |
|
||||||
} |
|
||||||
addSep = true |
|
||||||
|
|
||||||
return "" |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func helperJoin(sep string, elems ...string) string { |
|
||||||
return strings.Join(elems, sep) |
|
||||||
} |
|
||||||
@ -1,25 +0,0 @@ |
|||||||
{ |
|
||||||
"apiVersion": "playlist.grafana.app/v0alpha1", |
|
||||||
"kind": "Playlist", |
|
||||||
"metadata": { |
|
||||||
"name": "fdgsv37qslr0ga", |
|
||||||
"namespace": "default", |
|
||||||
"annotations": { |
|
||||||
"grafana.app/originName": "elsewhere", |
|
||||||
"grafana.app/originPath": "path/to/item", |
|
||||||
"grafana.app/originTimestamp": "2024-02-02T00:00:00Z" |
|
||||||
}, |
|
||||||
"creationTimestamp": "2024-03-03T00:00:00Z", |
|
||||||
"uid": "8tGrXJgGbFI0" |
|
||||||
}, |
|
||||||
"spec": { |
|
||||||
"title": "hello", |
|
||||||
"interval": "5m", |
|
||||||
"items": [ |
|
||||||
{ |
|
||||||
"type": "dashboard_by_uid", |
|
||||||
"value": "vmie2cmWz" |
|
||||||
} |
|
||||||
] |
|
||||||
} |
|
||||||
} |
|
||||||
@ -1,503 +0,0 @@ |
|||||||
package sqlstash |
|
||||||
|
|
||||||
import ( |
|
||||||
"database/sql" |
|
||||||
"database/sql/driver" |
|
||||||
"errors" |
|
||||||
"fmt" |
|
||||||
"io" |
|
||||||
"regexp" |
|
||||||
"strings" |
|
||||||
"testing" |
|
||||||
"text/template" |
|
||||||
|
|
||||||
sqlmock "github.com/DATA-DOG/go-sqlmock" |
|
||||||
"github.com/stretchr/testify/require" |
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/services/store/entity/db" |
|
||||||
"github.com/grafana/grafana/pkg/services/store/entity/db/dbimpl" |
|
||||||
"github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate" |
|
||||||
sqltemplateMocks "github.com/grafana/grafana/pkg/services/store/entity/sqlstash/sqltemplate/mocks" |
|
||||||
"github.com/grafana/grafana/pkg/util/testutil" |
|
||||||
) |
|
||||||
|
|
||||||
// newMockDBNopSQL returns a db.DB and a sqlmock.Sqlmock that doesn't validates
|
|
||||||
// SQL. This is only meant to be used to test wrapping utilities exec, query and
|
|
||||||
// queryRow, where the actual SQL is not relevant to the unit tests, but rather
|
|
||||||
// how the possible derived error conditions handled.
|
|
||||||
func newMockDBNopSQL(t *testing.T) (db.DB, sqlmock.Sqlmock) { |
|
||||||
t.Helper() |
|
||||||
|
|
||||||
db, mock, err := sqlmock.New( |
|
||||||
sqlmock.MonitorPingsOption(true), |
|
||||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherFunc( |
|
||||||
func(expectedSQL, actualSQL string) error { |
|
||||||
return nil |
|
||||||
}, |
|
||||||
)), |
|
||||||
) |
|
||||||
|
|
||||||
return newUnitTestDB(t, db, mock, err) |
|
||||||
} |
|
||||||
|
|
||||||
// newMockDBMatchWords returns a db.DB and a sqlmock.Sqlmock that will match SQL
|
|
||||||
// by splitting the expected SQL string into words, and then try to find all of
|
|
||||||
// them in the actual SQL, in the given order, case insensitively. Prepend a
|
|
||||||
// word with a `!` to say that word should not be found.
|
|
||||||
func newMockDBMatchWords(t *testing.T) (db.DB, sqlmock.Sqlmock) { |
|
||||||
t.Helper() |
|
||||||
|
|
||||||
db, mock, err := sqlmock.New( |
|
||||||
sqlmock.MonitorPingsOption(true), |
|
||||||
sqlmock.QueryMatcherOption( |
|
||||||
sqlmock.QueryMatcherFunc(func(expectedSQL, actualSQL string) error { |
|
||||||
actualSQL = strings.ToLower(sqltemplate.FormatSQL(actualSQL)) |
|
||||||
expectedSQL = strings.ToLower(expectedSQL) |
|
||||||
|
|
||||||
var offset int |
|
||||||
for _, vv := range mockDBMatchWordsRE.FindAllStringSubmatch(expectedSQL, -1) { |
|
||||||
v := vv[1] |
|
||||||
|
|
||||||
var shouldNotMatch bool |
|
||||||
if v != "" && v[0] == '!' { |
|
||||||
v = v[1:] |
|
||||||
shouldNotMatch = true |
|
||||||
} |
|
||||||
if v == "" { |
|
||||||
return fmt.Errorf("invalid expected word %q in %q", v, |
|
||||||
expectedSQL) |
|
||||||
} |
|
||||||
|
|
||||||
reWord, err := regexp.Compile(`\b` + regexp.QuoteMeta(v) + `\b`) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("compile word %q from expected SQL: %s", v, |
|
||||||
expectedSQL) |
|
||||||
} |
|
||||||
|
|
||||||
if shouldNotMatch { |
|
||||||
if reWord.MatchString(actualSQL[offset:]) { |
|
||||||
return fmt.Errorf("actual SQL fragent should not cont"+ |
|
||||||
"ain %q but it does\n\tFragment: %s\n\tFull SQL: %s", |
|
||||||
v, actualSQL[offset:], actualSQL) |
|
||||||
} |
|
||||||
} else { |
|
||||||
loc := reWord.FindStringIndex(actualSQL[offset:]) |
|
||||||
if len(loc) == 0 { |
|
||||||
return fmt.Errorf("actual SQL fragment should contain "+ |
|
||||||
"%q but it doesn't\n\tFragment: %s\n\tFull SQL: %s", |
|
||||||
v, actualSQL[offset:], actualSQL) |
|
||||||
} |
|
||||||
offset = loc[1] // advance the offset
|
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
}, |
|
||||||
), |
|
||||||
), |
|
||||||
) |
|
||||||
|
|
||||||
return newUnitTestDB(t, db, mock, err) |
|
||||||
} |
|
||||||
|
|
||||||
var mockDBMatchWordsRE = regexp.MustCompile(`(?:\W|\A)(!?\w+)\b`) |
|
||||||
|
|
||||||
func newUnitTestDB(t *testing.T, db *sql.DB, mock sqlmock.Sqlmock, err error) (db.DB, sqlmock.Sqlmock) { |
|
||||||
t.Helper() |
|
||||||
|
|
||||||
require.NoError(t, err) |
|
||||||
|
|
||||||
return dbimpl.NewDB(db, "sqlmock"), mock |
|
||||||
} |
|
||||||
|
|
||||||
// mockResults aids in testing code paths with queries returning large number of
|
|
||||||
// values, like those returning *entity.Entity. This is because we want to
|
|
||||||
// emulate returning the same row columns and row values the same as a real
|
|
||||||
// database would do. This utility the same template SQL that is expected to be
|
|
||||||
// used to help populate all the expected fields.
|
|
||||||
// fileds
|
|
||||||
type mockResults[T any] struct { |
|
||||||
t *testing.T |
|
||||||
tmpl *template.Template |
|
||||||
data sqltemplate.WithResults[T] |
|
||||||
rows *sqlmock.Rows |
|
||||||
} |
|
||||||
|
|
||||||
// newMockResults returns a new *mockResults. If you want to emulate a call
|
|
||||||
// returning zero rows, then immediately call the Row method afterward.
|
|
||||||
func newMockResults[T any](t *testing.T, mock sqlmock.Sqlmock, tmpl *template.Template, data sqltemplate.WithResults[T]) *mockResults[T] { |
|
||||||
t.Helper() |
|
||||||
|
|
||||||
data.Reset() |
|
||||||
err := tmpl.Execute(io.Discard, data) |
|
||||||
require.NoError(t, err) |
|
||||||
rows := mock.NewRows(data.GetColNames()) |
|
||||||
|
|
||||||
return &mockResults[T]{ |
|
||||||
t: t, |
|
||||||
tmpl: tmpl, |
|
||||||
data: data, |
|
||||||
rows: rows, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// AddCurrentData uses the values contained in the `data` argument used during
|
|
||||||
// creation to populate a new expected row. It will access `data` with pointers,
|
|
||||||
// so you should replace the internal values of `data` with freshly allocated
|
|
||||||
// results to return different rows.
|
|
||||||
func (r *mockResults[T]) AddCurrentData() *mockResults[T] { |
|
||||||
r.t.Helper() |
|
||||||
|
|
||||||
r.data.Reset() |
|
||||||
err := r.tmpl.Execute(io.Discard, r.data) |
|
||||||
require.NoError(r.t, err) |
|
||||||
|
|
||||||
d := r.data.GetScanDest() |
|
||||||
dv := make([]driver.Value, len(d)) |
|
||||||
for i, v := range d { |
|
||||||
dv[i] = v |
|
||||||
} |
|
||||||
r.rows.AddRow(dv...) |
|
||||||
|
|
||||||
return r |
|
||||||
} |
|
||||||
|
|
||||||
// Rows returns the *sqlmock.Rows object built.
|
|
||||||
func (r *mockResults[T]) Rows() *sqlmock.Rows { |
|
||||||
return r.rows |
|
||||||
} |
|
||||||
|
|
||||||
func TestPtrOr(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
p := ptrOr[*int]() |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Zero(t, *p) |
|
||||||
|
|
||||||
p = ptrOr[*int](nil, nil, nil, nil, nil, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Zero(t, *p) |
|
||||||
|
|
||||||
v := 42 |
|
||||||
v2 := 5 |
|
||||||
p = ptrOr(nil, nil, nil, &v, nil, &v2, nil, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, *p) |
|
||||||
|
|
||||||
p = ptrOr(nil, nil, nil, &v) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, *p) |
|
||||||
} |
|
||||||
|
|
||||||
func TestSliceOr(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
p := sliceOr[[]int]() |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Len(t, p, 0) |
|
||||||
|
|
||||||
p = sliceOr[[]int](nil, nil, nil, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Len(t, p, 0) |
|
||||||
|
|
||||||
p = sliceOr([]int{}, []int{}, []int{}, []int{}) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Len(t, p, 0) |
|
||||||
|
|
||||||
v := []int{1, 2} |
|
||||||
p = sliceOr([]int{}, nil, []int{}, v, nil, []int{}, []int{10}, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, p) |
|
||||||
|
|
||||||
p = sliceOr([]int{}, nil, []int{}, v) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, p) |
|
||||||
} |
|
||||||
|
|
||||||
func TestMapOr(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
p := mapOr[map[string]int]() |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Len(t, p, 0) |
|
||||||
|
|
||||||
p = mapOr(nil, map[string]int(nil), nil, map[string]int{}, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Len(t, p, 0) |
|
||||||
|
|
||||||
v := map[string]int{"a": 0, "b": 1} |
|
||||||
v2 := map[string]int{"c": 2, "d": 3} |
|
||||||
|
|
||||||
p = mapOr(nil, map[string]int(nil), v, v2, nil, map[string]int{}, nil) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, p) |
|
||||||
|
|
||||||
p = mapOr(nil, map[string]int(nil), v) |
|
||||||
require.NotNil(t, p) |
|
||||||
require.Equal(t, v, p) |
|
||||||
} |
|
||||||
|
|
||||||
var ( |
|
||||||
validTestTmpl = template.Must(template.New("test").Parse("nothing special")) |
|
||||||
invalidTestTmpl = template.New("no definition should fail to exec") |
|
||||||
errTest = errors.New("because of reasons") |
|
||||||
) |
|
||||||
|
|
||||||
// expectRows is a testing helper to keep mocks in sync when adding rows to a
|
|
||||||
// mocked SQL result. This is a helper to test `query` and `queryRow`.
|
|
||||||
type expectRows[T any] struct { |
|
||||||
*sqlmock.Rows |
|
||||||
ExpectedResults []T |
|
||||||
|
|
||||||
req *sqltemplateMocks.WithResults[T] |
|
||||||
} |
|
||||||
|
|
||||||
func newReturnsRow[T any](dbmock sqlmock.Sqlmock, req *sqltemplateMocks.WithResults[T]) *expectRows[T] { |
|
||||||
return &expectRows[T]{ |
|
||||||
Rows: dbmock.NewRows(nil), |
|
||||||
req: req, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// Add adds a new value that should be returned by the `query` or `queryRow`
|
|
||||||
// operation.
|
|
||||||
func (r *expectRows[T]) Add(value T, err error) *expectRows[T] { |
|
||||||
r.req.EXPECT().GetScanDest().Return(nil).Once() |
|
||||||
r.req.EXPECT().Results().Return(value, err).Once() |
|
||||||
r.Rows.AddRow() |
|
||||||
r.ExpectedResults = append(r.ExpectedResults, value) |
|
||||||
|
|
||||||
return r |
|
||||||
} |
|
||||||
|
|
||||||
func TestQueryRow(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
t.Run("happy path", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
db, dbmock := newMockDBNopSQL(t) |
|
||||||
rows := newReturnsRow(dbmock, req) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
req.EXPECT().GetArgs().Return(nil).Once() |
|
||||||
rows.Add(1, nil) |
|
||||||
dbmock.ExpectQuery("").WillReturnRows(rows.Rows) |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := queryRow(ctx, db, validTestTmpl, req) |
|
||||||
require.NoError(t, err) |
|
||||||
require.Equal(t, rows.ExpectedResults[0], res) |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("invalid request", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
db, _ := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(errTest).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := queryRow(ctx, db, invalidTestTmpl, req) |
|
||||||
require.Zero(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorContains(t, err, "invalid request") |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("error executing template", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
db, _ := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := queryRow(ctx, db, invalidTestTmpl, req) |
|
||||||
require.Zero(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorContains(t, err, "execute template") |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("error executing query", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
db, dbmock := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
req.EXPECT().GetArgs().Return(nil) |
|
||||||
req.EXPECT().GetScanDest().Return(nil).Maybe() |
|
||||||
dbmock.ExpectQuery("").WillReturnError(errTest) |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := queryRow(ctx, db, validTestTmpl, req) |
|
||||||
require.Zero(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorAs(t, err, new(SQLError)) |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
// scannerFunc is an adapter for the `scanner` interface.
|
|
||||||
type scannerFunc func(dest ...any) error |
|
||||||
|
|
||||||
func (f scannerFunc) Scan(dest ...any) error { |
|
||||||
return f(dest...) |
|
||||||
} |
|
||||||
|
|
||||||
func TestScanRow(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
const value int64 = 1 |
|
||||||
|
|
||||||
t.Run("happy path", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
sc := scannerFunc(func(dest ...any) error { |
|
||||||
return nil |
|
||||||
}) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().GetScanDest().Return(nil).Once() |
|
||||||
req.EXPECT().Results().Return(value, nil).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := scanRow(sc, req) |
|
||||||
require.NoError(t, err) |
|
||||||
require.Equal(t, value, res) |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("scan error", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
sc := scannerFunc(func(dest ...any) error { |
|
||||||
return errTest |
|
||||||
}) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().GetScanDest().Return(nil).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := scanRow(sc, req) |
|
||||||
require.Zero(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorIs(t, err, errTest) |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("results error", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
req := sqltemplateMocks.NewWithResults[int64](t) |
|
||||||
sc := scannerFunc(func(dest ...any) error { |
|
||||||
return nil |
|
||||||
}) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().GetScanDest().Return(nil).Once() |
|
||||||
req.EXPECT().Results().Return(0, errTest).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := scanRow(sc, req) |
|
||||||
require.Zero(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorIs(t, err, errTest) |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
func TestExec(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
t.Run("happy path", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
|
||||||
db, dbmock := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
req.EXPECT().GetArgs().Return(nil).Once() |
|
||||||
dbmock.ExpectExec("").WillReturnResult(sqlmock.NewResult(0, 0)) |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := exec(ctx, db, validTestTmpl, req) |
|
||||||
require.NoError(t, err) |
|
||||||
require.NotNil(t, res) |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("invalid request", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
|
||||||
db, _ := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(errTest).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := exec(ctx, db, invalidTestTmpl, req) |
|
||||||
require.Nil(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorContains(t, err, "invalid request") |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("error executing template", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
|
||||||
db, _ := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := exec(ctx, db, invalidTestTmpl, req) |
|
||||||
require.Nil(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorContains(t, err, "execute template") |
|
||||||
}) |
|
||||||
|
|
||||||
t.Run("error executing SQL", func(t *testing.T) { |
|
||||||
t.Parallel() |
|
||||||
|
|
||||||
// test declarations
|
|
||||||
ctx := testutil.NewDefaultTestContext(t) |
|
||||||
req := sqltemplateMocks.NewSQLTemplateIface(t) |
|
||||||
db, dbmock := newMockDBNopSQL(t) |
|
||||||
|
|
||||||
// setup expectations
|
|
||||||
req.EXPECT().Validate().Return(nil).Once() |
|
||||||
req.EXPECT().GetArgs().Return(nil) |
|
||||||
dbmock.ExpectExec("").WillReturnError(errTest) |
|
||||||
|
|
||||||
// execute and assert
|
|
||||||
res, err := exec(ctx, db, validTestTmpl, req) |
|
||||||
require.Nil(t, res) |
|
||||||
require.Error(t, err) |
|
||||||
require.ErrorAs(t, err, new(SQLError)) |
|
||||||
}) |
|
||||||
} |
|
||||||
Loading…
Reference in new issue