accesscontrol service read replica (#89963)

* accesscontrol service read replica
* now using the ReplDB interface
* ReadReplica for GetUser
pull/90212/head
Kristin Laemmert 12 months ago committed by GitHub
parent e9876749d4
commit 77a4869fca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 20
      pkg/api/folder_bench_test.go
  2. 14
      pkg/cmd/grafana-cli/commands/conflict_user_command.go
  3. 2
      pkg/server/wire.go
  4. 4
      pkg/services/accesscontrol/acimpl/service.go
  5. 2
      pkg/services/accesscontrol/acimpl/service_bench_test.go
  6. 4
      pkg/services/accesscontrol/acimpl/service_test.go
  7. 20
      pkg/services/accesscontrol/database/database.go
  8. 4
      pkg/services/accesscontrol/database/database_test.go
  9. 4
      pkg/services/accesscontrol/database/externalservices.go
  10. 17
      pkg/services/accesscontrol/database/externalservices_test.go
  11. 6
      pkg/services/accesscontrol/migrator/migrator.go
  12. 2
      pkg/services/accesscontrol/migrator/migrator_bench_test.go
  13. 2
      pkg/services/accesscontrol/migrator/migrator_test.go
  14. 12
      pkg/services/sqlstore/replstore.go

@ -69,7 +69,7 @@ const (
) )
type benchScenario struct { type benchScenario struct {
db db.DB db db.ReplDB
// signedInUser is the user that is signed in to the server // signedInUser is the user that is signed in to the server
cfg *setting.Cfg cfg *setting.Cfg
signedInUser *user.SignedInUser signedInUser *user.SignedInUser
@ -202,7 +202,7 @@ func BenchmarkFolderListAndSearch(b *testing.B) {
func setupDB(b testing.TB) benchScenario { func setupDB(b testing.TB) benchScenario {
b.Helper() b.Helper()
db, cfg := sqlstore.InitTestDB(b) db, cfg := sqlstore.InitTestReplDB(b)
IDs := map[int64]struct{}{} IDs := map[int64]struct{}{}
opts := sqlstore.NativeSettingsForDialect(db.GetDialect()) opts := sqlstore.NativeSettingsForDialect(db.GetDialect())
@ -451,26 +451,26 @@ func setupServer(b testing.TB, sc benchScenario, features featuremgmt.FeatureTog
quotaSrv := quotatest.New(false, nil) quotaSrv := quotatest.New(false, nil)
dashStore, err := database.ProvideDashboardStore(sc.db, sc.cfg, features, tagimpl.ProvideService(sc.db), quotaSrv) dashStore, err := database.ProvideDashboardStore(sc.db.DB(), sc.cfg, features, tagimpl.ProvideService(sc.db.DB()), quotaSrv)
require.NoError(b, err) require.NoError(b, err)
folderStore := folderimpl.ProvideDashboardFolderStore(sc.db) folderStore := folderimpl.ProvideDashboardFolderStore(sc.db.DB())
ac := acimpl.ProvideAccessControl(featuremgmt.WithFeatures(), zanzana.NewNoopClient()) ac := acimpl.ProvideAccessControl(featuremgmt.WithFeatures(), zanzana.NewNoopClient())
folderServiceWithFlagOn := folderimpl.ProvideService(ac, bus.ProvideBus(tracing.InitializeTracerForTest()), dashStore, folderStore, sc.db, features, supportbundlestest.NewFakeBundleService(), nil) folderServiceWithFlagOn := folderimpl.ProvideService(ac, bus.ProvideBus(tracing.InitializeTracerForTest()), dashStore, folderStore, sc.db.DB(), features, supportbundlestest.NewFakeBundleService(), nil)
cfg := setting.NewCfg() cfg := setting.NewCfg()
actionSets := resourcepermissions.NewActionSetService() actionSets := resourcepermissions.NewActionSetService()
acSvc := acimpl.ProvideOSSService( acSvc := acimpl.ProvideOSSService(
sc.cfg, acdb.ProvideService(sc.db), actionSets, localcache.ProvideService(), sc.cfg, acdb.ProvideService(sc.db), actionSets, localcache.ProvideService(),
features, tracing.InitializeTracerForTest(), zanzana.NewNoopClient(), sc.db, features, tracing.InitializeTracerForTest(), zanzana.NewNoopClient(), sc.db.DB(),
) )
folderPermissions, err := ossaccesscontrol.ProvideFolderPermissions( folderPermissions, err := ossaccesscontrol.ProvideFolderPermissions(
cfg, features, routing.NewRouteRegister(), sc.db, ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) cfg, features, routing.NewRouteRegister(), sc.db.DB(), ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets)
require.NoError(b, err) require.NoError(b, err)
dashboardPermissions, err := ossaccesscontrol.ProvideDashboardPermissions( dashboardPermissions, err := ossaccesscontrol.ProvideDashboardPermissions(
cfg, features, routing.NewRouteRegister(), sc.db, ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets) cfg, features, routing.NewRouteRegister(), sc.db.DB(), ac, license, &dashboards.FakeDashboardStore{}, folderServiceWithFlagOn, acSvc, sc.teamSvc, sc.userSvc, actionSets)
require.NoError(b, err) require.NoError(b, err)
dashboardSvc, err := dashboardservice.ProvideDashboardServiceImpl( dashboardSvc, err := dashboardservice.ProvideDashboardServiceImpl(
@ -486,10 +486,10 @@ func setupServer(b testing.TB, sc benchScenario, features featuremgmt.FeatureTog
hs := &HTTPServer{ hs := &HTTPServer{
CacheService: localcache.New(5*time.Minute, 10*time.Minute), CacheService: localcache.New(5*time.Minute, 10*time.Minute),
Cfg: sc.cfg, Cfg: sc.cfg,
SQLStore: sc.db, SQLStore: sc.db.DB(),
Features: features, Features: features,
QuotaService: quotaSrv, QuotaService: quotaSrv,
SearchService: search.ProvideService(sc.cfg, sc.db, starSvc, dashboardSvc), SearchService: search.ProvideService(sc.cfg, sc.db.DB(), starSvc, dashboardSvc),
folderService: folderServiceWithFlagOn, folderService: folderServiceWithFlagOn,
DashboardService: dashboardSvc, DashboardService: dashboardSvc,
} }

@ -70,7 +70,7 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %w", "failed to load configuration", err) return nil, fmt.Errorf("%v: %w", "failed to load configuration", err)
} }
s, err := getSqlStore(cfg, tracer, features) s, replstore, err := getSqlStore(cfg, tracer, features)
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %w", "failed to get to sql", err) return nil, fmt.Errorf("%v: %w", "failed to get to sql", err)
} }
@ -90,7 +90,7 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %w", "failed to initialize tracer service", err) return nil, fmt.Errorf("%v: %w", "failed to initialize tracer service", err)
} }
acService, err := acimpl.ProvideService(cfg, s, routing, nil, nil, nil, features, tracer, zanzana.NewNoopClient()) acService, err := acimpl.ProvideService(cfg, replstore, routing, nil, nil, nil, features, tracer, zanzana.NewNoopClient())
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %w", "failed to get access control", err) return nil, fmt.Errorf("%v: %w", "failed to get access control", err)
} }
@ -99,9 +99,15 @@ func initializeConflictResolver(cmd *utils.ContextCommandLine, f Formatter, ctx
return &resolver, nil return &resolver, nil
} }
func getSqlStore(cfg *setting.Cfg, tracer tracing.Tracer, features featuremgmt.FeatureToggles) (*sqlstore.SQLStore, error) { func getSqlStore(cfg *setting.Cfg, tracer tracing.Tracer, features featuremgmt.FeatureToggles) (*sqlstore.SQLStore, *sqlstore.ReplStore, error) {
bus := bus.ProvideBus(tracer) bus := bus.ProvideBus(tracer)
return sqlstore.ProvideService(cfg, features, &migrations.OSSMigrations{}, bus, tracer) ss, err := sqlstore.ProvideService(cfg, features, &migrations.OSSMigrations{}, bus, tracer)
if err != nil {
return nil, nil, err
}
replStore, err := sqlstore.ProvideServiceWithReadReplica(ss, cfg, features, &migrations.OSSMigrations{}, bus, tracer)
return ss, replStore, err
} }
func runListConflictUsers() func(context *cli.Context) error { func runListConflictUsers() func(context *cli.Context) error {

@ -396,6 +396,7 @@ var wireSet = wire.NewSet(
wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)),
wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)),
wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)),
wire.Bind(new(db.ReplDB), new(*sqlstore.ReplStore)),
prefimpl.ProvideService, prefimpl.ProvideService,
oauthtoken.ProvideService, oauthtoken.ProvideService,
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)),
@ -412,6 +413,7 @@ var wireCLISet = wire.NewSet(
wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationService)),
wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)), wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationService)),
wire.Bind(new(db.DB), new(*sqlstore.SQLStore)), wire.Bind(new(db.DB), new(*sqlstore.SQLStore)),
wire.Bind(new(db.ReplDB), new(*sqlstore.ReplStore)),
prefimpl.ProvideService, prefimpl.ProvideService,
oauthtoken.ProvideService, oauthtoken.ProvideService,
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)),

