diff --git a/pkg/services/auth/id.go b/pkg/services/auth/id.go index 82aa7905a5b..92a18a9b8d7 100644 --- a/pkg/services/auth/id.go +++ b/pkg/services/auth/id.go @@ -20,6 +20,7 @@ type IDSigner interface { type IDClaims struct { jwt.Claims + AuthenticatedBy string `json:"authenticatedBy,omitempty"` } const settingsKey = "forwardGrafanaIdToken" diff --git a/pkg/services/auth/idimpl/service.go b/pkg/services/auth/idimpl/service.go index 566ad620838..3c42c5c103b 100644 --- a/pkg/services/auth/idimpl/service.go +++ b/pkg/services/auth/idimpl/service.go @@ -2,7 +2,9 @@ package idimpl import ( "context" + "errors" "fmt" + "strconv" "time" "github.com/go-jose/go-jose/v3/jwt" @@ -15,6 +17,8 @@ import ( "github.com/grafana/grafana/pkg/services/auth/identity" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -28,9 +32,14 @@ var _ auth.IDService = (*Service)(nil) func ProvideService( cfg *setting.Cfg, signer auth.IDSigner, cache remotecache.CacheStorage, - features featuremgmt.FeatureToggles, authnService authn.Service, reg prometheus.Registerer, + features featuremgmt.FeatureToggles, authnService authn.Service, + authInfoService login.AuthInfoService, reg prometheus.Registerer, ) *Service { - s := &Service{cfg: cfg, logger: log.New("id-service"), signer: signer, cache: cache, metrics: newMetrics(reg)} + s := &Service{ + cfg: cfg, logger: log.New("id-service"), + signer: signer, cache: cache, + authInfoService: authInfoService, metrics: newMetrics(reg), + } if features.IsEnabledGlobally(featuremgmt.FlagIdForwarding) { authnService.RegisterPostAuthHook(s.hook, 140) @@ -40,12 +49,13 @@ func ProvideService( } type Service struct { - cfg *setting.Cfg - logger log.Logger - signer auth.IDSigner - cache remotecache.CacheStorage - si singleflight.Group - metrics *metrics + cfg *setting.Cfg + logger log.Logger + signer auth.IDSigner + cache remotecache.CacheStorage + authInfoService login.AuthInfoService + si singleflight.Group + metrics *metrics } func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (string, error) { @@ -61,15 +71,15 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri cachedToken, err := s.cache.Get(ctx, cacheKey) if err == nil { s.metrics.tokenSigningFromCacheCounter.Inc() - s.logger.Debug("Cached token found", "namespace", namespace, "id", identifier) + s.logger.FromContext(ctx).Debug("Cached token found", "namespace", namespace, "id", identifier) return string(cachedToken), nil } s.metrics.tokenSigningCounter.Inc() - s.logger.Debug("Sign new id token", "namespace", namespace, "id", identifier) + s.logger.FromContext(ctx).Debug("Sign new id token", "namespace", namespace, "id", identifier) now := time.Now() - token, err := s.signer.SignIDToken(ctx, &auth.IDClaims{ + claims := &auth.IDClaims{ Claims: jwt.Claims{ Issuer: s.cfg.AppURL, Audience: getAudience(id.GetOrgID()), @@ -77,15 +87,22 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri Expiry: jwt.NewNumericDate(now.Add(tokenTTL)), IssuedAt: jwt.NewNumericDate(now), }, - }) + } + if identity.IsNamespace(namespace, identity.NamespaceUser) { + if err := s.setUserClaims(ctx, identifier, claims); err != nil { + return "", err + } + } + + token, err := s.signer.SignIDToken(ctx, claims) if err != nil { s.metrics.failedTokenSigningCounter.Inc() return "", err } if err := s.cache.Set(ctx, cacheKey, []byte(token), cacheTTL); err != nil { - s.logger.Error("Failed to add id token to cache", "error", err) + s.logger.FromContext(ctx).Error("Failed to add id token to cache", "error", err) } return token, nil @@ -98,12 +115,37 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri return result.(string), nil } +func (s *Service) setUserClaims(ctx context.Context, identifier string, claims *auth.IDClaims) error { + id, err := strconv.ParseInt(identifier, 10, 64) + if err != nil { + return err + } + + if id == 0 { + return nil + } + + info, err := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: id}) + if err != nil { + // we ignore errors when a user don't have external user auth + if !errors.Is(err, user.ErrUserNotFound) { + s.logger.FromContext(ctx).Error("Failed to fetch auth info", "userId", id, "error", err) + } + + return nil + } + + claims.AuthenticatedBy = info.AuthModule + + return nil +} + func (s *Service) hook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { // FIXME(kalleep): we should probably lazy load this token, err := s.SignIdentity(ctx, identity) if err != nil { namespace, id := identity.GetNamespacedID() - s.logger.Error("Failed to sign id token", "err", err, "namespace", namespace, "id", id) + s.logger.FromContext(ctx).Error("Failed to sign id token", "err", err, "namespace", namespace, "id", id) // for now don't return error so we don't break authentication from this hook return nil } diff --git a/pkg/services/auth/idimpl/service_test.go b/pkg/services/auth/idimpl/service_test.go index 14a7fb36d5e..7d214231f6c 100644 --- a/pkg/services/auth/idimpl/service_test.go +++ b/pkg/services/auth/idimpl/service_test.go @@ -1,13 +1,22 @@ package idimpl import ( + "context" + "encoding/json" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/idtest" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn/authntest" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/login/authinfotest" + "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" ) @@ -22,7 +31,7 @@ func Test_ProvideService(t *testing.T) { }, } - _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil) + _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil, nil) assert.True(t, hookRegistered) }) @@ -36,7 +45,44 @@ func Test_ProvideService(t *testing.T) { }, } - _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil) + _ = ProvideService(setting.NewCfg(), nil, nil, features, authnService, nil, nil) assert.False(t, hookRegistered) }) } + +func TestService_SignIdentity(t *testing.T) { + signer := &idtest.MockSigner{ + SignIDTokenFn: func(_ context.Context, claims *auth.IDClaims) (string, error) { + data, err := json.Marshal(claims) + if err != nil { + return "", err + } + return string(data), nil + }, + } + + t.Run("should sing identity", func(t *testing.T) { + s := ProvideService( + setting.NewCfg(), signer, remotecache.NewFakeCacheStorage(), + featuremgmt.WithFeatures(featuremgmt.FlagIdForwarding), + &authntest.FakeService{}, &authinfotest.FakeService{ExpectedError: user.ErrUserNotFound}, nil, + ) + token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"}) + require.NoError(t, err) + require.NotEmpty(t, token) + }) + + t.Run("should sing identity with authenticated by if user is externally authenticated", func(t *testing.T) { + s := ProvideService( + setting.NewCfg(), signer, remotecache.NewFakeCacheStorage(), + featuremgmt.WithFeatures(featuremgmt.FlagIdForwarding), + &authntest.FakeService{}, &authinfotest.FakeService{ExpectedUserAuth: &login.UserAuth{AuthModule: login.AzureADAuthModule}}, nil, + ) + token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"}) + require.NoError(t, err) + + claims := &auth.IDClaims{} + require.NoError(t, json.Unmarshal([]byte(token), claims)) + assert.Equal(t, login.AzureADAuthModule, claims.AuthenticatedBy) + }) +} diff --git a/pkg/services/auth/idimpl/signer.go b/pkg/services/auth/idimpl/signer.go index 5bb004deb11..e69a8422b48 100644 --- a/pkg/services/auth/idimpl/signer.go +++ b/pkg/services/auth/idimpl/signer.go @@ -37,7 +37,7 @@ func (s *LocalSigner) SignIDToken(ctx context.Context, claims *auth.IDClaims) (s return "", err } - builder := jwt.Signed(signer).Claims(claims.Claims) + builder := jwt.Signed(signer).Claims(claims) token, err := builder.CompactSerialize() if err != nil { diff --git a/pkg/services/auth/idtest/mock.go b/pkg/services/auth/idtest/mock.go new file mode 100644 index 00000000000..eafc244ec5b --- /dev/null +++ b/pkg/services/auth/idtest/mock.go @@ -0,0 +1,18 @@ +package idtest + +import ( + "context" + + "github.com/grafana/grafana/pkg/services/auth" +) + +type MockSigner struct { + SignIDTokenFn func(ctx context.Context, claims *auth.IDClaims) (string, error) +} + +func (s *MockSigner) SignIDToken(ctx context.Context, claims *auth.IDClaims) (string, error) { + if s.SignIDTokenFn != nil { + return s.SignIDTokenFn(ctx, claims) + } + return "", nil +}