Auth: Enable retries and transaction for some db calls for auth tokens (#16785)

the WithSession wrapper handles retries and connection
management so the caller dont have to worry about it.
pull/16813/head
Carl Bergquist 6 years ago committed by GitHub
parent eb82a75668
commit 9660356638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      pkg/api/admin_users.go
  2. 4
      pkg/api/login.go
  3. 11
      pkg/api/user_token.go
  4. 9
      pkg/api/user_token_test.go
  5. 4
      pkg/middleware/middleware.go
  6. 9
      pkg/middleware/middleware_test.go
  7. 5
      pkg/middleware/org_redirect_test.go
  8. 3
      pkg/middleware/quota_test.go
  9. 17
      pkg/models/user_token.go
  10. 181
      pkg/services/auth/auth_token.go
  11. 83
      pkg/services/auth/auth_token_test.go
  12. 70
      pkg/services/auth/testing.go
  13. 41
      pkg/services/auth/token_cleanup.go
  14. 5
      pkg/services/auth/token_cleanup_test.go
  15. 2
      pkg/services/quota/quota.go

@ -119,7 +119,7 @@ func (server *HTTPServer) AdminLogoutUser(c *m.ReqContext) Response {
return Error(400, "You cannot logout yourself", nil) return Error(400, "You cannot logout yourself", nil)
} }
return server.logoutUserFromAllDevicesInternal(userID) return server.logoutUserFromAllDevicesInternal(c.Req.Context(), userID)
} }
// GET /api/admin/users/:id/auth-tokens // GET /api/admin/users/:id/auth-tokens