@ -48,11 +48,11 @@ var SharedWithMeFolderPermission = accesscontrol.Permission{
var OSSRolesPrefixes = []string{accesscontrol.ManagedRolePrefix, accesscontrol.ExternalServiceRolePrefix} var OSSRolesPrefixes = []string{accesscontrol.ManagedRolePrefix, accesscontrol.ExternalServiceRolePrefix}
func ProvideService( func ProvideService(
cfg *setting.Cfg, db db.DB, routeRegister routing.RouteRegister, cache *localcache.CacheService, cfg *setting.Cfg, db db.ReplDB, routeRegister routing.RouteRegister, cache *localcache.CacheService,
accessControl accesscontrol.AccessControl, actionResolver accesscontrol.ActionResolver, accessControl accesscontrol.AccessControl, actionResolver accesscontrol.ActionResolver,
features featuremgmt.FeatureToggles, tracer tracing.Tracer, zclient zanzana.Client, features featuremgmt.FeatureToggles, tracer tracing.Tracer, zclient zanzana.Client,
) (*Service, error) { ) (*Service, error) {
service := ProvideOSSService(cfg, database.ProvideService(db), actionResolver, cache, features, tracer, zclient, db) service := ProvideOSSService(cfg, database.ProvideService(db), actionResolver, cache, features, tracer, zclient, db.DB())
api.NewAccessControlAPI(routeRegister, accessControl, service, features).RegisterAPIEndpoints() api.NewAccessControlAPI(routeRegister, accessControl, service, features).RegisterAPIEndpoints()
if err := accesscontrol.DeclareFixedRoles(service, cfg); err != nil { if err := accesscontrol.DeclareFixedRoles(service, cfg); err != nil {

@ -25,7 +25,7 @@ import (
// - each managed role will have 3 permissions {"resources:action2", "resources:id:x"} where x belongs to [1, 3] // - each managed role will have 3 permissions {"resources:action2", "resources:id:x"} where x belongs to [1, 3]
func setupBenchEnv(b *testing.B, usersCount, resourceCount int) (accesscontrol.Service, *user.SignedInUser) { func setupBenchEnv(b *testing.B, usersCount, resourceCount int) (accesscontrol.Service, *user.SignedInUser) {
now := time.Now() now := time.Now()
sqlStore := db.InitTestDB(b) sqlStore := db.InitTestReplDB(b)
store := database.ProvideService(sqlStore) store := database.ProvideService(sqlStore)
acService := &Service{ acService := &Service{
cfg: setting.NewCfg(), cfg: setting.NewCfg(),

@ -41,8 +41,8 @@ func setupTestEnv(t testing.TB) *Service {
log: log.New("accesscontrol"), log: log.New("accesscontrol"),
registrations: accesscontrol.RegistrationList{}, registrations: accesscontrol.RegistrationList{},
roles: accesscontrol.BuildBasicRoleDefinitions(), roles: accesscontrol.BuildBasicRoleDefinitions(),
store: database.ProvideService(db.InitTestDB(t)),
tracer: tracing.InitializeTracerForTest(), tracer: tracing.InitializeTracerForTest(),
store: database.ProvideService(db.InitTestReplDB(t)),
} }
require.NoError(t, ac.RegisterFixedRoles(context.Background())) require.NoError(t, ac.RegisterFixedRoles(context.Background()))
return ac return ac
@ -65,7 +65,7 @@ func TestUsageMetrics(t *testing.T) {
s := ProvideOSSService( s := ProvideOSSService(
cfg, cfg,
database.ProvideService(db.InitTestDB(t)), database.ProvideService(db.InitTestReplDB(t)),
&resourcepermissions.FakeActionSetSvc{}, &resourcepermissions.FakeActionSetSvc{},
localcache.ProvideService(), localcache.ProvideService(),
featuremgmt.WithFeatures(), featuremgmt.WithFeatures(),

@ -36,17 +36,17 @@ const (
WHERE br.role = ?` WHERE br.role = ?`
) )
func ProvideService(sql db.DB) *AccessControlStore { func ProvideService(sql db.ReplDB) *AccessControlStore {
return &AccessControlStore{sql} return &AccessControlStore{sql}
} }
type AccessControlStore struct { type AccessControlStore struct {
sql db.DB sql db.ReplDB
} }
func (s *AccessControlStore) GetUserPermissions(ctx context.Context, query accesscontrol.GetUserPermissionsQuery) ([]accesscontrol.Permission, error) { func (s *AccessControlStore) GetUserPermissions(ctx context.Context, query accesscontrol.GetUserPermissionsQuery) ([]accesscontrol.Permission, error) {
result := make([]accesscontrol.Permission, 0) result := make([]accesscontrol.Permission, 0)
err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error {
if query.UserID == 0 && len(query.TeamIDs) == 0 && len(query.Roles) == 0 { if query.UserID == 0 && len(query.TeamIDs) == 0 && len(query.Roles) == 0 {
// no permission to fetch // no permission to fetch
return nil return nil
@ -104,7 +104,7 @@ func (s *AccessControlStore) GetTeamsPermissions(ctx context.Context, query acce
orgID := query.OrgID orgID := query.OrgID
rolePrefixes := query.RolePrefixes rolePrefixes := query.RolePrefixes
result := make([]teamPermission, 0) result := make([]teamPermission, 0)
err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error {
if len(teams) == 0 { if len(teams) == 0 {
// no permission to fetch // no permission to fetch
return nil return nil
@ -172,7 +172,7 @@ func (s *AccessControlStore) SearchUsersPermissions(ctx context.Context, orgID i
} }
} }
if err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { if err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error {
roleNameFilterJoin := "" roleNameFilterJoin := ""
if len(options.RolePrefixes) > 0 { if len(options.RolePrefixes) > 0 {
roleNameFilterJoin = "INNER JOIN role AS r ON up.role_id = r.id" roleNameFilterJoin = "INNER JOIN role AS r ON up.role_id = r.id"
@ -198,7 +198,7 @@ func (s *AccessControlStore) SearchUsersPermissions(ctx context.Context, orgID i
params = append(params, userID) params = append(params, userID)
} }
grafanaAdmin := fmt.Sprintf(grafanaAdminAssignsSQL, s.sql.Quote("user")) grafanaAdmin := fmt.Sprintf(grafanaAdminAssignsSQL, s.sql.ReadReplica().Quote("user"))
params = append(params, accesscontrol.RoleGrafanaAdmin) params = append(params, accesscontrol.RoleGrafanaAdmin)
if options.NamespacedID != "" { if options.NamespacedID != "" {
grafanaAdmin += " AND sa.user_id = ?" grafanaAdmin += " AND sa.user_id = ?"
@ -284,11 +284,11 @@ func (s *AccessControlStore) GetUsersBasicRoles(ctx context.Context, userFilter
IsAdmin bool `xorm:"is_admin"` IsAdmin bool `xorm:"is_admin"`
} }
dbRoles := make([]UserOrgRole, 0) dbRoles := make([]UserOrgRole, 0)
if err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { if err := s.sql.ReadReplica().WithDbSession(ctx, func(sess *db.Session) error {
// Find roles // Find roles
q := ` q := `
SELECT u.id, ou.role, u.is_admin SELECT u.id, ou.role, u.is_admin
FROM ` + s.sql.GetDialect().Quote("user") + ` AS u FROM ` + s.sql.ReadReplica().GetDialect().Quote("user") + ` AS u
LEFT JOIN org_user AS ou ON u.id = ou.user_id LEFT JOIN org_user AS ou ON u.id = ou.user_id
WHERE (u.is_admin OR ou.org_id = ?) WHERE (u.is_admin OR ou.org_id = ?)
` `
@ -318,7 +318,7 @@ func (s *AccessControlStore) GetUsersBasicRoles(ctx context.Context, userFilter
} }
func (s *AccessControlStore) DeleteUserPermissions(ctx context.Context, orgID, userID int64) error { func (s *AccessControlStore) DeleteUserPermissions(ctx context.Context, orgID, userID int64) error {
err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
roleDeleteQuery := "DELETE FROM user_role WHERE user_id = ?" roleDeleteQuery := "DELETE FROM user_role WHERE user_id = ?"
roleDeleteParams := []any{roleDeleteQuery, userID} roleDeleteParams := []any{roleDeleteQuery, userID}
if orgID != accesscontrol.GlobalOrgID { if orgID != accesscontrol.GlobalOrgID {
@ -383,7 +383,7 @@ func (s *AccessControlStore) DeleteUserPermissions(ctx context.Context, orgID, u
} }
func (s *AccessControlStore) DeleteTeamPermissions(ctx context.Context, orgID, teamID int64) error { func (s *AccessControlStore) DeleteTeamPermissions(ctx context.Context, orgID, teamID int64) error {
err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
roleDeleteQuery := "DELETE FROM team_role WHERE team_id = ? AND org_id = ?" roleDeleteQuery := "DELETE FROM team_role WHERE team_id = ? AND org_id = ?"
roleDeleteParams := []any{roleDeleteQuery, teamID, orgID} roleDeleteParams := []any{roleDeleteQuery, teamID, orgID}

@ -470,8 +470,8 @@ func createUsersAndTeams(t *testing.T, store db.DB, svcs helperServices, orgID i
return res return res
} }
func setupTestEnv(t testing.TB) (*database.AccessControlStore, rs.Store, user.Service, team.Service, org.Service, *sqlstore.SQLStore) { func setupTestEnv(t testing.TB) (*database.AccessControlStore, rs.Store, user.Service, team.Service, org.Service, *sqlstore.ReplStore) {
sql, cfg := db.InitTestDBWithCfg(t) sql, cfg := db.InitTestReplDBWithCfg(t)
cfg.AutoAssignOrg = true cfg.AutoAssignOrg = true
cfg.AutoAssignOrgRole = "Viewer" cfg.AutoAssignOrgRole = "Viewer"
cfg.AutoAssignOrgId = 1 cfg.AutoAssignOrgId = 1

@ -18,7 +18,7 @@ func extServiceRoleName(externalServiceID string) string {
func (s *AccessControlStore) DeleteExternalServiceRole(ctx context.Context, externalServiceID string) error { func (s *AccessControlStore) DeleteExternalServiceRole(ctx context.Context, externalServiceID string) error {
uid := accesscontrol.PrefixedRoleUID(extServiceRoleName(externalServiceID)) uid := accesscontrol.PrefixedRoleUID(extServiceRoleName(externalServiceID))
return s.sql.WithDbSession(ctx, func(sess *db.Session) error { return s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
stored, errGet := getRoleByUID(ctx, sess, uid) stored, errGet := getRoleByUID(ctx, sess, uid)
if errGet != nil { if errGet != nil {
// Role not found, nothing to do // Role not found, nothing to do
@ -55,7 +55,7 @@ func (s *AccessControlStore) SaveExternalServiceRole(ctx context.Context, cmd ac
role := genExternalServiceRole(cmd) role := genExternalServiceRole(cmd)
assignment := genExternalServiceAssignment(cmd) assignment := genExternalServiceAssignment(cmd)
return s.sql.WithDbSession(ctx, func(sess *db.Session) error { return s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
// Create or update the role // Create or update the role
existingRole, errSaveRole := s.saveRole(ctx, sess, &role) existingRole, errSaveRole := s.saveRole(ctx, sess, &role)
if errSaveRole != nil { if errSaveRole != nil {

@ -7,9 +7,10 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/stretchr/testify/require"
) )
func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) { func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) {
@ -114,7 +115,7 @@ func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s := &AccessControlStore{ s := &AccessControlStore{
sql: db.InitTestDB(t), sql: db.InitTestReplDB(t),
} }
for i := range tt.runs { for i := range tt.runs {
@ -125,7 +126,7 @@ func TestAccessControlStore_SaveExternalServiceRole(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
errDBSession := s.sql.WithDbSession(ctx, func(sess *db.Session) error { errDBSession := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.runs[i].cmd.ExternalServiceID))) storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.runs[i].cmd.ExternalServiceID)))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, storedRole) require.NotNil(t, storedRole)
@ -187,13 +188,13 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s := &AccessControlStore{ s := &AccessControlStore{
sql: db.InitTestDB(t), sql: db.InitTestReplDB(t),
} }
if tt.init != nil { if tt.init != nil {
tt.init(t, ctx, s) tt.init(t, ctx, s)
} }
roleID := int64(-1) roleID := int64(-1)
err := s.sql.WithDbSession(ctx, func(sess *db.Session) error { err := s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
role, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id))) role, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id)))
if err != nil && !errors.Is(err, accesscontrol.ErrRoleNotFound) { if err != nil && !errors.Is(err, accesscontrol.ErrRoleNotFound) {
return err return err
@ -217,7 +218,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) {
} }
// Assignments should be deleted // Assignments should be deleted
_ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
var assignment accesscontrol.UserRole var assignment accesscontrol.UserRole
count, err := sess.Where("role_id = ?", roleID).Count(&assignment) count, err := sess.Where("role_id = ?", roleID).Count(&assignment)
require.NoError(t, err) require.NoError(t, err)
@ -226,7 +227,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) {
}) })
// Permissions should be deleted // Permissions should be deleted
_ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
var permission accesscontrol.Permission var permission accesscontrol.Permission
count, err := sess.Where("role_id = ?", roleID).Count(&permission) count, err := sess.Where("role_id = ?", roleID).Count(&permission)
require.NoError(t, err) require.NoError(t, err)
@ -235,7 +236,7 @@ func TestAccessControlStore_DeleteExternalServiceRole(t *testing.T) {
}) })
// Role should be deleted // Role should be deleted
_ = s.sql.WithDbSession(ctx, func(sess *db.Session) error { _ = s.sql.DB().WithDbSession(ctx, func(sess *db.Session) error {
storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id))) storedRole, err := getRoleByUID(ctx, sess, accesscontrol.PrefixedRoleUID(extServiceRoleName(tt.id)))
require.ErrorIs(t, err, accesscontrol.ErrRoleNotFound) require.ErrorIs(t, err, accesscontrol.ErrRoleNotFound)
require.Nil(t, storedRole) require.Nil(t, storedRole)

@ -19,14 +19,14 @@ const (
maxLen = 40 maxLen = 40
) )
func MigrateScopeSplit(db db.DB, log log.Logger) error { func MigrateScopeSplit(db db.ReplDB, log log.Logger) error {
t := time.Now() t := time.Now()
ctx := context.Background() ctx := context.Background()
cnt := 0 cnt := 0
// Search for the permissions to update // Search for the permissions to update
var permissions []ac.Permission var permissions []ac.Permission
if errFind := db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error { if errFind := db.DB().WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
return sess.SQL("SELECT * FROM permission WHERE NOT scope = '' AND identifier = ''").Find(&permissions) return sess.SQL("SELECT * FROM permission WHERE NOT scope = '' AND identifier = ''").Find(&permissions)
}); errFind != nil { }); errFind != nil {
log.Error("Could not search for permissions to update", "migration", "scopeSplit", "error", errFind) log.Error("Could not search for permissions to update", "migration", "scopeSplit", "error", errFind)
@ -76,7 +76,7 @@ func MigrateScopeSplit(db db.DB, log log.Logger) error {
delQuery = delQuery[:len(delQuery)-1] + ")" delQuery = delQuery[:len(delQuery)-1] + ")"
// Batch update the permissions // Batch update the permissions
if errBatchUpdate := db.GetSqlxSession().WithTransaction(ctx, func(tx *session.SessionTx) error { if errBatchUpdate := db.DB().GetSqlxSession().WithTransaction(ctx, func(tx *session.SessionTx) error {
if _, errDel := tx.Exec(ctx, delQuery, delArgs...); errDel != nil { if _, errDel := tx.Exec(ctx, delQuery, delArgs...); errDel != nil {
log.Error("Error deleting permissions", "migration", "scopeSplit", "error", errDel) log.Error("Error deleting permissions", "migration", "scopeSplit", "error", errDel)
return errDel return errDel

@ -10,7 +10,7 @@ import (
) )
func benchScopeSplitConcurrent(b *testing.B, count int) { func benchScopeSplitConcurrent(b *testing.B, count int) {
store := db.InitTestDB(b) store := db.InitTestReplDB(b)
// Populate permissions // Populate permissions
require.NoError(b, batchInsertPermissions(count, store), "could not insert permissions") require.NoError(b, batchInsertPermissions(count, store), "could not insert permissions")
logger := log.New("migrator.test") logger := log.New("migrator.test")

@ -46,7 +46,7 @@ func batchInsertPermissions(cnt int, sqlStore db.DB) error {
// TestIntegrationMigrateScopeSplit tests the scope split migration // TestIntegrationMigrateScopeSplit tests the scope split migration
// also tests the scope split truncation logic // also tests the scope split truncation logic
func TestIntegrationMigrateScopeSplitTruncation(t *testing.T) { func TestIntegrationMigrateScopeSplitTruncation(t *testing.T) {
sqlStore := db.InitTestDB(t) sqlStore := db.InitTestReplDB(t)
logger := log.New("accesscontrol.migrator.test") logger := log.New("accesscontrol.migrator.test")
batchSize = 20 batchSize = 20

@ -192,3 +192,15 @@ func InitTestReplDB(t sqlutil.ITestDB, opts ...InitTestDBOpt) (*ReplStore, *sett
} }
return &ReplStore{ss, ss}, cfg return &ReplStore{ss, ss}, cfg
} }
// InitTestReplDBWithMigration initializes the test DB given custom migrations.
func InitTestReplDBWithMigration(t sqlutil.ITestDB, migration registry.DatabaseMigrator, opts ...InitTestDBOpt) *ReplStore {
t.Helper()
features := getFeaturesForTesting(opts...)
cfg := getCfgForTesting(opts...)
ss, err := initTestDB(t, cfg, features, migration, opts...)
if err != nil {
t.Fatalf("failed to initialize sql store: %s", err)
}
return &ReplStore{ss, ss}
}

Loading…
Cancel
Save