iam: Refresh live connection when ID tokens expire (#107209)

* iam: refresh live connection when ID tokens expire

* add coverage for the handler functions

* reinstate inadvertently broken unit test
pull/107472/head^2
Victor Cinaglia 2 weeks ago committed by GitHub
parent 8d8b824f73
commit 4f66c4a2a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      go.work.sum
  2. 2
      pkg/apimachinery/go.mod
  3. 30
      pkg/apimachinery/identity/requester.go
  4. 98
      pkg/apimachinery/identity/requester_test.go
  5. 30
      pkg/services/live/live.go
  6. 190
      pkg/services/live/live_test.go

@ -1363,6 +1363,7 @@ github.com/grafana/grafana/pkg/build v0.0.0-20250220114259-be81314e2118/go.mod h
github.com/grafana/grafana/pkg/build v0.0.0-20250227105625-8f465f124924/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k=
github.com/grafana/grafana/pkg/build v0.0.0-20250227163402-d78c646f93bb/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k=
github.com/grafana/grafana/pkg/build v0.0.0-20250403075254-4918d8720c61/go.mod h1:LGVnSwdrS0ZnJ2WXEl5acgDoYPm74EUSFavca1NKHI8=
github.com/grafana/grafana/pkg/build v0.0.0-20250625151647-35f89a456cc6/go.mod h1:dIu5dZy00k2TBdpVBXkvSbxHNj5H7lW/sOTpJTtKIXg=
github.com/grafana/grafana/pkg/semconv v0.0.0-20250121113133-e747350fee2d/go.mod h1:tfLnBpPYgwrBMRz4EXqPCZJyCjEG4Ev37FSlXnocJ2c=
github.com/grafana/grafana/pkg/semconv v0.0.0-20250627191313-2f1a6ae1712b/go.mod h1:mu3yl0GxB0eQZV1q7Kka0pkF3Th9x7W04WrjR9wqBlc=
github.com/grafana/grafana/pkg/storage/unified/apistore v0.0.0-20250121113133-e747350fee2d/go.mod h1:CXpwZ3Mkw6xVlGKc0SqUxqXCP3Uv182q6qAQnLaLxRg=

@ -3,6 +3,7 @@ module github.com/grafana/grafana/pkg/apimachinery
go 1.24.4
require (
github.com/go-jose/go-jose/v3 v3.0.4 // @grafana/identity-access-team
github.com/grafana/authlib v0.0.0-20250618124654-54543efcfeed // @grafana/identity-access-team
github.com/grafana/authlib/types v0.0.0-20250325095148-d6da9c164a7d // @grafana/identity-access-team
github.com/stretchr/testify v1.10.0
@ -15,7 +16,6 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect

@ -3,7 +3,9 @@ package identity
import (
"fmt"
"strconv"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"k8s.io/apiserver/pkg/authentication/user"
claims "github.com/grafana/authlib/types"
@ -125,3 +127,31 @@ func intIdentifier(typ claims.IdentityType, id string, expected ...claims.Identi
return 0, ErrNotIntIdentifier
}
// IsIDTokenExpired reports whether the requester's ID token carries an "exp"
// claim that lies in the past.
//
// It returns false when the requester has no ID token, when the token cannot
// be parsed, or when the token has no expiry claim. The claims are read
// without verifying the token signature.
func IsIDTokenExpired(requester Requester) bool {
	rawToken := requester.GetIDToken()
	if rawToken == "" {
		return false
	}

	token, err := jwt.ParseSigned(rawToken)
	if err != nil {
		return false
	}

	// Only the expiry matters here, so decode just that claim.
	var payload struct {
		Expiry *jwt.NumericDate `json:"exp"`
	}
	if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
		return false
	}
	if payload.Expiry == nil {
		return false
	}

	return payload.Expiry.Time().Before(time.Now())
}

@ -0,0 +1,98 @@
package identity_test
import (
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/apimachinery/identity"
)
// TestIsIDTokenExpired exercises identity.IsIDTokenExpired with a missing
// token, valid and expired tokens, a token without an expiry claim, and a
// malformed token.
func TestIsIDTokenExpired(t *testing.T) {
	// literal yields a generator that returns the given string unchanged.
	literal := func(raw string) func(*testing.T) string {
		return func(*testing.T) string { return raw }
	}
	// expiringIn yields a generator for a signed token whose expiry is the
	// current time (at generation) shifted by offset.
	expiringIn := func(offset time.Duration) func(*testing.T) string {
		return func(t *testing.T) string {
			expiry := time.Now().Add(offset)
			return createToken(t, &expiry)
		}
	}
	// withoutExpiry generates a signed token that has no expiry claim.
	withoutExpiry := func(t *testing.T) string {
		return createToken(t, nil)
	}

	cases := []struct {
		name     string
		newToken func(t *testing.T) string
		want     bool
	}{
		{name: "should return false when ID token is not set", newToken: literal(""), want: false},
		{name: "should return false when ID token is not expired", newToken: expiringIn(time.Hour), want: false},
		{name: "should return true when ID token is expired", newToken: expiringIn(-time.Hour), want: true},
		{name: "should return false when ID token has no expiry claim", newToken: withoutExpiry, want: false},
		{name: "should return false when ID token is malformed", newToken: literal("invalid.jwt.token"), want: false},
		{name: "should handle token that expires exactly now", newToken: expiringIn(-time.Millisecond), want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			requester := &identity.StaticRequester{IDToken: tc.newToken(t)}
			got := identity.IsIDTokenExpired(requester)
			require.Equal(t, tc.want, got)
		})
	}
}
// createToken returns a compact-serialized JWT signed with HS256 using a fixed
// test key. The token always has the subject "test-user"; the "exp" claim is
// set only when exp is non-nil. Any signing error fails the test immediately.
func createToken(t *testing.T, exp *time.Time) string {
	// Mark as a helper so assertion failures point at the calling test.
	t.Helper()

	signingKey := jose.SigningKey{Algorithm: jose.HS256, Key: []byte("test-secret-key")}
	signer, err := jose.NewSigner(signingKey, nil)
	require.NoError(t, err)

	claims := jwt.Claims{Subject: "test-user"}
	if exp != nil {
		claims.Expiry = jwt.NewNumericDate(*exp)
	}

	token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
	require.NoError(t, err)
	return token
}

@ -641,6 +641,20 @@ func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn fu
return nil
}
// checkIDTokenExpirationAndRefresh reports whether the user's ID token has
// expired. When it has, a refresh flagged as expired is requested on the node
// for the client's user. A failed refresh is only logged; the function still
// returns true so the caller can reject the request.
func (g *GrafanaLive) checkIDTokenExpirationAndRefresh(user identity.Requester, client *centrifuge.Client) bool {
	if !identity.IsIDTokenExpired(user) {
		return false
	}

	logger.Debug("ID token expired, triggering refresh", "user", client.UserID(), "client", client.ID())

	if err := g.node.Refresh(client.UserID(), centrifuge.WithRefreshExpired(true)); err != nil {
		logger.Error("Failed to refresh expired ID token", "user", client.UserID(), "client", client.ID(), "error", err)
	}

	return true
}
func (g *GrafanaLive) HandleDatasourceDelete(orgID int64, dsUID string) {
if g.runStreamManager == nil {
return
@ -676,6 +690,12 @@ func (g *GrafanaLive) handleOnRPC(clientContextWithSpan context.Context, client
logger.Error("No user found in context", "user", client.UserID(), "client", client.ID(), "method", e.Method)
return centrifuge.RPCReply{}, centrifuge.ErrorInternal
}
// Check if ID token is expired and trigger refresh if needed
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
return centrifuge.RPCReply{}, centrifuge.ErrorExpired
}
var req dtos.MetricRequest
err := json.Unmarshal(e.Data, &req)
if err != nil {
@ -712,6 +732,11 @@ func (g *GrafanaLive) handleOnSubscribe(clientContextWithSpan context.Context, c
return centrifuge.SubscribeReply{}, centrifuge.ErrorInternal
}
// Check if ID token is expired and trigger refresh if needed
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
return centrifuge.SubscribeReply{}, centrifuge.ErrorExpired
}
// See a detailed comment for StripOrgID about orgID management in Live.
orgID, channel, err := orgchannel.StripOrgID(e.Channel)
if err != nil {
@ -813,6 +838,11 @@ func (g *GrafanaLive) handleOnPublish(clientCtxWithSpan context.Context, client
return centrifuge.PublishReply{}, centrifuge.ErrorInternal
}
// Check if ID token is expired and trigger refresh if needed
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
return centrifuge.PublishReply{}, centrifuge.ErrorExpired
}
// See a detailed comment for StripOrgID about orgID management in Live.
orgID, channel, err := orgchannel.StripOrgID(e.Channel)
if err != nil {

@ -7,15 +7,20 @@ import (
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/require"
"github.com/centrifugal/centrifuge"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
"github.com/grafana/grafana/pkg/services/annotations/annotationstest"
"github.com/grafana/grafana/pkg/services/dashboards"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/live/livecontext"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tests/testsuite"
)
@ -29,20 +34,9 @@ func TestIntegration_provideLiveService_RedisUnavailable(t *testing.T) {
cfg.LiveHAEngine = "testredisunavailable"
_, err := ProvideService(nil, cfg,
routing.NewRouteRegister(),
nil, nil, nil, nil,
db.InitTestDB(t),
nil,
&usagestats.UsageStatsMock{T: t},
nil,
featuremgmt.WithFeatures(),
acimpl.ProvideAccessControl(featuremgmt.WithFeatures()),
&dashboards.FakeDashboardService{},
annotationstest.NewFakeAnnotationsRepo(),
nil, nil)
_, err := setupLiveService(cfg, t)
// Proceeds without live HA if redis is unavaialble
// Proceeds without live HA if redis is unavailable
require.NoError(t, err)
}
@ -233,3 +227,173 @@ func Test_getHistogramMetric(t *testing.T) {
})
}
}
// Test_handleOnPublish_IDTokenExpiration checks that handleOnPublish returns
// centrifuge.ErrorExpired when the signed-in user's ID token has expired, and
// does not return that error when the token is still valid.
func Test_handleOnPublish_IDTokenExpiration(t *testing.T) {
	svc, err := setupLiveService(nil, t)
	require.NoError(t, err)

	client, _, err := centrifuge.NewClient(context.Background(), svc.node, newDummyTransport("test"))
	require.NoError(t, err)

	scenarios := []struct {
		name        string
		expiresIn   time.Duration
		wantExpired bool
	}{
		{name: "expired token", expiresIn: -time.Hour, wantExpired: true},
		// With a valid token the handler gets past the expiry check. The call
		// may still fail later (presumably because the channel "test" has no
		// valid orgID), so only the absence of ErrorExpired is asserted.
		{name: "unexpired token", expiresIn: time.Hour, wantExpired: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			expiry := time.Now().Add(sc.expiresIn)
			requester := &identity.StaticRequester{IDToken: createToken(t, &expiry)}
			ctx := livecontext.SetContextSignedUser(context.Background(), requester)

			reply, err := svc.handleOnPublish(ctx, client, centrifuge.PublishEvent{
				Channel: "test",
				Data:    []byte("test"),
			})

			if sc.wantExpired {
				require.ErrorIs(t, err, centrifuge.ErrorExpired)
			} else {
				require.NotErrorIs(t, err, centrifuge.ErrorExpired)
			}
			require.Empty(t, reply)
		})
	}
}
// Test_handleOnRPC_IDTokenExpiration checks that handleOnRPC returns
// centrifuge.ErrorExpired when the signed-in user's ID token has expired, and
// does not return that error when the token is still valid.
func Test_handleOnRPC_IDTokenExpiration(t *testing.T) {
	svc, err := setupLiveService(nil, t)
	require.NoError(t, err)

	client, _, err := centrifuge.NewClient(context.Background(), svc.node, newDummyTransport("test"))
	require.NoError(t, err)

	scenarios := []struct {
		name        string
		expiresIn   time.Duration
		wantExpired bool
	}{
		{name: "expired token", expiresIn: -time.Hour, wantExpired: true},
		// With a valid token the handler gets past the expiry check. The call
		// may still fail later (presumably because the payload "test" is not a
		// decodable request), so only the absence of ErrorExpired is asserted.
		{name: "unexpired token", expiresIn: time.Hour, wantExpired: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			expiry := time.Now().Add(sc.expiresIn)
			requester := &identity.StaticRequester{IDToken: createToken(t, &expiry)}
			ctx := livecontext.SetContextSignedUser(context.Background(), requester)

			reply, err := svc.handleOnRPC(ctx, client, centrifuge.RPCEvent{
				Method: "grafana.query",
				Data:   []byte("test"),
			})

			if sc.wantExpired {
				require.ErrorIs(t, err, centrifuge.ErrorExpired)
			} else {
				require.NotErrorIs(t, err, centrifuge.ErrorExpired)
			}
			require.Empty(t, reply)
		})
	}
}
// Test_handleOnSubscribe_IDTokenExpiration checks that handleOnSubscribe
// returns centrifuge.ErrorExpired when the signed-in user's ID token has
// expired, and does not return that error when the token is still valid.
func Test_handleOnSubscribe_IDTokenExpiration(t *testing.T) {
	svc, err := setupLiveService(nil, t)
	require.NoError(t, err)

	client, _, err := centrifuge.NewClient(context.Background(), svc.node, newDummyTransport("test"))
	require.NoError(t, err)

	scenarios := []struct {
		name        string
		expiresIn   time.Duration
		wantExpired bool
	}{
		{name: "expired token", expiresIn: -time.Hour, wantExpired: true},
		// With a valid token the handler gets past the expiry check. The call
		// may still fail later (presumably because the channel "test" has no
		// valid orgID), so only the absence of ErrorExpired is asserted.
		{name: "unexpired token", expiresIn: time.Hour, wantExpired: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			expiry := time.Now().Add(sc.expiresIn)
			requester := &identity.StaticRequester{IDToken: createToken(t, &expiry)}
			ctx := livecontext.SetContextSignedUser(context.Background(), requester)

			reply, err := svc.handleOnSubscribe(ctx, client, centrifuge.SubscribeEvent{
				Channel: "test",
			})

			if sc.wantExpired {
				require.ErrorIs(t, err, centrifuge.ErrorExpired)
			} else {
				require.NotErrorIs(t, err, centrifuge.ErrorExpired)
			}
			require.Empty(t, reply)
		})
	}
}
// setupLiveService builds a GrafanaLive instance for tests through
// ProvideService, using a test database, a usage-stats mock, a fake dashboard
// service and a fake annotations repository; the remaining dependencies are
// passed as nil. A nil cfg is replaced with setting.NewCfg().
//
// The error from ProvideService is returned to the caller rather than failing
// the test, so callers can assert on it.
func setupLiveService(cfg *setting.Cfg, t *testing.T) (*GrafanaLive, error) {
	// Mark as a helper so failures raised while building dependencies are
	// attributed to the calling test.
	t.Helper()

	if cfg == nil {
		cfg = setting.NewCfg()
	}

	return ProvideService(nil,
		cfg,
		routing.NewRouteRegister(),
		nil, nil, nil, nil,
		db.InitTestDB(t),
		nil,
		&usagestats.UsageStatsMock{T: t},
		nil,
		featuremgmt.WithFeatures(),
		acimpl.ProvideAccessControl(featuremgmt.WithFeatures()),
		&dashboards.FakeDashboardService{},
		annotationstest.NewFakeAnnotationsRepo(),
		nil, nil)
}
// dummyTransport is a do-nothing transport used to create centrifuge clients
// in tests. Its write and close methods always succeed without side effects.
type dummyTransport struct {
	name string
}

