The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/authn/clients/oauth_test.go

602 lines
18 KiB

package clients
import (
"context"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
"time"
"golang.org/x/oauth2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
)
// TestOAuth_Authenticate runs OAuth.Authenticate against a table of callback
// requests. Each case controls which cookies are attached to the request, the
// provider configuration returned by the fake social service, the user info
// returned by the fake connector and the enabled feature toggles. The returned
// error is always checked; the returned identity must be nil when no identity
// is expected, and otherwise selected identity fields are compared.
func TestOAuth_Authenticate(t *testing.T) {
	type testCase struct {
		desc                  string
		req                   *authn.Request
		oauthCfg              *social.OAuthInfo
		allowInsecureTakeover bool // written to auth.oauth_allow_insecure_email_lookup
		addStateCookie        bool
		stateCookieValue      string // hashed before being stored, unless empty
		addPKCECookie         bool
		pkceCookieValue       string
		features              []any
		isEmailAllowed        bool
		userInfo              *social.BasicUserInfo
		expectedErr           error
		expectedIdentity      *authn.Identity
	}

	tests := []testCase{
		{
			desc:        "should return error when missing state cookie",
			req:         &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
			oauthCfg:    &social.OAuthInfo{Enabled: true},
			expectedErr: errOAuthMissingState,
		},
		{
			desc:             "should return error when state cookie is present but don't have a value",
			req:              &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
			oauthCfg:         &social.OAuthInfo{Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "",
			expectedErr:      errOAuthMissingState,
		},
		{
			desc:        "should return error when the client is not enabled",
			req:         &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
			oauthCfg:    &social.OAuthInfo{Enabled: false},
			expectedErr: errOAuthClientDisabled,
		},
		{
			desc: "should return error when state from ipd does not match stored state",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-other-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			expectedErr:      errOAuthInvalidState,
		},
		{
			desc: "should return error when pkce is configured but the cookie is not present",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			expectedErr:      errOAuthMissingPKCE,
		},
		{
			desc: "should return error when email is empty",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			userInfo:         &social.BasicUserInfo{},
			expectedErr:      errOAuthMissingRequiredEmail,
		},
		{
			desc: "should return error when email is not allowed",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			userInfo:         &social.BasicUserInfo{Email: "some@email.com"},
			isEmailAllowed:   false,
			expectedErr:      errOAuthEmailNotAllowed,
		},
		{
			desc: "should return error when no auth id is set and feature toggle is enabled",
			req: &authn.Request{
				HTTPRequest: &http.Request{
					Header: map[string][]string{},
					URL:    mustParseURL("http://grafana.com/?state=some-state"),
				},
			},
			features:         []any{featuremgmt.FlagOauthRequireSubClaim},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			userInfo:         &social.BasicUserInfo{Email: "some@email.com"},
			isEmailAllowed:   false,
			expectedErr:      errOAuthUserInfo,
		},
		{
			desc: "should return identity for valid request",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true, Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			isEmailAllowed:   true,
			userInfo: &social.BasicUserInfo{
				Id:     "123",
				Name:   "name",
				Email:  "some@email.com",
				Role:   "Admin",
				Groups: []string{"grp1", "grp2"},
			},
			expectedIdentity: &authn.Identity{
				Email:           "some@email.com",
				AuthenticatedBy: login.AzureADAuthModule,
				AuthID:          "123",
				Name:            "name",
				Groups:          []string{"grp1", "grp2"},
				OAuthToken:      &oauth2.Token{},
				OrgRoles:        map[int64]org.RoleType{1: org.RoleAdmin},
				ClientParams: authn.ClientParams{
					SyncUser:        true,
					SyncTeams:       true,
					AllowSignUp:     true,
					FetchSyncedUser: true,
					SyncOrgRoles:    true,
					LookUpParams:    login.UserLookupParams{},
				},
			},
		},
		{
			desc: "should return identity for valid request - and lookup user by email",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:              &social.OAuthInfo{UsePKCE: true, Enabled: true},
			allowInsecureTakeover: true,
			addStateCookie:        true,
			stateCookieValue:      "some-state",
			addPKCECookie:         true,
			pkceCookieValue:       "some-pkce-value",
			isEmailAllowed:        true,
			userInfo: &social.BasicUserInfo{
				Id:     "123",
				Name:   "name",
				Email:  "some@email.com",
				Role:   "Admin",
				Groups: []string{"grp1", "grp2"},
			},
			expectedIdentity: &authn.Identity{
				Email:           "some@email.com",
				AuthenticatedBy: login.AzureADAuthModule,
				AuthID:          "123",
				Name:            "name",
				Groups:          []string{"grp1", "grp2"},
				OAuthToken:      &oauth2.Token{},
				OrgRoles:        map[int64]org.RoleType{1: org.RoleAdmin},
				ClientParams: authn.ClientParams{
					SyncUser:        true,
					SyncTeams:       true,
					AllowSignUp:     true,
					FetchSyncedUser: true,
					SyncOrgRoles:    true,
					LookUpParams:    login.UserLookupParams{Email: strPtr("some@email.com")},
				},
			},
		},
		{
			desc: "should return identity when feature toggle is enabled and auth id is set",
			req: &authn.Request{
				HTTPRequest: &http.Request{
					Header: map[string][]string{},
					URL:    mustParseURL("http://grafana.com/?state=some-state"),
				},
			},
			oauthCfg:         &social.OAuthInfo{Enabled: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			isEmailAllowed:   true,
			features:         []any{featuremgmt.FlagOauthRequireSubClaim},
			userInfo: &social.BasicUserInfo{
				Id:    "123",
				Name:  "name",
				Email: "some@email.com",
				Role:  "Admin",
			},
			expectedIdentity: &authn.Identity{
				Email:           "some@email.com",
				AuthenticatedBy: login.AzureADAuthModule,
				AuthID:          "123",
				Name:            "name",
				OAuthToken:      &oauth2.Token{},
				OrgRoles:        map[int64]org.RoleType{1: org.RoleAdmin},
				ClientParams: authn.ClientParams{
					SyncUser:        true,
					SyncTeams:       true,
					AllowSignUp:     true,
					FetchSyncedUser: true,
					SyncOrgRoles:    true,
					LookUpParams:    login.UserLookupParams{},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			cfg := setting.NewCfg()
			auth, err := cfg.Raw.NewSection("auth")
			assert.NoError(t, err)
			_, err = auth.NewKey("oauth_allow_insecure_email_lookup", strconv.FormatBool(tt.allowInsecureTakeover))
			assert.NoError(t, err)
			settingsProvider := &setting.OSSImpl{Cfg: cfg}

			if tt.addStateCookie {
				// A non-empty state is stored hashed; an empty one is kept empty so
				// the "cookie present but without a value" case can be exercised.
				v := tt.stateCookieValue
				if v != "" {
					v = hashOAuthState(v, cfg.SecretKey, tt.oauthCfg.ClientSecret)
				}
				tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: v})
			}

			if tt.addPKCECookie {
				tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
			}

			fakeSocialSvc := &socialtest.FakeSocialService{
				ExpectedAuthInfoProvider: tt.oauthCfg,
				ExpectedConnector: fakeConnector{
					ExpectedUserInfo:        tt.userInfo,
					ExpectedToken:           &oauth2.Token{},
					ExpectedIsSignupAllowed: true,
					ExpectedIsEmailAllowed:  tt.isEmailAllowed,
				},
			}

			c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, fakeSocialSvc, settingsProvider, featuremgmt.WithFeatures(tt.features...))

			identity, err := c.Authenticate(context.Background(), tt.req)
			assert.ErrorIs(t, err, tt.expectedErr)

			if tt.expectedIdentity == nil {
				// Verify the value that was actually returned; the expectation
				// itself is known to be nil in this branch.
				assert.Nil(t, identity)
				return
			}

			// Stop here on a nil result so the field comparisons below report a
			// test failure instead of panicking on a nil dereference.
			require.NotNil(t, identity)

			assert.Equal(t, tt.expectedIdentity.Login, identity.Login)
			assert.Equal(t, tt.expectedIdentity.Name, identity.Name)
			assert.Equal(t, tt.expectedIdentity.Email, identity.Email)
			assert.Equal(t, tt.expectedIdentity.AuthID, identity.AuthID)
			assert.Equal(t, tt.expectedIdentity.AuthenticatedBy, identity.AuthenticatedBy)
			assert.Equal(t, tt.expectedIdentity.Groups, identity.Groups)

			assert.Equal(t, tt.expectedIdentity.ClientParams.SyncUser, identity.ClientParams.SyncUser)
			assert.Equal(t, tt.expectedIdentity.ClientParams.AllowSignUp, identity.ClientParams.AllowSignUp)
			assert.Equal(t, tt.expectedIdentity.ClientParams.SyncTeams, identity.ClientParams.SyncTeams)
			assert.Equal(t, tt.expectedIdentity.ClientParams.EnableUser, identity.ClientParams.EnableUser)

			assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Email, identity.ClientParams.LookUpParams.Email)
			assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Login, identity.ClientParams.LookUpParams.Login)
		})
	}
}
// TestOAuth_RedirectURL exercises OAuth.RedirectURL for several provider
// configurations. For each one it checks the returned error, whether the
// connector's AuthCodeURL was invoked and with how many options, and that the
// redirect carries a state value (plus a PKCE value when PKCE is configured).
func TestOAuth_RedirectURL(t *testing.T) {
	type scenario struct {
		desc         string
		oauthCfg     *social.OAuthInfo
		wantErr      error
		wantOptCount int  // number of AuthCodeOptions expected by AuthCodeURL
		wantURLCall  bool // whether AuthCodeURL is expected to be invoked
	}

	scenarios := []scenario{
		{
			desc:        "should generate redirect url and state",
			oauthCfg:    &social.OAuthInfo{Enabled: true},
			wantURLCall: true,
		},
		{
			desc:         "should generate redirect url with hosted domain option if configured",
			oauthCfg:     &social.OAuthInfo{HostedDomain: "grafana.com", Enabled: true},
			wantOptCount: 1,
			wantURLCall:  true,
		},
		{
			desc:         "should generate redirect url with pkce if configured",
			oauthCfg:     &social.OAuthInfo{UsePKCE: true, Enabled: true},
			wantOptCount: 1,
			wantURLCall:  true,
		},
		{
			desc:     "should return error if the client is not enabled",
			oauthCfg: &social.OAuthInfo{Enabled: false},
			wantErr:  errOAuthClientDisabled,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.desc, func(t *testing.T) {
			urlRequested := false

			// The connector records the call and checks the option count itself.
			connector := mockConnector{
				AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
					urlRequested = true
					require.Len(t, opts, sc.wantOptCount)
					return ""
				},
			}
			socialSvc := &socialtest.FakeSocialService{
				ExpectedAuthInfoProvider: sc.oauthCfg,
				ExpectedConnector:        connector,
			}

			cfg := setting.NewCfg()
			client := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, nil, socialSvc, &setting.OSSImpl{Cfg: cfg}, featuremgmt.WithFeatures())

			redirect, err := client.RedirectURL(context.Background(), nil)
			assert.ErrorIs(t, err, sc.wantErr)
			assert.Equal(t, sc.wantURLCall, urlRequested)

			// Nothing further to inspect when an error was expected.
			if sc.wantErr != nil {
				return
			}

			assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthState])
			if sc.oauthCfg.UsePKCE {
				assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthPKCE])
			}
		})
	}
}
// TestOAuth_Logout exercises OAuth.Logout with different combinations of the
// global and per-client sign-out redirect URL. When a redirect is expected it
// checks the URL prefix and that the expected id token hint and post-logout
// URI appear in it. Every case also asserts that the token service was asked
// for the current token and asked to invalidate tokens.
func TestOAuth_Logout(t *testing.T) {
	type scenario struct {
		desc              string
		cfg               *setting.Cfg
		oauthCfg          *social.OAuthInfo
		wantOK            bool
		wantURLPrefix     string
		wantIDTokenHint   string
		wantPostLogoutURI string
	}

	scenarios := []scenario{
		{
			desc:     "should not return redirect url if not configured for client or globably",
			cfg:      &setting.Cfg{},
			oauthCfg: &social.OAuthInfo{},
		},
		{
			desc:     "should not return redirect url when client is not enabled",
			cfg:      &setting.Cfg{},
			oauthCfg: &social.OAuthInfo{Enabled: false},
		},
		{
			desc:          "should return redirect url for globably configured redirect url",
			cfg:           &setting.Cfg{SignoutRedirectUrl: "http://idp.com/logout"},
			oauthCfg:      &social.OAuthInfo{Enabled: true},
			wantURLPrefix: "http://idp.com/logout",
			wantOK:        true,
		},
		{
			desc: "should return redirect url for client configured redirect url",
			cfg:  &setting.Cfg{},
			oauthCfg: &social.OAuthInfo{
				Enabled:            true,
				SignoutRedirectUrl: "http://idp.com/logout",
			},
			wantURLPrefix: "http://idp.com/logout",
			wantOK:        true,
		},
		{
			desc: "client specific url should take precedence",
			cfg:  &setting.Cfg{SignoutRedirectUrl: "http://idp.com/logout"},
			oauthCfg: &social.OAuthInfo{
				Enabled:            true,
				SignoutRedirectUrl: "http://idp-2.com/logout",
			},
			wantURLPrefix: "http://idp-2.com/logout",
			wantOK:        true,
		},
		{
			desc: "should add id token hint if oicd logout is configured and token is valid",
			cfg:  &setting.Cfg{},
			oauthCfg: &social.OAuthInfo{
				Enabled:            true,
				SignoutRedirectUrl: "http://idp.com/logout?post_logout_redirect_uri=http%3A%3A%2F%2Ftest.com%2Flogin",
			},
			wantURLPrefix:     "http://idp.com/logout",
			wantIDTokenHint:   "id_token_hint=some.id.token",
			wantPostLogoutURI: "http%3A%3A%2F%2Ftest.com%2Flogin",
			wantOK:            true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.desc, func(t *testing.T) {
			var tokenFetched, tokensInvalidated bool

			tokenSvc := &oauthtokentest.MockOauthTokenService{
				GetCurrentOauthTokenFunc: func(_ context.Context, _ identity.Requester) *oauth2.Token {
					tokenFetched = true
					// Token expiring ten minutes from now, carrying an id_token extra.
					base := &oauth2.Token{
						AccessToken: "some.access.token",
						Expiry:      time.Now().Add(10 * time.Minute),
					}
					extras := map[string]any{"id_token": "some.id.token"}
					return base.WithExtra(extras)
				},
				InvalidateOAuthTokensFunc: func(_ context.Context, _ *login.UserAuth) error {
					tokensInvalidated = true
					return nil
				},
			}
			socialSvc := &socialtest.FakeSocialService{ExpectedAuthInfoProvider: sc.oauthCfg}

			client := ProvideOAuth(authn.ClientWithPrefix("azuread"), sc.cfg, tokenSvc, socialSvc, &setting.OSSImpl{Cfg: sc.cfg}, featuremgmt.WithFeatures())

			redirect, ok := client.Logout(context.Background(), &authn.Identity{})
			assert.Equal(t, sc.wantOK, ok)

			if sc.wantOK {
				hasPrefix := strings.HasPrefix(redirect.URL, sc.wantURLPrefix)
				assert.True(t, hasPrefix)
				assert.Contains(t, redirect.URL, sc.wantIDTokenHint)
				assert.Contains(t, redirect.URL, sc.wantPostLogoutURI)
			}

			assert.True(t, tokenFetched)
			assert.True(t, tokensInvalidated)
		})
	}
}
// TestGenPKCECodeVerifier checks that genPKCECodeVerifier returns no error and
// produces a verifier of length 128.
func TestGenPKCECodeVerifier(t *testing.T) {
	const wantLen = 128

	got, genErr := genPKCECodeVerifier()

	assert.NoError(t, genErr)
	assert.Len(t, got, wantLen)
}
// TestIsEnabled checks that OAuth.IsEnabled reports true only when the social
// service returns provider info with Enabled set, and false when the provider
// is disabled or no provider info is returned at all.
func TestIsEnabled(t *testing.T) {
	type scenario struct {
		desc     string
		oauthCfg *social.OAuthInfo
		want     bool
	}

	scenarios := []scenario{
		{
			desc:     "should return false when client is not enabled",
			oauthCfg: &social.OAuthInfo{Enabled: false},
			want:     false,
		},
		{
			desc:     "should return false when client doesnt exists",
			oauthCfg: nil,
			want:     false,
		},
		{
			desc:     "should return true when client is enabled",
			oauthCfg: &social.OAuthInfo{Enabled: true},
			want:     true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.desc, func(t *testing.T) {
			socialSvc := &socialtest.FakeSocialService{ExpectedAuthInfoProvider: sc.oauthCfg}
			cfg := setting.NewCfg()
			settingsProvider := &setting.OSSImpl{Cfg: cfg}

			client := ProvideOAuth(social.GitHubProviderName, cfg, nil, socialSvc, settingsProvider, featuremgmt.WithFeatures())

			got := client.IsEnabled()
			assert.Equal(t, sc.want, got)
		})
	}
}
// mockConnector is a social.SocialConnector test double whose AuthCodeURL
// behaviour can be supplied per test through AuthCodeURLFunc. All other
// methods come from the embedded interface value.
type mockConnector struct {
	AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
	social.SocialConnector
}

