Chore: Add context to temp user (#41284)

* Add context to temp user

* Remove xorm and InTransaction
pull/41310/head
idafurjes 4 years ago committed by GitHub
parent b82797d1b0
commit da5033f3fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 27
      pkg/api/org_invite.go
  2. 15
      pkg/api/signup.go
  3. 2
      pkg/services/searchusers/searchusers.go
  4. 1
      pkg/services/sqlstore/sqlstore.go
  5. 92
      pkg/services/sqlstore/temp_user.go
  6. 20
      pkg/services/sqlstore/temp_user_test.go

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"fmt"
@ -18,7 +19,7 @@ import (
func GetPendingOrgInvites(c *models.ReqContext) response.Response {
query := models.GetTempUsersQuery{OrgId: c.OrgId, Status: models.TmpUserInvitePending}
if err := bus.Dispatch(&query); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
return response.Error(500, "Failed to get invites from db", err)
}
@ -62,7 +63,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
cmd.Role = inviteDto.Role
cmd.RemoteAddr = c.Req.RemoteAddr
if err := bus.Dispatch(&cmd); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to save invite to database", err)
}
@ -102,7 +103,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto *dtos.AddInviteForm) response.Response {
// user exists, add org role
createOrgUserCmd := models.AddOrgUserCommand{OrgId: c.OrgId, UserId: user.Id, Role: inviteDto.Role}
if err := bus.Dispatch(&createOrgUserCmd); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &createOrgUserCmd); err != nil {
if errors.Is(err, models.ErrOrgUserAlreadyAdded) {
return response.Error(412, fmt.Sprintf("User %s is already added to organization", inviteDto.LoginOrEmail), err)
}
@ -132,7 +133,7 @@ func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto
}
func RevokeInvite(c *models.ReqContext) response.Response {
if ok, rsp := updateTempUserStatus(web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok {
if ok, rsp := updateTempUserStatus(c.Req.Context(), web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok {
return rsp
}
@ -144,7 +145,7 @@ func RevokeInvite(c *models.ReqContext) response.Response {
// If a (pending) invite is not found, 404 is returned.
func GetInviteInfoByCode(c *models.ReqContext) response.Response {
query := models.GetTempUserByCodeQuery{Code: web.Params(c.Req)[":code"]}
if err := bus.Dispatch(&query); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) {
return response.Error(404, "Invite not found", nil)
}
@ -167,7 +168,7 @@ func GetInviteInfoByCode(c *models.ReqContext) response.Response {
func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.CompleteInviteForm) response.Response {
query := models.GetTempUserByCodeQuery{Code: completeInvite.InviteCode}
if err := bus.Dispatch(&query); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) {
return response.Error(404, "Invite not found", nil)
}
@ -203,7 +204,7 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C
return response.Error(500, "failed to publish event", err)
}
if ok, rsp := applyUserInvite(user, invite, true); !ok {
if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, true); !ok {
return rsp
}
@ -221,33 +222,33 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C
})
}
func updateTempUserStatus(code string, status models.TempUserStatus) (bool, response.Response) {
func updateTempUserStatus(ctx context.Context, code string, status models.TempUserStatus) (bool, response.Response) {
// update temp user status
updateTmpUserCmd := models.UpdateTempUserStatusCommand{Code: code, Status: status}
if err := bus.Dispatch(&updateTmpUserCmd); err != nil {
if err := bus.DispatchCtx(ctx, &updateTmpUserCmd); err != nil {
return false, response.Error(500, "Failed to update invite status", err)
}
return true, nil
}
func applyUserInvite(user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) {
func applyUserInvite(ctx context.Context, user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) {
// add to org
addOrgUserCmd := models.AddOrgUserCommand{OrgId: invite.OrgId, UserId: user.Id, Role: invite.Role}
if err := bus.Dispatch(&addOrgUserCmd); err != nil {
if err := bus.DispatchCtx(ctx, &addOrgUserCmd); err != nil {
if !errors.Is(err, models.ErrOrgUserAlreadyAdded) {
return false, response.Error(500, "Error while trying to create org user", err)
}
}
// update temp user status
if ok, rsp := updateTempUserStatus(invite.Code, models.TmpUserCompleted); !ok {
if ok, rsp := updateTempUserStatus(ctx, invite.Code, models.TmpUserCompleted); !ok {
return false, rsp
}
if setActive {
// set org to active
if err := bus.Dispatch(&models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil {
if err := bus.DispatchCtx(ctx, &models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil {
return false, response.Error(500, "Failed to set org as active", err)
}
}

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"github.com/grafana/grafana/pkg/api/dtos"
@ -44,7 +45,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response {
}
cmd.RemoteAddr = c.Req.RemoteAddr
if err := bus.Dispatch(&cmd); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to create signup", err)
}
@ -75,7 +76,7 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
// verify email
if setting.VerifyEmailEnabled {
if ok, rsp := verifyUserSignUpEmail(form.Email, form.Code); !ok {
if ok, rsp := verifyUserSignUpEmail(c.Req.Context(), form.Email, form.Code); !ok {
return rsp
}
createUserCmd.EmailVerified = true
@ -99,19 +100,19 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
}
// mark temp user as completed
if ok, rsp := updateTempUserStatus(form.Code, models.TmpUserCompleted); !ok {
if ok, rsp := updateTempUserStatus(c.Req.Context(), form.Code, models.TmpUserCompleted); !ok {
return rsp
}
// check for pending invites
invitesQuery := models.GetTempUsersQuery{Email: form.Email, Status: models.TmpUserInvitePending}
if err := bus.Dispatch(&invitesQuery); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &invitesQuery); err != nil {
return response.Error(500, "Failed to query database for invites", err)
}
apiResponse := util.DynMap{"message": "User sign up completed successfully", "code": "redirect-to-landing-page"}
for _, invite := range invitesQuery.Result {
if ok, rsp := applyUserInvite(user, invite, false); !ok {
if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, false); !ok {
return rsp
}
apiResponse["code"] = "redirect-to-select-org"
@ -127,10 +128,10 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
return response.JSON(200, apiResponse)
}
func verifyUserSignUpEmail(email string, code string) (bool, response.Response) {
func verifyUserSignUpEmail(ctx context.Context, email string, code string) (bool, response.Response) {
query := models.GetTempUserByCodeQuery{Code: code}
if err := bus.Dispatch(&query); err != nil {
if err := bus.DispatchCtx(ctx, &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) {
return false, response.Error(404, "Invalid email verification code", nil)
}

@ -60,7 +60,7 @@ func (s *OSSService) SearchUser(c *models.ReqContext) (*models.SearchUsersQuery,
}
query := &models.SearchUsersQuery{Query: searchQuery, Filters: filters, Page: page, Limit: perPage}
if err := s.bus.Dispatch(query); err != nil {
if err := s.bus.DispatchCtx(c.Req.Context(), query); err != nil {
return nil, err
}

@ -118,6 +118,7 @@ func newSQLStore(cfg *setting.Cfg, cacheService *localcache.CacheService, bus bu
ss.addOrgUsersQueryAndCommandHandlers()
ss.addStarQueryAndCommandHandlers()
ss.addAlertQueryAndCommandHandlers()
ss.addTempUserQueryAndCommandHandlers()
// if err := ss.Reset(); err != nil {
// return nil, err

@ -1,31 +1,32 @@
package sqlstore
import (
"context"
"time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
)
func init() {
bus.AddHandler("sql", CreateTempUser)
bus.AddHandler("sql", GetTempUsersQuery)
bus.AddHandler("sql", UpdateTempUserStatus)
bus.AddHandler("sql", GetTempUserByCode)
bus.AddHandler("sql", UpdateTempUserWithEmailSent)
bus.AddHandler("sql", ExpireOldUserInvites)
func (ss *SQLStore) addTempUserQueryAndCommandHandlers() {
bus.AddHandlerCtx("sql", ss.CreateTempUser)
bus.AddHandlerCtx("sql", ss.GetTempUsersQuery)
bus.AddHandlerCtx("sql", ss.UpdateTempUserStatus)
bus.AddHandlerCtx("sql", ss.GetTempUserByCode)
bus.AddHandlerCtx("sql", ss.UpdateTempUserWithEmailSent)
bus.AddHandlerCtx("sql", ss.ExpireOldUserInvites)
}
func UpdateTempUserStatus(cmd *models.UpdateTempUserStatusCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) UpdateTempUserStatus(ctx context.Context, cmd *models.UpdateTempUserStatusCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var rawSQL = "UPDATE temp_user SET status=? WHERE code=?"
_, err := sess.Exec(rawSQL, string(cmd.Status), cmd.Code)
return err
})
}
func CreateTempUser(cmd *models.CreateTempUserCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) CreateTempUser(ctx context.Context, cmd *models.CreateTempUserCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
// create user
user := &models.TempUser{
Email: cmd.Email,
@ -46,12 +47,13 @@ func CreateTempUser(cmd *models.CreateTempUserCommand) error {
}
cmd.Result = user
return nil
})
}
func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) UpdateTempUserWithEmailSent(ctx context.Context, cmd *models.UpdateTempUserWithEmailSentCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
user := &models.TempUser{
EmailSent: true,
EmailSentOn: time.Now(),
@ -63,8 +65,9 @@ func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand)
})
}
func GetTempUsersQuery(query *models.GetTempUsersQuery) error {
rawSQL := `SELECT
func (ss *SQLStore) GetTempUsersQuery(ctx context.Context, query *models.GetTempUsersQuery) error {
return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
rawSQL := `SELECT
tu.id as id,
tu.org_id as org_id,
tu.email as email,
@ -81,28 +84,30 @@ func GetTempUsersQuery(query *models.GetTempUsersQuery) error {
FROM ` + dialect.Quote("temp_user") + ` as tu
LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id
WHERE tu.status=?`
params := []interface{}{string(query.Status)}
params := []interface{}{string(query.Status)}
if query.OrgId > 0 {
rawSQL += ` AND tu.org_id=?`
params = append(params, query.OrgId)
}
if query.OrgId > 0 {
rawSQL += ` AND tu.org_id=?`
params = append(params, query.OrgId)
}
if query.Email != "" {
rawSQL += ` AND tu.email=?`
params = append(params, query.Email)
}
if query.Email != "" {
rawSQL += ` AND tu.email=?`
params = append(params, query.Email)
}
rawSQL += " ORDER BY tu.created desc"
rawSQL += " ORDER BY tu.created desc"
query.Result = make([]*models.TempUserDTO, 0)
sess := x.SQL(rawSQL, params...)
err := sess.Find(&query.Result)
return err
query.Result = make([]*models.TempUserDTO, 0)
sess := dbSess.SQL(rawSQL, params...)
err := sess.Find(&query.Result)
return err
})
}
func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error {
var rawSQL = `SELECT
func (ss *SQLStore) GetTempUserByCode(ctx context.Context, query *models.GetTempUserByCodeQuery) error {
return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
var rawSQL = `SELECT
tu.id as id,
tu.org_id as org_id,
tu.email as email,
@ -120,22 +125,23 @@ func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error {
LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id
WHERE tu.code=?`
var tempUser models.TempUserDTO
sess := x.SQL(rawSQL, query.Code)
has, err := sess.Get(&tempUser)
var tempUser models.TempUserDTO
sess := dbSess.SQL(rawSQL, query.Code)
has, err := sess.Get(&tempUser)
if err != nil {
return err
} else if !has {
return models.ErrTempUserNotFound
}
if err != nil {
return err
} else if !has {
return models.ErrTempUserNotFound
}
query.Result = &tempUser
return err
query.Result = &tempUser
return err
})
}
func ExpireOldUserInvites(cmd *models.ExpireTempUsersCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) ExpireOldUserInvites(ctx context.Context, cmd *models.ExpireTempUsersCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var rawSQL = "UPDATE temp_user SET status = ?, updated = ? WHERE created <= ? AND status in (?, ?)"
if result, err := sess.Exec(rawSQL, string(models.TmpUserExpired), time.Now().Unix(), cmd.OlderThan.Unix(), string(models.TmpUserSignUpStarted), string(models.TmpUserInvitePending)); err != nil {
return err

@ -4,6 +4,7 @@
package sqlstore
import (
"context"
"testing"
"time"
@ -13,6 +14,7 @@ import (
)
func TestTempUserCommandsAndQueries(t *testing.T) {
ss := InitTestDB(t)
cmd := models.CreateTempUserCommand{
OrgId: 2256,
Name: "hello",
@ -22,14 +24,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
}
setup := func(t *testing.T) {
InitTestDB(t)
err := CreateTempUser(&cmd)
err := ss.CreateTempUser(context.Background(), &cmd)
require.Nil(t, err)
}
t.Run("Should be able to get temp users by org id", func(t *testing.T) {
setup(t)
query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending}
err := GetTempUsersQuery(&query)
err := ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, 1, len(query.Result))
@ -38,7 +40,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able to get temp users by email", func(t *testing.T) {
setup(t)
query := models.GetTempUsersQuery{Email: "e@as.co", Status: models.TmpUserInvitePending}
err := GetTempUsersQuery(&query)
err := ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, 1, len(query.Result))
@ -47,7 +49,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able to get temp users by code", func(t *testing.T) {
setup(t)
query := models.GetTempUserByCodeQuery{Code: "asd"}
err := GetTempUserByCode(&query)
err := ss.GetTempUserByCode(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, "hello", query.Result.Name)
@ -56,18 +58,18 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able update status", func(t *testing.T) {
setup(t)
cmd2 := models.UpdateTempUserStatusCommand{Code: "asd", Status: models.TmpUserRevoked}
err := UpdateTempUserStatus(&cmd2)
err := ss.UpdateTempUserStatus(context.Background(), &cmd2)
require.Nil(t, err)
})
t.Run("Should be able update email sent and email sent on", func(t *testing.T) {
setup(t)
cmd2 := models.UpdateTempUserWithEmailSentCommand{Code: cmd.Result.Code}
err := UpdateTempUserWithEmailSent(&cmd2)
err := ss.UpdateTempUserWithEmailSent(context.Background(), &cmd2)
require.Nil(t, err)
query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending}
err = GetTempUsersQuery(&query)
err = ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err)
require.True(t, query.Result[0].EmailSent)
@ -78,14 +80,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
setup(t)
createdAt := time.Unix(cmd.Result.Created, 0)
cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)}
err := ExpireOldUserInvites(&cmd2)
err := ss.ExpireOldUserInvites(context.Background(), &cmd2)
require.Nil(t, err)
require.Equal(t, int64(1), cmd2.NumExpired)
t.Run("Should do nothing when no temp users to expire", func(t *testing.T) {
createdAt := time.Unix(cmd.Result.Created, 0)
cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)}
err := ExpireOldUserInvites(&cmd2)
err := ss.ExpireOldUserInvites(context.Background(), &cmd2)
require.Nil(t, err)
require.Equal(t, int64(0), cmd2.NumExpired)
})

Loading…
Cancel
Save