Chore: Remove result field from API keys commands and queries (#65055)

* Chore: remove result field from api keys

* fix shadowing

* actually shadowing was all right
pull/65112/head
Serge Zaitsev 2 years ago committed by GitHub
parent 590b07539f
commit 743d66396a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      pkg/api/apikey.go
  2. 8
      pkg/services/apikey/apikey.go
  3. 8
      pkg/services/apikey/apikeyimpl/apikey.go
  4. 35
      pkg/services/apikey/apikeyimpl/sqlx_store.go
  5. 8
      pkg/services/apikey/apikeyimpl/store.go
  6. 58
      pkg/services/apikey/apikeyimpl/store_test.go
  7. 30
      pkg/services/apikey/apikeyimpl/xorm_store.go
  8. 20
      pkg/services/apikey/apikeytest/fake.go
  9. 5
      pkg/services/apikey/model.go
  10. 7
      pkg/services/authn/clients/api_key.go
  11. 7
      pkg/services/contexthandler/contexthandler.go
  12. 5
      pkg/services/serviceaccounts/database/token_store.go
  13. 8
      pkg/services/serviceaccounts/tests/common.go

@ -29,13 +29,14 @@ import (
func (hs *HTTPServer) GetAPIKeys(c *contextmodel.ReqContext) response.Response { func (hs *HTTPServer) GetAPIKeys(c *contextmodel.ReqContext) response.Response {
query := apikey.GetApiKeysQuery{OrgID: c.OrgID, User: c.SignedInUser, IncludeExpired: c.QueryBool("includeExpired")} query := apikey.GetApiKeysQuery{OrgID: c.OrgID, User: c.SignedInUser, IncludeExpired: c.QueryBool("includeExpired")}
if err := hs.apiKeyService.GetAPIKeys(c.Req.Context(), &query); err != nil { keys, err := hs.apiKeyService.GetAPIKeys(c.Req.Context(), &query)
if err != nil {
return response.Error(500, "Failed to list api keys", err) return response.Error(500, "Failed to list api keys", err)
} }
ids := map[string]bool{} ids := map[string]bool{}
result := make([]*dtos.ApiKeyDTO, len(query.Result)) result := make([]*dtos.ApiKeyDTO, len(keys))
for i, t := range query.Result { for i, t := range keys {
ids[strconv.FormatInt(t.ID, 10)] = true ids[strconv.FormatInt(t.ID, 10)] = true
var expiration *time.Time = nil var expiration *time.Time = nil
if t.Expires != nil { if t.Expires != nil {
@ -134,7 +135,8 @@ func (hs *HTTPServer) AddAPIKey(c *contextmodel.ReqContext) response.Response {
} }
cmd.Key = newKeyInfo.HashedKey cmd.Key = newKeyInfo.HashedKey
if err := hs.apiKeyService.AddAPIKey(c.Req.Context(), &cmd); err != nil { key, err := hs.apiKeyService.AddAPIKey(c.Req.Context(), &cmd)
if err != nil {
if errors.Is(err, apikey.ErrInvalidExpiration) { if errors.Is(err, apikey.ErrInvalidExpiration) {
return response.Error(400, err.Error(), nil) return response.Error(400, err.Error(), nil)
} }
@ -145,8 +147,8 @@ func (hs *HTTPServer) AddAPIKey(c *contextmodel.ReqContext) response.Response {
} }
result := &dtos.NewApiKeyResult{ result := &dtos.NewApiKeyResult{
ID: cmd.Result.ID, ID: key.ID,
Name: cmd.Result.Name, Name: key.Name,
Key: newKeyInfo.ClientSecret, Key: newKeyInfo.ClientSecret,
} }

@ -5,12 +5,12 @@ import (
) )
type Service interface { type Service interface {
GetAPIKeys(ctx context.Context, query *GetApiKeysQuery) error GetAPIKeys(ctx context.Context, query *GetApiKeysQuery) (res []*APIKey, err error)
GetAllAPIKeys(ctx context.Context, orgID int64) ([]*APIKey, error) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*APIKey, error)
DeleteApiKey(ctx context.Context, cmd *DeleteCommand) error DeleteApiKey(ctx context.Context, cmd *DeleteCommand) error
AddAPIKey(ctx context.Context, cmd *AddCommand) error AddAPIKey(ctx context.Context, cmd *AddCommand) (res *APIKey, err error)
GetApiKeyById(ctx context.Context, query *GetByIDQuery) error GetApiKeyById(ctx context.Context, query *GetByIDQuery) (res *APIKey, err error)
GetApiKeyByName(ctx context.Context, query *GetByNameQuery) error GetApiKeyByName(ctx context.Context, query *GetByNameQuery) (res *APIKey, err error)
GetAPIKeyByHash(ctx context.Context, hash string) (*APIKey, error) GetAPIKeyByHash(ctx context.Context, hash string) (*APIKey, error)
UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error
// IsDisabled returns true if the API key is not available for use. // IsDisabled returns true if the API key is not available for use.

@ -44,16 +44,16 @@ func (s *Service) Usage(ctx context.Context, scopeParams *quota.ScopeParameters)
return s.store.Count(ctx, scopeParams) return s.store.Count(ctx, scopeParams)
} }
func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) {
return s.store.GetAPIKeys(ctx, query) return s.store.GetAPIKeys(ctx, query)
} }
func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) {
return s.store.GetAllAPIKeys(ctx, orgID) return s.store.GetAllAPIKeys(ctx, orgID)
} }
func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) {
return s.store.GetApiKeyById(ctx, query) return s.store.GetApiKeyById(ctx, query)
} }
func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) {
return s.store.GetApiKeyByName(ctx, query) return s.store.GetApiKeyByName(ctx, query)
} }
func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) {
@ -62,7 +62,7 @@ func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.API
func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error { func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error {
return s.store.DeleteApiKey(ctx, cmd) return s.store.DeleteApiKey(ctx, cmd)
} }
func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error) {
return s.store.AddAPIKey(ctx, cmd) return s.store.AddAPIKey(ctx, cmd)
} }
func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error { func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error {

@ -20,7 +20,7 @@ type sqlxStore struct {
cfg *setting.Cfg cfg *setting.Cfg
} }
func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) {
var where []string var where []string
var args []interface{} var args []interface{}
@ -37,7 +37,7 @@ func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQue
if !accesscontrol.IsDisabled(ss.cfg) { if !accesscontrol.IsDisabled(ss.cfg) {
filter, err := accesscontrol.Filter(query.User, "id", "apikeys:id:", accesscontrol.ActionAPIKeyRead) filter, err := accesscontrol.Filter(query.User, "id", "apikeys:id:", accesscontrol.ActionAPIKeyRead)
if err != nil { if err != nil {
return err return nil, err
} }
where = append(where, filter.Where) where = append(where, filter.Where)
args = append(args, filter.Args...) args = append(args, filter.Args...)
@ -45,9 +45,9 @@ func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQue
ws := fmt.Sprint(strings.Join(where[:], " AND ")) ws := fmt.Sprint(strings.Join(where[:], " AND "))
qr := fmt.Sprintf(`SELECT * FROM api_key WHERE %s ORDER BY name ASC LIMIT 100`, ws) qr := fmt.Sprintf(`SELECT * FROM api_key WHERE %s ORDER BY name ASC LIMIT 100`, ws)
query.Result = make([]*apikey.APIKey, 0) keys := make([]*apikey.APIKey, 0)
err := ss.sess.Select(ctx, &query.Result, qr, args...) err := ss.sess.Select(ctx, &keys, qr, args...)
return err return keys, err
} }
func (ss *sqlxStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { func (ss *sqlxStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) {
@ -87,20 +87,20 @@ func (ss *sqlxStore) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand
return err return err
} }
func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (*apikey.APIKey, error) {
updated := timeNow() updated := timeNow()
var expires *int64 = nil var expires *int64 = nil
if cmd.SecondsToLive > 0 { if cmd.SecondsToLive > 0 {
v := updated.Add(time.Second * time.Duration(cmd.SecondsToLive)).Unix() v := updated.Add(time.Second * time.Duration(cmd.SecondsToLive)).Unix()
expires = &v expires = &v
} else if cmd.SecondsToLive < 0 { } else if cmd.SecondsToLive < 0 {
return apikey.ErrInvalidExpiration return nil, apikey.ErrInvalidExpiration
} }
err := ss.GetApiKeyByName(ctx, &apikey.GetByNameQuery{OrgID: cmd.OrgID, KeyName: cmd.Name}) _, err := ss.GetApiKeyByName(ctx, &apikey.GetByNameQuery{OrgID: cmd.OrgID, KeyName: cmd.Name})
// If key with the same orgId and name already exist return err // If key with the same orgId and name already exist return err
if !errors.Is(err, apikey.ErrInvalid) { if !errors.Is(err, apikey.ErrInvalid) {
return apikey.ErrDuplicate return nil, apikey.ErrDuplicate
} }
isRevoked := false isRevoked := false
t := apikey.APIKey{ t := apikey.APIKey{
@ -117,28 +117,25 @@ func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) erro
t.ID, err = ss.sess.ExecWithReturningId(ctx, t.ID, err = ss.sess.ExecWithReturningId(ctx,
`INSERT INTO api_key (org_id, name, role, "key", created, updated, expires, service_account_id, is_revoked) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, t.OrgID, t.Name, t.Role, t.Key, t.Created, t.Updated, t.Expires, t.ServiceAccountId, t.IsRevoked) `INSERT INTO api_key (org_id, name, role, "key", created, updated, expires, service_account_id, is_revoked) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, t.OrgID, t.Name, t.Role, t.Key, t.Created, t.Updated, t.Expires, t.ServiceAccountId, t.IsRevoked)
cmd.Result = &t return &t, err
return err
} }
func (ss *sqlxStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { func (ss *sqlxStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) {
var key apikey.APIKey var key apikey.APIKey
err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE id=?", query.ApiKeyID) err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE id=?", query.ApiKeyID)
if err != nil && errors.Is(err, sql.ErrNoRows) { if err != nil && errors.Is(err, sql.ErrNoRows) {
return apikey.ErrInvalid return nil, apikey.ErrInvalid
} }
query.Result = &key return &key, err
return err
} }
func (ss *sqlxStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { func (ss *sqlxStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) {
var key apikey.APIKey var key apikey.APIKey
err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE org_id=? AND name=?", query.OrgID, query.KeyName) err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE org_id=? AND name=?", query.OrgID, query.KeyName)
if err != nil && errors.Is(err, sql.ErrNoRows) { if err != nil && errors.Is(err, sql.ErrNoRows) {
return apikey.ErrInvalid return nil, apikey.ErrInvalid
} }
query.Result = &key return &key, err
return err
} }
func (ss *sqlxStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { func (ss *sqlxStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) {

@ -8,13 +8,13 @@ import (
) )
type store interface { type store interface {
GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) (res []*apikey.APIKey, err error)
GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error)
CountAPIKeys(ctx context.Context, orgID int64) (int64, error) CountAPIKeys(ctx context.Context, orgID int64) (int64, error)
DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error
AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error)
GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (res *apikey.APIKey, err error)
GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (res *apikey.APIKey, err error)
GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error)
UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error

@ -43,7 +43,7 @@ func seedApiKeys(t *testing.T, store store, num int) {
t.Helper() t.Helper()
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
err := store.AddAPIKey(context.Background(), &apikey.AddCommand{ _, err := store.AddAPIKey(context.Background(), &apikey.AddCommand{
Name: fmt.Sprintf("key:%d", i), Name: fmt.Sprintf("key:%d", i),
Key: fmt.Sprintf("key:%d", i), Key: fmt.Sprintf("key:%d", i),
OrgID: 1, OrgID: 1,
@ -64,15 +64,15 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) {
t.Run("Given saved api key", func(t *testing.T) { t.Run("Given saved api key", func(t *testing.T) {
cmd := apikey.AddCommand{OrgID: 1, Name: "hello", Key: "asd"} cmd := apikey.AddCommand{OrgID: 1, Name: "hello", Key: "asd"}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
t.Run("Should be able to get key by name", func(t *testing.T) { t.Run("Should be able to get key by name", func(t *testing.T) {
query := apikey.GetByNameQuery{KeyName: "hello", OrgID: 1} query := apikey.GetByNameQuery{KeyName: "hello", OrgID: 1}
err = ss.GetApiKeyByName(context.Background(), &query) key, err := ss.GetApiKeyByName(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, query.Result) assert.NotNil(t, key)
}) })
t.Run("Should be able to get key by hash", func(t *testing.T) { t.Run("Should be able to get key by hash", func(t *testing.T) {
@ -91,77 +91,77 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) {
t.Run("Add non expiring key", func(t *testing.T) { t.Run("Add non expiring key", func(t *testing.T) {
cmd := apikey.AddCommand{OrgID: 1, Name: "non-expiring", Key: "asd1", SecondsToLive: 0} cmd := apikey.AddCommand{OrgID: 1, Name: "non-expiring", Key: "asd1", SecondsToLive: 0}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
query := apikey.GetByNameQuery{KeyName: "non-expiring", OrgID: 1} query := apikey.GetByNameQuery{KeyName: "non-expiring", OrgID: 1}
err = ss.GetApiKeyByName(context.Background(), &query) key, err := ss.GetApiKeyByName(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, query.Result.Expires) assert.Nil(t, key.Expires)
}) })
t.Run("Add an expiring key", func(t *testing.T) { t.Run("Add an expiring key", func(t *testing.T) {
// expires in one hour // expires in one hour
cmd := apikey.AddCommand{OrgID: 1, Name: "expiring-in-an-hour", Key: "asd2", SecondsToLive: 3600} cmd := apikey.AddCommand{OrgID: 1, Name: "expiring-in-an-hour", Key: "asd2", SecondsToLive: 3600}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
query := apikey.GetByNameQuery{KeyName: "expiring-in-an-hour", OrgID: 1} query := apikey.GetByNameQuery{KeyName: "expiring-in-an-hour", OrgID: 1}
err = ss.GetApiKeyByName(context.Background(), &query) key, err := ss.GetApiKeyByName(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, *query.Result.Expires >= timeNow().Unix()) assert.True(t, *key.Expires >= timeNow().Unix())
// timeNow() has been called twice since creation; once by AddAPIKey and once by GetApiKeyByName // timeNow() has been called twice since creation; once by AddAPIKey and once by GetApiKeyByName
// therefore two seconds should be subtracted by next value returned by timeNow() // therefore two seconds should be subtracted by next value returned by timeNow()
// that equals the number by which timeSeed has been advanced // that equals the number by which timeSeed has been advanced
then := timeNow().Add(-2 * time.Second) then := timeNow().Add(-2 * time.Second)
expected := then.Add(1 * time.Hour).UTC().Unix() expected := then.Add(1 * time.Hour).UTC().Unix()
assert.Equal(t, *query.Result.Expires, expected) assert.Equal(t, *key.Expires, expected)
}) })
t.Run("Last Used At datetime update", func(t *testing.T) { t.Run("Last Used At datetime update", func(t *testing.T) {
// expires in one hour // expires in one hour
cmd := apikey.AddCommand{OrgID: 1, Name: "last-update-at", Key: "asd3", SecondsToLive: 3600} cmd := apikey.AddCommand{OrgID: 1, Name: "last-update-at", Key: "asd3", SecondsToLive: 3600}
err := ss.AddAPIKey(context.Background(), &cmd) key, err := ss.AddAPIKey(context.Background(), &cmd)
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, cmd.Result.LastUsedAt) assert.Nil(t, key.LastUsedAt)
err = ss.UpdateAPIKeyLastUsedDate(context.Background(), cmd.Result.ID) err = ss.UpdateAPIKeyLastUsedDate(context.Background(), key.ID)
require.NoError(t, err) require.NoError(t, err)
query := apikey.GetByNameQuery{KeyName: "last-update-at", OrgID: 1} query := apikey.GetByNameQuery{KeyName: "last-update-at", OrgID: 1}
err = ss.GetApiKeyByName(context.Background(), &query) key, err = ss.GetApiKeyByName(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, query.Result.LastUsedAt) assert.NotNil(t, key.LastUsedAt)
}) })
t.Run("Add a key with negative lifespan", func(t *testing.T) { t.Run("Add a key with negative lifespan", func(t *testing.T) {
// expires in one day // expires in one day
cmd := apikey.AddCommand{OrgID: 1, Name: "key-with-negative-lifespan", Key: "asd3", SecondsToLive: -3600} cmd := apikey.AddCommand{OrgID: 1, Name: "key-with-negative-lifespan", Key: "asd3", SecondsToLive: -3600}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.EqualError(t, err, apikey.ErrInvalidExpiration.Error()) assert.EqualError(t, err, apikey.ErrInvalidExpiration.Error())
query := apikey.GetByNameQuery{KeyName: "key-with-negative-lifespan", OrgID: 1} query := apikey.GetByNameQuery{KeyName: "key-with-negative-lifespan", OrgID: 1}
err = ss.GetApiKeyByName(context.Background(), &query) _, err = ss.GetApiKeyByName(context.Background(), &query)
assert.EqualError(t, err, "invalid API key") assert.EqualError(t, err, "invalid API key")
}) })
t.Run("Add keys", func(t *testing.T) { t.Run("Add keys", func(t *testing.T) {
// never expires // never expires
cmd := apikey.AddCommand{OrgID: 1, Name: "key1", Key: "key1", SecondsToLive: 0} cmd := apikey.AddCommand{OrgID: 1, Name: "key1", Key: "key1", SecondsToLive: 0}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
// expires in 1s // expires in 1s
cmd = apikey.AddCommand{OrgID: 1, Name: "key2", Key: "key2", SecondsToLive: 1} cmd = apikey.AddCommand{OrgID: 1, Name: "key2", Key: "key2", SecondsToLive: 1}
err = ss.AddAPIKey(context.Background(), &cmd) _, err = ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
// expires in one hour // expires in one hour
cmd = apikey.AddCommand{OrgID: 1, Name: "key3", Key: "key3", SecondsToLive: 3600} cmd = apikey.AddCommand{OrgID: 1, Name: "key3", Key: "key3", SecondsToLive: 3600}
err = ss.AddAPIKey(context.Background(), &cmd) _, err = ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
// advance mocked getTime by 1s // advance mocked getTime by 1s
@ -174,21 +174,21 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) {
}, },
} }
query := apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: false, User: testUser} query := apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: false, User: testUser}
err = ss.GetAPIKeys(context.Background(), &query) keys, err := ss.GetAPIKeys(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
for _, k := range query.Result { for _, k := range keys {
if k.Name == "key2" { if k.Name == "key2" {
t.Fatalf("key2 should not be there") t.Fatalf("key2 should not be there")
} }
} }
query = apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: true, User: testUser} query = apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: true, User: testUser}
err = ss.GetAPIKeys(context.Background(), &query) keys, err = ss.GetAPIKeys(context.Background(), &query)
assert.Nil(t, err) assert.Nil(t, err)
found := false found := false
for _, k := range query.Result { for _, k := range keys {
if k.Name == "key2" { if k.Name == "key2" {
found = true found = true
} }
@ -211,12 +211,12 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) {
t.Run("Testing API Duplicate Key Errors", func(t *testing.T) { t.Run("Testing API Duplicate Key Errors", func(t *testing.T) {
t.Run("Given saved api key", func(t *testing.T) { t.Run("Given saved api key", func(t *testing.T) {
cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"} cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"}
err := ss.AddAPIKey(context.Background(), &cmd) _, err := ss.AddAPIKey(context.Background(), &cmd)
assert.Nil(t, err) assert.Nil(t, err)
t.Run("Add API Key with existing Org ID and Name", func(t *testing.T) { t.Run("Add API Key with existing Org ID and Name", func(t *testing.T) {
cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"} cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"}
err = ss.AddAPIKey(context.Background(), &cmd) _, err = ss.AddAPIKey(context.Background(), &cmd)
assert.EqualError(t, err, apikey.ErrDuplicate.Error()) assert.EqualError(t, err, apikey.ErrDuplicate.Error())
}) })
}) })
@ -258,9 +258,9 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) {
seedApiKeys(t, store, 10) seedApiKeys(t, store, 10)
query := &apikey.GetApiKeysQuery{OrgID: 1, User: tt.user} query := &apikey.GetApiKeysQuery{OrgID: 1, User: tt.user}
err := store.GetAPIKeys(context.Background(), query) keys, err := store.GetAPIKeys(context.Background(), query)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, query.Result, tt.expectedNumKeys) assert.Len(t, keys, tt.expectedNumKeys)
res, err := store.GetAllAPIKeys(context.Background(), 1) res, err := store.GetAllAPIKeys(context.Background(), 1)
require.NoError(t, err) require.NoError(t, err)

@ -23,8 +23,8 @@ type sqlStore struct {
// timeNow makes it possible to test usage of time // timeNow makes it possible to test usage of time
var timeNow = time.Now var timeNow = time.Now
func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) (res []*apikey.APIKey, err error) {
return ss.db.WithDbSession(ctx, func(dbSession *db.Session) error { err = ss.db.WithDbSession(ctx, func(dbSession *db.Session) error {
var sess *xorm.Session var sess *xorm.Session
if query.IncludeExpired { if query.IncludeExpired {
@ -47,9 +47,10 @@ func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuer
sess.And(filter.Where, filter.Args...) sess.And(filter.Where, filter.Args...)
} }
query.Result = make([]*apikey.APIKey, 0) res = make([]*apikey.APIKey, 0)
return sess.Find(&query.Result) return sess.Find(&res)
}) })
return res, err
} }
func (ss *sqlStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { func (ss *sqlStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) {
@ -100,8 +101,8 @@ func (ss *sqlStore) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand)
}) })
} }
func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error) {
return ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error { err = ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
key := apikey.APIKey{OrgID: cmd.OrgID, Name: cmd.Name} key := apikey.APIKey{OrgID: cmd.OrgID, Name: cmd.Name}
exists, _ := sess.Get(&key) exists, _ := sess.Get(&key)
if exists { if exists {
@ -133,13 +134,14 @@ func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error
if _, err := sess.Insert(&t); err != nil { if _, err := sess.Insert(&t); err != nil {
return fmt.Errorf("%s: %w", "failed to insert token", err) return fmt.Errorf("%s: %w", "failed to insert token", err)
} }
cmd.Result = &t res = &t
return nil return nil
}) })
return res, err
} }
func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (res *apikey.APIKey, err error) {
return ss.db.WithDbSession(ctx, func(sess *db.Session) error { err = ss.db.WithDbSession(ctx, func(sess *db.Session) error {
var key apikey.APIKey var key apikey.APIKey
has, err := sess.ID(query.ApiKeyID).Get(&key) has, err := sess.ID(query.ApiKeyID).Get(&key)
@ -149,13 +151,14 @@ func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuer
return apikey.ErrInvalid return apikey.ErrInvalid
} }
query.Result = &key res = &key
return nil return nil
}) })
return res, err
} }
func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (res *apikey.APIKey, err error) {
return ss.db.WithDbSession(ctx, func(sess *db.Session) error { err = ss.db.WithDbSession(ctx, func(sess *db.Session) error {
var key apikey.APIKey var key apikey.APIKey
has, err := sess.Where("org_id=? AND name=?", query.OrgID, query.KeyName).Get(&key) has, err := sess.Where("org_id=? AND name=?", query.OrgID, query.KeyName).Get(&key)
@ -165,9 +168,10 @@ func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByName
return apikey.ErrInvalid return apikey.ErrInvalid
} }
query.Result = &key res = &key
return nil return nil
}) })
return res, err
} }
func (ss *sqlStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { func (ss *sqlStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) {

@ -13,20 +13,17 @@ type Service struct {
ExpectedAPIKey *apikey.APIKey ExpectedAPIKey *apikey.APIKey
} }
func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) {
query.Result = s.ExpectedAPIKeys return s.ExpectedAPIKeys, s.ExpectedError
return s.ExpectedError
} }
func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) {
return s.ExpectedAPIKeys, s.ExpectedError return s.ExpectedAPIKeys, s.ExpectedError
} }
func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) {
query.Result = s.ExpectedAPIKey return s.ExpectedAPIKey, s.ExpectedError
return s.ExpectedError
} }
func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) {
query.Result = s.ExpectedAPIKey return s.ExpectedAPIKey, s.ExpectedError
return s.ExpectedError
} }
func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) {
return s.ExpectedAPIKey, s.ExpectedError return s.ExpectedAPIKey, s.ExpectedError
@ -34,9 +31,8 @@ func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.API
func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error { func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error {
return s.ExpectedError return s.ExpectedError
} }
func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (*apikey.APIKey, error) {
cmd.Result = s.ExpectedAPIKey return s.ExpectedAPIKey, s.ExpectedError
return s.ExpectedError
} }
func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error { func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error {
return s.ExpectedError return s.ExpectedError

@ -40,8 +40,6 @@ type AddCommand struct {
Key string `json:"-"` Key string `json:"-"`
SecondsToLive int64 `json:"secondsToLive"` SecondsToLive int64 `json:"secondsToLive"`
ServiceAccountID *int64 `json:"-"` ServiceAccountID *int64 `json:"-"`
Result *APIKey `json:"-"`
} }
type DeleteCommand struct { type DeleteCommand struct {
@ -53,17 +51,14 @@ type GetApiKeysQuery struct {
OrgID int64 OrgID int64
IncludeExpired bool IncludeExpired bool
User *user.SignedInUser User *user.SignedInUser
Result []*APIKey
} }
type GetByNameQuery struct { type GetByNameQuery struct {
KeyName string KeyName string
OrgID int64 OrgID int64
Result *APIKey
} }
type GetByIDQuery struct { type GetByIDQuery struct {
ApiKeyID int64 ApiKeyID int64
Result *APIKey
} }
const ( const (

@ -119,12 +119,13 @@ func (s *APIKey) getFromTokenLegacy(ctx context.Context, token string) (*apikey.
// fetch key // fetch key
keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId} keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId}
if err := s.apiKeyService.GetApiKeyByName(ctx, &keyQuery); err != nil { key, err := s.apiKeyService.GetApiKeyByName(ctx, &keyQuery)
if err != nil {
return nil, err return nil, err
} }
// validate api key // validate api key
isValid, err := apikeygen.IsValid(decoded, keyQuery.Result.Key) isValid, err := apikeygen.IsValid(decoded, key.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,7 +133,7 @@ func (s *APIKey) getFromTokenLegacy(ctx context.Context, token string) (*apikey.
return nil, apikeygen.ErrInvalidApiKey return nil, apikeygen.ErrInvalidApiKey
} }
return keyQuery.Result, nil return key, nil
} }
func (s *APIKey) Test(ctx context.Context, r *authn.Request) bool { func (s *APIKey) Test(ctx context.Context, r *authn.Request) bool {

@ -282,12 +282,13 @@ func (h *ContextHandler) getAPIKey(ctx context.Context, keyString string) (*apik
// fetch key // fetch key
keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId} keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId}
if err := h.apiKeyService.GetApiKeyByName(ctx, &keyQuery); err != nil { key, err := h.apiKeyService.GetApiKeyByName(ctx, &keyQuery)
if err != nil {
return nil, err return nil, err
} }
// validate api key // validate api key
isValid, err := apikeygen.IsValid(decoded, keyQuery.Result.Key) isValid, err := apikeygen.IsValid(decoded, key.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -295,7 +296,7 @@ func (h *ContextHandler) getAPIKey(ctx context.Context, keyString string) (*apik
return nil, apikeygen.ErrInvalidApiKey return nil, apikeygen.ErrInvalidApiKey
} }
return keyQuery.Result, nil return key, nil
} }
func (h *ContextHandler) initContextWithAPIKey(reqContext *contextmodel.ReqContext) bool { func (h *ContextHandler) initContextWithAPIKey(reqContext *contextmodel.ReqContext) bool {

@ -58,7 +58,8 @@ func (s *ServiceAccountsStoreImpl) AddServiceAccountToken(ctx context.Context, s
ServiceAccountID: &serviceAccountId, ServiceAccountID: &serviceAccountId,
} }
if err := s.apiKeyService.AddAPIKey(ctx, addKeyCmd); err != nil { key, err := s.apiKeyService.AddAPIKey(ctx, addKeyCmd)
if err != nil {
switch { switch {
case errors.Is(err, apikey.ErrDuplicate): case errors.Is(err, apikey.ErrDuplicate):
return serviceaccounts.ErrDuplicateToken.Errorf("service account token with name %s already exists in the organization", cmd.Name) return serviceaccounts.ErrDuplicateToken.Errorf("service account token with name %s already exists in the organization", cmd.Name)
@ -69,7 +70,7 @@ func (s *ServiceAccountsStoreImpl) AddServiceAccountToken(ctx context.Context, s
return err return err
} }
apiKey = addKeyCmd.Result apiKey = key
return nil return nil
}) })
} }

