mirror of https://github.com/grafana/grafana
commit
3f9a19dcd5
@ -0,0 +1,12 @@ |
||||
package dtos |
||||
|
||||
import "time" |
||||
|
||||
type UserToken struct { |
||||
Id int64 `json:"id"` |
||||
IsActive bool `json:"isActive"` |
||||
ClientIp string `json:"clientIp"` |
||||
UserAgent string `json:"userAgent"` |
||||
CreatedAt time.Time `json:"createdAt"` |
||||
SeenAt time.Time `json:"seenAt"` |
||||
} |
||||
@ -0,0 +1,110 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/api/dtos" |
||||
"github.com/grafana/grafana/pkg/bus" |
||||
"github.com/grafana/grafana/pkg/models" |
||||
"github.com/grafana/grafana/pkg/util" |
||||
) |
||||
|
||||
// GET /api/user/auth-tokens
|
||||
func (server *HTTPServer) GetUserAuthTokens(c *models.ReqContext) Response { |
||||
return server.getUserAuthTokensInternal(c, c.UserId) |
||||
} |
||||
|
||||
// POST /api/user/revoke-auth-token
|
||||
func (server *HTTPServer) RevokeUserAuthToken(c *models.ReqContext, cmd models.RevokeAuthTokenCmd) Response { |
||||
return server.revokeUserAuthTokenInternal(c, c.UserId, cmd) |
||||
} |
||||
|
||||
func (server *HTTPServer) logoutUserFromAllDevicesInternal(userID int64) Response { |
||||
userQuery := models.GetUserByIdQuery{Id: userID} |
||||
|
||||
if err := bus.Dispatch(&userQuery); err != nil { |
||||
if err == models.ErrUserNotFound { |
||||
return Error(404, "User not found", err) |
||||
} |
||||
return Error(500, "Could not read user from database", err) |
||||
} |
||||
|
||||
err := server.AuthTokenService.RevokeAllUserTokens(userID) |
||||
if err != nil { |
||||
return Error(500, "Failed to logout user", err) |
||||
} |
||||
|
||||
return JSON(200, util.DynMap{ |
||||
"message": "User logged out", |
||||
}) |
||||
} |
||||
|
||||
func (server *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID int64) Response { |
||||
userQuery := models.GetUserByIdQuery{Id: userID} |
||||
|
||||
if err := bus.Dispatch(&userQuery); err != nil { |
||||
if err == models.ErrUserNotFound { |
||||
return Error(404, "User not found", err) |
||||
} |
||||
return Error(500, "Failed to get user", err) |
||||
} |
||||
|
||||
tokens, err := server.AuthTokenService.GetUserTokens(userID) |
||||
if err != nil { |
||||
return Error(500, "Failed to get user auth tokens", err) |
||||
} |
||||
|
||||
result := []*dtos.UserToken{} |
||||
for _, token := range tokens { |
||||
isActive := false |
||||
if c.UserToken != nil && c.UserToken.Id == token.Id { |
||||
isActive = true |
||||
} |
||||
|
||||
result = append(result, &dtos.UserToken{ |
||||
Id: token.Id, |
||||
IsActive: isActive, |
||||
ClientIp: token.ClientIp, |
||||
UserAgent: token.UserAgent, |
||||
CreatedAt: time.Unix(token.CreatedAt, 0), |
||||
SeenAt: time.Unix(token.SeenAt, 0), |
||||
}) |
||||
} |
||||
|
||||
return JSON(200, result) |
||||
} |
||||
|
||||
func (server *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID int64, cmd models.RevokeAuthTokenCmd) Response { |
||||
userQuery := models.GetUserByIdQuery{Id: userID} |
||||
|
||||
if err := bus.Dispatch(&userQuery); err != nil { |
||||
if err == models.ErrUserNotFound { |
||||
return Error(404, "User not found", err) |
||||
} |
||||
return Error(500, "Failed to get user", err) |
||||
} |
||||
|
||||
token, err := server.AuthTokenService.GetUserToken(userID, cmd.AuthTokenId) |
||||
if err != nil { |
||||
if err == models.ErrUserTokenNotFound { |
||||
return Error(404, "User auth token not found", err) |
||||
} |
||||
return Error(500, "Failed to get user auth token", err) |
||||
} |
||||
|
||||
if c.UserToken != nil && c.UserToken.Id == token.Id { |
||||
return Error(400, "Cannot revoke active user auth token", nil) |
||||
} |
||||
|
||||
err = server.AuthTokenService.RevokeToken(token) |
||||
if err != nil { |
||||
if err == models.ErrUserTokenNotFound { |
||||
return Error(404, "User auth token not found", err) |
||||
} |
||||
return Error(500, "Failed to revoke user auth token", err) |
||||
} |
||||
|
||||
return JSON(200, util.DynMap{ |
||||
"message": "User auth token revoked", |
||||
}) |
||||
} |
||||
@ -0,0 +1,294 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/bus" |
||||
m "github.com/grafana/grafana/pkg/models" |
||||
"github.com/grafana/grafana/pkg/services/auth" |
||||
|
||||
. "github.com/smartystreets/goconvey/convey" |
||||
) |
||||
|
||||
func TestUserTokenApiEndpoint(t *testing.T) { |
||||
Convey("When current user attempts to revoke an auth token for a non-existing user", t, func() { |
||||
userId := int64(0) |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
userId = cmd.Id |
||||
return m.ErrUserNotFound |
||||
}) |
||||
|
||||
cmd := m.RevokeAuthTokenCmd{AuthTokenId: 2} |
||||
|
||||
revokeUserAuthTokenScenario("Should return not found when calling POST on", "/api/user/revoke-auth-token", "/api/user/revoke-auth-token", cmd, 200, func(sc *scenarioContext) { |
||||
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 404) |
||||
So(userId, ShouldEqual, 200) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When current user gets auth tokens for a non-existing user", t, func() { |
||||
userId := int64(0) |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
userId = cmd.Id |
||||
return m.ErrUserNotFound |
||||
}) |
||||
|
||||
getUserAuthTokensScenario("Should return not found when calling GET on", "/api/user/auth-tokens", "/api/user/auth-tokens", 200, func(sc *scenarioContext) { |
||||
sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 404) |
||||
So(userId, ShouldEqual, 200) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When logout an existing user from all devices", t, func() { |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
cmd.Result = &m.User{Id: 200} |
||||
return nil |
||||
}) |
||||
|
||||
logoutUserFromAllDevicesInternalScenario("Should be successful", 1, func(sc *scenarioContext) { |
||||
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 200) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When logout a non-existing user from all devices", t, func() { |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
return m.ErrUserNotFound |
||||
}) |
||||
|
||||
logoutUserFromAllDevicesInternalScenario("Should return not found", TestUserID, func(sc *scenarioContext) { |
||||
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 404) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When revoke an auth token for a user", t, func() { |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
cmd.Result = &m.User{Id: 200} |
||||
return nil |
||||
}) |
||||
|
||||
cmd := m.RevokeAuthTokenCmd{AuthTokenId: 2} |
||||
token := &m.UserToken{Id: 1} |
||||
|
||||
revokeUserAuthTokenInternalScenario("Should be successful", cmd, 200, token, func(sc *scenarioContext) { |
||||
sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { |
||||
return &m.UserToken{Id: 2}, nil |
||||
} |
||||
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 200) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When revoke the active auth token used by himself", t, func() { |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
cmd.Result = &m.User{Id: TestUserID} |
||||
return nil |
||||
}) |
||||
|
||||
cmd := m.RevokeAuthTokenCmd{AuthTokenId: 2} |
||||
token := &m.UserToken{Id: 2} |
||||
|
||||
revokeUserAuthTokenInternalScenario("Should not be successful", cmd, TestUserID, token, func(sc *scenarioContext) { |
||||
sc.userAuthTokenService.GetUserTokenProvider = func(userId, userTokenId int64) (*m.UserToken, error) { |
||||
return token, nil |
||||
} |
||||
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() |
||||
So(sc.resp.Code, ShouldEqual, 400) |
||||
}) |
||||
}) |
||||
|
||||
Convey("When gets auth tokens for a user", t, func() { |
||||
bus.AddHandler("test", func(cmd *m.GetUserByIdQuery) error { |
||||
cmd.Result = &m.User{Id: TestUserID} |
||||
return nil |
||||
}) |
||||
|
||||
currentToken := &m.UserToken{Id: 1} |
||||
|
||||
getUserAuthTokensInternalScenario("Should be successful", currentToken, func(sc *scenarioContext) { |
||||
tokens := []*m.UserToken{ |
||||
{ |
||||
Id: 1, |
||||
ClientIp: "127.0.0.1", |
||||
UserAgent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.119 Safari/537.36", |
||||
CreatedAt: time.Now().Unix(), |
||||
SeenAt: time.Now().Unix(), |
||||
}, |
||||
{ |
||||
Id: 2, |
||||
ClientIp: "127.0.0.2", |
||||
UserAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 11_0 like Mac OS X) AppleWebKit/604.1.38 (KHTML, like Gecko) Version/11.0 Mobile/15A372 Safari/604.1", |
||||
CreatedAt: time.Now().Unix(), |
||||
SeenAt: time.Now().Unix(), |
||||
}, |
||||
} |
||||
sc.userAuthTokenService.GetUserTokensProvider = func(userId int64) ([]*m.UserToken, error) { |
||||
return tokens, nil |
||||
} |
||||
sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() |
||||
|
||||
So(sc.resp.Code, ShouldEqual, 200) |
||||
result := sc.ToJSON() |
||||
So(result.MustArray(), ShouldHaveLength, 2) |
||||
|
||||
resultOne := result.GetIndex(0) |
||||
So(resultOne.Get("id").MustInt64(), ShouldEqual, tokens[0].Id) |
||||
So(resultOne.Get("isActive").MustBool(), ShouldBeTrue) |
||||
So(resultOne.Get("clientIp").MustString(), ShouldEqual, "127.0.0.1") |
||||
So(resultOne.Get("userAgent").MustString(), ShouldEqual, "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.119 Safari/537.36") |
||||
So(resultOne.Get("createdAt").MustString(), ShouldEqual, time.Unix(tokens[0].CreatedAt, 0).Format(time.RFC3339)) |
||||
So(resultOne.Get("seenAt").MustString(), ShouldEqual, time.Unix(tokens[0].SeenAt, 0).Format(time.RFC3339)) |
||||
|
||||
resultTwo := result.GetIndex(1) |
||||
So(resultTwo.Get("id").MustInt64(), ShouldEqual, tokens[1].Id) |
||||
So(resultTwo.Get("isActive").MustBool(), ShouldBeFalse) |
||||
So(resultTwo.Get("clientIp").MustString(), ShouldEqual, "127.0.0.2") |
||||
So(resultTwo.Get("userAgent").MustString(), ShouldEqual, "Mozilla/5.0 (iPhone; CPU iPhone OS 11_0 like Mac OS X) AppleWebKit/604.1.38 (KHTML, like Gecko) Version/11.0 Mobile/15A372 Safari/604.1") |
||||
So(resultTwo.Get("createdAt").MustString(), ShouldEqual, time.Unix(tokens[1].CreatedAt, 0).Format(time.RFC3339)) |
||||
So(resultTwo.Get("seenAt").MustString(), ShouldEqual, time.Unix(tokens[1].SeenAt, 0).Format(time.RFC3339)) |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
func revokeUserAuthTokenScenario(desc string, url string, routePattern string, cmd m.RevokeAuthTokenCmd, userId int64, fn scenarioFunc) { |
||||
Convey(desc+" "+url, func() { |
||||
defer bus.ClearBusHandlers() |
||||
|
||||
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() |
||||
|
||||
hs := HTTPServer{ |
||||
Bus: bus.GetBus(), |
||||
AuthTokenService: fakeAuthTokenService, |
||||
} |
||||
|
||||
sc := setupScenarioContext(url) |
||||
sc.userAuthTokenService = fakeAuthTokenService |
||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { |
||||
sc.context = c |
||||
sc.context.UserId = userId |
||||
sc.context.OrgId = TestOrgID |
||||
sc.context.OrgRole = m.ROLE_ADMIN |
||||
|
||||
return hs.RevokeUserAuthToken(c, cmd) |
||||
}) |
||||
|
||||
sc.m.Post(routePattern, sc.defaultHandler) |
||||
|
||||
fn(sc) |
||||
}) |
||||
} |
||||
|
||||
func getUserAuthTokensScenario(desc string, url string, routePattern string, userId int64, fn scenarioFunc) { |
||||
Convey(desc+" "+url, func() { |
||||
defer bus.ClearBusHandlers() |
||||
|
||||
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() |
||||
|
||||
hs := HTTPServer{ |
||||
Bus: bus.GetBus(), |
||||
AuthTokenService: fakeAuthTokenService, |
||||
} |
||||
|
||||
sc := setupScenarioContext(url) |
||||
sc.userAuthTokenService = fakeAuthTokenService |
||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { |
||||
sc.context = c |
||||
sc.context.UserId = userId |
||||
sc.context.OrgId = TestOrgID |
||||
sc.context.OrgRole = m.ROLE_ADMIN |
||||
|
||||
return hs.GetUserAuthTokens(c) |
||||
}) |
||||
|
||||
sc.m.Get(routePattern, sc.defaultHandler) |
||||
|
||||
fn(sc) |
||||
}) |
||||
} |
||||
|
||||
func logoutUserFromAllDevicesInternalScenario(desc string, userId int64, fn scenarioFunc) { |
||||
Convey(desc, func() { |
||||
defer bus.ClearBusHandlers() |
||||
|
||||
hs := HTTPServer{ |
||||
Bus: bus.GetBus(), |
||||
AuthTokenService: auth.NewFakeUserAuthTokenService(), |
||||
} |
||||
|
||||
sc := setupScenarioContext("/") |
||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { |
||||
sc.context = c |
||||
sc.context.UserId = TestUserID |
||||
sc.context.OrgId = TestOrgID |
||||
sc.context.OrgRole = m.ROLE_ADMIN |
||||
|
||||
return hs.logoutUserFromAllDevicesInternal(userId) |
||||
}) |
||||
|
||||
sc.m.Post("/", sc.defaultHandler) |
||||
|
||||
fn(sc) |
||||
}) |
||||
} |
||||
|
||||
func revokeUserAuthTokenInternalScenario(desc string, cmd m.RevokeAuthTokenCmd, userId int64, token *m.UserToken, fn scenarioFunc) { |
||||
Convey(desc, func() { |
||||
defer bus.ClearBusHandlers() |
||||
|
||||
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() |
||||
|
||||
hs := HTTPServer{ |
||||
Bus: bus.GetBus(), |
||||
AuthTokenService: fakeAuthTokenService, |
||||
} |
||||
|
||||
sc := setupScenarioContext("/") |
||||
sc.userAuthTokenService = fakeAuthTokenService |
||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { |
||||
sc.context = c |
||||
sc.context.UserId = TestUserID |
||||
sc.context.OrgId = TestOrgID |
||||
sc.context.OrgRole = m.ROLE_ADMIN |
||||
sc.context.UserToken = token |
||||
|
||||
return hs.revokeUserAuthTokenInternal(c, userId, cmd) |
||||
}) |
||||
|
||||
sc.m.Post("/", sc.defaultHandler) |
||||
|
||||
fn(sc) |
||||
}) |
||||
} |
||||
|
||||
func getUserAuthTokensInternalScenario(desc string, token *m.UserToken, fn scenarioFunc) { |
||||
Convey(desc, func() { |
||||
defer bus.ClearBusHandlers() |
||||
|
||||
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() |
||||
|
||||
hs := HTTPServer{ |
||||
Bus: bus.GetBus(), |
||||
AuthTokenService: fakeAuthTokenService, |
||||
} |
||||
|
||||
sc := setupScenarioContext("/") |
||||
sc.userAuthTokenService = fakeAuthTokenService |
||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { |
||||
sc.context = c |
||||
sc.context.UserId = TestUserID |
||||
sc.context.OrgId = TestOrgID |
||||
sc.context.OrgRole = m.ROLE_ADMIN |
||||
sc.context.UserToken = token |
||||
|
||||
return hs.getUserAuthTokensInternal(c, TestUserID) |
||||
}) |
||||
|
||||
sc.m.Get("/", sc.defaultHandler) |
||||
|
||||
fn(sc) |
||||
}) |
||||
} |
||||
@ -0,0 +1,126 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"context" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/log" |
||||
"github.com/grafana/grafana/pkg/services/sqlstore" |
||||
) |
||||
|
||||
var getTime = time.Now |
||||
|
||||
const databaseCacheType = "database" |
||||
|
||||
type databaseCache struct { |
||||
SQLStore *sqlstore.SqlStore |
||||
log log.Logger |
||||
} |
||||
|
||||
func newDatabaseCache(sqlstore *sqlstore.SqlStore) *databaseCache { |
||||
dc := &databaseCache{ |
||||
SQLStore: sqlstore, |
||||
log: log.New("remotecache.database"), |
||||
} |
||||
|
||||
return dc |
||||
} |
||||
|
||||
func (dc *databaseCache) Run(ctx context.Context) error { |
||||
ticker := time.NewTicker(time.Minute * 10) |
||||
for { |
||||
select { |
||||
case <-ctx.Done(): |
||||
return ctx.Err() |
||||
case <-ticker.C: |
||||
dc.internalRunGC() |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (dc *databaseCache) internalRunGC() { |
||||
now := getTime().Unix() |
||||
sql := `DELETE FROM cache_data WHERE (? - created_at) >= expires AND expires <> 0` |
||||
|
||||
_, err := dc.SQLStore.NewSession().Exec(sql, now) |
||||
if err != nil { |
||||
dc.log.Error("failed to run garbage collect", "error", err) |
||||
} |
||||
} |
||||
|
||||
func (dc *databaseCache) Get(key string) (interface{}, error) { |
||||
cacheHit := CacheData{} |
||||
session := dc.SQLStore.NewSession() |
||||
defer session.Close() |
||||
|
||||
exist, err := session.Where("cache_key= ?", key).Get(&cacheHit) |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if !exist { |
||||
return nil, ErrCacheItemNotFound |
||||
} |
||||
|
||||
if cacheHit.Expires > 0 { |
||||
existedButExpired := getTime().Unix()-cacheHit.CreatedAt >= cacheHit.Expires |
||||
if existedButExpired { |
||||
_ = dc.Delete(key) //ignore this error since we will return `ErrCacheItemNotFound` anyway
|
||||
return nil, ErrCacheItemNotFound |
||||
} |
||||
} |
||||
|
||||
item := &cachedItem{} |
||||
if err = decodeGob(cacheHit.Data, item); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return item.Val, nil |
||||
} |
||||
|
||||
func (dc *databaseCache) Set(key string, value interface{}, expire time.Duration) error { |
||||
item := &cachedItem{Val: value} |
||||
data, err := encodeGob(item) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
session := dc.SQLStore.NewSession() |
||||
|
||||
var cacheHit CacheData |
||||
has, err := session.Where("cache_key = ?", key).Get(&cacheHit) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
var expiresInSeconds int64 |
||||
if expire != 0 { |
||||
expiresInSeconds = int64(expire) / int64(time.Second) |
||||
} |
||||
|
||||
// insert or update depending on if item already exist
|
||||
if has { |
||||
sql := `UPDATE cache_data SET data=?, created=?, expire=? WHERE cache_key='?'` |
||||
_, err = session.Exec(sql, data, getTime().Unix(), expiresInSeconds, key) |
||||
} else { |
||||
sql := `INSERT INTO cache_data (cache_key,data,created_at,expires) VALUES(?,?,?,?)` |
||||
_, err = session.Exec(sql, key, data, getTime().Unix(), expiresInSeconds) |
||||
} |
||||
|
||||
return err |
||||
} |
||||
|
||||
func (dc *databaseCache) Delete(key string) error { |
||||
sql := "DELETE FROM cache_data WHERE cache_key=?" |
||||
_, err := dc.SQLStore.NewSession().Exec(sql, key) |
||||
|
||||
return err |
||||
} |
||||
|
||||
type CacheData struct { |
||||
CacheKey string |
||||
Data []byte |
||||
Expires int64 |
||||
CreatedAt int64 |
||||
} |
||||
@ -0,0 +1,56 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bmizerany/assert" |
||||
|
||||
"github.com/grafana/grafana/pkg/log" |
||||
"github.com/grafana/grafana/pkg/services/sqlstore" |
||||
) |
||||
|
||||
func TestDatabaseStorageGarbageCollection(t *testing.T) { |
||||
sqlstore := sqlstore.InitTestDB(t) |
||||
|
||||
db := &databaseCache{ |
||||
SQLStore: sqlstore, |
||||
log: log.New("remotecache.database"), |
||||
} |
||||
|
||||
obj := &CacheableStruct{String: "foolbar"} |
||||
|
||||
//set time.now to 2 weeks ago
|
||||
var err error |
||||
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } |
||||
err = db.Set("key1", obj, 1000*time.Second) |
||||
assert.Equal(t, err, nil) |
||||
|
||||
err = db.Set("key2", obj, 1000*time.Second) |
||||
assert.Equal(t, err, nil) |
||||
|
||||
err = db.Set("key3", obj, 1000*time.Second) |
||||
assert.Equal(t, err, nil) |
||||
|
||||
// insert object that should never expire
|
||||
db.Set("key4", obj, 0) |
||||
|
||||
getTime = time.Now |
||||
db.Set("key5", obj, 1000*time.Second) |
||||
|
||||
//run GC
|
||||
db.internalRunGC() |
||||
|
||||
//try to read values
|
||||
_, err = db.Get("key1") |
||||
assert.Equal(t, err, ErrCacheItemNotFound, "expected cache item not found. got: ", err) |
||||
_, err = db.Get("key2") |
||||
assert.Equal(t, err, ErrCacheItemNotFound) |
||||
_, err = db.Get("key3") |
||||
assert.Equal(t, err, ErrCacheItemNotFound) |
||||
|
||||
_, err = db.Get("key4") |
||||
assert.Equal(t, err, nil) |
||||
_, err = db.Get("key5") |
||||
assert.Equal(t, err, nil) |
||||
} |
||||
@ -0,0 +1,71 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/bradfitz/gomemcache/memcache" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
) |
||||
|
||||
const memcachedCacheType = "memcached" |
||||
|
||||
type memcachedStorage struct { |
||||
c *memcache.Client |
||||
} |
||||
|
||||
func newMemcachedStorage(opts *setting.RemoteCacheOptions) *memcachedStorage { |
||||
return &memcachedStorage{ |
||||
c: memcache.New(opts.ConnStr), |
||||
} |
||||
} |
||||
|
||||
func newItem(sid string, data []byte, expire int32) *memcache.Item { |
||||
return &memcache.Item{ |
||||
Key: sid, |
||||
Value: data, |
||||
Expiration: expire, |
||||
} |
||||
} |
||||
|
||||
// Set sets value to given key in the cache.
|
||||
func (s *memcachedStorage) Set(key string, val interface{}, expires time.Duration) error { |
||||
item := &cachedItem{Val: val} |
||||
bytes, err := encodeGob(item) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
var expiresInSeconds int64 |
||||
if expires != 0 { |
||||
expiresInSeconds = int64(expires) / int64(time.Second) |
||||
} |
||||
|
||||
memcachedItem := newItem(key, bytes, int32(expiresInSeconds)) |
||||
return s.c.Set(memcachedItem) |
||||
} |
||||
|
||||
// Get gets value by given key in the cache.
|
||||
func (s *memcachedStorage) Get(key string) (interface{}, error) { |
||||
memcachedItem, err := s.c.Get(key) |
||||
if err != nil && err.Error() == "memcache: cache miss" { |
||||
return nil, ErrCacheItemNotFound |
||||
} |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
item := &cachedItem{} |
||||
|
||||
err = decodeGob(memcachedItem.Value, item) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return item.Val, nil |
||||
} |
||||
|
||||
// Delete delete a key from the cache
|
||||
func (s *memcachedStorage) Delete(key string) error { |
||||
return s.c.Delete(key) |
||||
} |
||||
@ -0,0 +1,15 @@ |
||||
// +build memcached
|
||||
|
||||
package remotecache |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
) |
||||
|
||||
func TestMemcachedCacheStorage(t *testing.T) { |
||||
opts := &setting.RemoteCacheOptions{Name: memcachedCacheType, ConnStr: "localhost:11211"} |
||||
client := createTestClient(t, opts, nil) |
||||
runTestsForClient(t, client) |
||||
} |
||||
@ -0,0 +1,62 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
redis "gopkg.in/redis.v2" |
||||
) |
||||
|
||||
const redisCacheType = "redis" |
||||
|
||||
type redisStorage struct { |
||||
c *redis.Client |
||||
} |
||||
|
||||
func newRedisStorage(opts *setting.RemoteCacheOptions) *redisStorage { |
||||
opt := &redis.Options{ |
||||
Network: "tcp", |
||||
Addr: opts.ConnStr, |
||||
} |
||||
return &redisStorage{c: redis.NewClient(opt)} |
||||
} |
||||
|
||||
// Set sets value to given key in session.
|
||||
func (s *redisStorage) Set(key string, val interface{}, expires time.Duration) error { |
||||
item := &cachedItem{Val: val} |
||||
value, err := encodeGob(item) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
status := s.c.SetEx(key, expires, string(value)) |
||||
return status.Err() |
||||
} |
||||
|
||||
// Get gets value by given key in session.
|
||||
func (s *redisStorage) Get(key string) (interface{}, error) { |
||||
v := s.c.Get(key) |
||||
|
||||
item := &cachedItem{} |
||||
err := decodeGob([]byte(v.Val()), item) |
||||
|
||||
if err == nil { |
||||
return item.Val, nil |
||||
} |
||||
|
||||
if err.Error() == "EOF" { |
||||
return nil, ErrCacheItemNotFound |
||||
} |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return item.Val, nil |
||||
} |
||||
|
||||
// Delete delete a key from session.
|
||||
func (s *redisStorage) Delete(key string) error { |
||||
cmd := s.c.Del(key) |
||||
return cmd.Err() |
||||
} |
||||
@ -0,0 +1,16 @@ |
||||
// +build redis
|
||||
|
||||
package remotecache |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
) |
||||
|
||||
func TestRedisCacheStorage(t *testing.T) { |
||||
|
||||
opts := &setting.RemoteCacheOptions{Name: redisCacheType, ConnStr: "localhost:6379"} |
||||
client := createTestClient(t, opts, nil) |
||||
runTestsForClient(t, client) |
||||
} |
||||
@ -0,0 +1,133 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"encoding/gob" |
||||
"errors" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
|
||||
"github.com/grafana/grafana/pkg/log" |
||||
"github.com/grafana/grafana/pkg/services/sqlstore" |
||||
|
||||
"github.com/grafana/grafana/pkg/registry" |
||||
) |
||||
|
||||
var ( |
||||
// ErrCacheItemNotFound is returned if cache does not exist
|
||||
ErrCacheItemNotFound = errors.New("cache item not found") |
||||
|
||||
// ErrInvalidCacheType is returned if the type is invalid
|
||||
ErrInvalidCacheType = errors.New("invalid remote cache name") |
||||
|
||||
defaultMaxCacheExpiration = time.Hour * 24 |
||||
) |
||||
|
||||
func init() { |
||||
registry.RegisterService(&RemoteCache{}) |
||||
} |
||||
|
||||
// CacheStorage allows the caller to set, get and delete items in the cache.
|
||||
// Cached items are stored as byte arrays and marshalled using "encoding/gob"
|
||||
// so any struct added to the cache needs to be registred with `remotecache.Register`
|
||||
// ex `remotecache.Register(CacheableStruct{})``
|
||||
type CacheStorage interface { |
||||
// Get reads object from Cache
|
||||
Get(key string) (interface{}, error) |
||||
|
||||
// Set sets an object into the cache. if `expire` is set to zero it will default to 24h
|
||||
Set(key string, value interface{}, expire time.Duration) error |
||||
|
||||
// Delete object from cache
|
||||
Delete(key string) error |
||||
} |
||||
|
||||
// RemoteCache allows Grafana to cache data outside its own process
|
||||
type RemoteCache struct { |
||||
log log.Logger |
||||
client CacheStorage |
||||
SQLStore *sqlstore.SqlStore `inject:""` |
||||
Cfg *setting.Cfg `inject:""` |
||||
} |
||||
|
||||
// Get reads object from Cache
|
||||
func (ds *RemoteCache) Get(key string) (interface{}, error) { |
||||
return ds.client.Get(key) |
||||
} |
||||
|
||||
// Set sets an object into the cache. if `expire` is set to zero it will default to 24h
|
||||
func (ds *RemoteCache) Set(key string, value interface{}, expire time.Duration) error { |
||||
if expire == 0 { |
||||
expire = defaultMaxCacheExpiration |
||||
} |
||||
|
||||
return ds.client.Set(key, value, expire) |
||||
} |
||||
|
||||
// Delete object from cache
|
||||
func (ds *RemoteCache) Delete(key string) error { |
||||
return ds.client.Delete(key) |
||||
} |
||||
|
||||
// Init initializes the service
|
||||
func (ds *RemoteCache) Init() error { |
||||
ds.log = log.New("cache.remote") |
||||
var err error |
||||
ds.client, err = createClient(ds.Cfg.RemoteCacheOptions, ds.SQLStore) |
||||
return err |
||||
} |
||||
|
||||
// Run start the backend processes for cache clients
|
||||
func (ds *RemoteCache) Run(ctx context.Context) error { |
||||
//create new interface if more clients need GC jobs
|
||||
backgroundjob, ok := ds.client.(registry.BackgroundService) |
||||
if ok { |
||||
return backgroundjob.Run(ctx) |
||||
} |
||||
|
||||
<-ctx.Done() |
||||
return ctx.Err() |
||||
} |
||||
|
||||
func createClient(opts *setting.RemoteCacheOptions, sqlstore *sqlstore.SqlStore) (CacheStorage, error) { |
||||
if opts.Name == redisCacheType { |
||||
return newRedisStorage(opts), nil |
||||
} |
||||
|
||||
if opts.Name == memcachedCacheType { |
||||
return newMemcachedStorage(opts), nil |
||||
} |
||||
|
||||
if opts.Name == databaseCacheType { |
||||
return newDatabaseCache(sqlstore), nil |
||||
} |
||||
|
||||
return nil, ErrInvalidCacheType |
||||
} |
||||
|
||||
// Register records a type, identified by a value for that type, under its
|
||||
// internal type name. That name will identify the concrete type of a value
|
||||
// sent or received as an interface variable. Only types that will be
|
||||
// transferred as implementations of interface values need to be registered.
|
||||
// Expecting to be used only during initialization, it panics if the mapping
|
||||
// between types and names is not a bijection.
|
||||
func Register(value interface{}) { |
||||
gob.Register(value) |
||||
} |
||||
|
||||
type cachedItem struct { |
||||
Val interface{} |
||||
} |
||||
|
||||
func encodeGob(item *cachedItem) ([]byte, error) { |
||||
buf := bytes.NewBuffer(nil) |
||||
err := gob.NewEncoder(buf).Encode(item) |
||||
return buf.Bytes(), err |
||||
} |
||||
|
||||
func decodeGob(data []byte, out *cachedItem) error { |
||||
buf := bytes.NewBuffer(data) |
||||
return gob.NewDecoder(buf).Decode(&out) |
||||
} |
||||
@ -0,0 +1,93 @@ |
||||
package remotecache |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bmizerany/assert" |
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
) |
||||
|
||||
type CacheableStruct struct { |
||||
String string |
||||
Int64 int64 |
||||
} |
||||
|
||||
func init() { |
||||
Register(CacheableStruct{}) |
||||
} |
||||
|
||||
func createTestClient(t *testing.T, opts *setting.RemoteCacheOptions, sqlstore *sqlstore.SqlStore) CacheStorage { |
||||
t.Helper() |
||||
|
||||
dc := &RemoteCache{ |
||||
SQLStore: sqlstore, |
||||
Cfg: &setting.Cfg{ |
||||
RemoteCacheOptions: opts, |
||||
}, |
||||
} |
||||
|
||||
err := dc.Init() |
||||
if err != nil { |
||||
t.Fatalf("failed to init client for test. error: %v", err) |
||||
} |
||||
|
||||
return dc |
||||
} |
||||
|
||||
func TestCachedBasedOnConfig(t *testing.T) { |
||||
|
||||
cfg := setting.NewCfg() |
||||
cfg.Load(&setting.CommandLineArgs{ |
||||
HomePath: "../../../", |
||||
}) |
||||
|
||||
client := createTestClient(t, cfg.RemoteCacheOptions, sqlstore.InitTestDB(t)) |
||||
runTestsForClient(t, client) |
||||
} |
||||
|
||||
func TestInvalidCacheTypeReturnsError(t *testing.T) { |
||||
_, err := createClient(&setting.RemoteCacheOptions{Name: "invalid"}, nil) |
||||
assert.Equal(t, err, ErrInvalidCacheType) |
||||
} |
||||
|
||||
func runTestsForClient(t *testing.T, client CacheStorage) { |
||||
canPutGetAndDeleteCachedObjects(t, client) |
||||
canNotFetchExpiredItems(t, client) |
||||
} |
||||
|
||||
func canPutGetAndDeleteCachedObjects(t *testing.T, client CacheStorage) { |
||||
cacheableStruct := CacheableStruct{String: "hej", Int64: 2000} |
||||
|
||||
err := client.Set("key1", cacheableStruct, 0) |
||||
assert.Equal(t, err, nil, "expected nil. got: ", err) |
||||
|
||||
data, err := client.Get("key1") |
||||
s, ok := data.(CacheableStruct) |
||||
|
||||
assert.Equal(t, ok, true) |
||||
assert.Equal(t, s.String, "hej") |
||||
assert.Equal(t, s.Int64, int64(2000)) |
||||
|
||||
err = client.Delete("key1") |
||||
assert.Equal(t, err, nil) |
||||
|
||||
_, err = client.Get("key1") |
||||
assert.Equal(t, err, ErrCacheItemNotFound) |
||||
} |
||||
|
||||
func canNotFetchExpiredItems(t *testing.T, client CacheStorage) { |
||||
cacheableStruct := CacheableStruct{String: "hej", Int64: 2000} |
||||
|
||||
err := client.Set("key1", cacheableStruct, time.Second) |
||||
assert.Equal(t, err, nil) |
||||
|
||||
//not sure how this can be avoided when testing redis/memcached :/
|
||||
<-time.After(time.Second + time.Millisecond) |
||||
|
||||
// should not be able to read that value since its expired
|
||||
_, err = client.Get("key1") |
||||
assert.Equal(t, err, ErrCacheItemNotFound) |
||||
} |
||||
@ -0,0 +1,81 @@ |
||||
package auth |
||||
|
||||
import "github.com/grafana/grafana/pkg/models" |
||||
|
||||
type FakeUserAuthTokenService struct { |
||||
CreateTokenProvider func(userId int64, clientIP, userAgent string) (*models.UserToken, error) |
||||
TryRotateTokenProvider func(token *models.UserToken, clientIP, userAgent string) (bool, error) |
||||
LookupTokenProvider func(unhashedToken string) (*models.UserToken, error) |
||||
RevokeTokenProvider func(token *models.UserToken) error |
||||
RevokeAllUserTokensProvider func(userId int64) error |
||||
ActiveAuthTokenCount func() (int64, error) |
||||
GetUserTokenProvider func(userId, userTokenId int64) (*models.UserToken, error) |
||||
GetUserTokensProvider func(userId int64) ([]*models.UserToken, error) |
||||
} |
||||
|
||||
func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { |
||||
return &FakeUserAuthTokenService{ |
||||
CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) { |
||||
return &models.UserToken{ |
||||
UserId: 0, |
||||
UnhashedToken: "", |
||||
}, nil |
||||
}, |
||||
TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) { |
||||
return false, nil |
||||
}, |
||||
LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) { |
||||
return &models.UserToken{ |
||||
UserId: 0, |
||||
UnhashedToken: "", |
||||
}, nil |
||||
}, |
||||
RevokeTokenProvider: func(token *models.UserToken) error { |
||||
return nil |
||||
}, |
||||
RevokeAllUserTokensProvider: func(userId int64) error { |
||||
return nil |
||||
}, |
||||
ActiveAuthTokenCount: func() (int64, error) { |
||||
return 10, nil |
||||
}, |
||||
GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) { |
||||
return nil, nil |
||||
}, |
||||
GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) { |
||||
return nil, nil |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { |
||||
return s.CreateTokenProvider(userId, clientIP, userAgent) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) { |
||||
return s.LookupTokenProvider(unhashedToken) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) { |
||||
return s.TryRotateTokenProvider(token, clientIP, userAgent) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error { |
||||
return s.RevokeTokenProvider(token) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error { |
||||
return s.RevokeAllUserTokensProvider(userId) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) { |
||||
return s.ActiveAuthTokenCount() |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) { |
||||
return s.GetUserTokenProvider(userId, userTokenId) |
||||
} |
||||
|
||||
func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) { |
||||
return s.GetUserTokensProvider(userId) |
||||
} |
||||
@ -0,0 +1,22 @@ |
||||
package migrations |
||||
|
||||
import "github.com/grafana/grafana/pkg/services/sqlstore/migrator" |
||||
|
||||
func addCacheMigration(mg *migrator.Migrator) { |
||||
var cacheDataV1 = migrator.Table{ |
||||
Name: "cache_data", |
||||
Columns: []*migrator.Column{ |
||||
{Name: "cache_key", Type: migrator.DB_NVarchar, IsPrimaryKey: true, Length: 168}, |
||||
{Name: "data", Type: migrator.DB_Blob}, |
||||
{Name: "expires", Type: migrator.DB_Integer, Length: 255, Nullable: false}, |
||||
{Name: "created_at", Type: migrator.DB_Integer, Length: 255, Nullable: false}, |
||||
}, |
||||
Indices: []*migrator.Index{ |
||||
{Cols: []string{"cache_key"}, Type: migrator.UniqueIndex}, |
||||
}, |
||||
} |
||||
|
||||
mg.AddMigration("create cache_data table", migrator.NewAddTableMigration(cacheDataV1)) |
||||
|
||||
mg.AddMigration("add unique index cache_data.cache_key", migrator.NewAddIndexMigration(cacheDataV1, cacheDataV1.Indices[0])) |
||||
} |
||||
@ -0,0 +1,16 @@ |
||||
#!/bin/bash |
||||
function exit_if_fail { |
||||
command=$@ |
||||
echo "Executing '$command'" |
||||
eval $command |
||||
rc=$? |
||||
if [ $rc -ne 0 ]; then |
||||
echo "'$command' returned $rc." |
||||
exit $rc |
||||
fi |
||||
} |
||||
|
||||
echo "running redis and memcache tests" |
||||
|
||||
time exit_if_fail go test -tags=redis ./pkg/infra/remotecache/... |
||||
time exit_if_fail go test -tags=memcached ./pkg/infra/remotecache/... |
||||
Loading…
Reference in new issue