diff --git a/pkg/api/org_users_test.go b/pkg/api/org_users_test.go index 74820816843..eb108229481 100644 --- a/pkg/api/org_users_test.go +++ b/pkg/api/org_users_test.go @@ -716,7 +716,7 @@ func TestOrgUsersAPIEndpointWithSetPerms_AccessControl(t *testing.T) { sc := setupHTTPServer(t, true, func(hs *HTTPServer) { hs.tempUserService = tempuserimpl.ProvideService(hs.SQLStore) hs.userService = userimpl.ProvideService( - hs.SQLStore, nil, nil, hs.SQLStore.(*sqlstore.SQLStore), + hs.SQLStore, nil, setting.NewCfg(), hs.SQLStore.(*sqlstore.SQLStore), ) }) setInitCtxSignedInViewer(sc.initCtx) diff --git a/pkg/api/user.go b/pkg/api/user.go index 09e05ee0e13..22326ae2de3 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -51,7 +51,7 @@ func (hs *HTTPServer) GetUserByID(c *models.ReqContext) response.Response { func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) response.Response { query := user.GetUserProfileQuery{UserID: userID} - userProfile, err := hs.userService.GetUserProfile(c.Req.Context(), &query) + userProfile, err := hs.userService.GetProfile(c.Req.Context(), &query) if err != nil { if errors.Is(err, user.ErrUserNotFound) { return response.Error(404, user.ErrUserNotFound.Error(), nil) diff --git a/pkg/services/sqlstore/db/db.go b/pkg/services/sqlstore/db/db.go index f35a27c5c98..8af1321bd00 100644 --- a/pkg/services/sqlstore/db/db.go +++ b/pkg/services/sqlstore/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/sqlstore/session" + "xorm.io/core" ) type DB interface { @@ -13,6 +14,7 @@ type DB interface { WithDbSession(ctx context.Context, callback sqlstore.DBTransactionFunc) error NewSession(ctx context.Context) *sqlstore.DBSession GetDialect() migrator.Dialect + GetDBType() core.DbType GetSqlxSession() *session.SessionDB InTransaction(ctx context.Context, fn func(ctx context.Context) error) error } diff --git a/pkg/services/sqlstore/mockstore/mockstore.go b/pkg/services/sqlstore/mockstore/mockstore.go index 2437d8ab8a9..48941e14e13 100644 --- a/pkg/services/sqlstore/mockstore/mockstore.go +++ b/pkg/services/sqlstore/mockstore/mockstore.go @@ -9,6 +9,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/sqlstore/session" "github.com/grafana/grafana/pkg/services/user" + "xorm.io/core" ) type OrgListResponse []struct { @@ -76,6 +77,10 @@ func (m *SQLStoreMock) GetDialect() migrator.Dialect { return nil } +func (m *SQLStoreMock) GetDBType() core.DbType { + return "" +} + func (m *SQLStoreMock) HasEditPermissionInFolders(ctx context.Context, query *models.HasEditPermissionInFoldersQuery) error { return m.ExpectedError } diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index 0ab1b8bc68c..ef50821c462 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -16,6 +16,7 @@ import ( "github.com/jmoiron/sqlx" _ "github.com/lib/pq" "github.com/prometheus/client_golang/prometheus" + "xorm.io/core" "xorm.io/xorm" "github.com/grafana/grafana/pkg/bus" @@ -171,6 +172,10 @@ func (ss *SQLStore) GetDialect() migrator.Dialect { return ss.Dialect } +func (ss *SQLStore) GetDBType() core.DbType { + return ss.engine.Dialect().DBType() +} + func (ss *SQLStore) Bus() bus.Bus { return ss.bus } diff --git a/pkg/services/sqlstore/store.go b/pkg/services/sqlstore/store.go index aada82826e4..23412cf1d25 100644 --- a/pkg/services/sqlstore/store.go +++ b/pkg/services/sqlstore/store.go @@ -7,6 +7,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/sqlstore/session" "github.com/grafana/grafana/pkg/services/user" + "xorm.io/core" ) type Store interface { @@ -15,6 +16,7 @@ type Store interface { GetDataSourceStats(ctx context.Context, query *models.GetDataSourceStatsQuery) error GetDataSourceAccessStats(ctx context.Context, query *models.GetDataSourceAccessStatsQuery) error GetDialect() migrator.Dialect + GetDBType() core.DbType GetSystemStats(ctx context.Context, query *models.GetSystemStatsQuery) error GetOrgByName(name string) (*models.Org, error) CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error diff --git a/pkg/services/sqlstore/user.go b/pkg/services/sqlstore/user.go index 9c564cf97db..eacaa039dc3 100644 --- a/pkg/services/sqlstore/user.go +++ b/pkg/services/sqlstore/user.go @@ -201,87 +201,6 @@ func (ss *SQLStore) GetUserById(ctx context.Context, query *models.GetUserByIdQu }) } -func (ss *SQLStore) GetUserByLogin(ctx context.Context, query *models.GetUserByLoginQuery) error { - return ss.WithDbSession(ctx, func(sess *DBSession) error { - if query.LoginOrEmail == "" { - return user.ErrUserNotFound - } - - // Try and find the user by login first. - // It's not sufficient to assume that a LoginOrEmail with an "@" is an email. - usr := &user.User{} - where := "login=?" - if ss.Cfg.CaseInsensitiveLogin { - where = "LOWER(login)=LOWER(?)" - } - - has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr) - if err != nil { - return err - } - - if !has && strings.Contains(query.LoginOrEmail, "@") { - // If the user wasn't found, and it contains an "@" fallback to finding the - // user by email. - - where = "email=?" - if ss.Cfg.CaseInsensitiveLogin { - where = "LOWER(email)=LOWER(?)" - } - usr = &user.User{} - has, err = sess.Where(notServiceAccountFilter(ss)).Where(where, query.LoginOrEmail).Get(usr) - } - - if err != nil { - return err - } else if !has { - return user.ErrUserNotFound - } - - if ss.Cfg.CaseInsensitiveLogin { - if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil { - return err - } - } - - query.Result = usr - - return nil - }) -} - -func (ss *SQLStore) GetUserByEmail(ctx context.Context, query *models.GetUserByEmailQuery) error { - return ss.WithDbSession(ctx, func(sess *DBSession) error { - if query.Email == "" { - return user.ErrUserNotFound - } - - usr := &user.User{} - where := "email=?" - if ss.Cfg.CaseInsensitiveLogin { - where = "LOWER(email)=LOWER(?)" - } - - has, err := sess.Where(notServiceAccountFilter(ss)).Where(where, query.Email).Get(usr) - - if err != nil { - return err - } else if !has { - return user.ErrUserNotFound - } - - if ss.Cfg.CaseInsensitiveLogin { - if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil { - return err - } - } - - query.Result = usr - - return nil - }) -} - func (ss *SQLStore) UpdateUser(ctx context.Context, cmd *models.UpdateUserCommand) error { if ss.Cfg.CaseInsensitiveLogin { cmd.Login = strings.ToLower(cmd.Login) diff --git a/pkg/services/sqlstore/user_test.go b/pkg/services/sqlstore/user_test.go index bc80e36dac9..51e1d56993c 100644 --- a/pkg/services/sqlstore/user_test.go +++ b/pkg/services/sqlstore/user_test.go @@ -7,7 +7,6 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/org" - "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/user" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -90,82 +89,6 @@ func TestIntegrationUserDataAccess(t *testing.T) { Permissions: map[int64]map[string][]string{1: {"users:read": {"global.users:*"}}}, } - t.Run("Testing DB - creates and loads user", func(t *testing.T) { - cmd := user.CreateUserCommand{ - Email: "usertest@test.com", - Name: "user name", - Login: "user_test_login", - } - user, err := ss.CreateUser(context.Background(), cmd) - require.NoError(t, err) - - query := models.GetUserByIdQuery{Id: user.ID} - err = ss.GetUserById(context.Background(), &query) - require.Nil(t, err) - - require.Equal(t, query.Result.Email, "usertest@test.com") - require.Equal(t, query.Result.Password, "") - require.Len(t, query.Result.Rands, 10) - require.Len(t, query.Result.Salt, 10) - require.False(t, query.Result.IsDisabled) - - query = models.GetUserByIdQuery{Id: user.ID} - err = ss.GetUserById(context.Background(), &query) - require.Nil(t, err) - - require.Equal(t, query.Result.Email, "usertest@test.com") - require.Equal(t, query.Result.Password, "") - require.Len(t, query.Result.Rands, 10) - require.Len(t, query.Result.Salt, 10) - require.False(t, query.Result.IsDisabled) - - t.Run("Get User by email case insensitive", func(t *testing.T) { - ss.Cfg.CaseInsensitiveLogin = true - query := models.GetUserByEmailQuery{Email: "USERtest@TEST.COM"} - err = ss.GetUserByEmail(context.Background(), &query) - require.Nil(t, err) - - require.Equal(t, query.Result.Email, "usertest@test.com") - require.Equal(t, query.Result.Password, "") - require.Len(t, query.Result.Rands, 10) - require.Len(t, query.Result.Salt, 10) - require.False(t, query.Result.IsDisabled) - - ss.Cfg.CaseInsensitiveLogin = false - }) - - t.Run("Get User by login - case insensitive", func(t *testing.T) { - ss.Cfg.CaseInsensitiveLogin = true - - query := models.GetUserByLoginQuery{LoginOrEmail: "USER_test_login"} - err = ss.GetUserByLogin(context.Background(), &query) - require.Nil(t, err) - - require.Equal(t, query.Result.Email, "usertest@test.com") - require.Equal(t, query.Result.Password, "") - require.Len(t, query.Result.Rands, 10) - require.Len(t, query.Result.Salt, 10) - require.False(t, query.Result.IsDisabled) - - ss.Cfg.CaseInsensitiveLogin = false - }) - - t.Run("Get User by login - email fallback case insensitive", func(t *testing.T) { - ss.Cfg.CaseInsensitiveLogin = true - query := models.GetUserByLoginQuery{LoginOrEmail: "USERtest@TEST.COM"} - err = ss.GetUserByLogin(context.Background(), &query) - require.Nil(t, err) - - require.Equal(t, query.Result.Email, "usertest@test.com") - require.Equal(t, query.Result.Password, "") - require.Len(t, query.Result.Rands, 10) - require.Len(t, query.Result.Salt, 10) - require.False(t, query.Result.IsDisabled) - - ss.Cfg.CaseInsensitiveLogin = false - }) - }) - t.Run("Testing DB - creates and loads disabled user", func(t *testing.T) { ss = InitTestDB(t) cmd := user.CreateUserCommand{ @@ -475,90 +398,6 @@ func TestIntegrationUserDataAccess(t *testing.T) { assert.Len(t, query.Result.Users, 2) }) - t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) { - if ss.engine.Dialect().DBType() == migrator.MySQL { - t.Skip("Skipping on MySQL due to case insensitive indexes") - } - - cmd := user.CreateUserCommand{ - Email: "confusertest@test.com", - Name: "user name", - Login: "user_email_conflict", - } - userEmailConflict, err := ss.CreateUser(context.Background(), cmd) - require.NoError(t, err) - - cmd = user.CreateUserCommand{ - Email: "confusertest@TEST.COM", - Name: "user name", - Login: "user_email_conflict_two", - } - _, err = ss.CreateUser(context.Background(), cmd) - require.NoError(t, err) - - cmd = user.CreateUserCommand{ - Email: "user_test_login_conflict@test.com", - Name: "user name", - Login: "user_test_login_conflict", - } - userLoginConflict, err := ss.CreateUser(context.Background(), cmd) - require.NoError(t, err) - - cmd = user.CreateUserCommand{ - Email: "user_test_login_conflict_two@test.com", - Name: "user name", - Login: "user_test_login_CONFLICT", - } - _, err = ss.CreateUser(context.Background(), cmd) - require.NoError(t, err) - - ss.Cfg.CaseInsensitiveLogin = true - - t.Run("GetUserByEmail - email conflict", func(t *testing.T) { - query := models.GetUserByEmailQuery{Email: "confusertest@test.com"} - err = ss.GetUserByEmail(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByEmail - login conflict", func(t *testing.T) { - query := models.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"} - err = ss.GetUserByEmail(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByID - email conflict", func(t *testing.T) { - query := models.GetUserByIdQuery{Id: userEmailConflict.ID} - err = ss.GetUserById(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByID - login conflict", func(t *testing.T) { - query := models.GetUserByIdQuery{Id: userLoginConflict.ID} - err = ss.GetUserById(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByLogin - email conflict", func(t *testing.T) { - query := models.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"} - err = ss.GetUserByLogin(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByLogin - login conflict", func(t *testing.T) { - query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"} - err = ss.GetUserByLogin(context.Background(), &query) - require.Error(t, err) - }) - - t.Run("GetUserByLogin - login conflict by email", func(t *testing.T) { - query := models.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"} - err = ss.GetUserByLogin(context.Background(), &query) - require.Error(t, err) - }) - - ss.Cfg.CaseInsensitiveLogin = false - }) - ss = InitTestDB(t) t.Run("Testing DB - enable all users", func(t *testing.T) { diff --git a/pkg/services/user/user.go b/pkg/services/user/user.go index 7df9b493e67..c718018783c 100644 --- a/pkg/services/user/user.go +++ b/pkg/services/user/user.go @@ -21,5 +21,5 @@ type Service interface { BatchDisableUsers(context.Context, *BatchDisableUsersCommand) error UpdatePermissions(int64, bool) error SetUserHelpFlag(context.Context, *SetUserHelpFlagCommand) error - GetUserProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error) + GetProfile(context.Context, *GetUserProfileQuery) (UserProfileDTO, error) } diff --git a/pkg/services/user/userimpl/store.go b/pkg/services/user/userimpl/store.go index f215789a8da..e986352254b 100644 --- a/pkg/services/user/userimpl/store.go +++ b/pkg/services/user/userimpl/store.go @@ -3,6 +3,7 @@ package userimpl import ( "context" "fmt" + "strings" "github.com/grafana/grafana/pkg/events" "github.com/grafana/grafana/pkg/infra/log" @@ -20,6 +21,8 @@ type store interface { GetNotServiceAccount(context.Context, int64) (*user.User, error) Delete(context.Context, int64) error CaseInsensitiveLoginConflict(context.Context, string, string) error + GetByLogin(context.Context, *user.GetUserByLoginQuery) (*user.User, error) + GetByEmail(context.Context, *user.GetUserByEmailQuery) (*user.User, error) } type sqlStore struct { @@ -145,3 +148,101 @@ func (ss *sqlStore) CaseInsensitiveLoginConflict(ctx context.Context, login, ema }) return err } + +func (ss *sqlStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) { + usr := &user.User{} + err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + if query.LoginOrEmail == "" { + return user.ErrUserNotFound + } + + // Try and find the user by login first. + // It's not sufficient to assume that a LoginOrEmail with an "@" is an email. + where := "login=?" + if ss.cfg.CaseInsensitiveLogin { + where = "LOWER(login)=LOWER(?)" + } + + has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr) + if err != nil { + return err + } + + if !has && strings.Contains(query.LoginOrEmail, "@") { + // If the user wasn't found, and it contains an "@" fallback to finding the + // user by email. + + where = "email=?" + if ss.cfg.CaseInsensitiveLogin { + where = "LOWER(email)=LOWER(?)" + } + usr = &user.User{} + has, err = sess.Where(ss.notServiceAccountFilter()).Where(where, query.LoginOrEmail).Get(usr) + } + + if err != nil { + return err + } else if !has { + return user.ErrUserNotFound + } + + if ss.cfg.CaseInsensitiveLogin { + if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, err + } + return usr, nil +} + +func (ss *sqlStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) { + usr := &user.User{} + err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error { + if query.Email == "" { + return user.ErrUserNotFound + } + + where := "email=?" + if ss.cfg.CaseInsensitiveLogin { + where = "LOWER(email)=LOWER(?)" + } + + has, err := sess.Where(ss.notServiceAccountFilter()).Where(where, query.Email).Get(usr) + + if err != nil { + return err + } else if !has { + return user.ErrUserNotFound + } + + if ss.cfg.CaseInsensitiveLogin { + if err := ss.userCaseInsensitiveLoginConflict(ctx, sess, usr.Login, usr.Email); err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, err + } + return usr, nil +} + +func (ss *sqlStore) userCaseInsensitiveLoginConflict(ctx context.Context, sess *sqlstore.DBSession, login, email string) error { + users := make([]user.User, 0) + + if err := sess.Where("LOWER(email)=LOWER(?) OR LOWER(login)=LOWER(?)", + email, login).Find(&users); err != nil { + return err + } + + if len(users) > 1 { + return &user.ErrCaseInsensitiveLoginConflict{Users: users} + } + + return nil +} diff --git a/pkg/services/user/userimpl/store_test.go b/pkg/services/user/userimpl/store_test.go index 3df8e61c9b8..9e865038c2b 100644 --- a/pkg/services/user/userimpl/store_test.go +++ b/pkg/services/user/userimpl/store_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -54,4 +55,143 @@ func TestIntegrationUserDataAccess(t *testing.T) { ) require.NoError(t, err) }) + + t.Run("Testing DB - creates and loads user", func(t *testing.T) { + ss := sqlstore.InitTestDB(t) + cmd := user.CreateUserCommand{ + Email: "usertest@test.com", + Name: "user name", + Login: "user_test_login", + } + usr, err := ss.CreateUser(context.Background(), cmd) + require.NoError(t, err) + + result, err := userStore.GetByID(context.Background(), usr.ID) + require.Nil(t, err) + + require.Equal(t, result.Email, "usertest@test.com") + require.Equal(t, result.Password, "") + require.Len(t, result.Rands, 10) + require.Len(t, result.Salt, 10) + require.False(t, result.IsDisabled) + + result, err = userStore.GetByID(context.Background(), usr.ID) + require.Nil(t, err) + + require.Equal(t, result.Email, "usertest@test.com") + require.Equal(t, result.Password, "") + require.Len(t, result.Rands, 10) + require.Len(t, result.Salt, 10) + require.False(t, result.IsDisabled) + + t.Run("Get User by email case insensitive", func(t *testing.T) { + userStore.cfg.CaseInsensitiveLogin = true + query := user.GetUserByEmailQuery{Email: "USERtest@TEST.COM"} + result, err := userStore.GetByEmail(context.Background(), &query) + require.Nil(t, err) + + require.Equal(t, result.Email, "usertest@test.com") + require.Equal(t, result.Password, "") + require.Len(t, result.Rands, 10) + require.Len(t, result.Salt, 10) + require.False(t, result.IsDisabled) + + userStore.cfg.CaseInsensitiveLogin = false + }) + + t.Run("Testing DB - creates and loads user", func(t *testing.T) { + result, err = userStore.GetByID(context.Background(), usr.ID) + require.Nil(t, err) + + require.Equal(t, result.Email, "usertest@test.com") + require.Equal(t, result.Password, "") + require.Len(t, result.Rands, 10) + require.Len(t, result.Salt, 10) + require.False(t, result.IsDisabled) + + result, err = userStore.GetByID(context.Background(), usr.ID) + require.Nil(t, err) + + require.Equal(t, result.Email, "usertest@test.com") + require.Equal(t, result.Password, "") + require.Len(t, result.Rands, 10) + require.Len(t, result.Salt, 10) + require.False(t, result.IsDisabled) + ss.Cfg.CaseInsensitiveLogin = false + }) + }) + + t.Run("Testing DB - error on case insensitive conflict", func(t *testing.T) { + if ss.GetDBType() == migrator.MySQL { + t.Skip("Skipping on MySQL due to case insensitive indexes") + } + userStore.cfg.CaseInsensitiveLogin = true + cmd := user.CreateUserCommand{ + Email: "confusertest@test.com", + Name: "user name", + Login: "user_email_conflict", + } + // userEmailConflict + _, err := ss.CreateUser(context.Background(), cmd) + require.NoError(t, err) + + cmd = user.CreateUserCommand{ + Email: "confusertest@TEST.COM", + Name: "user name", + Login: "user_email_conflict_two", + } + _, err = ss.CreateUser(context.Background(), cmd) + require.NoError(t, err) + + cmd = user.CreateUserCommand{ + Email: "user_test_login_conflict@test.com", + Name: "user name", + Login: "user_test_login_conflict", + } + // userLoginConflict + _, err = ss.CreateUser(context.Background(), cmd) + require.NoError(t, err) + + cmd = user.CreateUserCommand{ + Email: "user_test_login_conflict_two@test.com", + Name: "user name", + Login: "user_test_login_CONFLICT", + } + _, err = ss.CreateUser(context.Background(), cmd) + require.NoError(t, err) + + ss.Cfg.CaseInsensitiveLogin = true + + t.Run("GetByEmail - email conflict", func(t *testing.T) { + query := user.GetUserByEmailQuery{Email: "confusertest@test.com"} + _, err = userStore.GetByEmail(context.Background(), &query) + require.Error(t, err) + }) + + t.Run("GetByEmail - login conflict", func(t *testing.T) { + query := user.GetUserByEmailQuery{Email: "user_test_login_conflict@test.com"} + _, err = userStore.GetByEmail(context.Background(), &query) + require.Error(t, err) + }) + + t.Run("GetByLogin - email conflict", func(t *testing.T) { + query := user.GetUserByLoginQuery{LoginOrEmail: "user_email_conflict_two"} + _, err = userStore.GetByLogin(context.Background(), &query) + require.Error(t, err) + }) + + t.Run("GetByLogin - login conflict", func(t *testing.T) { + query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict"} + _, err = userStore.GetByLogin(context.Background(), &query) + require.Error(t, err) + }) + + t.Run("GetByLogin - login conflict by email", func(t *testing.T) { + query := user.GetUserByLoginQuery{LoginOrEmail: "user_test_login_conflict@test.com"} + _, err = userStore.GetByLogin(context.Background(), &query) + require.Error(t, err) + }) + + ss.Cfg.CaseInsensitiveLogin = false + }) } diff --git a/pkg/services/user/userimpl/user.go b/pkg/services/user/userimpl/user.go index 47824d9421e..e713e1be78b 100644 --- a/pkg/services/user/userimpl/user.go +++ b/pkg/services/user/userimpl/user.go @@ -152,24 +152,12 @@ func (s *Service) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*u return user, nil } -// TODO: remove wrapper around sqlstore func (s *Service) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) { - q := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail} - err := s.sqlStore.GetUserByLogin(ctx, &q) - if err != nil { - return nil, err - } - return q.Result, nil + return s.store.GetByLogin(ctx, query) } -// TODO: remove wrapper around sqlstore func (s *Service) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) { - q := models.GetUserByEmailQuery{Email: query.Email} - err := s.sqlStore.GetUserByEmail(ctx, &q) - if err != nil { - return nil, err - } - return q.Result, nil + return s.store.GetByEmail(ctx, query) } // TODO: remove wrapper around sqlstore @@ -313,7 +301,7 @@ func (s *Service) SetUserHelpFlag(ctx context.Context, cmd *user.SetUserHelpFlag } // TODO: remove wrapper around sqlstore -func (s *Service) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) { +func (s *Service) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) { q := &models.GetUserProfileQuery{ UserId: query.UserID, } diff --git a/pkg/services/user/userimpl/user_test.go b/pkg/services/user/userimpl/user_test.go index f8a634d96bb..d69329a0d64 100644 --- a/pkg/services/user/userimpl/user_test.go +++ b/pkg/services/user/userimpl/user_test.go @@ -2,6 +2,7 @@ package userimpl import ( "context" + "errors" "testing" "github.com/grafana/grafana/pkg/services/org/orgtest" @@ -77,6 +78,14 @@ func TestUserService(t *testing.T) { err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1}) require.NoError(t, err) }) + + t.Run("GetByID - email conflict", func(t *testing.T) { + userService.cfg.CaseInsensitiveLogin = true + userStore.ExpectedError = errors.New("email conflict") + query := user.GetUserByIDQuery{} + _, err := userService.GetByID(context.Background(), &query) + require.Error(t, err) + }) } type FakeUserStore struct { @@ -112,3 +121,11 @@ func (f *FakeUserStore) GetByID(context.Context, int64) (*user.User, error) { func (f *FakeUserStore) CaseInsensitiveLoginConflict(context.Context, string, string) error { return f.ExpectedError } + +func (f *FakeUserStore) GetByLogin(ctx context.Context, query *user.GetUserByLoginQuery) (*user.User, error) { + return f.ExpectedUser, f.ExpectedError +} + +func (f *FakeUserStore) GetByEmail(ctx context.Context, query *user.GetUserByEmailQuery) (*user.User, error) { + return f.ExpectedUser, f.ExpectedError +} diff --git a/pkg/services/user/usertest/fake.go b/pkg/services/user/usertest/fake.go index 47ce819f436..862b4ac4e88 100644 --- a/pkg/services/user/usertest/fake.go +++ b/pkg/services/user/usertest/fake.go @@ -91,6 +91,6 @@ func (f *FakeUserService) SetUserHelpFlag(ctx context.Context, cmd *user.SetUser return f.ExpectedError } -func (f *FakeUserService) GetUserProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) { +func (f *FakeUserService) GetProfile(ctx context.Context, query *user.GetUserProfileQuery) (user.UserProfileDTO, error) { return f.ExpectedUSerProfileDTO, f.ExpectedError }