// AuthCodeURL forwards to AuthCodeURLFunc when one is set and returns an empty
// string otherwise.
func (m mockConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
	if m.AuthCodeURLFunc == nil {
		return ""
	}
	return m.AuthCodeURLFunc(state, opts...)
}
// Compile-time check that *fakeConnector satisfies social.SocialConnector.
var _ social.SocialConnector = (*fakeConnector)(nil)

// fakeConnector is a social.SocialConnector test double that returns canned
// values from its Expected* fields. Methods not overridden below come from the
// embedded interface value.
type fakeConnector struct {
	ExpectedUserInfo        *social.BasicUserInfo
	ExpectedUserInfoErr     error
	ExpectedIsEmailAllowed  bool
	ExpectedIsSignupAllowed bool
	ExpectedToken           *oauth2.Token
	ExpectedTokenErr        error
	social.SocialConnector
}

// UserInfo ignores its arguments and returns the configured user info and error.
func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
	info, err := f.ExpectedUserInfo, f.ExpectedUserInfoErr
	return info, err
}

// IsEmailAllowed returns the configured answer regardless of the email given.
func (f fakeConnector) IsEmailAllowed(email string) bool {
	allowed := f.ExpectedIsEmailAllowed
	return allowed
}

// IsSignupAllowed returns the configured sign-up answer.
func (f fakeConnector) IsSignupAllowed() bool {
	allowed := f.ExpectedIsSignupAllowed
	return allowed
}

// Exchange ignores the code and options and returns the configured token and error.
func (f fakeConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
	tok, err := f.ExpectedToken, f.ExpectedTokenErr
	return tok, err
}

// Client always returns a nil HTTP client.
func (f fakeConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
	var none *http.Client
	return none
}
// mustParseURL parses s with url.Parse and returns the result. It panics with
// the parse error when s is not a valid URL, which keeps test-table literals
// free of error handling.
func mustParseURL(s string) *url.URL {
	parsed, parseErr := url.Parse(s)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}