Authn: Handle logout logic in auth broker (#79635)

* AuthN: Add new client extension interface that allows for custom logout logic

* AuthN: Add tests for oauth client logout

* Call authn.Logout

Co-authored-by: Gabriel MABILLE <gamab@users.noreply.github.com>
pull/79678/head
Karl Persson 1 year ago committed by GitHub
parent eb490193b9
commit 8cb351e54a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 117
      pkg/api/login.go
  2. 11
      pkg/services/authn/authn.go
  3. 75
      pkg/services/authn/authnimpl/service.go
  4. 93
      pkg/services/authn/authnimpl/service_test.go
  5. 6
      pkg/services/authn/authntest/fake.go
  6. 16
      pkg/services/authn/authntest/mock.go
  7. 93
      pkg/services/authn/clients/oauth.go
  8. 109
      pkg/services/authn/clients/oauth_test.go

@ -29,8 +29,6 @@ import (
const (
viewIndex = "index"
loginErrorCookieName = "login_error"
// #nosec G101 - this is not a hardcoded secret
postLogoutRedirectParam = "post_logout_redirect_uri"
)
var setIndexViewData = (*HTTPServer).setIndexViewData
@ -243,70 +241,31 @@ func (hs *HTTPServer) loginUserWithUser(user *user.User, c *contextmodel.ReqCont
}
func (hs *HTTPServer) Logout(c *contextmodel.ReqContext) {
userID, errID := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
if errID != nil {
hs.log.Error("failed to retrieve user ID", "error", errID)
}
oauthProviderSignoutRedirectUrl := ""
getAuthQuery := loginservice.GetAuthInfoQuery{UserId: userID}
authInfo, err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery)
if err == nil {
// If SAML is enabled and this is a SAML user use saml logout
if hs.samlSingleLogoutEnabled() {
if authInfo.AuthModule == loginservice.SAMLAuthModule {
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
return
}
}
oauthProvider := hs.SocialService.GetOAuthInfoProvider(strings.TrimPrefix(authInfo.AuthModule, "oauth_"))
if oauthProvider != nil {
oauthProviderSignoutRedirectUrl = oauthProvider.SignoutRedirectUrl
}
}
hs.log.Debug("Logout Redirect url", "auth.SignoutRedirectUrl:", hs.Cfg.SignoutRedirectUrl)
hs.log.Debug("Logout Redirect url", "oauth provider redirect url:", oauthProviderSignoutRedirectUrl)
signOutRedirectUrl := getSignOutRedirectUrl(hs.Cfg.SignoutRedirectUrl, oauthProviderSignoutRedirectUrl)
hs.log.Debug("Logout Redirect url", "signOurRedirectUrl:", signOutRedirectUrl)
idTokenHint := ""
oidcLogout := isPostLogoutRedirectConfigured(signOutRedirectUrl)
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
token := hs.oauthTokenService.GetCurrentOAuthToken(c.Req.Context(), c.SignedInUser)
if oidcLogout {
if token.Valid() {
idTokenHint = token.Extra("id_token").(string)
} else {
hs.log.Warn("Token is not valid")
}
// FIXME: restructure saml client to implement authn.LogoutClient
if hs.samlSingleLogoutEnabled() {
id, err := identity.UserIdentifier(c.SignedInUser.GetNamespacedID())
if err != nil {
hs.log.Error("failed to retrieve user ID", "error", err)
}
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", userID, "error", err)
authInfo, _ := hs.authInfoService.GetAuthInfo(c.Req.Context(), &loginservice.GetAuthInfoQuery{UserId: id})
if authInfo != nil && authInfo.AuthModule == loginservice.SAMLAuthModule {
c.Redirect(hs.Cfg.AppSubURL + "/logout/saml")
return
}
}
err = hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) {
hs.log.Error("failed to revoke auth token", "error", err)
}
redirect, err := hs.authnService.Logout(c.Req.Context(), c.SignedInUser, c.UserToken)
authn.DeleteSessionCookie(c.Resp, hs.Cfg)
rdUrl := signOutRedirectUrl
if rdUrl != "" {
if oidcLogout {
rdUrl = getPostRedirectUrl(signOutRedirectUrl, idTokenHint)
}
c.Redirect(rdUrl)
} else {
hs.log.Info("Successful Logout", "User", c.SignedInUser.GetEmail())
if err != nil {
hs.log.Error("Failed perform proper logout", "error", err)
c.Redirect(hs.Cfg.AppSubURL + "/login")
}
_, id := c.SignedInUser.GetNamespacedID()
hs.log.Info("Successful Logout", "userID", id)
c.Redirect(redirect.URL)
}
func (hs *HTTPServer) tryGetEncryptedCookie(ctx *contextmodel.ReqContext, cookieName string) (string, bool) {
@ -420,47 +379,3 @@ func getFirstPublicErrorMessage(err *errutil.Error) string {
return errPublic.Message
}
func isPostLogoutRedirectConfigured(redirectUrl string) bool {
if redirectUrl == "" {
return false
}
u, err := url.Parse(redirectUrl)
if err != nil {
return false
}
q := u.Query()
_, ok := q[postLogoutRedirectParam]
return ok
}
func getPostRedirectUrl(rdUrl string, tokenHint string) string {
if tokenHint == "" {
return rdUrl
}
if rdUrl == "" {
return rdUrl
}
u, err := url.Parse(rdUrl)
if err != nil {
return rdUrl
}
q := u.Query()
q.Set("id_token_hint", tokenHint)
u.RawQuery = q.Encode()
return u.String()
}
func getSignOutRedirectUrl(gRdUrl string, oauthProviderUrl string) string {
if oauthProviderUrl != "" {
return oauthProviderUrl
} else if gRdUrl != "" {
return gRdUrl
}
return ""
}

@ -11,6 +11,7 @@ import (
"github.com/grafana/grafana/pkg/api/response"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
@ -74,6 +75,8 @@ type Service interface {
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
// Logout revokes session token and does additional clean up if client used to authenticate supports it
Logout(ctx context.Context, user identity.Requester, sessionToken *usertoken.UserToken) (*Redirect, error)
// RegisterClient will register a new authn.Client that can be used for authentication
RegisterClient(c Client)
}
@ -115,6 +118,14 @@ type RedirectClient interface {
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
}
// LogoutCLient is an optional interface that auth client can implement.
// Clients that implements this interface can implement additional logic
// that should happen during logout and supports client specific redirect URL.
type LogoutClient interface {
Client
Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*Redirect, bool)
}
type PasswordClient interface {
AuthenticatePassword(ctx context.Context, r *Request, username, password string) (*Identity, error)
}

@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"strconv"
"strings"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
@ -19,6 +20,7 @@ import (
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/apikey"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
"github.com/grafana/grafana/pkg/services/authn/clients"
@ -73,15 +75,16 @@ func ProvideService(
signingKeysService signingkeys.Service, oauthServer oauthserver.OAuth2Server,
) *Service {
s := &Service{
log: log.New("authn.service"),
cfg: cfg,
clients: make(map[string]authn.Client),
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracer,
metrics: newMetrics(registerer),
sessionService: sessionService,
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
log: log.New("authn.service"),
cfg: cfg,
clients: make(map[string]authn.Client),
clientQueue: newQueue[authn.ContextAwareClient](),
tracer: tracer,
metrics: newMetrics(registerer),
authInfoService: authInfoService,
sessionService: sessionService,
postAuthHooks: newQueue[authn.PostAuthHookFn](),
postLoginHooks: newQueue[authn.PostLoginHookFn](),
}
usageStats.RegisterMetricsFunc(s.getUsageStats)
@ -146,7 +149,7 @@ func ProvideService(
if errConnector != nil || errHTTPClient != nil {
s.log.Error("Failed to configure oauth client", "client", clientName, "err", errors.Join(errConnector, errHTTPClient))
} else {
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient))
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient, oauthTokenService))
}
}
}
@ -175,7 +178,8 @@ type Service struct {
tracer tracing.Tracer
metrics *metrics
sessionService auth.UserTokenService
authInfoService login.AuthInfoService
sessionService auth.UserTokenService
// postAuthHooks are called after a successful authentication. They can modify the identity.
postAuthHooks *queue[authn.PostAuthHookFn]
@ -335,6 +339,55 @@ func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Reque
return redirectClient.RedirectURL(ctx, r)
}
func (s *Service) Logout(ctx context.Context, user identity.Requester, sessionToken *auth.UserToken) (*authn.Redirect, error) {
ctx, span := s.tracer.Start(ctx, "authn.Logout")
defer span.End()
redirect := &authn.Redirect{URL: s.cfg.AppSubURL + "/login"}
namespace, id := user.GetNamespacedID()
if namespace != authn.NamespaceUser {
return redirect, nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
s.log.FromContext(ctx).Debug("Invalid user id", "id", userID, "err", err)
return redirect, nil
}
info, _ := s.authInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{UserId: userID})
if info != nil {
client := authn.ClientWithPrefix(strings.TrimPrefix(info.AuthModule, "oauth_"))
c, ok := s.clients[client]
if !ok {
s.log.FromContext(ctx).Debug("No client configured for auth module", "client", client)
goto Default
}
logoutClient, ok := c.(authn.LogoutClient)
if !ok {
s.log.FromContext(ctx).Debug("Client do not support specialized logout logic", "client", client)
goto Default
}
clientRedirect, ok := logoutClient.Logout(ctx, user, info)
if !ok {
goto Default
}
redirect = clientRedirect
}
Default:
if err = s.sessionService.RevokeToken(ctx, sessionToken, false); err != nil {
return nil, err
}
return redirect, nil
}
func (s *Service) RegisterClient(c authn.Client) {
s.clients[c.Name()] = c
if cac, ok := c.(authn.ContextAwareClient); ok {

@ -13,10 +13,14 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfotest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
@ -299,6 +303,95 @@ func TestService_RedirectURL(t *testing.T) {
}
}
func TestService_Logout(t *testing.T) {
type TestCase struct {
desc string
identity *authn.Identity
sessionToken *usertoken.UserToken
info *login.UserAuth
client authn.Client
expectedErr error
expectedTokenRevoked bool
expectedRedirect *authn.Redirect
}
tests := []TestCase{
{
desc: "should redirect to default redirect url when identity is not a user",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceServiceAccount, 1)},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
},
{
desc: "should redirect to default redirect url when no external provider was used to authenticate",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
expectedTokenRevoked: true,
},
{
desc: "should redirect to default redirect url when client is not found",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "notFound"},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
expectedTokenRevoked: true,
},
{
desc: "should redirect to default redirect url when client do not implement logout extension",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "azuread"},
expectedRedirect: &authn.Redirect{URL: "http://localhost:3000/login"},
client: &authntest.FakeClient{ExpectedName: "auth.client.azuread"},
expectedTokenRevoked: true,
},
{
desc: "should redirect to client specific url",
identity: &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1)},
info: &login.UserAuth{AuthModule: "azuread"},
expectedRedirect: &authn.Redirect{URL: "http://idp.com/logout"},
client: &authntest.MockClient{
NameFunc: func() string { return "auth.client.azuread" },
LogoutFunc: func(ctx context.Context, _ identity.Requester, _ *login.UserAuth) (*authn.Redirect, bool) {
return &authn.Redirect{URL: "http://idp.com/logout"}, true
},
},
expectedTokenRevoked: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var tokenRevoked bool
s := setupTests(t, func(svc *Service) {
if tt.client != nil {
svc.RegisterClient(tt.client)
}
svc.cfg.AppSubURL = "http://localhost:3000"
svc.authInfoService = &authinfotest.FakeService{
ExpectedUserAuth: tt.info,
}
svc.sessionService = &authtest.FakeUserAuthTokenService{
RevokeTokenProvider: func(_ context.Context, sessionToken *auth.UserToken, soft bool) error {
tokenRevoked = true
assert.EqualValues(t, tt.sessionToken, sessionToken)
assert.False(t, soft)
return nil
},
}
})
redirect, err := s.Logout(context.Background(), tt.identity, tt.sessionToken)
assert.ErrorIs(t, err, tt.expectedErr)
assert.EqualValues(t, tt.expectedRedirect, redirect)
assert.Equal(t, tt.expectedTokenRevoked, tokenRevoked)
})
}
}
func mustParseURL(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {

@ -3,6 +3,8 @@ package authntest
import (
"context"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
)
@ -66,6 +68,10 @@ func (f *FakeService) RedirectURL(ctx context.Context, client string, r *authn.R
return f.ExpectedRedirect, f.ExpectedErr
}
func (*FakeService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
panic("unimplemented")
}
func (f *FakeService) RegisterClient(c authn.Client) {}
func (f *FakeService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {

@ -3,7 +3,10 @@ package authntest
import (
"context"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
)
var _ authn.Service = new(MockService)
@ -40,6 +43,10 @@ func (m *MockService) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority
panic("unimplemented")
}
func (*MockService) Logout(_ context.Context, _ identity.Requester, _ *usertoken.UserToken) (*authn.Redirect, error) {
panic("unimplemented")
}
func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity) error {
if m.SyncIdentityFunc != nil {
return m.SyncIdentityFunc(ctx, identity)
@ -48,6 +55,7 @@ func (m *MockService) SyncIdentity(ctx context.Context, identity *authn.Identity
}
var _ authn.HookClient = new(MockClient)
var _ authn.LogoutClient = new(MockClient)
var _ authn.ContextAwareClient = new(MockClient)
type MockClient struct {
@ -56,6 +64,7 @@ type MockClient struct {
TestFunc func(ctx context.Context, r *authn.Request) bool
PriorityFunc func() uint
HookFunc func(ctx context.Context, identity *authn.Identity, r *authn.Request) error
LogoutFunc func(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool)
}
func (m MockClient) Name() string {
@ -93,6 +102,13 @@ func (m MockClient) Hook(ctx context.Context, identity *authn.Identity, r *authn
return nil
}
func (m *MockClient) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
if m.LogoutFunc != nil {
return m.LogoutFunc(ctx, user, info)
}
return nil, false
}
var _ authn.ProxyClient = new(MockProxyClient)
type MockProxyClient struct {

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"golang.org/x/oauth2"
@ -16,8 +17,10 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/connectors"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
@ -30,9 +33,10 @@ const (
codeChallengeMethodParamName = "code_challenge_method"
codeChallengeMethod = "S256"
oauthStateQueryName = "state"
oauthStateCookieName = "oauth_state"
oauthPKCECookieName = "oauth_code_verifier"
oauthStateQueryName = "state"
oauthStateCookieName = "oauth_state"
oauthPKCECookieName = "oauth_code_verifier"
oauthPostLogoutRedirectParam = "post_logout_redirect_uri"
)
var (
@ -54,26 +58,28 @@ func fromSocialErr(err *connectors.SocialError) error {
return errutil.Unauthorized("auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err)
}
var _ authn.LogoutClient = new(OAuth)
var _ authn.RedirectClient = new(OAuth)
func ProvideOAuth(
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
connector social.SocialConnector, httpClient *http.Client,
connector social.SocialConnector, httpClient *http.Client, oauthService oauthtoken.OAuthTokenService,
) *OAuth {
return &OAuth{
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
log.New(name), cfg, oauthCfg, connector, httpClient,
log.New(name), cfg, oauthCfg, connector, httpClient, oauthService,
}
}
type OAuth struct {
name string
moduleName string
log log.Logger
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector
httpClient *http.Client
name string
moduleName string
log log.Logger
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector
httpClient *http.Client
oauthService oauthtoken.OAuthTokenService
}
func (c *OAuth) Name() string {
@ -204,6 +210,29 @@ func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redir
}, nil
}
func (c *OAuth) Logout(ctx context.Context, user identity.Requester, info *login.UserAuth) (*authn.Redirect, bool) {
token := c.oauthService.GetCurrentOAuthToken(ctx, user)
if err := c.oauthService.InvalidateOAuthTokens(ctx, info); err != nil {
namespace, id := user.GetNamespacedID()
c.log.FromContext(ctx).Error("Failed to invalidate tokens", "namespace", namespace, "id", id, "error", err)
}
redirctURL := getOAuthSignoutRedirectURL(c.cfg, c.oauthCfg)
if redirctURL == "" {
c.log.FromContext(ctx).Debug("No signout redirect url configured")
return nil, false
}
if isOICDLogout(redirctURL) && token != nil && token.Valid() {
if idToken, ok := token.Extra("id_token").(string); ok {
redirctURL = withIDTokenHint(redirctURL, idToken)
}
}
return &authn.Redirect{URL: redirctURL}, true
}
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
func genPKCECode() (string, string, error) {
// IETF RFC 7636 specifies that the code verifier should be 43-128
@ -243,3 +272,43 @@ func hashOAuthState(state, secret, seed string) string {
hashBytes := sha256.Sum256([]byte(state + secret + seed))
return hex.EncodeToString(hashBytes[:])
}
func getOAuthSignoutRedirectURL(cfg *setting.Cfg, oauthCfg *social.OAuthInfo) string {
if oauthCfg.SignoutRedirectUrl != "" {
return oauthCfg.SignoutRedirectUrl
}
return cfg.SignoutRedirectUrl
}
func withIDTokenHint(redirectURL string, idToken string) string {
if idToken == "" {
return redirectURL
}
u, err := url.Parse(redirectURL)
if err != nil {
return redirectURL
}
q := u.Query()
q.Set("id_token_hint", idToken)
u.RawQuery = q.Encode()
return u.String()
}
func isOICDLogout(redirectUrl string) bool {
if redirectUrl == "" {
return false
}
u, err := url.Parse(redirectUrl)
if err != nil {
return false
}
q := u.Query()
_, ok := q[oauthPostLogoutRedirectParam]
return ok
}

@ -4,7 +4,9 @@ import (
"context"
"net/http"
"net/url"
"strings"
"testing"
"time"
"golang.org/x/oauth2"
@ -12,8 +14,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
)
@ -212,7 +216,7 @@ func TestOAuth_Authenticate(t *testing.T) {
ExpectedToken: &oauth2.Token{},
ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
}, nil)
}, nil, nil)
identity, err := c.Authenticate(context.Background(), tt.req)
assert.ErrorIs(t, err, tt.expectedErr)
@ -281,7 +285,7 @@ func TestOAuth_RedirectURL(t *testing.T) {
require.Len(t, opts, tt.numCallOptions)
return ""
},
}, nil)
}, nil, nil)
redirect, err := c.RedirectURL(context.Background(), nil)
assert.ErrorIs(t, err, tt.expectedErr)
@ -299,6 +303,107 @@ func TestOAuth_RedirectURL(t *testing.T) {
}
}
func TestOAuth_Logout(t *testing.T) {
type testCase struct {
desc string
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
expectedOK bool
expectedURL string
expectedIDTokenHint string
expectedPostLogoutURI string
}
tests := []testCase{
{
desc: "should not return redirect url if not configured for client or globably",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{},
},
{
desc: "should return redirect url for globably configured redirect url",
cfg: &setting.Cfg{
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{},
expectedURL: "http://idp.com/logout",
expectedOK: true,
},
{
desc: "should return redirect url for client configured redirect url",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
SignoutRedirectUrl: "http://idp.com/logout",
},
expectedURL: "http://idp.com/logout",
expectedOK: true,
},
{
desc: "client specific url should take precedence",
cfg: &setting.Cfg{
SignoutRedirectUrl: "http://idp.com/logout",
},
oauthCfg: &social.OAuthInfo{
SignoutRedirectUrl: "http://idp-2.com/logout",
},
expectedURL: "http://idp-2.com/logout",
expectedOK: true,
},
{
desc: "should add id token hint if oicd logout is configured and token is valid",
cfg: &setting.Cfg{},
oauthCfg: &social.OAuthInfo{
SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
},
expectedURL: "http://idp.com/logout",
expectedIDTokenHint: "id_token_hint=some.id.token",
expectedPostLogoutURI: "http%3A%3A%2F%2Ftest.com%2Flogin",
expectedOK: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var (
getTokenCalled bool
invalidateTokenCalled bool
)
mockService := &oauthtokentest.MockOauthTokenService{
GetCurrentOauthTokenFunc: func(_ context.Context, _ identity.Requester) *oauth2.Token {
getTokenCalled = true
token := &oauth2.Token{
AccessToken: "some.access.token",
Expiry: time.Now().Add(10 * time.Minute),
}
return token.WithExtra(map[string]any{
"id_token": "some.id.token",
})
},
InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
invalidateTokenCalled = true
return nil
},
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, tt.oauthCfg, mockConnector{}, nil, mockService)
redirect, ok := c.Logout(context.Background(), &authn.Identity{}, &login.UserAuth{})
assert.Equal(t, tt.expectedOK, ok)
if tt.expectedOK {
assert.True(t, strings.HasPrefix(redirect.URL, tt.expectedURL))
assert.Contains(t, redirect.URL, tt.expectedIDTokenHint)
assert.Contains(t, redirect.URL, tt.expectedPostLogoutURI)
}
assert.True(t, getTokenCalled)
assert.True(t, invalidateTokenCalled)
})
}
}
type mockConnector struct {
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
social.SocialConnector

Loading…
Cancel
Save