@ -131,7 +131,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) {
hs.log.Error("user login with nil user") hs.log.Error("user login with nil user")
} }
userToken, err := hs.AuthTokenService.CreateToken(user.Id, c.RemoteAddr(), c.Req.UserAgent()) userToken, err := hs.AuthTokenService.CreateToken(c.Req.Context(), user.Id, c.RemoteAddr(), c.Req.UserAgent())
if err != nil { if err != nil {
hs.log.Error("failed to create auth token", "error", err) hs.log.Error("failed to create auth token", "error", err)
} }
@ -140,7 +140,7 @@ func (hs *HTTPServer) loginUserWithUser(user *m.User, c *m.ReqContext) {
} }
func (hs *HTTPServer) Logout(c *m.ReqContext) { func (hs *HTTPServer) Logout(c *m.ReqContext) {
if err := hs.AuthTokenService.RevokeToken(c.UserToken); err != nil && err != m.ErrUserTokenNotFound { if err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken); err != nil && err != m.ErrUserTokenNotFound {
hs.log.Error("failed to revoke auth token", "error", err) hs.log.Error("failed to revoke auth token", "error", err)
} }

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"time" "time"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@ -19,7 +20,7 @@ func (server *HTTPServer) RevokeUserAuthToken(c *models.ReqContext, cmd models.R
return server.revokeUserAuthTokenInternal(c, c.UserId, cmd) return server.revokeUserAuthTokenInternal(c, c.UserId, cmd)
} }
func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Response { func (server *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, userID int64) Response {
userQuery := models.GetUserByIdQuery{Id: userID} userQuery := models.GetUserByIdQuery{Id: userID}
if err := bus.Dispatch(&userQuery); err != nil { if err := bus.Dispatch(&userQuery); err != nil {
@ -29,7 +30,7 @@ func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Respons
return Error(500, "Could not read user from database", err) return Error(500, "Could not read user from database", err)
} }
err := server.AuthTokenService.RevokeAllUserTokens(userID) err := server.AuthTokenService.RevokeAllUserTokens(ctx, userID)
if err != nil { if err != nil {
return Error(500, "Failed to logout user", err) return Error(500, "Failed to logout user", err)
} }
@ -49,7 +50,7 @@ func (server *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID
return Error(500, "Failed to get user", err) return Error(500, "Failed to get user", err)
} }
tokens, err := server.AuthTokenService.GetUserTokens(userID) tokens, err := server.AuthTokenService.GetUserTokens(c.Req.Context(), userID)
if err != nil { if err != nil {
return Error(500, "Failed to get user auth tokens", err) return Error(500, "Failed to get user auth tokens", err)
} }
@ -84,7 +85,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user
return Error(500, "Failed to get user", err) return Error(500, "Failed to get user", err)
} }
token, err := server.AuthTokenService.GetUserToken(userID, cmd.AuthTokenId) token, err := server.AuthTokenService.GetUserToken(c.Req.Context(), userID, cmd.AuthTokenId)
if err != nil { if err != nil {
if err == models.ErrUserTokenNotFound { if err == models.ErrUserTokenNotFound {
return Error(404, "User auth token not found", err) return Error(404, "User auth token not found", err)
@ -96,7 +97,7 @@ func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, user
return Error(400, "Cannot revoke active user auth token", nil) return Error(400, "Cannot revoke active user auth token", nil)
} }
err = server.AuthTokenService.RevokeToken(token) err = server.AuthTokenService.RevokeToken(c.Req.Context(), token)
if err != nil { if err != nil {
if err == models.ErrUserTokenNotFound { if err == models.ErrUserTokenNotFound {
return Error(404, "User auth token not found", err) return Error(404, "User auth token not found", err)

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -75,7 +76,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
token := &m.UserToken{Id: 1} token := &m.UserToken{Id: 1}
revokeUserAuthTokenInternalScenario("Should be successful", cmd, 200, token, func(sc *scenarioContext) { revokeUserAuthTokenInternalScenario("Should be successful", cmd, 200, token, func(sc *scenarioContext) {
sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) {
return &m.UserToken{Id: 2}, nil return &m.UserToken{Id: 2}, nil
} }
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
@ -93,7 +94,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
token := &m.UserToken{Id: 2} token := &m.UserToken{Id: 2}
revokeUserAuthTokenInternalScenario("Should not be successful", cmd, TestUserID, token, func(sc *scenarioContext) { revokeUserAuthTokenInternalScenario("Should not be successful", cmd, TestUserID, token, func(sc *scenarioContext) {
sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*m.UserToken, error) {
return token, nil return token, nil
} }
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
@ -126,7 +127,7 @@ func TestUserTokenApiEndpoint(t *testing.T) {
SeenAt: time.Now().Unix(), SeenAt: time.Now().Unix(),
}, },
} }
sc.userAuthTokenService.GetUserTokensProvider = func(userId int64) ([]*m.UserToken, error) { sc.userAuthTokenService.GetUserTokensProvider = func(ctx context.Context, userId int64) ([]*m.UserToken, error) {
return tokens, nil return tokens, nil
} }
sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec()
@ -226,7 +227,7 @@ func logoutUserFromAllDevicesInternalScenario(desc string, userId int64, fn scen
sc.context.OrgId = TestOrgID sc.context.OrgId = TestOrgID
sc.context.OrgRole = m.ROLE_ADMIN sc.context.OrgRole = m.ROLE_ADMIN
return hs.logoutUserFromAllDevicesInternal(userId) return hs.logoutUserFromAllDevicesInternal(context.Background(), userId)
}) })
sc.m.Post("/", sc.defaultHandler) sc.m.Post("/", sc.defaultHandler)

@ -173,7 +173,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext
return false return false
} }
token, err := authTokenService.LookupToken(rawToken) token, err := authTokenService.LookupToken(ctx.Req.Context(), rawToken)
if err != nil { if err != nil {
ctx.Logger.Error("failed to look up user based on cookie", "error", err) ctx.Logger.Error("failed to look up user based on cookie", "error", err)
WriteSessionCookie(ctx, "", -1) WriteSessionCookie(ctx, "", -1)
@ -190,7 +190,7 @@ func initContextWithToken(authTokenService m.UserTokenService, ctx *m.ReqContext
ctx.IsSignedIn = true ctx.IsSignedIn = true
ctx.UserToken = token ctx.UserToken = token
rotated, err := authTokenService.TryRotateToken(token, ctx.RemoteAddr(), ctx.Req.UserAgent()) rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ctx.RemoteAddr(), ctx.Req.UserAgent())
if err != nil { if err != nil {
ctx.Logger.Error("failed to rotate token", "error", err) ctx.Logger.Error("failed to rotate token", "error", err)
return true return true

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -156,7 +157,7 @@ func TestMiddlewareContext(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: unhashedToken, UnhashedToken: unhashedToken,
@ -185,14 +186,14 @@ func TestMiddlewareContext(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",
}, nil }, nil
} }
sc.userAuthTokenService.TryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) { sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
userToken.UnhashedToken = "rotated" userToken.UnhashedToken = "rotated"
return true, nil return true, nil
} }
@ -227,7 +228,7 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) { middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) {
sc.withTokenSessionCookie("token") sc.withTokenSessionCookie("token")
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return nil, m.ErrUserTokenNotFound return nil, m.ErrUserTokenNotFound
} }

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
@ -23,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 0, UserId: 0,
UnhashedToken: "", UnhashedToken: "",
@ -49,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"testing" "testing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",

