diff --git a/go.work.sum b/go.work.sum index cf90c70f3b1..d684e218820 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1363,6 +1363,7 @@ github.com/grafana/grafana/pkg/build v0.0.0-20250220114259-be81314e2118/go.mod h github.com/grafana/grafana/pkg/build v0.0.0-20250227105625-8f465f124924/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k= github.com/grafana/grafana/pkg/build v0.0.0-20250227163402-d78c646f93bb/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k= github.com/grafana/grafana/pkg/build v0.0.0-20250403075254-4918d8720c61/go.mod h1:LGVnSwdrS0ZnJ2WXEl5acgDoYPm74EUSFavca1NKHI8= +github.com/grafana/grafana/pkg/build v0.0.0-20250625151647-35f89a456cc6/go.mod h1:dIu5dZy00k2TBdpVBXkvSbxHNj5H7lW/sOTpJTtKIXg= github.com/grafana/grafana/pkg/semconv v0.0.0-20250121113133-e747350fee2d/go.mod h1:tfLnBpPYgwrBMRz4EXqPCZJyCjEG4Ev37FSlXnocJ2c= github.com/grafana/grafana/pkg/semconv v0.0.0-20250627191313-2f1a6ae1712b/go.mod h1:mu3yl0GxB0eQZV1q7Kka0pkF3Th9x7W04WrjR9wqBlc= github.com/grafana/grafana/pkg/storage/unified/apistore v0.0.0-20250121113133-e747350fee2d/go.mod h1:CXpwZ3Mkw6xVlGKc0SqUxqXCP3Uv182q6qAQnLaLxRg= diff --git a/pkg/apimachinery/go.mod b/pkg/apimachinery/go.mod index b9a01c87fe0..2155cda748e 100644 --- a/pkg/apimachinery/go.mod +++ b/pkg/apimachinery/go.mod @@ -3,6 +3,7 @@ module github.com/grafana/grafana/pkg/apimachinery go 1.24.4 require ( + github.com/go-jose/go-jose/v3 v3.0.4 // @grafana/identity-access-team github.com/grafana/authlib v0.0.0-20250618124654-54543efcfeed // @grafana/identity-access-team github.com/grafana/authlib/types v0.0.0-20250325095148-d6da9c164a7d // @grafana/identity-access-team github.com/stretchr/testify v1.10.0 @@ -15,7 +16,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect - github.com/go-jose/go-jose/v3 v3.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect 
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect diff --git a/pkg/apimachinery/identity/requester.go b/pkg/apimachinery/identity/requester.go index 3bdbd30bf17..bd2ca2559ce 100644 --- a/pkg/apimachinery/identity/requester.go +++ b/pkg/apimachinery/identity/requester.go @@ -3,7 +3,9 @@ package identity import ( "fmt" "strconv" + "time" + "github.com/go-jose/go-jose/v3/jwt" "k8s.io/apiserver/pkg/authentication/user" claims "github.com/grafana/authlib/types" @@ -125,3 +127,31 @@ func intIdentifier(typ claims.IdentityType, id string, expected ...claims.Identi return 0, ErrNotIntIdentifier } + +// IsIDTokenExpired returns true if the ID token is expired. +// If no ID token exists, returns false. +func IsIDTokenExpired(requester Requester) bool { + idToken := requester.GetIDToken() + if idToken == "" { + return false + } + + parsed, err := jwt.ParseSigned(idToken) + if err != nil { + return false + } + + var claims struct { + Expiry *jwt.NumericDate `json:"exp"` + } + if err := parsed.UnsafeClaimsWithoutVerification(&claims); err != nil { + return false + } + + if claims.Expiry != nil { + expiryTime := claims.Expiry.Time() + return time.Now().After(expiryTime) + } + + return false +} diff --git a/pkg/apimachinery/identity/requester_test.go b/pkg/apimachinery/identity/requester_test.go new file mode 100644 index 00000000000..9dbb027ad3d --- /dev/null +++ b/pkg/apimachinery/identity/requester_test.go @@ -0,0 +1,98 @@ +package identity_test + +import ( + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/stretchr/testify/require" + + "github.com/grafana/grafana/pkg/apimachinery/identity" +) + +func TestIsIDTokenExpired(t *testing.T) { + tests := []struct { + name string + token func(t *testing.T) string + expected bool + }{ + { + name: "should return false when ID token is not set", + token: func(t *testing.T) string { + return "" + }, + expected: false, + }, + { + name: 
"should return false when ID token is not expired", + token: func(t *testing.T) string { + expiration := time.Now().Add(time.Hour) + return createToken(t, &expiration) + }, + expected: false, + }, + { + name: "should return true when ID token is expired", + token: func(t *testing.T) string { + expiration := time.Now().Add(-time.Hour) + return createToken(t, &expiration) + }, + expected: true, + }, + { + name: "should return false when ID token has no expiry claim", + token: func(t *testing.T) string { + return createToken(t, nil) + }, + expected: false, + }, + { + name: "should return false when ID token is malformed", + token: func(t *testing.T) string { + return "invalid.jwt.token" + }, + expected: false, + }, + { + name: "should handle token that expires exactly now", + token: func(t *testing.T) string { + expiration := time.Now().Add(-time.Millisecond) + return createToken(t, &expiration) + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := tt.token(t) + requester := &identity.StaticRequester{IDToken: token} + + result := identity.IsIDTokenExpired(requester) + require.Equal(t, tt.expected, result) + }) + } +} + +func createToken(t *testing.T, exp *time.Time) string { + key := []byte("test-secret-key") + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, nil) + require.NoError(t, err) + + claims := struct { + jwt.Claims + }{ + Claims: jwt.Claims{ + Subject: "test-user", + }, + } + + if exp != nil { + claims.Expiry = jwt.NewNumericDate(*exp) + } + + token, err := jwt.Signed(signer).Claims(claims).CompactSerialize() + require.NoError(t, err) + return token +} diff --git a/pkg/services/live/live.go b/pkg/services/live/live.go index 116726a9e86..0394e06fdc0 100644 --- a/pkg/services/live/live.go +++ b/pkg/services/live/live.go @@ -641,6 +641,20 @@ func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn fu return nil } +func (g *GrafanaLive) 
checkIDTokenExpirationAndRefresh(user identity.Requester, client *centrifuge.Client) bool { + if !identity.IsIDTokenExpired(user) { + return false + } + + logger.Debug("ID token expired, triggering refresh", "user", client.UserID(), "client", client.ID()) + err := g.node.Refresh(client.UserID(), centrifuge.WithRefreshExpired(true)) + if err != nil { + logger.Error("Failed to refresh expired ID token", "user", client.UserID(), "client", client.ID(), "error", err) + } + + return true +} + func (g *GrafanaLive) HandleDatasourceDelete(orgID int64, dsUID string) { if g.runStreamManager == nil { return @@ -676,6 +690,12 @@ func (g *GrafanaLive) handleOnRPC(clientContextWithSpan context.Context, client logger.Error("No user found in context", "user", client.UserID(), "client", client.ID(), "method", e.Method) return centrifuge.RPCReply{}, centrifuge.ErrorInternal } + + // Check if ID token is expired and trigger refresh if needed + if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired { + return centrifuge.RPCReply{}, centrifuge.ErrorExpired + } + var req dtos.MetricRequest err := json.Unmarshal(e.Data, &req) if err != nil { @@ -712,6 +732,11 @@ func (g *GrafanaLive) handleOnSubscribe(clientContextWithSpan context.Context, c return centrifuge.SubscribeReply{}, centrifuge.ErrorInternal } + // Check if ID token is expired and trigger refresh if needed + if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired { + return centrifuge.SubscribeReply{}, centrifuge.ErrorExpired + } + // See a detailed comment for StripOrgID about orgID management in Live. 
orgID, channel, err := orgchannel.StripOrgID(e.Channel) if err != nil { @@ -813,6 +838,11 @@ func (g *GrafanaLive) handleOnPublish(clientCtxWithSpan context.Context, client return centrifuge.PublishReply{}, centrifuge.ErrorInternal } + // Check if ID token is expired and trigger refresh if needed + if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired { + return centrifuge.PublishReply{}, centrifuge.ErrorExpired + } + // See a detailed comment for StripOrgID about orgID management in Live. orgID, channel, err := orgchannel.StripOrgID(e.Channel) if err != nil { diff --git a/pkg/services/live/live_test.go b/pkg/services/live/live_test.go index 7ae91b968f5..eb685d6a4d2 100644 --- a/pkg/services/live/live_test.go +++ b/pkg/services/live/live_test.go @@ -7,15 +7,20 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" "github.com/stretchr/testify/require" + "github.com/centrifugal/centrifuge" "github.com/grafana/grafana/pkg/api/routing" + "github.com/grafana/grafana/pkg/apimachinery/identity" "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/usagestats" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" "github.com/grafana/grafana/pkg/services/annotations/annotationstest" "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/services/live/livecontext" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tests/testsuite" ) @@ -29,20 +34,9 @@ func TestIntegration_provideLiveService_RedisUnavailable(t *testing.T) { cfg.LiveHAEngine = "testredisunavailable" - _, err := ProvideService(nil, cfg, - routing.NewRouteRegister(), - nil, nil, nil, nil, - db.InitTestDB(t), - nil, - &usagestats.UsageStatsMock{T: t}, - nil, - featuremgmt.WithFeatures(), - acimpl.ProvideAccessControl(featuremgmt.WithFeatures()), - &dashboards.FakeDashboardService{}, - 
// Test_handleOnPublish_IDTokenExpiration checks how handleOnPublish reacts to
// the expiry state of the ID token attached to the signed-in user in the
// request context: an expired token must yield centrifuge.ErrorExpired, an
// unexpired one must not.
func Test_handleOnPublish_IDTokenExpiration(t *testing.T) {
	g, err := setupLiveService(nil, t)
	require.NoError(t, err)

	// A client backed by a no-op transport; both subtests share it.
	client, _, err := centrifuge.NewClient(context.Background(), g.node, newDummyTransport("test"))
	require.NoError(t, err)

	t.Run("expired token", func(t *testing.T) {
		// Token whose "exp" claim is one hour in the past.
		expiration := time.Now().Add(-time.Hour)
		token := createToken(t, &expiration)
		ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
		reply, err := g.handleOnPublish(ctx, client, centrifuge.PublishEvent{
			Channel: "test",
			Data:    []byte("test"),
		})
		require.ErrorIs(t, err, centrifuge.ErrorExpired)
		require.Empty(t, reply)
	})

	t.Run("unexpired token", func(t *testing.T) {
		// Token whose "exp" claim is one hour in the future.
		expiration := time.Now().Add(time.Hour)
		token := createToken(t, &expiration)
		ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
		reply, err := g.handleOnPublish(ctx, client, centrifuge.PublishEvent{
			Channel: "test",
			Data:    []byte("test"),
		})

		// The token is still valid, so the expiry check passes and no refresh is
		// triggered. The handler may still fail further down (presumably because
		// "test" is not an org-prefixed channel; confirm against
		// orgchannel.StripOrgID). The only requirement here is that the error is
		// not ErrorExpired.
		require.NotErrorIs(t, err, centrifuge.ErrorExpired)
		require.Empty(t, reply)
	})
}
+ require.NotErrorIs(t, err, centrifuge.ErrorExpired) + require.Empty(t, reply) + }) +} + +func Test_handleOnSubscribe_IDTokenExpiration(t *testing.T) { + g, err := setupLiveService(nil, t) + require.NoError(t, err) + + client, _, err := centrifuge.NewClient(context.Background(), g.node, newDummyTransport("test")) + require.NoError(t, err) + + t.Run("expired token", func(t *testing.T) { + expiration := time.Now().Add(-time.Hour) + token := createToken(t, &expiration) + ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token}) + reply, err := g.handleOnSubscribe(ctx, client, centrifuge.SubscribeEvent{ + Channel: "test", + }) + require.ErrorIs(t, err, centrifuge.ErrorExpired) + require.Empty(t, reply) + }) + + t.Run("unexpired token", func(t *testing.T) { + expiration := time.Now().Add(time.Hour) + token := createToken(t, &expiration) + ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token}) + reply, err := g.handleOnSubscribe(ctx, client, centrifuge.SubscribeEvent{ + Channel: "test", + }) + + // Another error is returned if the token is not expired but the refresh fails. + // That happens because we're providing an invalid orgID as the channel. 
+ require.NotErrorIs(t, err, centrifuge.ErrorExpired) + require.Empty(t, reply) + }) +} + +func setupLiveService(cfg *setting.Cfg, t *testing.T) (*GrafanaLive, error) { + if cfg == nil { + cfg = setting.NewCfg() + } + + return ProvideService(nil, + cfg, + routing.NewRouteRegister(), + nil, nil, nil, nil, + db.InitTestDB(t), + nil, + &usagestats.UsageStatsMock{T: t}, + nil, + featuremgmt.WithFeatures(), + acimpl.ProvideAccessControl(featuremgmt.WithFeatures()), + &dashboards.FakeDashboardService{}, + annotationstest.NewFakeAnnotationsRepo(), + nil, nil) +} + +type dummyTransport struct { + name string +} + +func (t *dummyTransport) Name() string { return t.name } +func (t *dummyTransport) Protocol() centrifuge.ProtocolType { return centrifuge.ProtocolTypeJSON } +func (t *dummyTransport) ProtocolVersion() centrifuge.ProtocolVersion { + return centrifuge.ProtocolVersion2 +} +func (t *dummyTransport) Emulation() bool { return false } +func (t *dummyTransport) Unidirectional() bool { return false } +func (t *dummyTransport) DisabledPushFlags() uint64 { return 0 } +func (t *dummyTransport) PingPongConfig() centrifuge.PingPongConfig { + return centrifuge.PingPongConfig{} +} +func (t *dummyTransport) Write(data []byte) error { return nil } +func (t *dummyTransport) WriteMany(d ...[]byte) error { return nil } +func (t *dummyTransport) Close(disconnect centrifuge.Disconnect) error { + return nil +} + +func newDummyTransport(name string) *dummyTransport { + return &dummyTransport{name: name} +} + +func createToken(t *testing.T, exp *time.Time) string { + key := []byte("test-secret-key") + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, nil) + require.NoError(t, err) + + claims := struct { + jwt.Claims + }{ + Claims: jwt.Claims{ + Subject: "test-user", + }, + } + + if exp != nil { + claims.Expiry = jwt.NewNumericDate(*exp) + } + + token, err := jwt.Signed(signer).Claims(claims).CompactSerialize() + require.NoError(t, err) + return token +}