mirror of https://github.com/grafana/grafana
commit
0ad6336634
@ -0,0 +1,184 @@ |
||||
package login |
||||
|
||||
import ( |
||||
"github.com/grafana/grafana/pkg/bus" |
||||
"github.com/grafana/grafana/pkg/log" |
||||
m "github.com/grafana/grafana/pkg/models" |
||||
"github.com/grafana/grafana/pkg/services/quota" |
||||
) |
||||
|
||||
func init() { |
||||
bus.AddHandler("auth", UpsertUser) |
||||
} |
||||
|
||||
func UpsertUser(cmd *m.UpsertUserCommand) error { |
||||
extUser := cmd.ExternalUser |
||||
|
||||
userQuery := &m.GetUserByAuthInfoQuery{ |
||||
AuthModule: extUser.AuthModule, |
||||
AuthId: extUser.AuthId, |
||||
UserId: extUser.UserId, |
||||
Email: extUser.Email, |
||||
Login: extUser.Login, |
||||
} |
||||
err := bus.Dispatch(userQuery) |
||||
if err != m.ErrUserNotFound && err != nil { |
||||
return err |
||||
} |
||||
|
||||
if err != nil { |
||||
if !cmd.SignupAllowed { |
||||
log.Warn("Not allowing %s login, user not found in internal user database and allow signup = false", extUser.AuthModule) |
||||
return ErrInvalidCredentials |
||||
} |
||||
|
||||
limitReached, err := quota.QuotaReached(cmd.ReqContext, "user") |
||||
if err != nil { |
||||
log.Warn("Error getting user quota", "err", err) |
||||
return ErrGettingUserQuota |
||||
} |
||||
if limitReached { |
||||
return ErrUsersQuotaReached |
||||
} |
||||
|
||||
cmd.Result, err = createUser(extUser) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if extUser.AuthModule != "" && extUser.AuthId != "" { |
||||
cmd2 := &m.SetAuthInfoCommand{ |
||||
UserId: cmd.Result.Id, |
||||
AuthModule: extUser.AuthModule, |
||||
AuthId: extUser.AuthId, |
||||
} |
||||
if err := bus.Dispatch(cmd2); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
} else { |
||||
cmd.Result = userQuery.Result |
||||
|
||||
err = updateUser(cmd.Result, extUser) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return syncOrgRoles(cmd.Result, extUser) |
||||
} |
||||
|
||||
func createUser(extUser *m.ExternalUserInfo) (*m.User, error) { |
||||
cmd := &m.CreateUserCommand{ |
||||
Login: extUser.Login, |
||||
Email: extUser.Email, |
||||
Name: extUser.Name, |
||||
SkipOrgSetup: len(extUser.OrgRoles) > 0, |
||||
} |
||||
if err := bus.Dispatch(cmd); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &cmd.Result, nil |
||||
} |
||||
|
||||
func updateUser(user *m.User, extUser *m.ExternalUserInfo) error { |
||||
// sync user info
|
||||
updateCmd := &m.UpdateUserCommand{ |
||||
UserId: user.Id, |
||||
} |
||||
|
||||
needsUpdate := false |
||||
if extUser.Login != "" && extUser.Login != user.Login { |
||||
updateCmd.Login = extUser.Login |
||||
user.Login = extUser.Login |
||||
needsUpdate = true |
||||
} |
||||
|
||||
if extUser.Email != "" && extUser.Email != user.Email { |
||||
updateCmd.Email = extUser.Email |
||||
user.Email = extUser.Email |
||||
needsUpdate = true |
||||
} |
||||
|
||||
if extUser.Name != "" && extUser.Name != user.Name { |
||||
updateCmd.Name = extUser.Name |
||||
user.Name = extUser.Name |
||||
needsUpdate = true |
||||
} |
||||
|
||||
if !needsUpdate { |
||||
return nil |
||||
} |
||||
|
||||
log.Debug("Syncing user info", "id", user.Id, "update", updateCmd) |
||||
return bus.Dispatch(updateCmd) |
||||
} |
||||
|
||||
func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { |
||||
// don't sync org roles if none are specified
|
||||
if len(extUser.OrgRoles) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
orgsQuery := &m.GetUserOrgListQuery{UserId: user.Id} |
||||
if err := bus.Dispatch(orgsQuery); err != nil { |
||||
return err |
||||
} |
||||
|
||||
handledOrgIds := map[int64]bool{} |
||||
deleteOrgIds := []int64{} |
||||
|
||||
// update existing org roles
|
||||
for _, org := range orgsQuery.Result { |
||||
handledOrgIds[org.OrgId] = true |
||||
|
||||
if extUser.OrgRoles[org.OrgId] == "" { |
||||
deleteOrgIds = append(deleteOrgIds, org.OrgId) |
||||
} else if extUser.OrgRoles[org.OrgId] != org.Role { |
||||
// update role
|
||||
cmd := &m.UpdateOrgUserCommand{OrgId: org.OrgId, UserId: user.Id, Role: extUser.OrgRoles[org.OrgId]} |
||||
if err := bus.Dispatch(cmd); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// add any new org roles
|
||||
for orgId, orgRole := range extUser.OrgRoles { |
||||
if _, exists := handledOrgIds[orgId]; exists { |
||||
continue |
||||
} |
||||
|
||||
// add role
|
||||
cmd := &m.AddOrgUserCommand{UserId: user.Id, Role: orgRole, OrgId: orgId} |
||||
err := bus.Dispatch(cmd) |
||||
if err != nil && err != m.ErrOrgNotFound { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// delete any removed org roles
|
||||
for _, orgId := range deleteOrgIds { |
||||
cmd := &m.RemoveOrgUserCommand{OrgId: orgId, UserId: user.Id} |
||||
if err := bus.Dispatch(cmd); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// update user's default org if needed
|
||||
if _, ok := extUser.OrgRoles[user.OrgId]; !ok { |
||||
for orgId := range extUser.OrgRoles { |
||||
user.OrgId = orgId |
||||
break |
||||
} |
||||
|
||||
return bus.Dispatch(&m.SetUsingOrgCommand{ |
||||
UserId: user.Id, |
||||
OrgId: user.OrgId, |
||||
}) |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,72 @@ |
||||
package models |
||||
|
||||
import ( |
||||
"time" |
||||
) |
||||
|
||||
type UserAuth struct { |
||||
Id int64 |
||||
UserId int64 |
||||
AuthModule string |
||||
AuthId string |
||||
Created time.Time |
||||
} |
||||
|
||||
type ExternalUserInfo struct { |
||||
AuthModule string |
||||
AuthId string |
||||
UserId int64 |
||||
Email string |
||||
Login string |
||||
Name string |
||||
OrgRoles map[int64]RoleType |
||||
} |
||||
|
||||
// ---------------------
|
||||
// COMMANDS
|
||||
|
||||
type UpsertUserCommand struct { |
||||
ReqContext *ReqContext |
||||
ExternalUser *ExternalUserInfo |
||||
SignupAllowed bool |
||||
|
||||
Result *User |
||||
} |
||||
|
||||
type SetAuthInfoCommand struct { |
||||
AuthModule string |
||||
AuthId string |
||||
UserId int64 |
||||
} |
||||
|
||||
type DeleteAuthInfoCommand struct { |
||||
UserAuth *UserAuth |
||||
} |
||||
|
||||
// ----------------------
|
||||
// QUERIES
|
||||
|
||||
type LoginUserQuery struct { |
||||
ReqContext *ReqContext |
||||
Username string |
||||
Password string |
||||
User *User |
||||
IpAddress string |
||||
} |
||||
|
||||
type GetUserByAuthInfoQuery struct { |
||||
AuthModule string |
||||
AuthId string |
||||
UserId int64 |
||||
Email string |
||||
Login string |
||||
|
||||
Result *User |
||||
} |
||||
|
||||
type GetAuthInfoQuery struct { |
||||
AuthModule string |
||||
AuthId string |
||||
|
||||
Result *UserAuth |
||||
} |
@ -0,0 +1,24 @@ |
||||
package migrations |
||||
|
||||
import . "github.com/grafana/grafana/pkg/services/sqlstore/migrator" |
||||
|
||||
func addUserAuthMigrations(mg *Migrator) { |
||||
userAuthV1 := Table{ |
||||
Name: "user_auth", |
||||
Columns: []*Column{ |
||||
{Name: "id", Type: DB_BigInt, IsPrimaryKey: true, IsAutoIncrement: true}, |
||||
{Name: "user_id", Type: DB_BigInt, Nullable: false}, |
||||
{Name: "auth_module", Type: DB_NVarchar, Length: 190, Nullable: false}, |
||||
{Name: "auth_id", Type: DB_NVarchar, Length: 100, Nullable: false}, |
||||
{Name: "created", Type: DB_DateTime, Nullable: false}, |
||||
}, |
||||
Indices: []*Index{ |
||||
{Cols: []string{"auth_module", "auth_id"}}, |
||||
}, |
||||
} |
||||
|
||||
// create table
|
||||
mg.AddMigration("create user auth table", NewAddTableMigration(userAuthV1)) |
||||
// add indices
|
||||
addTableIndicesMigrations(mg, "v1", userAuthV1) |
||||
} |
@ -0,0 +1,148 @@ |
||||
package sqlstore |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/bus" |
||||
m "github.com/grafana/grafana/pkg/models" |
||||
) |
||||
|
||||
func init() { |
||||
bus.AddHandler("sql", GetUserByAuthInfo) |
||||
bus.AddHandler("sql", GetAuthInfo) |
||||
bus.AddHandler("sql", SetAuthInfo) |
||||
bus.AddHandler("sql", DeleteAuthInfo) |
||||
} |
||||
|
||||
func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { |
||||
user := &m.User{} |
||||
has := false |
||||
var err error |
||||
authQuery := &m.GetAuthInfoQuery{} |
||||
|
||||
// Try to find the user by auth module and id first
|
||||
if query.AuthModule != "" && query.AuthId != "" { |
||||
authQuery.AuthModule = query.AuthModule |
||||
authQuery.AuthId = query.AuthId |
||||
|
||||
err = GetAuthInfo(authQuery) |
||||
if err != m.ErrUserNotFound { |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserId != 0 && query.UserId != authQuery.Result.UserId { |
||||
err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{ |
||||
UserAuth: authQuery.Result, |
||||
}) |
||||
if err != nil { |
||||
sqlog.Error("Error removing user_auth entry", "error", err) |
||||
} |
||||
|
||||
authQuery.Result = nil |
||||
} else { |
||||
has, err = x.Id(authQuery.Result.UserId).Get(user) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if !has { |
||||
// if the user has been deleted then remove the entry
|
||||
err = DeleteAuthInfo(&m.DeleteAuthInfoCommand{ |
||||
UserAuth: authQuery.Result, |
||||
}) |
||||
if err != nil { |
||||
sqlog.Error("Error removing user_auth entry", "error", err) |
||||
} |
||||
|
||||
authQuery.Result = nil |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// If not found, try to find the user by id
|
||||
if !has && query.UserId != 0 { |
||||
has, err = x.Id(query.UserId).Get(user) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// If not found, try to find the user by email address
|
||||
if !has && query.Email != "" { |
||||
user = &m.User{Email: query.Email} |
||||
has, err = x.Get(user) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// If not found, try to find the user by login
|
||||
if !has && query.Login != "" { |
||||
user = &m.User{Login: query.Login} |
||||
has, err = x.Get(user) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// No user found
|
||||
if !has { |
||||
return m.ErrUserNotFound |
||||
} |
||||
|
||||
// create authInfo record to link accounts
|
||||
if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" { |
||||
cmd2 := &m.SetAuthInfoCommand{ |
||||
UserId: user.Id, |
||||
AuthModule: query.AuthModule, |
||||
AuthId: query.AuthId, |
||||
} |
||||
if err := SetAuthInfo(cmd2); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
query.Result = user |
||||
return nil |
||||
} |
||||
|
||||
func GetAuthInfo(query *m.GetAuthInfoQuery) error { |
||||
userAuth := &m.UserAuth{ |
||||
AuthModule: query.AuthModule, |
||||
AuthId: query.AuthId, |
||||
} |
||||
has, err := x.Get(userAuth) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !has { |
||||
return m.ErrUserNotFound |
||||
} |
||||
|
||||
query.Result = userAuth |
||||
return nil |
||||
} |
||||
|
||||
func SetAuthInfo(cmd *m.SetAuthInfoCommand) error { |
||||
return inTransaction(func(sess *DBSession) error { |
||||
authUser := &m.UserAuth{ |
||||
UserId: cmd.UserId, |
||||
AuthModule: cmd.AuthModule, |
||||
AuthId: cmd.AuthId, |
||||
Created: time.Now(), |
||||
} |
||||
|
||||
_, err := sess.Insert(authUser) |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error { |
||||
return inTransaction(func(sess *DBSession) error { |
||||
_, err := sess.Delete(cmd.UserAuth) |
||||
return err |
||||
}) |
||||
} |
@ -0,0 +1,131 @@ |
||||
package sqlstore |
||||
|
||||
import ( |
||||
"fmt" |
||||
"testing" |
||||
|
||||
. "github.com/smartystreets/goconvey/convey" |
||||
|
||||
m "github.com/grafana/grafana/pkg/models" |
||||
) |
||||
|
||||
func TestUserAuth(t *testing.T) { |
||||
InitTestDB(t) |
||||
|
||||
Convey("Given 5 users", t, func() { |
||||
var err error |
||||
var cmd *m.CreateUserCommand |
||||
users := []m.User{} |
||||
for i := 0; i < 5; i++ { |
||||
cmd = &m.CreateUserCommand{ |
||||
Email: fmt.Sprint("user", i, "@test.com"), |
||||
Name: fmt.Sprint("user", i), |
||||
Login: fmt.Sprint("loginuser", i), |
||||
} |
||||
err = CreateUser(cmd) |
||||
So(err, ShouldBeNil) |
||||
users = append(users, cmd.Result) |
||||
} |
||||
|
||||
Reset(func() { |
||||
_, err := x.Exec("DELETE FROM org_user WHERE 1=1") |
||||
So(err, ShouldBeNil) |
||||
_, err = x.Exec("DELETE FROM org WHERE 1=1") |
||||
So(err, ShouldBeNil) |
||||
_, err = x.Exec("DELETE FROM user WHERE 1=1") |
||||
So(err, ShouldBeNil) |
||||
_, err = x.Exec("DELETE FROM user_auth WHERE 1=1") |
||||
So(err, ShouldBeNil) |
||||
}) |
||||
|
||||
Convey("Can find existing user", func() { |
||||
// By Login
|
||||
login := "loginuser0" |
||||
|
||||
query := &m.GetUserByAuthInfoQuery{Login: login} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Login, ShouldEqual, login) |
||||
|
||||
// By ID
|
||||
id := query.Result.Id |
||||
|
||||
query = &m.GetUserByAuthInfoQuery{UserId: id} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Id, ShouldEqual, id) |
||||
|
||||
// By Email
|
||||
email := "user1@test.com" |
||||
|
||||
query = &m.GetUserByAuthInfoQuery{Email: email} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Email, ShouldEqual, email) |
||||
|
||||
// Don't find nonexistent user
|
||||
email = "nonexistent@test.com" |
||||
|
||||
query = &m.GetUserByAuthInfoQuery{Email: email} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldEqual, m.ErrUserNotFound) |
||||
So(query.Result, ShouldBeNil) |
||||
}) |
||||
|
||||
Convey("Can set & locate by AuthModule and AuthId", func() { |
||||
// get nonexistent user_auth entry
|
||||
query := &m.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldEqual, m.ErrUserNotFound) |
||||
So(query.Result, ShouldBeNil) |
||||
|
||||
// create user_auth entry
|
||||
login := "loginuser0" |
||||
|
||||
query.Login = login |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Login, ShouldEqual, login) |
||||
|
||||
// get via user_auth
|
||||
query = &m.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Login, ShouldEqual, login) |
||||
|
||||
// get with non-matching id
|
||||
id := query.Result.Id |
||||
|
||||
query.UserId = id + 1 |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Login, ShouldEqual, "loginuser1") |
||||
|
||||
// get via user_auth
|
||||
query = &m.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldBeNil) |
||||
So(query.Result.Login, ShouldEqual, "loginuser1") |
||||
|
||||
// remove user
|
||||
_, err = x.Exec("DELETE FROM user WHERE id=?", query.Result.Id) |
||||
So(err, ShouldBeNil) |
||||
|
||||
// get via user_auth for deleted user
|
||||
query = &m.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} |
||||
err = GetUserByAuthInfo(query) |
||||
|
||||
So(err, ShouldEqual, m.ErrUserNotFound) |
||||
So(query.Result, ShouldBeNil) |
||||
}) |
||||
}) |
||||
} |
Loading…
Reference in new issue