diff --git a/pkg/api/apikey.go b/pkg/api/apikey.go index dcabb191f95..fc48575ee50 100644 --- a/pkg/api/apikey.go +++ b/pkg/api/apikey.go @@ -29,13 +29,14 @@ import ( func (hs *HTTPServer) GetAPIKeys(c *contextmodel.ReqContext) response.Response { query := apikey.GetApiKeysQuery{OrgID: c.OrgID, User: c.SignedInUser, IncludeExpired: c.QueryBool("includeExpired")} - if err := hs.apiKeyService.GetAPIKeys(c.Req.Context(), &query); err != nil { + keys, err := hs.apiKeyService.GetAPIKeys(c.Req.Context(), &query) + if err != nil { return response.Error(500, "Failed to list api keys", err) } ids := map[string]bool{} - result := make([]*dtos.ApiKeyDTO, len(query.Result)) - for i, t := range query.Result { + result := make([]*dtos.ApiKeyDTO, len(keys)) + for i, t := range keys { ids[strconv.FormatInt(t.ID, 10)] = true var expiration *time.Time = nil if t.Expires != nil { @@ -134,7 +135,8 @@ func (hs *HTTPServer) AddAPIKey(c *contextmodel.ReqContext) response.Response { } cmd.Key = newKeyInfo.HashedKey - if err := hs.apiKeyService.AddAPIKey(c.Req.Context(), &cmd); err != nil { + key, err := hs.apiKeyService.AddAPIKey(c.Req.Context(), &cmd) + if err != nil { if errors.Is(err, apikey.ErrInvalidExpiration) { return response.Error(400, err.Error(), nil) } @@ -145,8 +147,8 @@ func (hs *HTTPServer) AddAPIKey(c *contextmodel.ReqContext) response.Response { } result := &dtos.NewApiKeyResult{ - ID: cmd.Result.ID, - Name: cmd.Result.Name, + ID: key.ID, + Name: key.Name, Key: newKeyInfo.ClientSecret, } diff --git a/pkg/services/apikey/apikey.go b/pkg/services/apikey/apikey.go index cde6c9fd097..57b3b539257 100644 --- a/pkg/services/apikey/apikey.go +++ b/pkg/services/apikey/apikey.go @@ -5,12 +5,12 @@ import ( ) type Service interface { - GetAPIKeys(ctx context.Context, query *GetApiKeysQuery) error + GetAPIKeys(ctx context.Context, query *GetApiKeysQuery) (res []*APIKey, err error) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*APIKey, error) DeleteApiKey(ctx context.Context, cmd *DeleteCommand) error - AddAPIKey(ctx context.Context, cmd *AddCommand) error - GetApiKeyById(ctx context.Context, query *GetByIDQuery) error - GetApiKeyByName(ctx context.Context, query *GetByNameQuery) error + AddAPIKey(ctx context.Context, cmd *AddCommand) (res *APIKey, err error) + GetApiKeyById(ctx context.Context, query *GetByIDQuery) (res *APIKey, err error) + GetApiKeyByName(ctx context.Context, query *GetByNameQuery) (res *APIKey, err error) GetAPIKeyByHash(ctx context.Context, hash string) (*APIKey, error) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error // IsDisabled returns true if the API key is not available for use. diff --git a/pkg/services/apikey/apikeyimpl/apikey.go b/pkg/services/apikey/apikeyimpl/apikey.go index e5f6ea2f098..89addfc3beb 100644 --- a/pkg/services/apikey/apikeyimpl/apikey.go +++ b/pkg/services/apikey/apikeyimpl/apikey.go @@ -44,16 +44,16 @@ func (s *Service) Usage(ctx context.Context, scopeParams *quota.ScopeParameters) return s.store.Count(ctx, scopeParams) } -func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { +func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) { return s.store.GetAPIKeys(ctx, query) } func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { return s.store.GetAllAPIKeys(ctx, orgID) } -func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { +func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) { return s.store.GetApiKeyById(ctx, query) } -func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { +func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) { return s.store.GetApiKeyByName(ctx, query) } func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { @@ -62,7 +62,7 @@ func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.API func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error { return s.store.DeleteApiKey(ctx, cmd) } -func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { +func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error) { return s.store.AddAPIKey(ctx, cmd) } func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error { diff --git a/pkg/services/apikey/apikeyimpl/sqlx_store.go b/pkg/services/apikey/apikeyimpl/sqlx_store.go index 7b43490ba48..abb14484b27 100644 --- a/pkg/services/apikey/apikeyimpl/sqlx_store.go +++ b/pkg/services/apikey/apikeyimpl/sqlx_store.go @@ -20,7 +20,7 @@ type sqlxStore struct { cfg *setting.Cfg } -func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { +func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) { var where []string var args []interface{} @@ -37,7 +37,7 @@ func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQue if !accesscontrol.IsDisabled(ss.cfg) { filter, err := accesscontrol.Filter(query.User, "id", "apikeys:id:", accesscontrol.ActionAPIKeyRead) if err != nil { - return err + return nil, err } where = append(where, filter.Where) args = append(args, filter.Args...) @@ -45,9 +45,9 @@ func (ss *sqlxStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQue ws := fmt.Sprint(strings.Join(where[:], " AND ")) qr := fmt.Sprintf(`SELECT * FROM api_key WHERE %s ORDER BY name ASC LIMIT 100`, ws) - query.Result = make([]*apikey.APIKey, 0) - err := ss.sess.Select(ctx, &query.Result, qr, args...) - return err + keys := make([]*apikey.APIKey, 0) + err := ss.sess.Select(ctx, &keys, qr, args...) + return keys, err } func (ss *sqlxStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { @@ -87,20 +87,20 @@ func (ss *sqlxStore) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand return err } -func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { +func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (*apikey.APIKey, error) { updated := timeNow() var expires *int64 = nil if cmd.SecondsToLive > 0 { v := updated.Add(time.Second * time.Duration(cmd.SecondsToLive)).Unix() expires = &v } else if cmd.SecondsToLive < 0 { - return apikey.ErrInvalidExpiration + return nil, apikey.ErrInvalidExpiration } - err := ss.GetApiKeyByName(ctx, &apikey.GetByNameQuery{OrgID: cmd.OrgID, KeyName: cmd.Name}) + _, err := ss.GetApiKeyByName(ctx, &apikey.GetByNameQuery{OrgID: cmd.OrgID, KeyName: cmd.Name}) // If key with the same orgId and name already exist return err if !errors.Is(err, apikey.ErrInvalid) { - return apikey.ErrDuplicate + return nil, apikey.ErrDuplicate } isRevoked := false t := apikey.APIKey{ @@ -117,28 +117,25 @@ func (ss *sqlxStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) erro t.ID, err = ss.sess.ExecWithReturningId(ctx, `INSERT INTO api_key (org_id, name, role, "key", created, updated, expires, service_account_id, is_revoked) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, t.OrgID, t.Name, t.Role, t.Key, t.Created, t.Updated, t.Expires, t.ServiceAccountId, t.IsRevoked) - cmd.Result = &t - return err + return &t, err } -func (ss *sqlxStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { +func (ss *sqlxStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) { var key apikey.APIKey err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE id=?", query.ApiKeyID) if err != nil && errors.Is(err, sql.ErrNoRows) { - return apikey.ErrInvalid + return nil, apikey.ErrInvalid } - query.Result = &key - return err + return &key, err } -func (ss *sqlxStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { +func (ss *sqlxStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) { var key apikey.APIKey err := ss.sess.Get(ctx, &key, "SELECT * FROM api_key WHERE org_id=? AND name=?", query.OrgID, query.KeyName) if err != nil && errors.Is(err, sql.ErrNoRows) { - return apikey.ErrInvalid + return nil, apikey.ErrInvalid } - query.Result = &key - return err + return &key, err } func (ss *sqlxStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { diff --git a/pkg/services/apikey/apikeyimpl/store.go b/pkg/services/apikey/apikeyimpl/store.go index fa3c33cae29..58bf1fc83e1 100644 --- a/pkg/services/apikey/apikeyimpl/store.go +++ b/pkg/services/apikey/apikeyimpl/store.go @@ -8,13 +8,13 @@ import ( ) type store interface { - GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error + GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) (res []*apikey.APIKey, err error) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) CountAPIKeys(ctx context.Context, orgID int64) (int64, error) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error - AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error - GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error - GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error + AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error) + GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (res *apikey.APIKey, err error) + GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (res *apikey.APIKey, err error) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error diff --git a/pkg/services/apikey/apikeyimpl/store_test.go b/pkg/services/apikey/apikeyimpl/store_test.go index 74971d78095..a5c62ac79a1 100644 --- a/pkg/services/apikey/apikeyimpl/store_test.go +++ b/pkg/services/apikey/apikeyimpl/store_test.go @@ -43,7 +43,7 @@ func seedApiKeys(t *testing.T, store store, num int) { t.Helper() for i := 0; i < num; i++ { - err := store.AddAPIKey(context.Background(), &apikey.AddCommand{ + _, err := store.AddAPIKey(context.Background(), &apikey.AddCommand{ Name: fmt.Sprintf("key:%d", i), Key: fmt.Sprintf("key:%d", i), OrgID: 1, @@ -64,15 +64,15 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) { t.Run("Given saved api key", func(t *testing.T) { cmd := apikey.AddCommand{OrgID: 1, Name: "hello", Key: "asd"} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) t.Run("Should be able to get key by name", func(t *testing.T) { query := apikey.GetByNameQuery{KeyName: "hello", OrgID: 1} - err = ss.GetApiKeyByName(context.Background(), &query) + key, err := ss.GetApiKeyByName(context.Background(), &query) assert.Nil(t, err) - assert.NotNil(t, query.Result) + assert.NotNil(t, key) }) t.Run("Should be able to get key by hash", func(t *testing.T) { @@ -91,77 +91,77 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) { t.Run("Add non expiring key", func(t *testing.T) { cmd := apikey.AddCommand{OrgID: 1, Name: "non-expiring", Key: "asd1", SecondsToLive: 0} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) query := apikey.GetByNameQuery{KeyName: "non-expiring", OrgID: 1} - err = ss.GetApiKeyByName(context.Background(), &query) + key, err := ss.GetApiKeyByName(context.Background(), &query) assert.Nil(t, err) - assert.Nil(t, query.Result.Expires) + assert.Nil(t, key.Expires) }) t.Run("Add an expiring key", func(t *testing.T) { // expires in one hour cmd := apikey.AddCommand{OrgID: 1, Name: "expiring-in-an-hour", Key: "asd2", SecondsToLive: 3600} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) query := apikey.GetByNameQuery{KeyName: "expiring-in-an-hour", OrgID: 1} - err = ss.GetApiKeyByName(context.Background(), &query) + key, err := ss.GetApiKeyByName(context.Background(), &query) assert.Nil(t, err) - assert.True(t, *query.Result.Expires >= timeNow().Unix()) + assert.True(t, *key.Expires >= timeNow().Unix()) // timeNow() has been called twice since creation; once by AddAPIKey and once by GetApiKeyByName // therefore two seconds should be subtracted by next value returned by timeNow() // that equals the number by which timeSeed has been advanced then := timeNow().Add(-2 * time.Second) expected := then.Add(1 * time.Hour).UTC().Unix() - assert.Equal(t, *query.Result.Expires, expected) + assert.Equal(t, *key.Expires, expected) }) t.Run("Last Used At datetime update", func(t *testing.T) { // expires in one hour cmd := apikey.AddCommand{OrgID: 1, Name: "last-update-at", Key: "asd3", SecondsToLive: 3600} - err := ss.AddAPIKey(context.Background(), &cmd) + key, err := ss.AddAPIKey(context.Background(), &cmd) require.NoError(t, err) - assert.Nil(t, cmd.Result.LastUsedAt) + assert.Nil(t, key.LastUsedAt) - err = ss.UpdateAPIKeyLastUsedDate(context.Background(), cmd.Result.ID) + err = ss.UpdateAPIKeyLastUsedDate(context.Background(), key.ID) require.NoError(t, err) query := apikey.GetByNameQuery{KeyName: "last-update-at", OrgID: 1} - err = ss.GetApiKeyByName(context.Background(), &query) + key, err = ss.GetApiKeyByName(context.Background(), &query) assert.Nil(t, err) - assert.NotNil(t, query.Result.LastUsedAt) + assert.NotNil(t, key.LastUsedAt) }) t.Run("Add a key with negative lifespan", func(t *testing.T) { // expires in one day cmd := apikey.AddCommand{OrgID: 1, Name: "key-with-negative-lifespan", Key: "asd3", SecondsToLive: -3600} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.EqualError(t, err, apikey.ErrInvalidExpiration.Error()) query := apikey.GetByNameQuery{KeyName: "key-with-negative-lifespan", OrgID: 1} - err = ss.GetApiKeyByName(context.Background(), &query) + _, err = ss.GetApiKeyByName(context.Background(), &query) assert.EqualError(t, err, "invalid API key") }) t.Run("Add keys", func(t *testing.T) { // never expires cmd := apikey.AddCommand{OrgID: 1, Name: "key1", Key: "key1", SecondsToLive: 0} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) // expires in 1s cmd = apikey.AddCommand{OrgID: 1, Name: "key2", Key: "key2", SecondsToLive: 1} - err = ss.AddAPIKey(context.Background(), &cmd) + _, err = ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) // expires in one hour cmd = apikey.AddCommand{OrgID: 1, Name: "key3", Key: "key3", SecondsToLive: 3600} - err = ss.AddAPIKey(context.Background(), &cmd) + _, err = ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) // advance mocked getTime by 1s @@ -174,21 +174,21 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) { }, } query := apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: false, User: testUser} - err = ss.GetAPIKeys(context.Background(), &query) + keys, err := ss.GetAPIKeys(context.Background(), &query) assert.Nil(t, err) - for _, k := range query.Result { + for _, k := range keys { if k.Name == "key2" { t.Fatalf("key2 should not be there") } } query = apikey.GetApiKeysQuery{OrgID: 1, IncludeExpired: true, User: testUser} - err = ss.GetAPIKeys(context.Background(), &query) + keys, err = ss.GetAPIKeys(context.Background(), &query) assert.Nil(t, err) found := false - for _, k := range query.Result { + for _, k := range keys { if k.Name == "key2" { found = true } @@ -211,12 +211,12 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) { t.Run("Testing API Duplicate Key Errors", func(t *testing.T) { t.Run("Given saved api key", func(t *testing.T) { cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"} - err := ss.AddAPIKey(context.Background(), &cmd) + _, err := ss.AddAPIKey(context.Background(), &cmd) assert.Nil(t, err) t.Run("Add API Key with existing Org ID and Name", func(t *testing.T) { cmd := apikey.AddCommand{OrgID: 0, Name: "duplicate", Key: "asd"} - err = ss.AddAPIKey(context.Background(), &cmd) + _, err = ss.AddAPIKey(context.Background(), &cmd) assert.EqualError(t, err, apikey.ErrDuplicate.Error()) }) }) @@ -258,9 +258,9 @@ func testIntegrationApiKeyDataAccess(t *testing.T, fn getStore) { seedApiKeys(t, store, 10) query := &apikey.GetApiKeysQuery{OrgID: 1, User: tt.user} - err := store.GetAPIKeys(context.Background(), query) + keys, err := store.GetAPIKeys(context.Background(), query) require.NoError(t, err) - assert.Len(t, query.Result, tt.expectedNumKeys) + assert.Len(t, keys, tt.expectedNumKeys) res, err := store.GetAllAPIKeys(context.Background(), 1) require.NoError(t, err) diff --git a/pkg/services/apikey/apikeyimpl/xorm_store.go b/pkg/services/apikey/apikeyimpl/xorm_store.go index 457e495ff6c..b94ea1c9fd7 100644 --- a/pkg/services/apikey/apikeyimpl/xorm_store.go +++ b/pkg/services/apikey/apikeyimpl/xorm_store.go @@ -23,8 +23,8 @@ type sqlStore struct { // timeNow makes it possible to test usage of time var timeNow = time.Now -func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { - return ss.db.WithDbSession(ctx, func(dbSession *db.Session) error { +func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) (res []*apikey.APIKey, err error) { + err = ss.db.WithDbSession(ctx, func(dbSession *db.Session) error { var sess *xorm.Session if query.IncludeExpired { @@ -47,9 +47,10 @@ func (ss *sqlStore) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuer sess.And(filter.Where, filter.Args...) } - query.Result = make([]*apikey.APIKey, 0) - return sess.Find(&query.Result) + res = make([]*apikey.APIKey, 0) + return sess.Find(&res) }) + return res, err } func (ss *sqlStore) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { @@ -100,8 +101,8 @@ func (ss *sqlStore) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) }) } -func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { - return ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error { +func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (res *apikey.APIKey, err error) { + err = ss.db.WithTransactionalDbSession(ctx, func(sess *db.Session) error { key := apikey.APIKey{OrgID: cmd.OrgID, Name: cmd.Name} exists, _ := sess.Get(&key) if exists { @@ -133,13 +134,14 @@ func (ss *sqlStore) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error if _, err := sess.Insert(&t); err != nil { return fmt.Errorf("%s: %w", "failed to insert token", err) } - cmd.Result = &t + res = &t return nil }) + return res, err } -func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { - return ss.db.WithDbSession(ctx, func(sess *db.Session) error { +func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (res *apikey.APIKey, err error) { + err = ss.db.WithDbSession(ctx, func(sess *db.Session) error { var key apikey.APIKey has, err := sess.ID(query.ApiKeyID).Get(&key) @@ -149,13 +151,14 @@ func (ss *sqlStore) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuer return apikey.ErrInvalid } - query.Result = &key + res = &key return nil }) + return res, err } -func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { - return ss.db.WithDbSession(ctx, func(sess *db.Session) error { +func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (res *apikey.APIKey, err error) { + err = ss.db.WithDbSession(ctx, func(sess *db.Session) error { var key apikey.APIKey has, err := sess.Where("org_id=? AND name=?", query.OrgID, query.KeyName).Get(&key) @@ -165,9 +168,10 @@ func (ss *sqlStore) GetApiKeyByName(ctx context.Context, query *apikey.GetByName return apikey.ErrInvalid } - query.Result = &key + res = &key return nil }) + return res, err } func (ss *sqlStore) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { diff --git a/pkg/services/apikey/apikeytest/fake.go b/pkg/services/apikey/apikeytest/fake.go index c03aa16c040..d48471bc0e5 100644 --- a/pkg/services/apikey/apikeytest/fake.go +++ b/pkg/services/apikey/apikeytest/fake.go @@ -13,20 +13,17 @@ type Service struct { ExpectedAPIKey *apikey.APIKey } -func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) error { - query.Result = s.ExpectedAPIKeys - return s.ExpectedError +func (s *Service) GetAPIKeys(ctx context.Context, query *apikey.GetApiKeysQuery) ([]*apikey.APIKey, error) { + return s.ExpectedAPIKeys, s.ExpectedError } func (s *Service) GetAllAPIKeys(ctx context.Context, orgID int64) ([]*apikey.APIKey, error) { return s.ExpectedAPIKeys, s.ExpectedError } -func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) error { - query.Result = s.ExpectedAPIKey - return s.ExpectedError +func (s *Service) GetApiKeyById(ctx context.Context, query *apikey.GetByIDQuery) (*apikey.APIKey, error) { + return s.ExpectedAPIKey, s.ExpectedError } -func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) error { - query.Result = s.ExpectedAPIKey - return s.ExpectedError +func (s *Service) GetApiKeyByName(ctx context.Context, query *apikey.GetByNameQuery) (*apikey.APIKey, error) { + return s.ExpectedAPIKey, s.ExpectedError } func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.APIKey, error) { return s.ExpectedAPIKey, s.ExpectedError @@ -34,9 +31,8 @@ func (s *Service) GetAPIKeyByHash(ctx context.Context, hash string) (*apikey.API func (s *Service) DeleteApiKey(ctx context.Context, cmd *apikey.DeleteCommand) error { return s.ExpectedError } -func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) error { - cmd.Result = s.ExpectedAPIKey - return s.ExpectedError +func (s *Service) AddAPIKey(ctx context.Context, cmd *apikey.AddCommand) (*apikey.APIKey, error) { + return s.ExpectedAPIKey, s.ExpectedError } func (s *Service) UpdateAPIKeyLastUsedDate(ctx context.Context, tokenID int64) error { return s.ExpectedError diff --git a/pkg/services/apikey/model.go b/pkg/services/apikey/model.go index 84f631e9a1d..1414b672287 100644 --- a/pkg/services/apikey/model.go +++ b/pkg/services/apikey/model.go @@ -40,8 +40,6 @@ type AddCommand struct { Key string `json:"-"` SecondsToLive int64 `json:"secondsToLive"` ServiceAccountID *int64 `json:"-"` - - Result *APIKey `json:"-"` } type DeleteCommand struct { @@ -53,17 +51,14 @@ type GetApiKeysQuery struct { OrgID int64 IncludeExpired bool User *user.SignedInUser - Result []*APIKey } type GetByNameQuery struct { KeyName string OrgID int64 - Result *APIKey } type GetByIDQuery struct { ApiKeyID int64 - Result *APIKey } const ( diff --git a/pkg/services/authn/clients/api_key.go b/pkg/services/authn/clients/api_key.go index 52da3782397..cc0b86fb5a3 100644 --- a/pkg/services/authn/clients/api_key.go +++ b/pkg/services/authn/clients/api_key.go @@ -119,12 +119,13 @@ func (s *APIKey) getFromTokenLegacy(ctx context.Context, token string) (*apikey. // fetch key keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId} - if err := s.apiKeyService.GetApiKeyByName(ctx, &keyQuery); err != nil { + key, err := s.apiKeyService.GetApiKeyByName(ctx, &keyQuery) + if err != nil { return nil, err } // validate api key - isValid, err := apikeygen.IsValid(decoded, keyQuery.Result.Key) + isValid, err := apikeygen.IsValid(decoded, key.Key) if err != nil { return nil, err } @@ -132,7 +133,7 @@ func (s *APIKey) getFromTokenLegacy(ctx context.Context, token string) (*apikey. return nil, apikeygen.ErrInvalidApiKey } - return keyQuery.Result, nil + return key, nil } func (s *APIKey) Test(ctx context.Context, r *authn.Request) bool { diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index f8e97b8fb92..c2f21147c3c 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -282,12 +282,13 @@ func (h *ContextHandler) getAPIKey(ctx context.Context, keyString string) (*apik // fetch key keyQuery := apikey.GetByNameQuery{KeyName: decoded.Name, OrgID: decoded.OrgId} - if err := h.apiKeyService.GetApiKeyByName(ctx, &keyQuery); err != nil { + key, err := h.apiKeyService.GetApiKeyByName(ctx, &keyQuery) + if err != nil { return nil, err } // validate api key - isValid, err := apikeygen.IsValid(decoded, keyQuery.Result.Key) + isValid, err := apikeygen.IsValid(decoded, key.Key) if err != nil { return nil, err } @@ -295,7 +296,7 @@ func (h *ContextHandler) getAPIKey(ctx context.Context, keyString string) (*apik return nil, apikeygen.ErrInvalidApiKey } - return keyQuery.Result, nil + return key, nil } func (h *ContextHandler) initContextWithAPIKey(reqContext *contextmodel.ReqContext) bool { diff --git a/pkg/services/serviceaccounts/database/token_store.go b/pkg/services/serviceaccounts/database/token_store.go index 4674f60c7bc..791d8353172 100644 --- a/pkg/services/serviceaccounts/database/token_store.go +++ b/pkg/services/serviceaccounts/database/token_store.go @@ -58,7 +58,8 @@ func (s *ServiceAccountsStoreImpl) AddServiceAccountToken(ctx context.Context, s ServiceAccountID: &serviceAccountId, } - if err := s.apiKeyService.AddAPIKey(ctx, addKeyCmd); err != nil { + key, err := s.apiKeyService.AddAPIKey(ctx, addKeyCmd) + if err != nil { switch { case errors.Is(err, apikey.ErrDuplicate): return serviceaccounts.ErrDuplicateToken.Errorf("service account token with name %s already exists in the organization", cmd.Name) @@ -69,7 +70,7 @@ func (s *ServiceAccountsStoreImpl) AddServiceAccountToken(ctx context.Context, s return err } - apiKey = addKeyCmd.Result + apiKey = key return nil }) } diff --git a/pkg/services/serviceaccounts/tests/common.go b/pkg/services/serviceaccounts/tests/common.go index c12e49a0075..e4fa16025f1 100644 --- a/pkg/services/serviceaccounts/tests/common.go +++ b/pkg/services/serviceaccounts/tests/common.go @@ -87,22 +87,22 @@ func SetupApiKey(t *testing.T, sqlStore *sqlstore.SQLStore, testKey TestApiKey) quotaService := quotatest.New(false, nil) apiKeyService, err := apikeyimpl.ProvideService(sqlStore, sqlStore.Cfg, quotaService) require.NoError(t, err) - err = apiKeyService.AddAPIKey(context.Background(), addKeyCmd) + key, err := apiKeyService.AddAPIKey(context.Background(), addKeyCmd) require.NoError(t, err) if testKey.IsExpired { err := sqlStore.WithTransactionalDbSession(context.Background(), func(sess *db.Session) error { // Force setting expires to time before now to make key expired var expires int64 = 1 - key := apikey.APIKey{Expires: &expires} - rowsAffected, err := sess.ID(addKeyCmd.Result.ID).Update(&key) + expiringKey := apikey.APIKey{Expires: &expires} + rowsAffected, err := sess.ID(key.ID).Update(&expiringKey) require.Equal(t, int64(1), rowsAffected) return err }) require.NoError(t, err) } - return addKeyCmd.Result + return key } func SetupMockAccesscontrol(t *testing.T,