@ -1,6 +1,7 @@
package models package models
import ( import (
"context"
"errors" "errors"
) )
@ -31,12 +32,12 @@ type RevokeAuthTokenCmd struct {
// UserTokenService are used for generating and validating user tokens // UserTokenService are used for generating and validating user tokens
type UserTokenService interface { type UserTokenService interface {
CreateToken(userId int64, clientIP, userAgent string) (*UserToken, error) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*UserToken, error)
LookupToken(unhashedToken string) (*UserToken, error) LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error)
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error) TryRotateToken(ctx context.Context, token *UserToken, clientIP, userAgent string) (bool, error)
RevokeToken(token *UserToken) error RevokeToken(ctx context.Context, token *UserToken) error
RevokeAllUserTokens(userId int64) error RevokeAllUserTokens(ctx context.Context, userId int64) error
ActiveTokenCount() (int64, error) ActiveTokenCount(ctx context.Context) (int64, error)
GetUserToken(userId, userTokenId int64) (*UserToken, error) GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error)
GetUserTokens(userId int64) ([]*UserToken, error) GetUserTokens(ctx context.Context, userId int64) ([]*UserToken, error)
} }

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"time" "time"
@ -35,14 +36,24 @@ func (s *UserAuthTokenService) Init() error {
return nil return nil
} }
func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) { func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) {
var model userAuthToken
count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model) var count int64
var err error
err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
var model userAuthToken
count, err = dbSession.Where(`created_at > ? AND rotated_at > ?`,
s.createdAfterParam(),
s.rotatedAfterParam()).
Count(&model)
return err
})
return count, err return count, err
} }
func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
clientIP = util.ParseIPAddress(clientIP) clientIP = util.ParseIPAddress(clientIP)
token, err := util.RandomHex(16) token, err := util.RandomHex(16)
if err != nil { if err != nil {
@ -65,7 +76,12 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
SeenAt: 0, SeenAt: 0,
AuthTokenSeen: false, AuthTokenSeen: false,
} }
_, err = s.SQLStore.NewSession().Insert(&userAuthToken)
err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
_, err = dbSession.Insert(&userAuthToken)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,14 +96,27 @@ func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent str
return &userToken, err return &userToken, err
} }
func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) { func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
hashedToken := hashToken(unhashedToken) hashedToken := hashToken(unhashedToken)
if setting.Env == setting.DEV { if setting.Env == setting.DEV {
s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken) s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
} }
var model userAuthToken var model userAuthToken
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model) var exists bool
var err error
err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?",
hashedToken,
hashedToken,
s.createdAfterParam(),
s.rotatedAfterParam()).
Get(&model)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -100,7 +129,18 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
modelCopy := model modelCopy := model
modelCopy.AuthTokenSeen = false modelCopy.AuthTokenSeen = false
expireBefore := getTime().Add(-urgentRotateTime).Unix() expireBefore := getTime().Add(-urgentRotateTime).Unix()
affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND prev_auth_token = ? AND rotated_at < ?", modelCopy.Id, modelCopy.PrevAuthToken, expireBefore).AllCols().Update(&modelCopy)
var affectedRows int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
affectedRows, err = dbSession.Where("id = ? AND prev_auth_token = ? AND rotated_at < ?",
modelCopy.Id,
modelCopy.PrevAuthToken,
expireBefore).
AllCols().Update(&modelCopy)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +156,17 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
modelCopy := model modelCopy := model
modelCopy.AuthTokenSeen = true modelCopy.AuthTokenSeen = true
modelCopy.SeenAt = getTime().Unix() modelCopy.SeenAt = getTime().Unix()
affectedRows, err := s.SQLStore.NewSession().Where("id = ? AND auth_token = ?", modelCopy.Id, modelCopy.AuthToken).AllCols().Update(&modelCopy)
var affectedRows int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
affectedRows, err = dbSession.Where("id = ? AND auth_token = ?",
modelCopy.Id,
modelCopy.AuthToken).
AllCols().Update(&modelCopy)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,7 +190,7 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
return &userToken, err return &userToken, err
} }
func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) { func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
if token == nil { if token == nil {
return false, nil return false, nil
} }
@ -183,12 +233,21 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP,
rotated_at = ? rotated_at = ?
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)` WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
res, err := s.SQLStore.NewSession().Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix()) var affected int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
res, err := dbSession.Exec(sql, userAgent, clientIP, s.SQLStore.Dialect.BooleanStr(true), hashedToken, s.SQLStore.Dialect.BooleanStr(false), now.Unix(), model.Id, s.SQLStore.Dialect.BooleanStr(true), now.Add(-30*time.Second).Unix())
if err != nil {
return err
}
affected, err = res.RowsAffected()
return err
})
if err != nil { if err != nil {
return false, err return false, err
} }
affected, _ := res.RowsAffected()
s.log.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId) s.log.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
if affected > 0 { if affected > 0 {
model.UnhashedToken = newToken model.UnhashedToken = newToken
@ -199,14 +258,20 @@ func (s *UserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP,
return false, nil return false, nil
} }
func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error { func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error {
if token == nil { if token == nil {
return models.ErrUserTokenNotFound return models.ErrUserTokenNotFound
} }
model := userAuthTokenFromUserToken(token) model := userAuthTokenFromUserToken(token)
rowsAffected, err := s.SQLStore.NewSession().Delete(model) var rowsAffected int64
var err error
err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
rowsAffected, err = dbSession.Delete(model)
return err
})
if err != nil { if err != nil {
return err return err
} }
@ -221,55 +286,71 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
return nil return nil
} }
func (s *UserAuthTokenService) RevokeAllUserTokens(userId int64) error { func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error {
sql := `DELETE from user_auth_token WHERE user_id = ?` return s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
res, err := s.SQLStore.NewSession().Exec(sql, userId) sql := `DELETE from user_auth_token WHERE user_id = ?`
if err != nil { res, err := dbSession.Exec(sql, userId)
return err if err != nil {
} return err
}
affected, err := res.RowsAffected() affected, err := res.RowsAffected()
if err != nil { if err != nil {
return err return err
} }
s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected) s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected)
return nil return err
})
} }
func (s *UserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) { func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
var token userAuthToken
exists, err := s.SQLStore.NewSession().Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
if err != nil {
return nil, err
}
if !exists {
return nil, models.ErrUserTokenNotFound
}
var result models.UserToken var result models.UserToken
token.toUserToken(&result) err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
var token userAuthToken
exists, err := dbSession.Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
if err != nil {
return err
}
if !exists {
return models.ErrUserTokenNotFound
}
token.toUserToken(&result)
return nil
})
return &result, nil return &result, err
} }
func (s *UserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) { func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) {
var tokens []*userAuthToken
err := s.SQLStore.NewSession().Where("user_id = ? AND created_at > ? AND rotated_at > ?", userId, s.createdAfterParam(), s.rotatedAfterParam()).Find(&tokens)
if err != nil {
return nil, err
}
result := []*models.UserToken{} result := []*models.UserToken{}
for _, token := range tokens { err := s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
var userToken models.UserToken var tokens []*userAuthToken
token.toUserToken(&userToken) err := dbSession.Where("user_id = ? AND created_at > ? AND rotated_at > ?",
result = append(result, &userToken) userId,
} s.createdAfterParam(),
s.rotatedAfterParam()).
Find(&tokens)
if err != nil {
return err
}
for _, token := range tokens {
var userToken models.UserToken
token.toUserToken(&userToken)
result = append(result, &userToken)
}
return nil
})
return result, nil return result, err
} }
func (s *UserAuthTokenService) createdAfterParam() int64 { func (s *UserAuthTokenService) createdAfterParam() int64 {

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"encoding/json" "encoding/json"
"testing" "testing"
"time" "time"
@ -26,19 +27,19 @@ func TestUserAuthToken(t *testing.T) {
} }
Convey("When creating token", func() { Convey("When creating token", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
So(userToken.AuthTokenSeen, ShouldBeFalse) So(userToken.AuthTokenSeen, ShouldBeFalse)
Convey("Can count active tokens", func() { Convey("Can count active tokens", func() {
count, err := userAuthTokenService.ActiveTokenCount() count, err := userAuthTokenService.ActiveTokenCount(context.Background())
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(count, ShouldEqual, 1) So(count, ShouldEqual, 1)
}) })
Convey("When lookup unhashed token should return user auth token", func() { Convey("When lookup unhashed token should return user auth token", func() {
userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
So(userToken.UserId, ShouldEqual, userID) So(userToken.UserId, ShouldEqual, userID)
@ -51,13 +52,13 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("When lookup hashed token should return user auth token not found error", func() { Convey("When lookup hashed token should return user auth token not found error", func() {
userToken, err := userAuthTokenService.LookupToken(userToken.AuthToken) userToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.AuthToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
So(userToken, ShouldBeNil) So(userToken, ShouldBeNil)
}) })
Convey("revoking existing token should delete token", func() { Convey("revoking existing token should delete token", func() {
err = userAuthTokenService.RevokeToken(userToken) err = userAuthTokenService.RevokeToken(context.Background(), userToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
model, err := ctx.getAuthTokenByID(userToken.Id) model, err := ctx.getAuthTokenByID(userToken.Id)
@ -66,37 +67,37 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("revoking nil token should return error", func() { Convey("revoking nil token should return error", func() {
err = userAuthTokenService.RevokeToken(nil) err = userAuthTokenService.RevokeToken(context.Background(), nil)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
}) })
Convey("revoking non-existing token should return error", func() { Convey("revoking non-existing token should return error", func() {
userToken.Id = 1000 userToken.Id = 1000
err = userAuthTokenService.RevokeToken(userToken) err = userAuthTokenService.RevokeToken(context.Background(), userToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
}) })
Convey("When creating an additional token", func() { Convey("When creating an additional token", func() {
userToken2, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken2, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken2, ShouldNotBeNil) So(userToken2, ShouldNotBeNil)
Convey("Can get first user token", func() { Convey("Can get first user token", func() {
token, err := userAuthTokenService.GetUserToken(userID, userToken.Id) token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken.Id)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(token, ShouldNotBeNil) So(token, ShouldNotBeNil)
So(token.Id, ShouldEqual, userToken.Id) So(token.Id, ShouldEqual, userToken.Id)
}) })
Convey("Can get second user token", func() { Convey("Can get second user token", func() {
token, err := userAuthTokenService.GetUserToken(userID, userToken2.Id) token, err := userAuthTokenService.GetUserToken(context.Background(), userID, userToken2.Id)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(token, ShouldNotBeNil) So(token, ShouldNotBeNil)
So(token.Id, ShouldEqual, userToken2.Id) So(token.Id, ShouldEqual, userToken2.Id)
}) })
Convey("Can get user tokens", func() { Convey("Can get user tokens", func() {
tokens, err := userAuthTokenService.GetUserTokens(userID) tokens, err := userAuthTokenService.GetUserTokens(context.Background(), userID)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(tokens, ShouldHaveLength, 2) So(tokens, ShouldHaveLength, 2)
So(tokens[0].Id, ShouldEqual, userToken.Id) So(tokens[0].Id, ShouldEqual, userToken.Id)
@ -104,7 +105,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("Can revoke all user tokens", func() { Convey("Can revoke all user tokens", func() {
err := userAuthTokenService.RevokeAllUserTokens(userID) err := userAuthTokenService.RevokeAllUserTokens(context.Background(), userID)
So(err, ShouldBeNil) So(err, ShouldBeNil)
model, err := ctx.getAuthTokenByID(userToken.Id) model, err := ctx.getAuthTokenByID(userToken.Id)
@ -119,24 +120,24 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("expires correctly", func() { Convey("expires correctly", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
getTime = func() time.Time { getTime = func() time.Time {
return t.Add(time.Hour) return t.Add(time.Hour)
} }
rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.11:1234", "some user agent") rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
userToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) userToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
stillGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) stillGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(stillGood, ShouldNotBeNil) So(stillGood, ShouldNotBeNil)
@ -148,7 +149,7 @@ func TestUserAuthToken(t *testing.T) {
return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour).Add(-time.Second) return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour).Add(-time.Second)
} }
stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken) stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(stillGood, ShouldNotBeNil) So(stillGood, ShouldNotBeNil)
}) })
@ -158,12 +159,12 @@ func TestUserAuthToken(t *testing.T) {
return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour) return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour)
} }
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
So(notGood, ShouldBeNil) So(notGood, ShouldBeNil)
Convey("should not find active token when expired", func() { Convey("should not find active token when expired", func() {
count, err := userAuthTokenService.ActiveTokenCount() count, err := userAuthTokenService.ActiveTokenCount(context.Background())
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(count, ShouldEqual, 0) So(count, ShouldEqual, 0)
}) })
@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) {
return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour).Add(-time.Second) return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour).Add(-time.Second)
} }
stillGood, err = userAuthTokenService.LookupToken(stillGood.UnhashedToken) stillGood, err = userAuthTokenService.LookupToken(context.Background(), stillGood.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(stillGood, ShouldNotBeNil) So(stillGood, ShouldNotBeNil)
}) })
@ -192,20 +193,20 @@ func TestUserAuthToken(t *testing.T) {
return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour) return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour)
} }
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
So(notGood, ShouldBeNil) So(notGood, ShouldBeNil)
}) })
}) })
Convey("can properly rotate tokens", func() { Convey("can properly rotate tokens", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
prevToken := userToken.AuthToken prevToken := userToken.AuthToken
unhashedPrev := userToken.UnhashedToken unhashedPrev := userToken.UnhashedToken
rotated, err := userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent") rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeFalse) So(rotated, ShouldBeFalse)
@ -224,7 +225,7 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(time.Hour) return t.Add(time.Hour)
} }
rotated, err = userAuthTokenService.TryRotateToken(&tok, "192.168.10.12:1234", "a new user agent") rotated, err = userAuthTokenService.TryRotateToken(context.Background(), &tok, "192.168.10.12:1234", "a new user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
@ -243,13 +244,13 @@ func TestUserAuthToken(t *testing.T) {
// ability to auth using an old token // ability to auth using an old token
lookedUpUserToken, err := userAuthTokenService.LookupToken(model.UnhashedToken) lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), model.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue) So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
So(lookedUpUserToken.SeenAt, ShouldEqual, getTime().Unix()) So(lookedUpUserToken.SeenAt, ShouldEqual, getTime().Unix())
lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev) lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
So(lookedUpUserToken.Id, ShouldEqual, model.Id) So(lookedUpUserToken.Id, ShouldEqual, model.Id)
@ -259,7 +260,7 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(time.Hour + (2 * time.Minute)) return t.Add(time.Hour + (2 * time.Minute))
} }
lookedUpUserToken, err = userAuthTokenService.LookupToken(unhashedPrev) lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), unhashedPrev)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue) So(lookedUpUserToken.AuthTokenSeen, ShouldBeTrue)
@ -269,7 +270,7 @@ func TestUserAuthToken(t *testing.T) {
So(lookedUpModel, ShouldNotBeNil) So(lookedUpModel, ShouldNotBeNil)
So(lookedUpModel.AuthTokenSeen, ShouldBeFalse) So(lookedUpModel.AuthTokenSeen, ShouldBeFalse)
rotated, err = userAuthTokenService.TryRotateToken(userToken, "192.168.10.12:1234", "a new user agent") rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "192.168.10.12:1234", "a new user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
@ -280,11 +281,11 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("keeps prev token valid for 1 minute after it is confirmed", func() { Convey("keeps prev token valid for 1 minute after it is confirmed", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
@ -293,7 +294,7 @@ func TestUserAuthToken(t *testing.T) {
} }
prevToken := userToken.UnhashedToken prevToken := userToken.UnhashedToken
rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
@ -301,25 +302,25 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(20 * time.Minute) return t.Add(20 * time.Minute)
} }
currentUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) currentUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(currentUserToken, ShouldNotBeNil) So(currentUserToken, ShouldNotBeNil)
prevUserToken, err := userAuthTokenService.LookupToken(prevToken) prevUserToken, err := userAuthTokenService.LookupToken(context.Background(), prevToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(prevUserToken, ShouldNotBeNil) So(prevUserToken, ShouldNotBeNil)
}) })
Convey("will not mark token unseen when prev and current are the same", func() { Convey("will not mark token unseen when prev and current are the same", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
lookedUpUserToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) lookedUpUserToken, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
lookedUpUserToken, err = userAuthTokenService.LookupToken(userToken.UnhashedToken) lookedUpUserToken, err = userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(lookedUpUserToken, ShouldNotBeNil) So(lookedUpUserToken, ShouldNotBeNil)
@ -330,7 +331,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("Rotate token", func() { Convey("Rotate token", func() {
userToken, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent") userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
@ -345,7 +346,7 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(10 * time.Minute) return t.Add(10 * time.Minute)
} }
rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
@ -366,7 +367,7 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(20 * time.Minute) return t.Add(20 * time.Minute)
} }
rotated, err = userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") rotated, err = userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)
@ -385,7 +386,7 @@ func TestUserAuthToken(t *testing.T) {
return t.Add(2 * time.Minute) return t.Add(2 * time.Minute)
} }
rotated, err := userAuthTokenService.TryRotateToken(userToken, "1.1.1.1", "firefox") rotated, err := userAuthTokenService.TryRotateToken(context.Background(), userToken, "1.1.1.1", "firefox")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(rotated, ShouldBeTrue) So(rotated, ShouldBeTrue)

@ -1,81 +1,85 @@
package auth package auth
import "github.com/grafana/grafana/pkg/models" import (
"context"
"github.com/grafana/grafana/pkg/models"
)
type FakeUserAuthTokenService struct { type FakeUserAuthTokenService struct {
CreateTokenProvider func(userId int64, clientIP, userAgent string) (*models.UserToken, error) CreateTokenProvider func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error)
TryRotateTokenProvider func(token *models.UserToken, clientIP, userAgent string) (bool, error) TryRotateTokenProvider func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error)
LookupTokenProvider func(unhashedToken string) (*models.UserToken, error) LookupTokenProvider func(ctx context.Context, unhashedToken string) (*models.UserToken, error)
RevokeTokenProvider func(token *models.UserToken) error RevokeTokenProvider func(ctx context.Context, token *models.UserToken) error
RevokeAllUserTokensProvider func(userId int64) error RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error
ActiveAuthTokenCount func() (int64, error) ActiveAuthTokenCount func(ctx context.Context) (int64, error)
GetUserTokenProvider func(userId, userTokenId int64) (*models.UserToken, error) GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error)
GetUserTokensProvider func(userId int64) ([]*models.UserToken, error) GetUserTokensProvider func(ctx context.Context, userId int64) ([]*models.UserToken, error)
} }
func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
return &FakeUserAuthTokenService{ return &FakeUserAuthTokenService{
CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) { CreateTokenProvider: func(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
return &models.UserToken{ return &models.UserToken{
UserId: 0, UserId: 0,
UnhashedToken: "", UnhashedToken: "",
}, nil }, nil
}, },
TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) { TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
return false, nil return false, nil
}, },
LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) { LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{ return &models.UserToken{
UserId: 0, UserId: 0,
UnhashedToken: "", UnhashedToken: "",
}, nil }, nil
}, },
RevokeTokenProvider: func(token *models.UserToken) error { RevokeTokenProvider: func(ctx context.Context, token *models.UserToken) error {
return nil return nil
}, },
RevokeAllUserTokensProvider: func(userId int64) error { RevokeAllUserTokensProvider: func(ctx context.Context, userId int64) error {
return nil return nil
}, },
ActiveAuthTokenCount: func() (int64, error) { ActiveAuthTokenCount: func(ctx context.Context) (int64, error) {
return 10, nil return 10, nil
}, },
GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) { GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
return nil, nil return nil, nil
}, },
GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) { GetUserTokensProvider: func(ctx context.Context, userId int64) ([]*models.UserToken, error) {
return nil, nil return nil, nil
}, },
} }
} }
func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP, userAgent string) (*models.UserToken, error) {
return s.CreateTokenProvider(userId, clientIP, userAgent) return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent)
} }
func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) { func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return s.LookupTokenProvider(unhashedToken) return s.LookupTokenProvider(context.Background(), unhashedToken)
} }
func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) { func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) {
return s.TryRotateTokenProvider(token, clientIP, userAgent) return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent)
} }
func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error { func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken) error {
return s.RevokeTokenProvider(token) return s.RevokeTokenProvider(context.Background(), token)
} }
func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error { func (s *FakeUserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error {
return s.RevokeAllUserTokensProvider(userId) return s.RevokeAllUserTokensProvider(context.Background(), userId)
} }
func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) { func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) {
return s.ActiveAuthTokenCount() return s.ActiveAuthTokenCount(context.Background())
} }
func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) { func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
return s.GetUserTokenProvider(userId, userTokenId) return s.GetUserTokenProvider(context.Background(), userId, userTokenId)
} }
func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) { func (s *FakeUserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) {
return s.GetUserTokensProvider(userId) return s.GetUserTokensProvider(context.Background(), userId)
} }

@ -3,6 +3,8 @@ package auth
import ( import (
"context" "context"
"time" "time"
"github.com/grafana/grafana/pkg/services/sqlstore"
) )
func (srv *UserAuthTokenService) Run(ctx context.Context) error { func (srv *UserAuthTokenService) Run(ctx context.Context) error {
@ -11,21 +13,22 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error {
maxLifetime := time.Duration(srv.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour maxLifetime := time.Duration(srv.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime) srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime)
}) })
if err != nil { if err != nil {
srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err) srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err)
} }
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { err := srv.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
srv.deleteExpiredTokens(maxInactiveLifetime, maxLifetime) srv.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime)
}) })
if err != nil { if err != nil {
srv.log.Error("failed to lock and execite cleanup of expired auth token", "erro", err) srv.log.Error("failed to lock and execute cleanup of expired auth token", "error", err)
} }
case <-ctx.Done(): case <-ctx.Done():
@ -34,24 +37,30 @@ func (srv *UserAuthTokenService) Run(ctx context.Context) error {
} }
} }
func (srv *UserAuthTokenService) deleteExpiredTokens(maxInactiveLifetime, maxLifetime time.Duration) (int64, error) { func (srv *UserAuthTokenService) deleteExpiredTokens(ctx context.Context, maxInactiveLifetime, maxLifetime time.Duration) (int64, error) {
createdBefore := getTime().Add(-maxLifetime) createdBefore := getTime().Add(-maxLifetime)
rotatedBefore := getTime().Add(-maxInactiveLifetime) rotatedBefore := getTime().Add(-maxInactiveLifetime)
srv.log.Debug("starting cleanup of expired auth tokens", "createdBefore", createdBefore, "rotatedBefore", rotatedBefore) srv.log.Debug("starting cleanup of expired auth tokens", "createdBefore", createdBefore, "rotatedBefore", rotatedBefore)
sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?` var affected int64
res, err := srv.SQLStore.NewSession().Exec(sql, createdBefore.Unix(), rotatedBefore.Unix()) err := srv.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
if err != nil { sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?`
return 0, err res, err := dbSession.Exec(sql, createdBefore.Unix(), rotatedBefore.Unix())
} if err != nil {
return err
}
affected, err := res.RowsAffected() affected, err = res.RowsAffected()
if err != nil { if err != nil {
srv.log.Error("failed to cleanup expired auth tokens", "error", err) srv.log.Error("failed to cleanup expired auth tokens", "error", err)
return 0, nil return nil
} }
srv.log.Debug("cleanup of expired auth tokens done", "count", affected)
return nil
})
srv.log.Debug("cleanup of expired auth tokens done", "count", affected)
return affected, err return affected, err
} }

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -40,7 +41,7 @@ func TestUserAuthTokenCleanup(t *testing.T) {
insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), from.Unix()) insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), from.Unix())
} }
affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour) affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(affected, ShouldEqual, 3) So(affected, ShouldEqual, 3)
}) })
@ -60,7 +61,7 @@ func TestUserAuthTokenCleanup(t *testing.T) {
insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), fromRotate.Unix()) insertToken(fmt.Sprintf("newA%d", i), fmt.Sprintf("newB%d", i), from.Unix(), fromRotate.Unix())
} }
affected, err := ctx.tokenService.deleteExpiredTokens(7*24*time.Hour, 30*24*time.Hour) affected, err := ctx.tokenService.deleteExpiredTokens(context.Background(), 7*24*time.Hour, 30*24*time.Hour)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(affected, ShouldEqual, 3) So(affected, ShouldEqual, 3)
}) })

@ -43,7 +43,7 @@ func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, erro
} }
if target == "session" { if target == "session" {
usedSessions, err := qs.AuthTokenService.ActiveTokenCount() usedSessions, err := qs.AuthTokenService.ActiveTokenCount(c.Req.Context())
if err != nil { if err != nil {
return false, err return false, err
} }

Loading…
Cancel
Save