The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go

142 lines
4.9 KiB

package sync
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
)
// TestOAuthTokenSync_SyncOAuthTokenHook runs OAuthTokenSync.SyncOauthTokenHook
// against a table of scenarios. The OAuth token service, the session token
// service and the social service are all replaced by test doubles; the token and
// session doubles record whether they were invoked, and every scenario asserts
// both the error returned by the hook and which of those doubles were called.
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
	type testCase struct {
		desc      string
		identity  *authn.Identity
		oauthInfo *social.OAuthInfo // nil: the fake social service returns {UseRefreshToken: true}

		// Canned responses of the OAuth token service double.
		expectedHasEntryToken *login.UserAuth // nil makes HasOAuthEntry report that no entry exists
		expectedTryRefreshErr error

		// Expected interactions with the doubles and the hook's result.
		expectHasEntryCalled              bool
		expectTryRefreshTokenCalled       bool
		expectRevokeTokenCalled           bool
		expectInvalidateOauthTokensCalled bool
		expectedErr                       error
	}

	cases := []testCase{
		{
			desc:                        "should skip sync when identity is not a user",
			identity:                    &authn.Identity{ID: "service-account:1"},
			expectTryRefreshTokenCalled: false,
		},
		{
			desc:                        "should skip sync when identity is a user but is not authenticated with session token",
			identity:                    &authn.Identity{ID: "user:1"},
			expectTryRefreshTokenCalled: false,
		},
		// Refresh fails and the stored token is already expired: the test expects
		// the OAuth tokens to be invalidated, the session token revoked and
		// authn.ErrExpiredAccessToken to be returned.
		{
			desc:                              "should invalidate access token and session token if token refresh fails",
			identity:                          &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
			expectedHasEntryToken:             &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
			expectedTryRefreshErr:             errors.New("some err"),
			expectHasEntryCalled:              true,
			expectTryRefreshTokenCalled:       true,
			expectInvalidateOauthTokensCalled: true,
			expectRevokeTokenCalled:           true,
			expectedErr:                       authn.ErrExpiredAccessToken,
		},
		// Refresh succeeds: the test expects the stored entry not to be consulted
		// and nothing to be invalidated or revoked.
		{
			desc:                              "should refresh the token successfully",
			identity:                          &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
			expectHasEntryCalled:              false,
			expectTryRefreshTokenCalled:       true,
			expectInvalidateOauthTokensCalled: false,
			expectRevokeTokenCalled:           false,
		},
		// Refresh fails, but the stored token is still valid (as if another request
		// had already refreshed it): the test expects no invalidation, no
		// revocation and no error.
		{
			desc:                              "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
			identity:                          &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
			expectedHasEntryToken:             &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
			expectedTryRefreshErr:             errors.New("some err"),
			expectHasEntryCalled:              true,
			expectTryRefreshTokenCalled:       true,
			expectInvalidateOauthTokensCalled: false,
			expectRevokeTokenCalled:           false,
		},
		// TODO: address coverage of oauthtoken sync
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// Interaction recorders, flipped by the test doubles below.
			var hasEntryCalled, tryRefreshCalled, invalidateTokensCalled, revokeTokenCalled bool

			// Resolve the provider info locally rather than mutating the table entry.
			authInfo := tc.oauthInfo
			if authInfo == nil {
				authInfo = &social.OAuthInfo{UseRefreshToken: true}
			}

			tokenService := &oauthtokentest.MockOauthTokenService{
				HasOAuthEntryFunc: func(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
					hasEntryCalled = true
					// An entry "exists" exactly when the scenario supplies a token.
					return tc.expectedHasEntryToken, tc.expectedHasEntryToken != nil, nil
				},
				TryTokenRefreshFunc: func(context.Context, identity.Requester) error {
					tryRefreshCalled = true
					return tc.expectedTryRefreshErr
				},
				InvalidateOAuthTokensFunc: func(context.Context, *login.UserAuth) error {
					invalidateTokensCalled = true
					return nil
				},
			}
			sessionTokens := &authtest.FakeUserAuthTokenService{
				RevokeTokenProvider: func(context.Context, *auth.UserToken, bool) error {
					revokeTokenCalled = true
					return nil
				},
			}

			hook := &OAuthTokenSync{
				log:               log.NewNopLogger(),
				service:           tokenService,
				sessionService:    sessionTokens,
				socialService:     &socialtest.FakeSocialService{ExpectedAuthInfoProvider: authInfo},
				singleflightGroup: new(singleflight.Group),
			}

			err := hook.SyncOauthTokenHook(context.Background(), tc.identity, nil)

			assert.ErrorIs(t, err, tc.expectedErr)
			assert.Equal(t, tc.expectHasEntryCalled, hasEntryCalled)
			assert.Equal(t, tc.expectTryRefreshTokenCalled, tryRefreshCalled)
			assert.Equal(t, tc.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
			assert.Equal(t, tc.expectRevokeTokenCalled, revokeTokenCalled)
		})
	}
}