Remove global variable from user (#46696)

* Remove global variable from user

* Remove missed x
pull/46283/head
idafurjes 3 years ago committed by GitHub
parent 93390b5a1e
commit 52bd7618dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      pkg/services/sqlstore/org_test.go
  2. 286
      pkg/services/sqlstore/user.go
  3. 24
      pkg/services/sqlstore/user_test.go

@ -164,7 +164,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Can search users", func(t *testing.T) { t.Run("Can search users", func(t *testing.T) {
query := models.SearchUsersQuery{Query: ""} query := models.SearchUsersQuery{Query: ""}
err := SearchUsers(context.Background(), &query) err := sqlStore.SearchUsers(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, query.Result.Users, 2) require.Len(t, query.Result.Users, 2)

@ -27,7 +27,7 @@ func (ss *SQLStore) addUserQueryAndCommandHandlers() {
bus.AddHandler("sql", ss.SetUsingOrg) bus.AddHandler("sql", ss.SetUsingOrg)
bus.AddHandler("sql", ss.UpdateUserLastSeenAt) bus.AddHandler("sql", ss.UpdateUserLastSeenAt)
bus.AddHandler("sql", ss.GetUserProfile) bus.AddHandler("sql", ss.GetUserProfile)
bus.AddHandler("sql", SearchUsers) bus.AddHandler("sql", ss.SearchUsers)
bus.AddHandler("sql", ss.GetUserOrgList) bus.AddHandler("sql", ss.GetUserOrgList)
bus.AddHandler("sql", ss.DisableUser) bus.AddHandler("sql", ss.DisableUser)
bus.AddHandler("sql", ss.BatchDisableUsers) bus.AddHandler("sql", ss.BatchDisableUsers)
@ -495,17 +495,19 @@ func (o byOrgName) Less(i, j int) bool {
} }
func (ss *SQLStore) GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error { func (ss *SQLStore) GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error {
query.Result = make([]*models.UserOrgDTO, 0) return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
sess := x.Table("org_user") query.Result = make([]*models.UserOrgDTO, 0)
sess.Join("INNER", "org", "org_user.org_id=org.id") sess := dbSess.Table("org_user")
sess.Join("INNER", x.Dialect().Quote("user"), fmt.Sprintf("org_user.user_id=%s.id", x.Dialect().Quote("user"))) sess.Join("INNER", "org", "org_user.org_id=org.id")
sess.Where("org_user.user_id=?", query.UserId) sess.Join("INNER", x.Dialect().Quote("user"), fmt.Sprintf("org_user.user_id=%s.id", x.Dialect().Quote("user")))
sess.Where(notServiceAccountFilter(ss)) sess.Where("org_user.user_id=?", query.UserId)
sess.Cols("org.name", "org_user.role", "org_user.org_id") sess.Where(notServiceAccountFilter(ss))
sess.OrderBy("org.name") sess.Cols("org.name", "org_user.role", "org_user.org_id")
err := sess.Find(&query.Result) sess.OrderBy("org.name")
sort.Sort(byOrgName(query.Result)) err := sess.Find(&query.Result)
return err sort.Sort(byOrgName(query.Result))
return err
})
} }
func newSignedInUserCacheKey(orgID, userID int64) string { func newSignedInUserCacheKey(orgID, userID int64) string {
@ -531,12 +533,13 @@ func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *mode
} }
func (ss *SQLStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error { func (ss *SQLStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
orgId := "u.org_id" return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
if query.OrgId > 0 { orgId := "u.org_id"
orgId = strconv.FormatInt(query.OrgId, 10) if query.OrgId > 0 {
} orgId = strconv.FormatInt(query.OrgId, 10)
}
var rawSQL = `SELECT var rawSQL = `SELECT
u.id as user_id, u.id as user_id,
u.is_admin as is_grafana_admin, u.is_admin as is_grafana_admin,
u.email as email, u.email as email,
@ -552,167 +555,168 @@ func (ss *SQLStore) GetSignedInUser(ctx context.Context, query *models.GetSigned
LEFT OUTER JOIN org_user on org_user.org_id = ` + orgId + ` and org_user.user_id = u.id LEFT OUTER JOIN org_user on org_user.org_id = ` + orgId + ` and org_user.user_id = u.id
LEFT OUTER JOIN org on org.id = org_user.org_id ` LEFT OUTER JOIN org on org.id = org_user.org_id `
sess := x.Table("user") sess := dbSess.Table("user")
sess = sess.Context(ctx) sess = sess.Context(ctx)
switch { switch {
case query.UserId > 0: case query.UserId > 0:
sess.SQL(rawSQL+"WHERE u.id=?", query.UserId) sess.SQL(rawSQL+"WHERE u.id=?", query.UserId)
case query.Login != "": case query.Login != "":
sess.SQL(rawSQL+"WHERE u.login=?", query.Login) sess.SQL(rawSQL+"WHERE u.login=?", query.Login)
case query.Email != "": case query.Email != "":
sess.SQL(rawSQL+"WHERE u.email=?", query.Email) sess.SQL(rawSQL+"WHERE u.email=?", query.Email)
} }
var user models.SignedInUser var user models.SignedInUser
has, err := sess.Get(&user) has, err := sess.Get(&user)
if err != nil { if err != nil {
return err return err
} else if !has { } else if !has {
return models.ErrUserNotFound return models.ErrUserNotFound
} }
if user.OrgRole == "" { if user.OrgRole == "" {
user.OrgId = -1 user.OrgId = -1
user.OrgName = "Org missing" user.OrgName = "Org missing"
} }
getTeamsByUserQuery := &models.GetTeamsByUserQuery{OrgId: user.OrgId, UserId: user.UserId} getTeamsByUserQuery := &models.GetTeamsByUserQuery{OrgId: user.OrgId, UserId: user.UserId}
err = ss.GetTeamsByUser(ctx, getTeamsByUserQuery) err = ss.GetTeamsByUser(ctx, getTeamsByUserQuery)
if err != nil { if err != nil {
return err return err
} }
user.Teams = make([]int64, len(getTeamsByUserQuery.Result)) user.Teams = make([]int64, len(getTeamsByUserQuery.Result))
for i, t := range getTeamsByUserQuery.Result { for i, t := range getTeamsByUserQuery.Result {
user.Teams[i] = t.Id user.Teams[i] = t.Id
} }
query.Result = &user query.Result = &user
return err return err
})
} }
func (ss *SQLStore) SearchUsers(ctx context.Context, query *models.SearchUsersQuery) error { func (ss *SQLStore) SearchUsers(ctx context.Context, query *models.SearchUsersQuery) error {
return SearchUsers(ctx, query) return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
} query.Result = models.SearchUserQueryResult{
Users: make([]*models.UserSearchHitDTO, 0),
func SearchUsers(ctx context.Context, query *models.SearchUsersQuery) error { }
query.Result = models.SearchUserQueryResult{
Users: make([]*models.UserSearchHitDTO, 0),
}
queryWithWildcards := "%" + query.Query + "%" queryWithWildcards := "%" + query.Query + "%"
whereConditions := make([]string, 0) whereConditions := make([]string, 0)
whereParams := make([]interface{}, 0) whereParams := make([]interface{}, 0)
sess := x.Table("user").Alias("u") sess := dbSess.Table("user").Alias("u")
whereConditions = append(whereConditions, "u.is_service_account = ?") whereConditions = append(whereConditions, "u.is_service_account = ?")
whereParams = append(whereParams, dialect.BooleanStr(false)) whereParams = append(whereParams, dialect.BooleanStr(false))
// Join with only most recent auth module // Join with only most recent auth module
joinCondition := `( joinCondition := `(
SELECT id from user_auth SELECT id from user_auth
WHERE user_auth.user_id = u.id WHERE user_auth.user_id = u.id
ORDER BY user_auth.created DESC ` ORDER BY user_auth.created DESC `
joinCondition = "user_auth.id=" + joinCondition + dialect.Limit(1) + ")" joinCondition = "user_auth.id=" + joinCondition + dialect.Limit(1) + ")"
sess.Join("LEFT", "user_auth", joinCondition) sess.Join("LEFT", "user_auth", joinCondition)
if query.OrgId > 0 { if query.OrgId > 0 {
whereConditions = append(whereConditions, "org_id = ?") whereConditions = append(whereConditions, "org_id = ?")
whereParams = append(whereParams, query.OrgId) whereParams = append(whereParams, query.OrgId)
} }
if query.Query != "" {
whereConditions = append(whereConditions, "(email "+dialect.LikeStr()+" ? OR name "+dialect.LikeStr()+" ? OR login "+dialect.LikeStr()+" ?)")
whereParams = append(whereParams, queryWithWildcards, queryWithWildcards, queryWithWildcards)
}
if query.IsDisabled != nil {
whereConditions = append(whereConditions, "is_disabled = ?")
whereParams = append(whereParams, query.IsDisabled)
}
if query.AuthModule != "" {
whereConditions = append(whereConditions, `auth_module=?`)
whereParams = append(whereParams, query.AuthModule)
}
if len(whereConditions) > 0 {
sess.Where(strings.Join(whereConditions, " AND "), whereParams...)
}
for _, filter := range query.Filters { if query.Query != "" {
if jc := filter.JoinCondition(); jc != nil { whereConditions = append(whereConditions, "(email "+dialect.LikeStr()+" ? OR name "+dialect.LikeStr()+" ? OR login "+dialect.LikeStr()+" ?)")
sess.Join(jc.Operator, jc.Table, jc.Params) whereParams = append(whereParams, queryWithWildcards, queryWithWildcards, queryWithWildcards)
} }
if ic := filter.InCondition(); ic != nil {
sess.In(ic.Condition, ic.Params) if query.IsDisabled != nil {
whereConditions = append(whereConditions, "is_disabled = ?")
whereParams = append(whereParams, query.IsDisabled)
} }
if wc := filter.WhereCondition(); wc != nil {
sess.Where(wc.Condition, wc.Params) if query.AuthModule != "" {
whereConditions = append(whereConditions, `auth_module=?`)
whereParams = append(whereParams, query.AuthModule)
} }
}
if query.Limit > 0 { if len(whereConditions) > 0 {
offset := query.Limit * (query.Page - 1) sess.Where(strings.Join(whereConditions, " AND "), whereParams...)
sess.Limit(query.Limit, offset) }
}
sess.Cols("u.id", "u.email", "u.name", "u.login", "u.is_admin", "u.is_disabled", "u.last_seen_at", "user_auth.auth_module") for _, filter := range query.Filters {
sess.Asc("u.login", "u.email") if jc := filter.JoinCondition(); jc != nil {
if err := sess.Find(&query.Result.Users); err != nil { sess.Join(jc.Operator, jc.Table, jc.Params)
return err }
} if ic := filter.InCondition(); ic != nil {
sess.In(ic.Condition, ic.Params)
}
if wc := filter.WhereCondition(); wc != nil {
sess.Where(wc.Condition, wc.Params)
}
}
// get total if query.Limit > 0 {
user := models.User{} offset := query.Limit * (query.Page - 1)
countSess := x.Table("user").Alias("u") sess.Limit(query.Limit, offset)
}
// Join with user_auth table if users filtered by auth_module sess.Cols("u.id", "u.email", "u.name", "u.login", "u.is_admin", "u.is_disabled", "u.last_seen_at", "user_auth.auth_module")
if query.AuthModule != "" { sess.Asc("u.login", "u.email")
countSess.Join("LEFT", "user_auth", joinCondition) if err := sess.Find(&query.Result.Users); err != nil {
} return err
}
if len(whereConditions) > 0 { // get total
countSess.Where(strings.Join(whereConditions, " AND "), whereParams...) user := models.User{}
} countSess := dbSess.Table("user").Alias("u")
for _, filter := range query.Filters { // Join with user_auth table if users filtered by auth_module
if jc := filter.JoinCondition(); jc != nil { if query.AuthModule != "" {
countSess.Join(jc.Operator, jc.Table, jc.Params) countSess.Join("LEFT", "user_auth", joinCondition)
} }
if ic := filter.InCondition(); ic != nil {
countSess.In(ic.Condition, ic.Params) if len(whereConditions) > 0 {
countSess.Where(strings.Join(whereConditions, " AND "), whereParams...)
} }
if wc := filter.WhereCondition(); wc != nil {
countSess.Where(wc.Condition, wc.Params) for _, filter := range query.Filters {
if jc := filter.JoinCondition(); jc != nil {
countSess.Join(jc.Operator, jc.Table, jc.Params)
}
if ic := filter.InCondition(); ic != nil {
countSess.In(ic.Condition, ic.Params)
}
if wc := filter.WhereCondition(); wc != nil {
countSess.Where(wc.Condition, wc.Params)
}
} }
}
count, err := countSess.Count(&user) count, err := countSess.Count(&user)
query.Result.TotalCount = count query.Result.TotalCount = count
for _, user := range query.Result.Users { for _, user := range query.Result.Users {
user.LastSeenAtAge = util.GetAgeString(user.LastSeenAt) user.LastSeenAtAge = util.GetAgeString(user.LastSeenAt)
} }
return err return err
})
} }
func (ss *SQLStore) DisableUser(ctx context.Context, cmd *models.DisableUserCommand) error { func (ss *SQLStore) DisableUser(ctx context.Context, cmd *models.DisableUserCommand) error {
user := models.User{} return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
sess := x.Table("user") user := models.User{}
sess := dbSess.Table("user")
if has, err := sess.ID(cmd.UserId).Where(notServiceAccountFilter(ss)).Get(&user); err != nil { if has, err := sess.ID(cmd.UserId).Where(notServiceAccountFilter(ss)).Get(&user); err != nil {
return err return err
} else if !has { } else if !has {
return models.ErrUserNotFound return models.ErrUserNotFound
} }
user.IsDisabled = cmd.IsDisabled user.IsDisabled = cmd.IsDisabled
sess.UseBool("is_disabled") sess.UseBool("is_disabled")
_, err := sess.ID(cmd.UserId).Update(&user) _, err := sess.ID(cmd.UserId).Update(&user)
return err return err
})
} }
func (ss *SQLStore) BatchDisableUsers(ctx context.Context, cmd *models.BatchDisableUsersCommand) error { func (ss *SQLStore) BatchDisableUsers(ctx context.Context, cmd *models.BatchDisableUsersCommand) error {

@ -130,7 +130,7 @@ func TestUserDataAccess(t *testing.T) {
// Return the first page of users and a total count // Return the first page of users and a total count
query := models.SearchUsersQuery{Query: "", Page: 1, Limit: 3} query := models.SearchUsersQuery{Query: "", Page: 1, Limit: 3}
err := SearchUsers(context.Background(), &query) err := ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 3) require.Len(t, query.Result.Users, 3)
@ -138,7 +138,7 @@ func TestUserDataAccess(t *testing.T) {
// Return the second page of users and a total count // Return the second page of users and a total count
query = models.SearchUsersQuery{Query: "", Page: 2, Limit: 3} query = models.SearchUsersQuery{Query: "", Page: 2, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 2) require.Len(t, query.Result.Users, 2)
@ -146,28 +146,28 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on user name // Return list of users matching query on user name
query = models.SearchUsersQuery{Query: "use", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "use", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 3) require.Len(t, query.Result.Users, 3)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)
query = models.SearchUsersQuery{Query: "ser1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "ser1", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
require.EqualValues(t, query.Result.TotalCount, 1) require.EqualValues(t, query.Result.TotalCount, 1)
query = models.SearchUsersQuery{Query: "USER1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "USER1", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
require.EqualValues(t, query.Result.TotalCount, 1) require.EqualValues(t, query.Result.TotalCount, 1)
query = models.SearchUsersQuery{Query: "idontexist", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "idontexist", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 0) require.Len(t, query.Result.Users, 0)
@ -175,7 +175,7 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on email // Return list of users matching query on email
query = models.SearchUsersQuery{Query: "ser1@test.com", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "ser1@test.com", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
@ -183,7 +183,7 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on login name // Return list of users matching query on login name
query = models.SearchUsersQuery{Query: "loginuser1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "loginuser1", Page: 1, Limit: 3}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
@ -203,7 +203,7 @@ func TestUserDataAccess(t *testing.T) {
isDisabled := false isDisabled := false
query := models.SearchUsersQuery{IsDisabled: &isDisabled} query := models.SearchUsersQuery{IsDisabled: &isDisabled}
err := SearchUsers(context.Background(), &query) err := ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 2) require.Len(t, query.Result.Users, 2)
@ -330,7 +330,7 @@ func TestUserDataAccess(t *testing.T) {
isDisabled = true isDisabled = true
query5 := &models.SearchUsersQuery{IsDisabled: &isDisabled} query5 := &models.SearchUsersQuery{IsDisabled: &isDisabled}
err = SearchUsers(context.Background(), query5) err = ss.SearchUsers(context.Background(), query5)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query5.Result.TotalCount, 5) require.EqualValues(t, query5.Result.TotalCount, 5)
@ -383,7 +383,7 @@ func TestUserDataAccess(t *testing.T) {
isDisabled := false isDisabled := false
query := &models.SearchUsersQuery{IsDisabled: &isDisabled} query := &models.SearchUsersQuery{IsDisabled: &isDisabled}
err = SearchUsers(context.Background(), query) err = ss.SearchUsers(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)
@ -414,7 +414,7 @@ func TestUserDataAccess(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
query := models.SearchUsersQuery{} query := models.SearchUsersQuery{}
err = SearchUsers(context.Background(), &query) err = ss.SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)

Loading…
Cancel
Save