diff --git a/pkg/services/authn/clients/api_key.go b/pkg/services/authn/clients/api_key.go index 558340cdaa0..b4cf6bc42a5 100644 --- a/pkg/services/authn/clients/api_key.go +++ b/pkg/services/authn/clients/api_key.go @@ -154,8 +154,9 @@ func (s *APIKey) Priority() uint { } func (s *APIKey) Hook(ctx context.Context, identity *authn.Identity, r *authn.Request) error { - namespace, id := identity.NamespacedID() - if namespace != authn.NamespaceAPIKey { + id, exists := s.getAPIKeyID(ctx, identity, r) + + if !exists { return nil } @@ -173,6 +174,27 @@ func (s *APIKey) Hook(ctx context.Context, identity *authn.Identity, r *authn.Re return nil } +func (s *APIKey) getAPIKeyID(ctx context.Context, identity *authn.Identity, r *authn.Request) (apiKeyID int64, exists bool) { + namespace, id := identity.NamespacedID() + + if namespace == authn.NamespaceAPIKey { + return id, true + } + + if namespace == authn.NamespaceServiceAccount { + // When the identity is service account, the ID in from the namespace is the service account ID. + // We need to fetch the API key in this scenario, as we could use it to uniquely identify a service account token. + apiKey, err := s.getAPIKey(ctx, getTokenFromRequest(r)) + if err != nil { + s.log.Warn("Failed to fetch the API Key from request") + return -1, false + } + + return apiKey.ID, true + } + return -1, false +} + func looksLikeApiKey(token string) bool { return token != "" } diff --git a/pkg/services/authn/clients/api_key_test.go b/pkg/services/authn/clients/api_key_test.go index 4848eb38358..f4dce265cf0 100644 --- a/pkg/services/authn/clients/api_key_test.go +++ b/pkg/services/authn/clients/api_key_test.go @@ -201,6 +201,112 @@ func TestAPIKey_Test(t *testing.T) { } } +func TestAPIKey_GetAPIKeyIDFromIdentity(t *testing.T) { + type TestCase struct { + desc string + expectedKey *apikey.APIKey + expectedIdentity *authn.Identity + expectedError error + expectedKeyID int64 + expectedExists bool + } + + tests := []TestCase{ + { + desc: "should return API Key ID for valid token that is connected to service account", + expectedKey: &apikey.APIKey{ + ID: 1, + OrgID: 1, + Key: hash, + ServiceAccountId: intPtr(1), + }, + expectedIdentity: &authn.Identity{ + ID: "service-account:1", + OrgID: 1, + Name: "test", + AuthenticatedBy: login.APIKeyAuthModule, + }, + expectedKeyID: 1, + expectedExists: true, + }, + { + desc: "should return API Key ID for valid token for API key", + expectedKey: &apikey.APIKey{ + ID: 2, + OrgID: 1, + Key: hash, + }, + expectedIdentity: &authn.Identity{ + ID: "api-key:2", + OrgID: 1, + Name: "test", + AuthenticatedBy: login.APIKeyAuthModule, + }, + expectedKeyID: 2, + expectedExists: true, + }, + { + desc: "should not return any ID when the request is not made by API key or service account", + expectedKey: &apikey.APIKey{ + ID: 2, + OrgID: 1, + Key: hash, + }, + expectedIdentity: &authn.Identity{ + ID: "user:2", + OrgID: 1, + Name: "test", + AuthenticatedBy: login.APIKeyAuthModule, + }, + expectedKeyID: -1, + expectedExists: false, + }, + { + desc: "should not return any ID when the can't fetch API Key", + expectedKey: &apikey.APIKey{ + ID: 1, + OrgID: 1, + Key: hash, + }, + expectedIdentity: &authn.Identity{ + ID: "service-account:2", + OrgID: 1, + Name: "test", + AuthenticatedBy: login.APIKeyAuthModule, + }, + expectedError: fmt.Errorf("invalid token"), + expectedKeyID: -1, + expectedExists: false, + }, + } + + req := &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{ + "Authorization": {"Bearer " + secret}, + }, + }} + + signedInUser := &user.SignedInUser{ + UserID: 1, + OrgID: 1, + Name: "test", + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + c := ProvideAPIKey(&apikeytest.Service{ + ExpectedError: tt.expectedError, + ExpectedAPIKey: tt.expectedKey, + }, &usertest.FakeUserService{ + ExpectedSignedInUser: signedInUser, + }) + id, exists := c.getAPIKeyID(context.Background(), tt.expectedIdentity, req) + assert.Equal(t, tt.expectedExists, exists) + assert.Equal(t, tt.expectedKeyID, id) + }) + } +} + func intPtr(n int64) *int64 { return &n }