@ -87,22 +87,22 @@ func SetupApiKey(t *testing.T, sqlStore *sqlstore.SQLStore, testKey TestApiKey)
quotaService := quotatest.New(false, nil) quotaService := quotatest.New(false, nil)
apiKeyService, err := apikeyimpl.ProvideService(sqlStore, sqlStore.Cfg, quotaService) apiKeyService, err := apikeyimpl.ProvideService(sqlStore, sqlStore.Cfg, quotaService)
require.NoError(t, err) require.NoError(t, err)
err = apiKeyService.AddAPIKey(context.Background(), addKeyCmd) key, err := apiKeyService.AddAPIKey(context.Background(), addKeyCmd)
require.NoError(t, err) require.NoError(t, err)
if testKey.IsExpired { if testKey.IsExpired {
err := sqlStore.WithTransactionalDbSession(context.Background(), func(sess *db.Session) error { err := sqlStore.WithTransactionalDbSession(context.Background(), func(sess *db.Session) error {
// Force setting expires to time before now to make key expired // Force setting expires to time before now to make key expired
var expires int64 = 1 var expires int64 = 1
key := apikey.APIKey{Expires: &expires} expiringKey := apikey.APIKey{Expires: &expires}
rowsAffected, err := sess.ID(addKeyCmd.Result.ID).Update(&key) rowsAffected, err := sess.ID(key.ID).Update(&expiringKey)
require.Equal(t, int64(1), rowsAffected) require.Equal(t, int64(1), rowsAffected)
return err return err
}) })
require.NoError(t, err) require.NoError(t, err)
} }
return addKeyCmd.Result return key
} }
func SetupMockAccesscontrol(t *testing.T, func SetupMockAccesscontrol(t *testing.T,

Loading…
Cancel
Save