// newDummyTransport returns a dummyTransport that reports the given name.
func newDummyTransport(name string) *dummyTransport {
	return &dummyTransport{name: name}
}

// Name returns the name the transport was created with.
func (dt *dummyTransport) Name() string { return dt.name }

// Protocol reports the JSON protocol type.
func (dt *dummyTransport) Protocol() centrifuge.ProtocolType { return centrifuge.ProtocolTypeJSON }

// ProtocolVersion reports protocol version 2.
func (dt *dummyTransport) ProtocolVersion() centrifuge.ProtocolVersion {
	return centrifuge.ProtocolVersion2
}

// Emulation always reports false.
func (dt *dummyTransport) Emulation() bool { return false }

// Unidirectional always reports false.
func (dt *dummyTransport) Unidirectional() bool { return false }

// DisabledPushFlags reports that no push flags are disabled.
func (dt *dummyTransport) DisabledPushFlags() uint64 { return 0 }

// PingPongConfig returns the zero-value ping/pong configuration.
func (dt *dummyTransport) PingPongConfig() centrifuge.PingPongConfig {
	return centrifuge.PingPongConfig{}
}

// Write discards the data and reports success.
func (dt *dummyTransport) Write(_ []byte) error { return nil }

// WriteMany discards all messages and reports success.
func (dt *dummyTransport) WriteMany(_ ...[]byte) error { return nil }

// Close ignores the disconnect reason and reports success.
func (dt *dummyTransport) Close(_ centrifuge.Disconnect) error {
	return nil
}
// createToken returns a compact-serialized JWT signed with HS256 using a fixed
// test key. The token always has the subject "test-user"; the "exp" claim is
// set only when exp is non-nil. Any signing error fails the test immediately.
func createToken(t *testing.T, exp *time.Time) string {
	// Mark as a helper so assertion failures point at the calling test.
	t.Helper()

	signingKey := jose.SigningKey{Algorithm: jose.HS256, Key: []byte("test-secret-key")}
	signer, err := jose.NewSigner(signingKey, nil)
	require.NoError(t, err)

	claims := jwt.Claims{Subject: "test-user"}
	if exp != nil {
		claims.Expiry = jwt.NewNumericDate(*exp)
	}

	token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
	require.NoError(t, err)
	return token
}

Loading…
Cancel
Save