diff --git a/pkg/login/ext_user.go b/pkg/login/ext_user.go index e8234e505e3..f60d5d1e22d 100644 --- a/pkg/login/ext_user.go +++ b/pkg/login/ext_user.go @@ -16,14 +16,14 @@ func init() { func UpsertUser(cmd *m.UpsertUserCommand) error { extUser := cmd.ExternalUser - userQuery := m.GetUserByAuthInfoQuery{ + userQuery := &m.GetUserByAuthInfoQuery{ AuthModule: extUser.AuthModule, AuthId: extUser.AuthId, UserId: extUser.UserId, Email: extUser.Email, Login: extUser.Login, } - err := bus.Dispatch(&userQuery) + err := bus.Dispatch(userQuery) if err != nil { if err != m.ErrUserNotFound { return err @@ -47,8 +47,19 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { if err != nil { return err } + + if extUser.AuthModule != "" && extUser.AuthId != "" { + cmd2 := &m.SetAuthInfoCommand{ + UserId: cmd.Result.Id, + AuthModule: extUser.AuthModule, + AuthId: extUser.AuthId, + } + if err := bus.Dispatch(cmd2); err != nil { + return err + } + } } else { - cmd.Result = userQuery.User + cmd.Result = userQuery.Result // sync user info err = updateUser(cmd.Result, extUser) @@ -57,17 +68,6 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { } } - if userQuery.UserAuth == nil && extUser.AuthModule != "" && extUser.AuthId != "" { - cmd2 := m.SetAuthInfoCommand{ - UserId: cmd.Result.Id, - AuthModule: extUser.AuthModule, - AuthId: extUser.AuthId, - } - if err := bus.Dispatch(&cmd2); err != nil { - return err - } - } - err = syncOrgRoles(cmd.Result, extUser) if err != nil { return err @@ -77,12 +77,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { } func createUser(extUser *m.ExternalUserInfo) (*m.User, error) { - cmd := m.CreateUserCommand{ + cmd := &m.CreateUserCommand{ Login: extUser.Login, Email: extUser.Email, Name: extUser.Name, } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.Dispatch(cmd); err != nil { return nil, err } @@ -91,7 +91,7 @@ func createUser(extUser *m.ExternalUserInfo) (*m.User, error) { func updateUser(user *m.User, extUser *m.ExternalUserInfo) error { // sync user info - updateCmd := m.UpdateUserCommand{ + updateCmd := &m.UpdateUserCommand{ UserId: user.Id, } needsUpdate := false @@ -111,7 +111,7 @@ func updateUser(user *m.User, extUser *m.ExternalUserInfo) error { if needsUpdate { log.Debug("Syncing user info", "id", user.Id, "update", updateCmd) - err := bus.Dispatch(&updateCmd) + err := bus.Dispatch(updateCmd) if err != nil { return err } @@ -126,8 +126,8 @@ func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { return nil } - orgsQuery := m.GetUserOrgListQuery{UserId: user.Id} - if err := bus.Dispatch(&orgsQuery); err != nil { + orgsQuery := &m.GetUserOrgListQuery{UserId: user.Id} + if err := bus.Dispatch(orgsQuery); err != nil { return err } @@ -142,8 +142,8 @@ func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { deleteOrgIds = append(deleteOrgIds, org.OrgId) } else if extUser.OrgRoles[org.OrgId] != org.Role { // update role - cmd := m.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extUser.OrgRoles[org.OrgId]} - if err := bus.Dispatch(&cmd); err != nil { + cmd := &m.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extUser.OrgRoles[org.OrgId]} + if err := bus.Dispatch(cmd); err != nil { return err } } @@ -156,8 +156,8 @@ func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { } // add role - cmd := m.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId} - err := bus.Dispatch(&cmd) + cmd := &m.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId} + err := bus.Dispatch(cmd) if err != nil && err != m.ErrOrgNotFound { return err } @@ -165,8 +165,8 @@ func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { // delete any removed org roles for _, orgId := range deleteOrgIds { - cmd := m.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id} - if err := bus.Dispatch(&cmd); err != nil { + cmd := &m.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id} + if err := bus.Dispatch(cmd); err != nil { return err } } diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index 85a063775c9..0ecd144d52c 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -61,8 +61,7 @@ type GetUserByAuthInfoQuery struct { Email string Login string - User *User - UserAuth *UserAuth + Result *User } type GetAuthInfoQuery struct { diff --git a/pkg/services/sqlstore/user_auth.go b/pkg/services/sqlstore/user_auth.go index 32d9e246dd7..ca26791c440 100644 --- a/pkg/services/sqlstore/user_auth.go +++ b/pkg/services/sqlstore/user_auth.go @@ -18,13 +18,12 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { user := &m.User{} has := false var err error + authQuery := &m.GetAuthInfoQuery{} // Try to find the user by auth module and id first if query.AuthModule != "" && query.AuthId != "" { - authQuery := &m.GetAuthInfoQuery{ - AuthModule: query.AuthModule, - AuthId: query.AuthId, - } + authQuery.AuthModule = query.AuthModule + authQuery.AuthId = query.AuthId err = GetAuthInfo(authQuery) // if user id was specified and doesn't match the user_auth entry, remove it @@ -35,15 +34,15 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { if err != nil { sqlog.Error("Error removing user_auth entry", "error", err) } + + authQuery.Result = nil } else if err == nil { has, err = x.Id(authQuery.Result.UserId).Get(user) if err != nil { return err } - if has { - query.UserAuth = authQuery.Result - } else { + if !has { // if the user has been deleted then remove the entry err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{ UserAuth: authQuery.Result, @@ -51,6 +50,8 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { if err != nil { sqlog.Error("Error removing user_auth entry", "error", err) } + + authQuery.Result = nil } } else if err != m.ErrUserNotFound { return err @@ -88,7 +89,19 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { return m.ErrUserNotFound } - query.User = user + // create authInfo record to link accounts + if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" { + cmd2 := &m.SetAuthInfoCommand{ + UserId: user.Id, + AuthModule: query.AuthModule, + AuthId: query.AuthId, + } + if err := SetAuthInfo(cmd2); err != nil { + return err + } + } + + query.Result = user return nil } @@ -111,14 +124,14 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error { func SetAuthInfo(cmd *m.SetAuthInfoCommand) error { return inTransaction(func(sess *DBSession) error { - authUser := m.UserAuth{ + authUser := &m.UserAuth{ UserId: cmd.UserId, AuthModule: cmd.AuthModule, AuthId: cmd.AuthId, Created: time.Now(), } - _, err := sess.Insert(&authUser) + _, err := sess.Insert(authUser) return err }) }