diff --git a/pkg/api/ldap_debug.go b/pkg/api/ldap_debug.go index cc4ef546405..48015002395 100644 --- a/pkg/api/ldap_debug.go +++ b/pkg/api/ldap_debug.go @@ -220,6 +220,11 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon ReqContext: c, ExternalUser: user, SignupAllowed: hs.Cfg.LDAPAllowSignup, + UserLookupParams: models.UserLookupParams{ + UserID: &query.Result.ID, // Upsert by ID only + Email: nil, + Login: nil, + }, } err = hs.Login.UpsertUser(c.Req.Context(), upsertCmd) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 7beff805990..8495f08de71 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -305,6 +305,11 @@ func (hs *HTTPServer) SyncUser( ReqContext: ctx, ExternalUser: extUser, SignupAllowed: connect.IsSignupAllowed(), + UserLookupParams: models.UserLookupParams{ + Email: &extUser.Email, + UserID: nil, + Login: nil, + }, } if err := hs.Login.UpsertUser(ctx.Req.Context(), cmd); err != nil { diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index bdd5cd6e071..c7ec173f524 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -76,7 +76,8 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { } idToken := "testidtoken" token = token.WithExtra(map[string]interface{}{"id_token": idToken}) - query := &models.GetUserByAuthInfoQuery{Login: "loginuser", AuthModule: "test", AuthId: "test"} + login := "loginuser" + query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{Login: &login}} cmd := &models.UpdateAuthInfoCommand{ UserId: user.ID, AuthId: query.AuthId, diff --git a/pkg/login/ldap_login.go b/pkg/login/ldap_login.go index 23d77224423..4e06ddc72ce 100644 --- a/pkg/login/ldap_login.go +++ b/pkg/login/ldap_login.go @@ -57,9 +57,13 @@ var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery, log ReqContext: query.ReqContext, ExternalUser: externalUser, SignupAllowed: setting.LDAPAllowSignup, + UserLookupParams: models.UserLookupParams{ + Login: &externalUser.Login, + Email: &externalUser.Email, + UserID: nil, + }, } - err = loginService.UpsertUser(ctx, upsert) - if err != nil { + if err = loginService.UpsertUser(ctx, upsert); err != nil { return true, err } query.User = upsert.Result diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index c9bd9ab0859..0732c81f47d 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -57,8 +57,9 @@ type RequestURIKey struct{} // COMMANDS type UpsertUserCommand struct { - ReqContext *ReqContext - ExternalUser *ExternalUserInfo + ReqContext *ReqContext + ExternalUser *ExternalUserInfo + UserLookupParams SignupAllowed bool Result *user.User @@ -98,9 +99,14 @@ type LoginUserQuery struct { type GetUserByAuthInfoQuery struct { AuthModule string AuthId string - UserId int64 - Email string - Login string + UserLookupParams +} + +type UserLookupParams struct { + // Describes lookup order as well + UserID *int64 // if set, will try to find the user by id + Email *string // if set, will try to find the user by email + Login *string // if set, will try to find the user by login } type GetExternalUserInfoByLoginQuery struct { diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index f0bddf8d6f2..1b1856426ba 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -66,6 +66,11 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64) ReqContext: ctx, SignupAllowed: h.Cfg.JWTAuthAutoSignUp, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + UserID: nil, + Login: &query.Login, + Email: &query.Email, + }, } if err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil { ctx.Logger.Error("Failed to upsert JWT user", "error", err) diff --git a/pkg/services/contexthandler/authproxy/authproxy.go b/pkg/services/contexthandler/authproxy/authproxy.go index 79eb395b004..b9c8cefd89a 100644 --- a/pkg/services/contexthandler/authproxy/authproxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -241,6 +241,11 @@ func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) { ReqContext: reqCtx, SignupAllowed: auth.cfg.LDAPAllowSignup, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + Login: &extUser.Login, + Email: &extUser.Email, + UserID: nil, + }, } if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil { return 0, err @@ -298,6 +303,11 @@ func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) ReqContext: reqCtx, SignupAllowed: auth.cfg.AuthProxyAutoSignUp, ExternalUser: extUser, + UserLookupParams: models.UserLookupParams{ + UserID: nil, + Login: &extUser.Login, + Email: &extUser.Email, + }, } err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert) diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index 1e8c79bc39a..ddcc1385352 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -44,11 +44,12 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser } // if user id was specified and doesn't match the user_auth entry, remove it - if query.UserId != 0 && query.UserId != authQuery.Result.UserId { - err := s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{ + if query.UserLookupParams.UserID != nil && + *query.UserLookupParams.UserID != 0 && + *query.UserLookupParams.UserID != authQuery.Result.UserId { + if err := s.authInfoStore.DeleteAuthInfo(ctx, &models.DeleteAuthInfoCommand{ UserAuth: authQuery.Result, - }) - if err != nil { + }); err != nil { s.logger.Error("Error removing user_auth entry", "error", err) } @@ -78,29 +79,29 @@ func (s *Implementation) LookupAndFix(ctx context.Context, query *models.GetUser return false, nil, nil, models.ErrUserNotFound } -func (s *Implementation) LookupByOneOf(ctx context.Context, userId int64, email string, login string) (*user.User, error) { +func (s *Implementation) LookupByOneOf(ctx context.Context, params *models.UserLookupParams) (*user.User, error) { var user *user.User var err error // If not found, try to find the user by id - if userId != 0 { - user, err = s.authInfoStore.GetUserById(ctx, userId) + if params.UserID != nil && *params.UserID != 0 { + user, err = s.authInfoStore.GetUserById(ctx, *params.UserID) if err != nil && !errors.Is(err, models.ErrUserNotFound) { return nil, err } } // If not found, try to find the user by email address - if user == nil && email != "" { - user, err = s.authInfoStore.GetUserByEmail(ctx, email) + if user == nil && params.Email != nil && *params.Email != "" { + user, err = s.authInfoStore.GetUserByEmail(ctx, *params.Email) if err != nil && !errors.Is(err, models.ErrUserNotFound) { return nil, err } } // If not found, try to find the user by login - if user == nil && login != "" { - user, err = s.authInfoStore.GetUserByLogin(ctx, login) + if user == nil && params.Login != nil && *params.Login != "" { + user, err = s.authInfoStore.GetUserByLogin(ctx, *params.Login) if err != nil && !errors.Is(err, models.ErrUserNotFound) { return nil, err } @@ -139,7 +140,7 @@ func (s *Implementation) LookupAndUpdate(ctx context.Context, query *models.GetU // 2. FindByUserDetails if !foundUser { - user, err = s.LookupByOneOf(ctx, query.UserId, query.Email, query.Login) + user, err = s.LookupByOneOf(ctx, &query.UserLookupParams) if err != nil { return nil, err } diff --git a/pkg/services/login/authinfoservice/user_auth_test.go b/pkg/services/login/authinfoservice/user_auth_test.go index 0c96d819e7b..c67ea0f5b4d 100644 --- a/pkg/services/login/authinfoservice/user_auth_test.go +++ b/pkg/services/login/authinfoservice/user_auth_test.go @@ -43,7 +43,7 @@ func TestUserAuth(t *testing.T) { // By Login login := "loginuser0" - query := &models.GetUserByAuthInfoQuery{Login: login} + query := &models.GetUserByAuthInfoQuery{UserLookupParams: models.UserLookupParams{Login: &login}} user, err := srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -52,7 +52,9 @@ func TestUserAuth(t *testing.T) { // By ID id := user.ID - user, err = srv.LookupByOneOf(context.Background(), id, "", "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + UserID: &id, + }) require.Nil(t, err) require.Equal(t, user.ID, id) @@ -60,7 +62,9 @@ func TestUserAuth(t *testing.T) { // By Email email := "user1@test.com" - user, err = srv.LookupByOneOf(context.Background(), 0, email, "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + Email: &email, + }) require.Nil(t, err) require.Equal(t, user.Email, email) @@ -68,7 +72,9 @@ func TestUserAuth(t *testing.T) { // Don't find nonexistent user email = "nonexistent@test.com" - user, err = srv.LookupByOneOf(context.Background(), 0, email, "") + user, err = srv.LookupByOneOf(context.Background(), &models.UserLookupParams{ + Email: &email, + }) require.Equal(t, models.ErrUserNotFound, err) require.Nil(t, user) @@ -85,7 +91,7 @@ func TestUserAuth(t *testing.T) { // create user_auth entry login := "loginuser0" - query.Login = login + query.UserLookupParams.Login = &login user, err = srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -99,9 +105,9 @@ func TestUserAuth(t *testing.T) { require.Equal(t, user.Login, login) // get with non-matching id - id := user.ID + idPlusOne := user.ID + 1 - query.UserId = id + 1 + query.UserLookupParams.UserID = &idPlusOne user, err = srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -143,7 +149,9 @@ func TestUserAuth(t *testing.T) { login := "loginuser0" // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"} + query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) require.Nil(t, err) @@ -192,7 +200,9 @@ func TestUserAuth(t *testing.T) { // Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table // Make the first log-in during the past database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"} + query := &models.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) database.GetTime = time.Now @@ -202,7 +212,9 @@ func TestUserAuth(t *testing.T) { // Add a second auth module for this user // Have this module's last log-in be more recent database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -1) } - query = &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"} + query = &models.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err = srv.LookupAndUpdate(context.Background(), query) database.GetTime = time.Now @@ -257,7 +269,9 @@ func TestUserAuth(t *testing.T) { // Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table // Make the first log-in during the past database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -2) } - queryOne := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"} + queryOne := &models.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), queryOne) database.GetTime = time.Now @@ -267,7 +281,9 @@ func TestUserAuth(t *testing.T) { // Add a second auth module for this user // Have this module's last log-in be more recent database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -1) } - queryTwo := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"} + queryTwo := &models.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err = srv.LookupAndUpdate(context.Background(), queryTwo) require.Nil(t, err) require.Equal(t, user.Login, login) @@ -333,16 +349,21 @@ func TestUserAuth(t *testing.T) { // Expect to pass since there's a matching login user database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query := &models.GetUserByAuthInfoQuery{Login: login, AuthModule: genericOAuthModule, AuthId: ""} + query := &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{ + Login: &login, + }} user, err := srv.LookupAndUpdate(context.Background(), query) database.GetTime = time.Now require.Nil(t, err) require.Equal(t, user.Login, login) + otherLoginUser := "aloginuser" // Should throw a "user not found" error since there's no matching login user database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } - query = &models.GetUserByAuthInfoQuery{Login: "aloginuser", AuthModule: genericOAuthModule, AuthId: ""} + query = &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{ + Login: &otherLoginUser, + }} user, err = srv.LookupAndUpdate(context.Background(), query) database.GetTime = time.Now diff --git a/pkg/services/login/loginservice/loginservice.go b/pkg/services/login/loginservice/loginservice.go index f49532c64d0..b138c6e5393 100644 --- a/pkg/services/login/loginservice/loginservice.go +++ b/pkg/services/login/loginservice/loginservice.go @@ -49,11 +49,9 @@ func (ls *Implementation) UpsertUser(ctx context.Context, cmd *models.UpsertUser extUser := cmd.ExternalUser usr, err := ls.AuthInfoService.LookupAndUpdate(ctx, &models.GetUserByAuthInfoQuery{ - AuthModule: extUser.AuthModule, - AuthId: extUser.AuthId, - UserId: extUser.UserId, - Email: extUser.Email, - Login: extUser.Login, + AuthModule: extUser.AuthModule, + AuthId: extUser.AuthId, + UserLookupParams: cmd.UserLookupParams, }) if err != nil { if !errors.Is(err, models.ErrUserNotFound) { diff --git a/pkg/services/login/loginservice/loginservice_test.go b/pkg/services/login/loginservice/loginservice_test.go index dd9328b2d91..1bd5be21c7b 100644 --- a/pkg/services/login/loginservice/loginservice_test.go +++ b/pkg/services/login/loginservice/loginservice_test.go @@ -69,10 +69,12 @@ func Test_teamSync(t *testing.T) { AuthInfoService: authInfoMock, } - upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: "test_user@example.org"}} + email := "test_user@example.org" + upserCmd := &models.UpsertUserCommand{ExternalUser: &models.ExternalUserInfo{Email: email}, + UserLookupParams: models.UserLookupParams{Email: &email}} expectedUser := &user.User{ ID: 1, - Email: "test_user@example.org", + Email: email, Name: "test_user", Login: "test_user", } diff --git a/pkg/services/login/logintest/logintest.go b/pkg/services/login/logintest/logintest.go index 16691823323..d4a9e37c3c6 100644 --- a/pkg/services/login/logintest/logintest.go +++ b/pkg/services/login/logintest/logintest.go @@ -29,7 +29,11 @@ type AuthInfoServiceFake struct { } func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*user.User, error) { - a.LatestUserID = query.UserId + if query.UserLookupParams.UserID != nil { + a.LatestUserID = *query.UserLookupParams.UserID + } else { + a.LatestUserID = 0 + } return a.ExpectedUser, a.ExpectedError }