mirror of https://github.com/grafana/grafana
Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)
* Verify OAuth token expiration for oauth users in the ctx handler middleware * Use refresh token to get a new access token * Refactor oauth_token.go * Add tests for the middleware changes * Align other tests * Add tests, wip * Add more tests * Add InvalidateOAuthTokens method * Fix ExpiryDate update to default * Invalidate OAuth tokens during logout * Improve logout * Add more comments * Cleanup * Fix import order * Add error to HasOAuthEntry return values * add dev debug logs * Fix tests Co-authored-by: jguer <joao.guerreiro@grafana.com>pull/57204/head
parent
984ec00aac
commit
9c954d06ab
@ -0,0 +1,374 @@ |
||||
package oauthtoken |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"net/http" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats" |
||||
"github.com/grafana/grafana/pkg/login/social" |
||||
"github.com/grafana/grafana/pkg/models" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/mock" |
||||
"golang.org/x/oauth2" |
||||
"golang.org/x/sync/singleflight" |
||||
) |
||||
|
||||
func TestService_HasOAuthEntry(t *testing.T) { |
||||
testCases := []struct { |
||||
name string |
||||
user *user.SignedInUser |
||||
want *models.UserAuth |
||||
wantExist bool |
||||
wantErr bool |
||||
err error |
||||
getAuthInfoErr error |
||||
getAuthInfoUser models.UserAuth |
||||
}{ |
||||
{ |
||||
name: "returns false without an error in case user is nil", |
||||
user: nil, |
||||
want: nil, |
||||
wantExist: false, |
||||
wantErr: false, |
||||
}, |
||||
{ |
||||
name: "returns false and an error in case GetAuthInfo returns an error", |
||||
user: &user.SignedInUser{}, |
||||
want: nil, |
||||
wantExist: false, |
||||
wantErr: true, |
||||
getAuthInfoErr: errors.New("error"), |
||||
}, |
||||
{ |
||||
name: "returns false without an error in case auth entry is not found", |
||||
user: &user.SignedInUser{}, |
||||
want: nil, |
||||
wantExist: false, |
||||
wantErr: false, |
||||
getAuthInfoErr: user.ErrUserNotFound, |
||||
}, |
||||
{ |
||||
name: "returns false without an error in case the auth entry is not oauth", |
||||
user: &user.SignedInUser{}, |
||||
want: nil, |
||||
wantExist: false, |
||||
wantErr: false, |
||||
getAuthInfoUser: models.UserAuth{AuthModule: "auth_saml"}, |
||||
}, |
||||
{ |
||||
name: "returns true when the auth entry is found", |
||||
user: &user.SignedInUser{}, |
||||
want: &models.UserAuth{AuthModule: "oauth_generic_oauth"}, |
||||
wantExist: true, |
||||
wantErr: false, |
||||
getAuthInfoUser: models.UserAuth{AuthModule: "oauth_generic_oauth"}, |
||||
}, |
||||
} |
||||
for _, tc := range testCases { |
||||
t.Run(tc.name, func(t *testing.T) { |
||||
srv, authInfoStore, _ := setupOAuthTokenService(t) |
||||
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser |
||||
authInfoStore.ExpectedError = tc.getAuthInfoErr |
||||
|
||||
entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user) |
||||
|
||||
if tc.wantErr { |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
if tc.want != nil { |
||||
assert.True(t, reflect.DeepEqual(tc.want, entry)) |
||||
} |
||||
assert.Equal(t, tc.wantExist, exists) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestService_TryTokenRefresh_ValidToken(t *testing.T) { |
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t) |
||||
ctx := context.Background() |
||||
token := &oauth2.Token{ |
||||
AccessToken: "testaccess", |
||||
RefreshToken: "testrefresh", |
||||
Expiry: time.Now(), |
||||
TokenType: "Bearer", |
||||
} |
||||
usr := &models.UserAuth{ |
||||
AuthModule: "oauth_generic_oauth", |
||||
OAuthAccessToken: token.AccessToken, |
||||
OAuthRefreshToken: token.RefreshToken, |
||||
OAuthExpiry: token.Expiry, |
||||
OAuthTokenType: token.TokenType, |
||||
} |
||||
|
||||
authInfoStore.ExpectedOAuth = usr |
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) |
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr) |
||||
assert.Nil(t, err) |
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) |
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{} |
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) |
||||
|
||||
assert.Nil(t, err) |
||||
|
||||
// User's token data had not been updated
|
||||
resultUsr := authInfoQuery.Result |
||||
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken) |
||||
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry) |
||||
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken) |
||||
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType) |
||||
} |
||||
|
||||
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) { |
||||
srv, _, socialConnector := setupOAuthTokenService(t) |
||||
ctx := context.Background() |
||||
token := &oauth2.Token{ |
||||
AccessToken: "testaccess", |
||||
RefreshToken: "", |
||||
Expiry: time.Now().Add(-time.Hour), |
||||
TokenType: "Bearer", |
||||
} |
||||
usr := &models.UserAuth{ |
||||
AuthModule: "oauth_generic_oauth", |
||||
OAuthAccessToken: token.AccessToken, |
||||
OAuthRefreshToken: token.RefreshToken, |
||||
OAuthExpiry: token.Expiry, |
||||
OAuthTokenType: token.TokenType, |
||||
} |
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) |
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr) |
||||
|
||||
assert.NotNil(t, err) |
||||
assert.ErrorIs(t, err, ErrNoRefreshTokenFound) |
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource") |
||||
} |
||||
|
||||
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) { |
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t) |
||||
ctx := context.Background() |
||||
token := &oauth2.Token{ |
||||
AccessToken: "testaccess", |
||||
RefreshToken: "testrefresh", |
||||
Expiry: time.Now().Add(-time.Hour), |
||||
TokenType: "Bearer", |
||||
} |
||||
|
||||
newToken := &oauth2.Token{ |
||||
AccessToken: "testaccess_new", |
||||
RefreshToken: "testrefresh_new", |
||||
Expiry: time.Now().Add(time.Hour), |
||||
TokenType: "Bearer", |
||||
} |
||||
|
||||
usr := &models.UserAuth{ |
||||
AuthModule: "oauth_generic_oauth", |
||||
OAuthAccessToken: token.AccessToken, |
||||
OAuthRefreshToken: token.RefreshToken, |
||||
OAuthExpiry: token.Expiry, |
||||
OAuthTokenType: token.TokenType, |
||||
} |
||||
|
||||
authInfoStore.ExpectedOAuth = usr |
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil) |
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr) |
||||
|
||||
assert.Nil(t, err) |
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) |
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{} |
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) |
||||
|
||||
assert.Nil(t, err) |
||||
|
||||
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken) |
||||
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry) |
||||
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken) |
||||
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType) |
||||
} |
||||
|
||||
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) { |
||||
srv, _, socialConnector := setupOAuthTokenService(t) |
||||
ctx := context.Background() |
||||
token := &oauth2.Token{} |
||||
usr := &models.UserAuth{ |
||||
AuthModule: "auth.saml", |
||||
} |
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) |
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr) |
||||
|
||||
assert.NotNil(t, err) |
||||
assert.ErrorIs(t, err, ErrNotAnOAuthProvider) |
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource") |
||||
} |
||||
|
||||
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *MockSocialConnector) { |
||||
t.Helper() |
||||
|
||||
socialConnector := &MockSocialConnector{} |
||||
socialService := &FakeSocialService{ |
||||
connector: socialConnector, |
||||
} |
||||
|
||||
authInfoStore := &FakeAuthInfoStore{} |
||||
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{}) |
||||
return &Service{ |
||||
Cfg: setting.NewCfg(), |
||||
SocialService: socialService, |
||||
AuthInfoService: authInfoService, |
||||
singleFlightGroup: &singleflight.Group{}, |
||||
}, authInfoStore, socialConnector |
||||
} |
||||
|
||||
type FakeSocialService struct { |
||||
httpClient *http.Client |
||||
connector *MockSocialConnector |
||||
} |
||||
|
||||
func (fss *FakeSocialService) GetOAuthProviders() map[string]bool { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (fss *FakeSocialService) GetOAuthHttpClient(string) (*http.Client, error) { |
||||
return fss.httpClient, nil |
||||
} |
||||
|
||||
func (fss *FakeSocialService) GetConnector(string) (social.SocialConnector, error) { |
||||
return fss.connector, nil |
||||
} |
||||
|
||||
func (fss *FakeSocialService) GetOAuthInfoProvider(string) *social.OAuthInfo { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (fss *FakeSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
type MockSocialConnector struct { |
||||
mock.Mock |
||||
} |
||||
|
||||
func (m *MockSocialConnector) Type() int { |
||||
args := m.Called() |
||||
return args.Int(0) |
||||
} |
||||
|
||||
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { |
||||
args := m.Called(client, token) |
||||
return args.Get(0).(*social.BasicUserInfo), args.Error(1) |
||||
} |
||||
|
||||
func (m *MockSocialConnector) IsEmailAllowed(email string) bool { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (m *MockSocialConnector) IsSignupAllowed() bool { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client { |
||||
panic("not implemented") |
||||
} |
||||
|
||||
func (m *MockSocialConnector) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource { |
||||
args := m.Called(ctx, t) |
||||
return args.Get(0).(oauth2.TokenSource) |
||||
} |
||||
|
||||
type FakeAuthInfoStore struct { |
||||
ExpectedError error |
||||
ExpectedUser *user.User |
||||
ExpectedOAuth *models.UserAuth |
||||
ExpectedDuplicateUserEntries int |
||||
ExpectedHasDuplicateUserEntries int |
||||
ExpectedLoginStats login.LoginStats |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error { |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error { |
||||
query.Result = f.ExpectedOAuth |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error { |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error { |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error { |
||||
f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken |
||||
f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry |
||||
f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType |
||||
f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error { |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) { |
||||
return f.ExpectedUser, f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) { |
||||
return f.ExpectedUser, f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) { |
||||
return f.ExpectedUser, f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) { |
||||
var res = make(map[string]interface{}) |
||||
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries |
||||
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries |
||||
res["stats.users.duplicate_user_entries_by_login"] = 0 |
||||
res["stats.users.has_duplicate_user_entries_by_login"] = 0 |
||||
res["stats.users.duplicate_user_entries_by_email"] = 0 |
||||
res["stats.users.has_duplicate_user_entries_by_email"] = 0 |
||||
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers |
||||
return res, f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error { |
||||
return f.ExpectedError |
||||
} |
||||
|
||||
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) { |
||||
return f.ExpectedLoginStats, f.ExpectedError |
||||
} |
||||
Loading…
Reference in new issue