diff --git a/pkg/api/admin_users.go b/pkg/api/admin_users.go index 4ad8a2b84ab..76193771eb9 100644 --- a/pkg/api/admin_users.go +++ b/pkg/api/admin_users.go @@ -119,7 +119,7 @@ func (server *HTTPServer) AdminLogoutUser(c *m.ReqContext) Response { return Error(400, "You cannot logout yourself", nil) } - return server.logoutUserFromAllDevicesInternal(userID) + return server.logoutUserFromAllDevicesInternal(c.Req.Context(), userID) } // GET /api/admin/users/:id/auth-tokens diff --git a/pkg/api/login.go b/pkg/api/login.go index 65ace1b2b83..ebf4cc8db07 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -131,7 +131,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) { hs.log.Error("user login with nil user") } - userToken, err := hs.AuthTokenService.CreateToken(user.Id, c.RemoteAddr(), c.Req.UserAgent()) + userToken, err := hs.AuthTokenService.CreateToken(c.Req.Context(), user.Id, c.RemoteAddr(), c.Req.UserAgent()) if err != nil { hs.log.Error("failed to create auth token", "error", err) } @@ -140,7 +140,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) { } func (hs *HTTPServer) Logout(c *m.ReqContext) { - if err := hs.AuthTokenService.RevokeToken(c.UserToken); err != nil && err != m.ErrUserTokenNotFound { + if err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken); err != nil && err != m.ErrUserTokenNotFound { hs.log.Error("failed to revoke auth token", "error", err) } diff --git a/pkg/api/user_token.go b/pkg/api/user_token.go index 2f74eedea5d..3e53a003bd8 100644 --- a/pkg/api/user_token.go +++ b/pkg/api/user_token.go @@ -1,6 +1,7 @@ package api import ( + "context" "time" "github.com/grafana/grafana/pkg/api/dtos" @@ -19,7 +20,7 @@ func (server *HTTPServer) RevokeUserAuthToken(c *models.ReqContext, cmd models.R return server.revokeUserAuthTokenInternal(c, c.UserId, cmd) } -func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Response { +func (server *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, userID int64) Response { userQuery := models.GetUserByIdQuery{Id: userID} if err := bus.Dispatch(&userQuery); err != nil { @@ -29,7 +30,7 @@ func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Respons return Error(500, "Could not read user from database", err) } - err := server.AuthTokenService.RevokeAllUserTokens(userID) + err := server.AuthTokenService.RevokeAllUserTokens(ctx, userID) if err != nil { return Error(500, "Failed to logout user", err) } @@ -49,7 +50,7 @@ func (server *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID return Error(500, "Failed to get user", err) } - tokens, err := server.AuthTokenService.GetUserTokens(userID) + tokens, err := server.AuthTokenService.GetUserTokens(c.Req.Context(), userID) if err != nil { return Error(500, "Failed to get user auth tokens", err) } @@ -84,7 +85,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user return Error(500, "Failed to get user", err) } - token, err := server.AuthTokenService.GetUserToken(userID, cmd.AuthTokenId) + token, err := server.AuthTokenService.GetUserToken(c.Req.Context(), userID, cmd.AuthTokenId) if err != nil { if err == models.ErrUserTokenNotFound { return Error(404, "User auth token not found", err) @@ -96,7 +97,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user return Error(400, "Cannot revoke active user auth token", nil) } - err = server.AuthTokenService.RevokeToken(token) + err = server.AuthTokenService.RevokeToken(c.Req.Context(), token) if err != nil { if err == models.ErrUserTokenNotFound { return Error(404, "User auth token not found", err) diff --git a/pkg/api/user_token_test.go b/pkg/api/user_token_test.go index 111070dca92..aa5bc47f93e 100644 --- a/pkg/api/user_token_test.go +++ b/pkg/api/user_token_test.go @@ -1,6 +1,7 @@ package api import ( + "context" "testing" "time" @@ -75,7 +76,7 @@ func TestUserTokenApiEndpoint(t *testing.T) { token := &m.UserToken{Id: 1} revokeUserAuthTokenInternalScenario("Should be successful", cmd, 200, token, func(sc *scenarioContext) { - sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { + sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) { return &m.UserToken{Id: 2}, nil } sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() @@ -93,7 +94,7 @@ func TestUserTokenApiEndpoint(t *testing.T) { token := &m.UserToken{Id: 2} revokeUserAuthTokenInternalScenario("Should not be successful", cmd, TestUserID, token, func(sc *scenarioContext) { - sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { + sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) { return token, nil } sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() @@ -126,7 +127,7 @@ func TestUserTokenApiEndpoint(t *testing.T) { SeenAt: time.Now().Unix(), }, } - sc.userAuthTokenService.GetUserTokensProvider = func(userId int64) ([]*m.UserToken, error) { + sc.userAuthTokenService.GetUserTokensProvider = func(ctx context.Context, userId int64) ([]*m.UserToken, error) { return tokens, nil } sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() @@ -226,7 +227,7 @@ func logoutUserFromAllDevicesInternalScenario(desc string, userId int64, fn scen sc.context.OrgId = TestOrgID sc.context.OrgRole = m.ROLE_ADMIN - return hs.logoutUserFromAllDevicesInternal(userId) + return hs.logoutUserFromAllDevicesInternal(context.Background(), userId) }) sc.m.Post("/", sc.defaultHandler) diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 61f86824118..be76c45d89d 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -173,7 +173,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext return false } - token, err := authTokenService.LookupToken(rawToken) + token, err := authTokenService.LookupToken(ctx.Req.Context(), rawToken) if err != nil { ctx.Logger.Error("failed to look up user based on cookie", "error", err) WriteSessionCookie(ctx, "", -1) @@ -190,7 +190,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext ctx.IsSignedIn = true ctx.UserToken = token - rotated, err := authTokenService.TryRotateToken(token, ctx.RemoteAddr(), ctx.Req.UserAgent()) + rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ctx.RemoteAddr(), ctx.Req.UserAgent()) if err != nil { ctx.Logger.Error("failed to rotate token", "error", err) return true diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 92d3da0896f..e59e017398a 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "encoding/json" "fmt" "net/http" @@ -156,7 +157,7 @@ func TestMiddlewareContext(t *testing.T) { return nil }) - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return &m.UserToken{ UserId: 12, UnhashedToken: unhashedToken, @@ -185,14 +186,14 @@ func TestMiddlewareContext(t *testing.T) { return nil }) - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return &m.UserToken{ UserId: 12, UnhashedToken: "", }, nil } - sc.userAuthTokenService.TryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) { + sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *m.UserToken, clientIP, userAgent string) (bool, error) { userToken.UnhashedToken = "rotated" return true, nil } @@ -227,7 +228,7 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) { sc.withTokenSessionCookie("token") - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return nil, m.ErrUserTokenNotFound } diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index f307376331e..e74b6e8451c 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "testing" @@ -23,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { return nil }) - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return &m.UserToken{ UserId: 0, UnhashedToken: "", @@ -49,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { return nil }) - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return &m.UserToken{ UserId: 12, UnhashedToken: "", diff --git a/pkg/middleware/quota_test.go b/pkg/middleware/quota_test.go index c3b448c9f6a..c6c8a1fd4d3 100644 --- a/pkg/middleware/quota_test.go +++ b/pkg/middleware/quota_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "testing" "github.com/grafana/grafana/pkg/bus" @@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) { return nil }) - sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) { return &m.UserToken{ UserId: 12, UnhashedToken: "", diff --git a/pkg/models/user_token.go b/pkg/models/user_token.go index 8c3e7985995..b07bd508114 100644 --- a/pkg/models/user_token.go +++ b/pkg/models/user_token.go @@ -1,6 +1,7 @@ package models import ( + "context" "errors" ) @@ -31,12 +32,12 @@ type RevokeAuthTokenCmd struct { // UserTokenService are used for generating and validating user tokens type UserTokenService interface { - CreateToken(userId int64, clientIP, userAgent string) (*UserToken, error) - LookupToken(unhashedToken string) (*UserToken, error) - TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error) - RevokeToken(token *UserToken) error - RevokeAllUserTokens(userId int64) error - ActiveTokenCount() (int64, error) - GetUserToken(userId, userTokenId int64) (*UserToken, error) - GetUserTokens(userId int64) ([]*UserToken, error) + CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*UserToken, error) + LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error) + TryRotateToken(ctx context.Context, token *UserToken, clientIP, userAgent string) (bool, error) + RevokeToken(ctx context.Context, token *UserToken) error + RevokeAllUserTokens(ctx context.Context, userId int64) error + ActiveTokenCount(ctx context.Context) (int64, error) + GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error) + GetUserTokens(ctx context.Context, userId int64) ([]*UserToken, error) } diff --git a/pkg/services/auth/auth_token.go b/pkg/services/auth/auth_token.go index 740e5081668..dc9936f2f3f 100644 --- a/pkg/services/auth/auth_token.go +++ b/pkg/services/auth/auth_token.go @@ -1,6 +1,7 @@ package auth import ( + "context" "crypto/sha256" "encoding/hex" "time" @@ -35,14 +36,24 @@ func (s *UserAuthTokenService) Init() error { return nil } -func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) { - var model userAuthToken - count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model) +func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) { + + var count int64 + var err error + err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + var model userAuthToken + count, err = dbSession.Where(`created_at > ? AND rotated_at > ?`, + s.createdAfterParam(), + s.rotatedAfterParam()). + Count(&model) + + return err + }) return count, err } -func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { +func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) { clientIP = util.ParseIPAddress(clientIP) token, err := util.RandomHex(16) if err != nil { @@ -65,7 +76,12 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str SeenAt: 0, AuthTokenSeen: false, } - _, err = s.SQLStore.NewSession().Insert(&userAuthToken) + + err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + _, err = dbSession.Insert(&userAuthToken) + return err + }) + if err != nil { return nil, err } @@ -80,14 +96,27 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str return &userToken, err } -func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) { +func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) { hashedToken := hashToken(unhashedToken) if setting.Env == setting.DEV { s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken) } var model userAuthToken - exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model) + var exists bool + var err error + err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", + hashedToken, + hashedToken, + s.createdAfterParam(), + s.rotatedAfterParam()). + Get(&model) + + return err + + }) + if err != nil { return nil, err } @@ -100,7 +129,18 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo modelCopy := model modelCopy.AuthTokenSeen = false expireBefore := getTime().Add(-urgentRotateTime).Unix() - affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND prev_auth_token = ? AND rotated_at < ?", modelCopy.Id, modelCopy.PrevAuthToken, expireBefore).AllCols().Update(&modelCopy) + + var affectedRows int64 + err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + affectedRows, err = dbSession.Where("id = ? AND prev_auth_token = ? AND rotated_at < ?", + modelCopy.Id, + modelCopy.PrevAuthToken, + expireBefore). + AllCols().Update(&modelCopy) + + return err + }) + if err != nil { return nil, err } @@ -116,7 +156,17 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo modelCopy := model modelCopy.AuthTokenSeen = true modelCopy.SeenAt = getTime().Unix() - affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND auth_token = ?", modelCopy.Id, modelCopy.AuthToken).AllCols().Update(&modelCopy) + + var affectedRows int64 + err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + affectedRows, err = dbSession.Where("id = ? AND auth_token = ?", + modelCopy.Id, + modelCopy.AuthToken). + AllCols().Update(&modelCopy) + + return err + }) + if err != nil { return nil, err } @@ -140,7 +190,7 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo return &userToken, err } -func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) { +func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) { if token == nil { return false, nil } @@ -183,12 +233,21 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, rotated_at = ? WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)` - res, err := s.SQLStore.NewSession().Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix()) + var affected int64 + err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + res, err := dbSession.Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix()) + if err != nil { + return err + } + + affected, err = res.RowsAffected() + return err + }) + if err != nil { return false, err } - affected, _ := res.RowsAffected() s.log.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId) if affected > 0 { model.UnhashedToken = newToken @@ -199,14 +258,20 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, return false, nil } -func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error { +func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error { if token == nil { return models.ErrUserTokenNotFound } model := userAuthTokenFromUserToken(token) - rowsAffected, err := s.SQLStore.NewSession().Delete(model) + var rowsAffected int64 + var err error + err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + rowsAffected, err = dbSession.Delete(model) + return err + }) + if err != nil { return err } @@ -221,55 +286,71 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error { return nil } -func (s *UserAuthTokenService) RevokeAllUserTokens(userId int64) error { - sql := `DELETE from user_auth_token WHERE user_id = ?` - res, err := s.SQLStore.NewSession().Exec(sql, userId) - if err != nil { - return err - } +func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error { + return s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + sql := `DELETE from user_auth_token WHERE user_id = ?` + res, err := dbSession.Exec(sql, userId) + if err != nil { + return err + } - affected, err := res.RowsAffected() - if err != nil { - return err - } + affected, err := res.RowsAffected() + if err != nil { + return err + } - s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected) + s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected) - return nil + return err + }) } -func (s *UserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) { - var token userAuthToken - exists, err := s.SQLStore.NewSession().Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token) - if err != nil { - return nil, err - } - - if !exists { - return nil, models.ErrUserTokenNotFound - } +func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { var result models.UserToken - token.toUserToken(&result) + err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + var token userAuthToken + exists, err := dbSession.Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token) + if err != nil { + return err + } + + if !exists { + return models.ErrUserTokenNotFound + } + + token.toUserToken(&result) + return nil + }) - return &result, nil + return &result, err } -func (s *UserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) { - var tokens []*userAuthToken - err := s.SQLStore.NewSession().Where("user_id = ? AND created_at > ? AND rotated_at > ?", userId, s.createdAfterParam(), s.rotatedAfterParam()).Find(&tokens) - if err != nil { - return nil, err - } +func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { result := []*models.UserToken{} - for _, token := range tokens { - var userToken models.UserToken - token.toUserToken(&userToken) - result = append(result, &userToken) - } + err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + var tokens []*userAuthToken + err := dbSession.Where("user_id = ? AND created_at > ? AND rotated_at > ?", + userId, + s.createdAfterParam(), + s.rotatedAfterParam()). + Find(&tokens) + + if err != nil { + return err + } + + for _, token := range tokens { + var userToken models.UserToken + token.toUserToken(&userToken) + result = append(result, &userToken) + } + + return nil + }) - return result, nil + return result, err } func (s *UserAuthTokenService) createdAfterParam() int64 { diff --git a/pkg/services/auth/auth_token_test.go b/pkg/services/auth/auth_token_test.go index 33eb309ad18..b1398834bdc 100644 --- a/pkg/services/auth/auth_token_test.go +++ b/pkg/services/auth/auth_token_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "testing" "time" @@ -26,19 +27,19 @@ func TestUserAuthToken(t *testing.T) { } Convey("When creating token", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(userToken, ShouldNotBeNil) So(userToken.AuthTokenSeen, ShouldBeFalse) Convey("Can count active tokens", func() { - count, err := userAuthTokenService.ActiveTokenCount() + count, err := userAuthTokenService.ActiveTokenCount(context.Background()) So(err, ShouldBeNil) So(count, ShouldEqual, 1) }) Convey("When lookup unhashed token should return user auth token", func() { - userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(userToken, ShouldNotBeNil) So(userToken.UserId, ShouldEqual, userID) @@ -51,13 +52,13 @@ func TestUserAuthToken(t *testing.T) { }) Convey("When lookup hashed token should return user auth token not found error", func() { - userToken, err := userAuthTokenService.LookupToken(userToken.AuthToken) + userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.AuthToken) So(err, ShouldEqual, models.ErrUserTokenNotFound) So(userToken, ShouldBeNil) }) Convey("revoking existing token should delete token", func() { - err = userAuthTokenService.RevokeToken(userToken) + err = userAuthTokenService.RevokeToken(context.Background(), userToken) So(err, ShouldBeNil) model, err := ctx.getAuthTokenByID(userToken.Id) @@ -66,37 +67,37 @@ func TestUserAuthToken(t *testing.T) { }) Convey("revoking nil token should return error", func() { - err = userAuthTokenService.RevokeToken(nil) + err = userAuthTokenService.RevokeToken(context.Background(), nil) So(err, ShouldEqual, models.ErrUserTokenNotFound) }) Convey("revoking non-existing token should return error", func() { userToken.Id = 1000 - err = userAuthTokenService.RevokeToken(userToken) + err = userAuthTokenService.RevokeToken(context.Background(), userToken) So(err, ShouldEqual, models.ErrUserTokenNotFound) }) Convey("When creating an additional token", func() { - userToken2, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken2, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(userToken2, ShouldNotBeNil) Convey("Can get first user token", func() { - token, err := userAuthTokenService.GetUserToken(userID, userToken.Id) + token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken.Id) So(err, ShouldBeNil) So(token, ShouldNotBeNil) So(token.Id, ShouldEqual, userToken.Id) }) Convey("Can get second user token", func() { - token, err := userAuthTokenService.GetUserToken(userID, userToken2.Id) + token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken2.Id) So(err, ShouldBeNil) So(token, ShouldNotBeNil) So(token.Id, ShouldEqual, userToken2.Id) }) Convey("Can get user tokens", func() { - tokens, err := userAuthTokenService.GetUserTokens(userID) + tokens, err := userAuthTokenService.GetUserTokens(context.Background(), userID) So(err, ShouldBeNil) So(tokens, ShouldHaveLength, 2) So(tokens[0].Id, ShouldEqual, userToken.Id) @@ -104,7 +105,7 @@ func TestUserAuthToken(t *testing.T) { }) Convey("Can revoke all user tokens", func() { - err := userAuthTokenService.RevokeAllUserTokens(userID) + err := userAuthTokenService.RevokeAllUserTokens(context.Background(), userID) So(err, ShouldBeNil) model, err := ctx.getAuthTokenByID(userToken.Id) @@ -119,24 +120,24 @@ func TestUserAuthToken(t *testing.T) { }) Convey("expires correctly", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) - userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) + userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) getTime = func() time.Time { return t.Add(time.Hour) } - rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.11:1234", "some user agent") + rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) - userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) + userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) - stillGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + stillGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(stillGood, ShouldNotBeNil) @@ -148,7 +149,7 @@ func TestUserAuthToken(t *testing.T) { return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour).Add(-time.Second) } - stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken) + stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken) So(err, ShouldBeNil) So(stillGood, ShouldNotBeNil) }) @@ -158,12 +159,12 @@ func TestUserAuthToken(t *testing.T) { return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour) } - notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldEqual, models.ErrUserTokenNotFound) So(notGood, ShouldBeNil) Convey("should not find active token when expired", func() { - count, err := userAuthTokenService.ActiveTokenCount() + count, err := userAuthTokenService.ActiveTokenCount(context.Background()) So(err, ShouldBeNil) So(count, ShouldEqual, 0) }) @@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) { return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour).Add(-time.Second) } - stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken) + stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken) So(err, ShouldBeNil) So(stillGood, ShouldNotBeNil) }) @@ -192,20 +193,20 @@ func TestUserAuthToken(t *testing.T) { return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour) } - notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldEqual, models.ErrUserTokenNotFound) So(notGood, ShouldBeNil) }) }) Convey("can properly rotate tokens", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) prevToken := userToken.AuthToken unhashedPrev := userToken.UnhashedToken - rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent") + rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent") So(err, ShouldBeNil) So(rotated, ShouldBeFalse) @@ -224,7 +225,7 @@ func TestUserAuthToken(t *testing.T) { return t.Add(time.Hour) } - rotated, err = userAuthTokenService.TryRotateToken(&tok, "192.168.10.12:1234", "a new user agent") + rotated, err = userAuthTokenService.TryRotateToken(context.Background(), &tok, "192.168.10.12:1234", "a new user agent") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) @@ -243,13 +244,13 @@ func TestUserAuthToken(t *testing.T) { // ability to auth using an old token - lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken) + lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), model.UnhashedToken) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue) So(lookedUpUserToken.SeenAt, ShouldEqual, getTime().Unix()) - lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev) + lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken.Id, ShouldEqual, model.Id) @@ -259,7 +260,7 @@ func TestUserAuthToken(t *testing.T) { return t.Add(time.Hour + (2 * time.Minute)) } - lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev) + lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue) @@ -269,7 +270,7 @@ func TestUserAuthToken(t *testing.T) { So(lookedUpModel, ShouldNotBeNil) So(lookedUpModel.AuthTokenSeen, ShouldBeFalse) - rotated, err = userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent") + rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) @@ -280,11 +281,11 @@ func TestUserAuthToken(t *testing.T) { }) Convey("keeps prev token valid for 1 minute after it is confirmed", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(userToken, ShouldNotBeNil) - lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) @@ -293,7 +294,7 @@ func TestUserAuthToken(t *testing.T) { } prevToken := userToken.UnhashedToken - rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") + rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) @@ -301,25 +302,25 @@ func TestUserAuthToken(t *testing.T) { return t.Add(20 * time.Minute) } - currentUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + currentUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(currentUserToken, ShouldNotBeNil) - prevUserToken, err := userAuthTokenService.LookupToken(prevToken) + prevUserToken, err := userAuthTokenService.LookupToken(context.Background(), prevToken) So(err, ShouldBeNil) So(prevUserToken, ShouldNotBeNil) }) Convey("will not mark token unseen when prev and current are the same", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(userToken, ShouldNotBeNil) - lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) + lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) - lookedUpUserToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) + lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) So(err, ShouldBeNil) So(lookedUpUserToken, ShouldNotBeNil) @@ -330,7 +331,7 @@ func TestUserAuthToken(t *testing.T) { }) Convey("Rotate token", func() { - userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") + userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent") So(err, ShouldBeNil) So(userToken, ShouldNotBeNil) @@ -345,7 +346,7 @@ func TestUserAuthToken(t *testing.T) { return t.Add(10 * time.Minute) } - rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") + rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) @@ -366,7 +367,7 @@ func TestUserAuthToken(t *testing.T) { return t.Add(20 * time.Minute) } - rotated, err = userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") + rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) @@ -385,7 +386,7 @@ func TestUserAuthToken(t *testing.T) { return t.Add(2 * time.Minute) } - rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") + rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox") So(err, ShouldBeNil) So(rotated, ShouldBeTrue) diff --git a/pkg/services/auth/testing.go b/pkg/services/auth/testing.go index 68e65466c3d..378a68b053c 100644 --- a/pkg/services/auth/testing.go +++ b/pkg/services/auth/testing.go @@ -1,81 +1,85 @@ package auth -import "github.com/grafana/grafana/pkg/models" +import ( + "context" + + "github.com/grafana/grafana/pkg/models" +) type FakeUserAuthTokenService struct { - CreateTokenProvider func(userId int64, clientIP, userAgent string) (*models.UserToken, error) - TryRotateTokenProvider func(token *models.UserToken, clientIP, userAgent string) (bool, error) - LookupTokenProvider func(unhashedToken string) (*models.UserToken, error) - RevokeTokenProvider func(token *models.UserToken) error - RevokeAllUserTokensProvider func(userId int64) error - ActiveAuthTokenCount func() (int64, error) - GetUserTokenProvider func(userId, userTokenId int64) (*models.UserToken, error) - GetUserTokensProvider func(userId int64) ([]*models.UserToken, error) + CreateTokenProvider func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) + TryRotateTokenProvider func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) + LookupTokenProvider func(ctx context.Context, unhashedToken string) (*models.UserToken, error) + RevokeTokenProvider func(ctx context.Context, token *models.UserToken) error + RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error + ActiveAuthTokenCount func(ctx context.Context) (int64, error) + GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) + GetUserTokensProvider func(ctx context.Context, userId int64) ([]*models.UserToken, error) } func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { return &FakeUserAuthTokenService{ - CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) { + CreateTokenProvider: func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) { return &models.UserToken{ UserId: 0, UnhashedToken: "", }, nil }, - TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) { + TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) { return false, nil }, - LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) { + LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ UserId: 0, UnhashedToken: "", }, nil }, - RevokeTokenProvider: func(token *models.UserToken) error { + RevokeTokenProvider: func(ctx context.Context, token *models.UserToken) error { return nil }, - RevokeAllUserTokensProvider: func(userId int64) error { + RevokeAllUserTokensProvider: func(ctx context.Context, userId int64) error { return nil }, - ActiveAuthTokenCount: func() (int64, error) { + ActiveAuthTokenCount: func(ctx context.Context) (int64, error) { return 10, nil }, - GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) { + GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { return nil, nil }, - GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) { + GetUserTokensProvider: func(ctx context.Context, userId int64) ([]*models.UserToken, error) { return nil, nil }, } } -func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { - return s.CreateTokenProvider(userId, clientIP, userAgent) +func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) { + return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent) } -func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) { - return s.LookupTokenProvider(unhashedToken) +func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return s.LookupTokenProvider(context.Background(), unhashedToken) } -func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) { - return s.TryRotateTokenProvider(token, clientIP, userAgent) +func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) { + return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent) } -func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error { - return s.RevokeTokenProvider(token) +func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error { + return s.RevokeTokenProvider(context.Background(), token) } -func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error { - return s.RevokeAllUserTokensProvider(userId) +func (s *FakeUserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error { + return s.RevokeAllUserTokensProvider(context.Background(), userId) } -func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) { - return s.ActiveAuthTokenCount() +func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) { + return s.ActiveAuthTokenCount(context.Background()) } -func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) { - return s.GetUserTokenProvider(userId, userTokenId) +func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { + return s.GetUserTokenProvider(context.Background(), userId, userTokenId) } -func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) { - return s.GetUserTokensProvider(userId) +func (s *FakeUserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { + return s.GetUserTokensProvider(context.Background(), userId) } diff --git a/pkg/services/auth/token_cleanup.go b/pkg/services/auth/token_cleanup.go index 1fe0996aa4c..671d3d7f5b7 100644 --- a/pkg/services/auth/token_cleanup.go +++ b/pkg/services/auth/token_cleanup.go @@ -3,6 +3,8 @@ package auth import ( "context" "time" + + "github.com/grafana/grafana/pkg/services/sqlstore" ) func (srv *UserAuthTokenService) Run(ctx context.Context) error { @@ -11,21 +13,22 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error { maxLifetime := time.Duration(srv.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { - srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime) + srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime) }) + if err != nil { - srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err) + srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err) } for { select { case <-ticker.C: err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { - srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime) + srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime) }) if err != nil { - srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err) + srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err) } case <-ctx.Done(): @@ -34,24 +37,30 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error { } } -func (srv *UserAuthTokenService) deleteExpiredTokens(maxInactiveLifetime, maxLifetime time.Duration) (int64, error) { +func (srv *UserAuthTokenService) deleteExpiredTokens(ctx context.Context, maxInactiveLifetime, maxLifetime time.Duration) (int64, error) { createdBefore := getTime().Add(-maxLifetime) rotatedBefore := getTime().Add(-maxInactiveLifetime) srv.log.Debug("starting cleanup of expired auth tokens", "createdBefore", createdBefore, "rotatedBefore", rotatedBefore) - sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?` - res, err := srv.SQLStore.NewSession().Exec(sql, createdBefore.Unix(), rotatedBefore.Unix()) - if err != nil { - return 0, err - } + var affected int64 + err := srv.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { + sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?` + res, err := dbSession.Exec(sql, createdBefore.Unix(), rotatedBefore.Unix()) + if err != nil { + return err + } - affected, err := res.RowsAffected() - if err != nil { - srv.log.Error("failed to cleanup expired auth tokens", "error", err) - return 0, nil - } + affected, err = res.RowsAffected() + if err != nil { + srv.log.Error("failed to cleanup expired auth tokens", "error", err) + return nil + } + + srv.log.Debug("cleanup of expired auth tokens done", "count", affected) + + return nil + }) - srv.log.Debug("cleanup of expired auth tokens done", "count", affected) return affected, err } diff --git a/pkg/services/auth/token_cleanup_test.go b/pkg/services/auth/token_cleanup_test.go index 410764d3f8d..2df42eb724c 100644 --- a/pkg/services/auth/token_cleanup_test.go +++ b/pkg/services/auth/token_cleanup_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "fmt" "testing" "time" @@ -40,7 +41,7 @@ func TestUserAuthTokenCleanup(t *testing.T) { insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), from.Unix()) } - affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour) + affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour) So(err, ShouldBeNil) So(affected, ShouldEqual, 3) }) @@ -60,7 +61,7 @@ func TestUserAuthTokenCleanup(t *testing.T) { insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), fromRotate.Unix()) } - affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour) + affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour) So(err, ShouldBeNil) So(affected, ShouldEqual, 3) }) diff --git a/pkg/services/quota/quota.go b/pkg/services/quota/quota.go index b65ad932699..7e1e62fc52a 100644 --- a/pkg/services/quota/quota.go +++ b/pkg/services/quota/quota.go @@ -43,7 +43,7 @@ func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, erro } if target == "session" { - usedSessions, err := qs.AuthTokenService.ActiveTokenCount() + usedSessions, err := qs.AuthTokenService.ActiveTokenCount(c.Req.Context()) if err != nil { return false, err }