package ssosettingsimpl import ( "context" "encoding/base64" "errors" "fmt" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes" "github.com/grafana/grafana/pkg/services/ssosettings" "github.com/grafana/grafana/pkg/services/ssosettings/models" "github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests" "github.com/grafana/grafana/pkg/setting" ) func TestSSOSettingsService_GetForProvider(t *testing.T) { testCases := []struct { name string setup func(env testEnv) want *models.SSOSettings wantErr bool }{ { name: "should return successfully", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", Settings: map[string]any{"enabled": true}, Source: models.DB, } env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "client_id": "client_id", "client_secret": "secret", }, } }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "client_id", "client_secret": "secret", }, }, wantErr: false, }, { name: "should return error if store returns an error different than not found", setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, want: nil, wantErr: true, }, { name: "should fallback to the system settings if store returns not found", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "enabled": true, "client_id": "client_id", }, } }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "client_id"}, Source: models.System, }, wantErr: false, }, { name: "should return error if the fallback strategy was not found", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = false }, want: nil, wantErr: true, }, { name: "should return error if fallback strategy returns error", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedError = fmt.Errorf("error") }, want: nil, wantErr: true, }, { name: "should decrypt secrets if data is coming from store", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_secret": base64.RawStdEncoding.EncodeToString([]byte("client_secret")), "other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")), }, Source: models.DB, } env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "client_id": "client_id", }, } env.secrets.On("Decrypt", mock.Anything, []byte("client_secret"), mock.Anything).Return([]byte("decrypted-client-secret"), nil).Once() env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once() }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "client_id", "client_secret": "decrypted-client-secret", "other_secret": "decrypted-other-secret", }, Source: models.DB, }, wantErr: false, }, { name: "should not decrypt secrets if data is coming from the fallback strategy", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "enabled": true, "client_id": "client_id", "client_secret": "client_secret", }, } }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "client_id", "client_secret": "client_secret", }, Source: models.System, }, wantErr: false, }, { name: "should return an error if the data in the store is invalid", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "client_secret": "not a valid base64 string", }, Source: models.DB, } env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "client_id": "client_id", }, } }, wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { env := setupTestEnv(t) if tc.setup != nil { tc.setup(env) } actual, err := env.service.GetForProvider(context.Background(), "github") if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, tc.want, actual) env.secrets.AssertExpectations(t) }) } } func TestSSOSettingsService_GetForProviderWithRedactedSecrets(t *testing.T) { testCases := []struct { name string setup func(env testEnv) want *models.SSOSettings wantErr bool }{ { name: "should return successfully and redact secrets", setup: func(env testEnv) { env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "secret": base64.RawStdEncoding.EncodeToString([]byte("secret")), "client_secret": base64.RawStdEncoding.EncodeToString([]byte("client_secret")), "client_id": "client_id", }, Source: models.DB, } env.secrets.On("Decrypt", mock.Anything, []byte("client_secret"), mock.Anything).Return([]byte("decrypted-client-secret"), nil).Once() env.secrets.On("Decrypt", mock.Anything, []byte("secret"), mock.Anything).Return([]byte("decrypted-secret"), nil).Once() }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{ "enabled": true, "secret": "*********", "client_secret": "*********", "client_id": "client_id", }, }, wantErr: false, }, { name: "should return error if store returns an error different than not found", setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, want: nil, wantErr: true, }, { name: "should fallback to strategy if store returns not found", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "enabled": true, }, } }, want: &models.SSOSettings{ Provider: "github", Settings: map[string]any{"enabled": true}, Source: models.System, }, wantErr: false, }, { name: "should return error if the fallback strategy was not found", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = false }, want: nil, wantErr: true, }, { name: "should return error if fallback strategy returns error", setup: func(env testEnv) { env.store.ExpectedError = ssosettings.ErrNotFound env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedError = fmt.Errorf("error") }, want: nil, wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { env := setupTestEnv(t) if tc.setup != nil { tc.setup(env) } actual, err := env.service.GetForProviderWithRedactedSecrets(context.Background(), "github") if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, tc.want, actual) }) } } func TestSSOSettingsService_List(t *testing.T) { testCases := []struct { name string setup func(env testEnv) want []*models.SSOSettings wantErr bool }{ { name: "should return successfully", setup: func(env testEnv) { env.store.ExpectedSSOSettings = []*models.SSOSettings{ { Provider: "github", Settings: map[string]any{ "enabled": true, "client_secret": base64.RawStdEncoding.EncodeToString([]byte("client_secret")), }, Source: models.DB, }, { Provider: "okta", Settings: map[string]any{ "enabled": false, "other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")), }, Source: models.DB, }, } env.secrets.On("Decrypt", mock.Anything, []byte("client_secret"), mock.Anything).Return([]byte("decrypted-client-secret"), nil).Once() env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once() env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "enabled": false, "client_id": "client_id", "client_secret": "secret1", "token_url": "token_url", }, "okta": { "enabled": false, "client_id": "client_id", "client_secret": "coming-from-system", "other_secret": "secret2", "token_url": "token_url", }, "gitlab": { "enabled": false, }, "generic_oauth": { "enabled": false, }, "google": { "enabled": false, }, "azuread": { "enabled": false, }, "grafana_com": { "enabled": false, }, } }, want: []*models.SSOSettings{ { Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "client_id", "client_secret": "decrypted-client-secret", // client_secret is coming from the database, must be decrypted first "token_url": "token_url", }, Source: models.DB, }, { Provider: "okta", Settings: map[string]any{ "enabled": false, "client_id": "client_id", "client_secret": "coming-from-system", // client_secret is coming from the system, must not be decrypted "other_secret": "decrypted-other-secret", "token_url": "token_url", }, Source: models.DB, }, { Provider: "gitlab", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "generic_oauth", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "google", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "azuread", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "grafana_com", Settings: map[string]any{"enabled": false}, Source: models.System, }, }, wantErr: false, }, { name: "should return error if store returns an error", setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") }, want: nil, wantErr: true, }, { name: "should use the fallback strategy if store returns empty list", setup: func(env testEnv) { env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = true env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{ "github": { "enabled": false, }, "okta": { "enabled": false, }, "gitlab": { "enabled": false, }, "generic_oauth": { "enabled": false, }, "google": { "enabled": false, }, "azuread": { "enabled": false, }, "grafana_com": { "enabled": false, }, } }, want: []*models.SSOSettings{ { Provider: "github", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "okta", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "gitlab", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "generic_oauth", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "google", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "azuread", Settings: map[string]any{"enabled": false}, Source: models.System, }, { Provider: "grafana_com", Settings: map[string]any{"enabled": false}, Source: models.System, }, }, wantErr: false, }, { name: "should return error if any of the fallback strategies was not found", setup: func(env testEnv) { env.store.ExpectedSSOSettings = []*models.SSOSettings{} env.fallbackStrategy.ExpectedIsMatch = false }, want: nil, wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { env := setupTestEnv(t) if tc.setup != nil { tc.setup(env) } actual, err := env.service.List(context.Background()) if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) require.ElementsMatch(t, tc.want, actual) }) } } func TestSSOSettingsService_Upsert(t *testing.T) { t.Run("successfully upsert SSO settings", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe() env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() err := env.service.Upsert(context.Background(), settings) require.NoError(t, err) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) require.EqualValues(t, settings, env.store.ActualSSOSettings) }) t.Run("returns error if provider is not configurable", func(t *testing.T) { env := setupTestEnv(t) provider := social.GrafanaComProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) env.reloadables[provider] = reloadable err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("returns error if provider was not found in reloadables", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) // the reloadable is available for other provider env.reloadables["github"] = reloadable err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("returns error if validation fails", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed")) env.reloadables[provider] = reloadable err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) { env := setupTestEnv(t) settings := models.SSOSettings{ Provider: social.AzureADProviderName, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } env.fallbackStrategy.ExpectedIsMatch = false err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("returns error if secrets encryption failed", func(t *testing.T) { env := setupTestEnv(t) provider := social.OktaProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once() err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("should not update the current secret if the secret has not been updated", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": setting.RedactedPassword, "enabled": true, }, IsDeleted: false, } env.store.ExpectedSSOSetting = &models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_secret": base64.RawStdEncoding.EncodeToString([]byte("current-client-secret")), }, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe() env.reloadables[provider] = reloadable env.secrets.On("Decrypt", mock.Anything, []byte("current-client-secret"), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.secrets.On("Encrypt", mock.Anything, []byte("encrypted-client-secret"), mock.Anything).Return([]byte("current-client-secret"), nil).Once() err := env.service.Upsert(context.Background(), settings) require.NoError(t, err) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("current-client-secret")) require.EqualValues(t, settings, env.store.ActualSSOSettings) }) t.Run("returns error if store failed to upsert settings", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() env.store.GetFn = func(ctx context.Context, provider string) (*models.SSOSettings, error) { return &models.SSOSettings{}, nil } env.store.UpsertFn = func(ctx context.Context, settings models.SSOSettings) error { return errors.New("failed to upsert settings") } err := env.service.Upsert(context.Background(), settings) require.Error(t, err) }) t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName settings := models.SSOSettings{ Provider: provider, Settings: map[string]any{ "client_id": "client-id", "client_secret": "client-secret", "enabled": true, }, IsDeleted: false, } reloadable := ssosettingstests.NewMockReloadable(t) reloadable.On("Validate", mock.Anything, settings).Return(nil) reloadable.On("Reload", mock.Anything, mock.Anything).Return(errors.New("failed reloading new settings")).Maybe() env.reloadables[provider] = reloadable env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once() err := env.service.Upsert(context.Background(), settings) require.NoError(t, err) settings.Settings["client_secret"] = base64.RawStdEncoding.EncodeToString([]byte("encrypted-client-secret")) require.EqualValues(t, settings, env.store.ActualSSOSettings) }) } func TestSSOSettingsService_Delete(t *testing.T) { t.Run("successfully delete SSO settings", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName env.store.ExpectedError = nil err := env.service.Delete(context.Background(), provider) require.NoError(t, err) }) t.Run("SSO settings not found for the specified provider", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName env.store.ExpectedError = ssosettings.ErrNotFound err := env.service.Delete(context.Background(), provider) require.Error(t, err) require.ErrorIs(t, err, ssosettings.ErrNotFound) }) t.Run("store fails to delete the SSO settings for the specified provider", func(t *testing.T) { env := setupTestEnv(t) provider := social.AzureADProviderName env.store.ExpectedError = errors.New("delete sso settings failed") err := env.service.Delete(context.Background(), provider) require.Error(t, err) require.NotErrorIs(t, err, ssosettings.ErrNotFound) }) } func TestSSOSettingsService_DoReload(t *testing.T) { t.Run("successfully reload settings", func(t *testing.T) { env := setupTestEnv(t) settingsList := []*models.SSOSettings{ { Provider: "github", Settings: map[string]any{ "enabled": true, "client_id": "github_client_id", }, }, { Provider: "google", Settings: map[string]any{ "enabled": true, "client_id": "google_client_id", }, }, { Provider: "azuread", Settings: map[string]any{ "enabled": true, "client_id": "azuread_client_id", }, }, } env.store.ExpectedSSOSettings = settingsList reloadable := ssosettingstests.NewMockReloadable(t) for _, settings := range settingsList { reloadable.On("Reload", mock.Anything, *settings).Return(nil).Once() env.reloadables[settings.Provider] = reloadable } env.service.doReload(context.Background()) }) t.Run("failed fetching the SSO settings", func(t *testing.T) { env := setupTestEnv(t) provider := "github" env.store.ExpectedError = errors.New("failed fetching the settings") reloadable := ssosettingstests.NewMockReloadable(t) env.reloadables[provider] = reloadable env.service.doReload(context.Background()) }) } func TestSSOSettingsService_decryptSecrets(t *testing.T) { testCases := []struct { name string setup func(env testEnv) settings map[string]any want map[string]any wantErr bool }{ { name: "should decrypt secrets successfully", setup: func(env testEnv) { env.secrets.On("Decrypt", mock.Anything, []byte("client_secret"), mock.Anything).Return([]byte("decrypted-client-secret"), nil).Once() env.secrets.On("Decrypt", mock.Anything, []byte("other_secret"), mock.Anything).Return([]byte("decrypted-other-secret"), nil).Once() }, settings: map[string]any{ "enabled": true, "client_secret": base64.RawStdEncoding.EncodeToString([]byte("client_secret")), "other_secret": base64.RawStdEncoding.EncodeToString([]byte("other_secret")), }, want: map[string]any{ "enabled": true, "client_secret": "decrypted-client-secret", "other_secret": "decrypted-other-secret", }, }, { name: "should return an error if data is not a string", settings: map[string]any{ "enabled": true, "client_secret": 2, "other_secret": 2, }, wantErr: true, }, { name: "should return an error if data is not a valid base64 string", settings: map[string]any{ "enabled": true, "client_secret": "client_secret", "other_secret": "other_secret", }, wantErr: true, }, { name: "should return an error decryption fails", setup: func(env testEnv) { env.secrets.On("Decrypt", mock.Anything, []byte("client_secret"), mock.Anything).Return(nil, errors.New("decryption failed")).Once() }, settings: map[string]any{ "enabled": true, "client_secret": base64.RawStdEncoding.EncodeToString([]byte("client_secret")), }, wantErr: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { env := setupTestEnv(t) if tc.setup != nil { tc.setup(env) } actual, err := env.service.decryptSecrets(context.Background(), tc.settings) if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, tc.want, actual) env.secrets.AssertExpectations(t) }) } } func setupTestEnv(t *testing.T) testEnv { store := ssosettingstests.NewFakeStore() fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy() secrets := secretsFakes.NewMockService(t) accessControl := acimpl.ProvideAccessControl(setting.NewCfg()) reloadables := make(map[string]ssosettings.Reloadable) fallbackStrategy.ExpectedIsMatch = true svc := &SSOSettingsService{ logger: log.NewNopLogger(), store: store, ac: accessControl, fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy}, reloadables: reloadables, secrets: secrets, } return testEnv{ service: svc, store: store, ac: accessControl, fallbackStrategy: fallbackStrategy, secrets: secrets, reloadables: reloadables, } } type testEnv struct { service *SSOSettingsService store *ssosettingstests.FakeStore ac accesscontrol.AccessControl fallbackStrategy *ssosettingstests.FakeFallbackStrategy secrets *secretsFakes.MockService reloadables map[string]ssosettings.Reloadable }