mirror of https://github.com/grafana/grafana
AzureMonitor: Use auth middleware for QueryData requests (#35343)
parent
36c997a625
commit
7109285ac9
@ -0,0 +1,127 @@ |
||||
package azuremonitor |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/google/go-cmp/cmp" |
||||
"github.com/google/go-cmp/cmp/cmpopts" |
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestNewInstanceSettings(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
settings backend.DataSourceInstanceSettings |
||||
expectedModel datasourceInfo |
||||
Err require.ErrorAssertionFunc |
||||
}{ |
||||
{ |
||||
name: "creates an instance", |
||||
settings: backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{"cloudName":"azuremonitor"}`), |
||||
DecryptedSecureJSONData: map[string]string{"key": "value"}, |
||||
ID: 40, |
||||
}, |
||||
expectedModel: datasourceInfo{ |
||||
Settings: azureMonitorSettings{CloudName: "azuremonitor"}, |
||||
Routes: routes["azuremonitor"], |
||||
JSONData: map[string]interface{}{"cloudName": string("azuremonitor")}, |
||||
DatasourceID: 40, |
||||
DecryptedSecureJSONData: map[string]string{"key": "value"}, |
||||
}, |
||||
Err: require.NoError, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
factory := NewInstanceSettings() |
||||
instance, err := factory(tt.settings) |
||||
tt.Err(t, err) |
||||
if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) { |
||||
t.Errorf("Unexpected instance: %v", cmp.Diff(instance, tt.expectedModel)) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
type fakeInstance struct{} |
||||
|
||||
func (f *fakeInstance) Get(pluginContext backend.PluginContext) (instancemgmt.Instance, error) { |
||||
return datasourceInfo{ |
||||
Services: map[string]datasourceService{}, |
||||
Routes: routes[azureMonitorPublic], |
||||
}, nil |
||||
} |
||||
|
||||
func (f *fakeInstance) Do(pluginContext backend.PluginContext, fn instancemgmt.InstanceCallbackFunc) error { |
||||
return nil |
||||
} |
||||
|
||||
type fakeExecutor struct { |
||||
t *testing.T |
||||
queryType string |
||||
expectedURL string |
||||
} |
||||
|
||||
func (f *fakeExecutor) executeTimeSeriesQuery(ctx context.Context, originalQueries []backend.DataQuery, dsInfo datasourceInfo) (*backend.QueryDataResponse, error) { |
||||
if s, ok := dsInfo.Services[f.queryType]; !ok { |
||||
f.t.Errorf("The HTTP client for %s is missing", f.queryType) |
||||
} else { |
||||
if s.URL != f.expectedURL { |
||||
f.t.Errorf("Unexpected URL %s wanted %s", s.URL, f.expectedURL) |
||||
} |
||||
} |
||||
return &backend.QueryDataResponse{}, nil |
||||
} |
||||
|
||||
func Test_newExecutor(t *testing.T) { |
||||
cfg := &setting.Cfg{} |
||||
|
||||
tests := []struct { |
||||
name string |
||||
queryType string |
||||
expectedURL string |
||||
Err require.ErrorAssertionFunc |
||||
}{ |
||||
{ |
||||
name: "creates an Azure Monitor executor", |
||||
queryType: azureMonitor, |
||||
expectedURL: routes[azureMonitorPublic][azureMonitor].URL, |
||||
Err: require.NoError, |
||||
}, |
||||
{ |
||||
name: "creates an Azure Log Analytics executor", |
||||
queryType: azureLogAnalytics, |
||||
expectedURL: routes[azureMonitorPublic][azureLogAnalytics].URL, |
||||
Err: require.NoError, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
mux := newExecutor(&fakeInstance{}, cfg, map[string]azDatasourceExecutor{ |
||||
tt.queryType: &fakeExecutor{ |
||||
t: t, |
||||
queryType: tt.queryType, |
||||
expectedURL: tt.expectedURL, |
||||
}, |
||||
}) |
||||
res, err := mux.QueryData(context.TODO(), &backend.QueryDataRequest{ |
||||
PluginContext: backend.PluginContext{}, |
||||
Queries: []backend.DataQuery{ |
||||
{QueryType: tt.queryType}, |
||||
}, |
||||
}) |
||||
tt.Err(t, err) |
||||
// Dummy response from the fake implementation
|
||||
if res == nil { |
||||
t.Errorf("Expecting a response") |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,54 @@ |
||||
package azuremonitor |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func Test_httpCliProvider(t *testing.T) { |
||||
ctx := context.TODO() |
||||
cfg := &setting.Cfg{} |
||||
model := datasourceInfo{ |
||||
Settings: azureMonitorSettings{}, |
||||
DecryptedSecureJSONData: map[string]string{"clientSecret": "content"}, |
||||
} |
||||
tests := []struct { |
||||
name string |
||||
route azRoute |
||||
expectedMiddlewares int |
||||
Err require.ErrorAssertionFunc |
||||
}{ |
||||
{ |
||||
name: "creates an HTTP client with a middleware", |
||||
route: azRoute{ |
||||
URL: "http://route", |
||||
Scopes: []string{"http://route/.default"}, |
||||
}, |
||||
expectedMiddlewares: 1, |
||||
Err: require.NoError, |
||||
}, |
||||
{ |
||||
name: "creates an HTTP client without a middleware", |
||||
route: azRoute{ |
||||
URL: "http://route", |
||||
Scopes: []string{}, |
||||
}, |
||||
// httpclient.NewProvider returns a client with 2 middlewares by default
|
||||
expectedMiddlewares: 2, |
||||
Err: require.NoError, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
cli := httpClientProvider(ctx, tt.route, model, cfg) |
||||
// Cannot test that the cli middleware works properly since the azcore sdk
|
||||
// rejects the TLS certs (if provided)
|
||||
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares { |
||||
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares) |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,53 @@ |
||||
package azuremonitor |
||||
|
||||
import ( |
||||
"context" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/google/go-cmp/cmp" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestInsightsAnalyticsCreateRequest(t *testing.T) { |
||||
ctx := context.Background() |
||||
dsInfo := datasourceInfo{ |
||||
Settings: azureMonitorSettings{AppInsightsAppId: "foo"}, |
||||
Services: map[string]datasourceService{ |
||||
insightsAnalytics: {URL: "http://ds"}, |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"appInsightsApiKey": "key", |
||||
}, |
||||
} |
||||
|
||||
tests := []struct { |
||||
name string |
||||
expectedURL string |
||||
expectedHeaders http.Header |
||||
Err require.ErrorAssertionFunc |
||||
}{ |
||||
{ |
||||
name: "creates a request", |
||||
expectedURL: "http://ds/v1/apps/foo", |
||||
expectedHeaders: http.Header{ |
||||
"X-Api-Key": []string{"key"}, |
||||
}, |
||||
Err: require.NoError, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
ds := InsightsAnalyticsDatasource{} |
||||
req, err := ds.createRequest(ctx, dsInfo) |
||||
tt.Err(t, err) |
||||
if req.URL.String() != tt.expectedURL { |
||||
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String()) |
||||
} |
||||
if !cmp.Equal(req.Header, tt.expectedHeaders) { |
||||
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders)) |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -1,45 +1,90 @@ |
||||
package azuremonitor |
||||
|
||||
import "fmt" |
||||
|
||||
func getManagementApiRoute(azureCloud string) (string, error) { |
||||
switch azureCloud { |
||||
case azureMonitorPublic: |
||||
return "azuremonitor", nil |
||||
case azureMonitorChina: |
||||
return "chinaazuremonitor", nil |
||||
case azureMonitorUSGovernment: |
||||
return "govazuremonitor", nil |
||||
case azureMonitorGermany: |
||||
return "germanyazuremonitor", nil |
||||
default: |
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud) |
||||
return "", err |
||||
} |
||||
type azRoute struct { |
||||
URL string |
||||
Scopes []string |
||||
Headers map[string]string |
||||
} |
||||
|
||||
func getLogAnalyticsApiRoute(azureCloud string) (string, error) { |
||||
switch azureCloud { |
||||
case azureMonitorPublic: |
||||
return "loganalyticsazure", nil |
||||
case azureMonitorChina: |
||||
return "chinaloganalyticsazure", nil |
||||
case azureMonitorUSGovernment: |
||||
return "govloganalyticsazure", nil |
||||
default: |
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud) |
||||
return "", err |
||||
} |
||||
var azManagement = azRoute{ |
||||
URL: "https://management.azure.com", |
||||
Scopes: []string{"https://management.azure.com/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
func getAppInsightsApiRoute(azureCloud string) (string, error) { |
||||
switch azureCloud { |
||||
case azureMonitorPublic: |
||||
return "appinsights", nil |
||||
case azureMonitorChina: |
||||
return "chinaappinsights", nil |
||||
default: |
||||
err := fmt.Errorf("the cloud '%s' not supported", azureCloud) |
||||
return "", err |
||||
} |
||||
var azUSGovManagement = azRoute{ |
||||
URL: "https://management.usgovcloudapi.net", |
||||
Scopes: []string{"https://management.usgovcloudapi.net/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
var azGermanyManagement = azRoute{ |
||||
URL: "https://management.microsoftazure.de", |
||||
Scopes: []string{"https://management.microsoftazure.de/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
var azChinaManagement = azRoute{ |
||||
URL: "https://management.chinacloudapi.cn", |
||||
Scopes: []string{"https://management.chinacloudapi.cn/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
var azAppInsights = azRoute{ |
||||
URL: "https://api.applicationinsights.io", |
||||
Scopes: []string{}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
var azChinaAppInsights = azRoute{ |
||||
URL: "https://api.applicationinsights.azure.cn", |
||||
Scopes: []string{}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana"}, |
||||
} |
||||
|
||||
var azLogAnalytics = azRoute{ |
||||
URL: "https://api.loganalytics.io", |
||||
Scopes: []string{"https://api.loganalytics.io/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"}, |
||||
} |
||||
|
||||
var azChinaLogAnalytics = azRoute{ |
||||
URL: "https://api.loganalytics.azure.cn", |
||||
Scopes: []string{"https://api.loganalytics.azure.cn/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"}, |
||||
} |
||||
|
||||
var azUSGovLogAnalytics = azRoute{ |
||||
URL: "https://api.loganalytics.us", |
||||
Scopes: []string{"https://api.loganalytics.us/.default"}, |
||||
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"}, |
||||
} |
||||
|
||||
var ( |
||||
// The different Azure routes are identified by its cloud (e.g. public or gov)
|
||||
// and the service to query (e.g. Azure Monitor or Azure Log Analytics)
|
||||
routes = map[string]map[string]azRoute{ |
||||
azureMonitorPublic: { |
||||
azureMonitor: azManagement, |
||||
azureLogAnalytics: azLogAnalytics, |
||||
azureResourceGraph: azManagement, |
||||
appInsights: azAppInsights, |
||||
insightsAnalytics: azAppInsights, |
||||
}, |
||||
azureMonitorUSGovernment: { |
||||
azureMonitor: azUSGovManagement, |
||||
azureLogAnalytics: azUSGovLogAnalytics, |
||||
azureResourceGraph: azUSGovManagement, |
||||
}, |
||||
azureMonitorGermany: { |
||||
azureMonitor: azGermanyManagement, |
||||
}, |
||||
azureMonitorChina: { |
||||
azureMonitor: azChinaManagement, |
||||
azureLogAnalytics: azChinaLogAnalytics, |
||||
azureResourceGraph: azChinaManagement, |
||||
appInsights: azChinaAppInsights, |
||||
insightsAnalytics: azChinaAppInsights, |
||||
}, |
||||
} |
||||
) |
||||
|
@ -0,0 +1,33 @@ |
||||
package tokenprovider |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
) |
||||
|
||||
var ( |
||||
// timeNow makes it possible to test usage of time
|
||||
timeNow = time.Now |
||||
) |
||||
|
||||
type TokenProvider interface { |
||||
GetAccessToken() (string, error) |
||||
} |
||||
|
||||
const authenticationMiddlewareName = "AzureAuthentication" |
||||
|
||||
func AuthMiddleware(tokenProvider TokenProvider) httpclient.Middleware { |
||||
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { |
||||
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
||||
token, err := tokenProvider.GetAccessToken() |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to retrieve azure access token: %w", err) |
||||
} |
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) |
||||
return next.RoundTrip(req) |
||||
}) |
||||
}) |
||||
} |
@ -0,0 +1,184 @@ |
||||
package tokenprovider |
||||
|
||||
import ( |
||||
"context" |
||||
"sort" |
||||
"strings" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
) |
||||
|
||||
type AccessToken struct { |
||||
Token string |
||||
ExpiresOn time.Time |
||||
} |
||||
|
||||
type TokenCredential interface { |
||||
GetCacheKey() string |
||||
Init() error |
||||
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) |
||||
} |
||||
|
||||
type ConcurrentTokenCache interface { |
||||
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) |
||||
} |
||||
|
||||
func NewConcurrentTokenCache() ConcurrentTokenCache { |
||||
return &tokenCacheImpl{} |
||||
} |
||||
|
||||
type tokenCacheImpl struct { |
||||
cache sync.Map // of *credentialCacheEntry
|
||||
} |
||||
type credentialCacheEntry struct { |
||||
credential TokenCredential |
||||
|
||||
credInit uint32 |
||||
credMutex sync.Mutex |
||||
cache sync.Map // of *scopesCacheEntry
|
||||
} |
||||
|
||||
type scopesCacheEntry struct { |
||||
credential TokenCredential |
||||
scopes []string |
||||
|
||||
cond *sync.Cond |
||||
refreshing bool |
||||
accessToken *AccessToken |
||||
} |
||||
|
||||
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) { |
||||
return c.getEntryFor(credential).getAccessToken(ctx, scopes) |
||||
} |
||||
|
||||
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry { |
||||
var entry interface{} |
||||
var ok bool |
||||
|
||||
key := credential.GetCacheKey() |
||||
|
||||
if entry, ok = c.cache.Load(key); !ok { |
||||
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{ |
||||
credential: credential, |
||||
}) |
||||
} |
||||
|
||||
return entry.(*credentialCacheEntry) |
||||
} |
||||
|
||||
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) { |
||||
err := c.ensureInitialized() |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
return c.getEntryFor(scopes).getAccessToken(ctx) |
||||
} |
||||
|
||||
func (c *credentialCacheEntry) ensureInitialized() error { |
||||
if atomic.LoadUint32(&c.credInit) == 0 { |
||||
c.credMutex.Lock() |
||||
defer c.credMutex.Unlock() |
||||
|
||||
if c.credInit == 0 { |
||||
// Initialize credential
|
||||
err := c.credential.Init() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
atomic.StoreUint32(&c.credInit, 1) |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry { |
||||
var entry interface{} |
||||
var ok bool |
||||
|
||||
key := getKeyForScopes(scopes) |
||||
|
||||
if entry, ok = c.cache.Load(key); !ok { |
||||
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{ |
||||
credential: c.credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
}) |
||||
} |
||||
|
||||
return entry.(*scopesCacheEntry) |
||||
} |
||||
|
||||
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) { |
||||
var accessToken *AccessToken |
||||
var err error |
||||
shouldRefresh := false |
||||
|
||||
c.cond.L.Lock() |
||||
for { |
||||
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) { |
||||
// Use the cached token since it's available and not expired yet
|
||||
accessToken = c.accessToken |
||||
break |
||||
} |
||||
|
||||
if !c.refreshing { |
||||
// Start refreshing the token
|
||||
c.refreshing = true |
||||
shouldRefresh = true |
||||
break |
||||
} |
||||
|
||||
// Wait for the token to be refreshed
|
||||
c.cond.Wait() |
||||
} |
||||
c.cond.L.Unlock() |
||||
|
||||
if shouldRefresh { |
||||
accessToken, err = c.refreshAccessToken(ctx) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
} |
||||
|
||||
return accessToken.Token, nil |
||||
} |
||||
|
||||
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) { |
||||
var accessToken *AccessToken |
||||
|
||||
// Safeguarding from panic caused by credential implementation
|
||||
defer func() { |
||||
c.cond.L.Lock() |
||||
|
||||
c.refreshing = false |
||||
|
||||
if accessToken != nil { |
||||
c.accessToken = accessToken |
||||
} |
||||
|
||||
c.cond.Broadcast() |
||||
c.cond.L.Unlock() |
||||
}() |
||||
|
||||
token, err := c.credential.GetAccessToken(ctx, c.scopes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
accessToken = token |
||||
return accessToken, nil |
||||
} |
||||
|
||||
func getKeyForScopes(scopes []string) string { |
||||
if len(scopes) > 1 { |
||||
arr := make([]string, len(scopes)) |
||||
copy(arr, scopes) |
||||
sort.Strings(arr) |
||||
scopes = arr |
||||
} |
||||
|
||||
return strings.Join(scopes, " ") |
||||
} |
@ -0,0 +1,457 @@ |
||||
package tokenprovider |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
type fakeCredential struct { |
||||
key string |
||||
initCalledTimes int |
||||
calledTimes int |
||||
initFunc func() error |
||||
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error) |
||||
} |
||||
|
||||
func (c *fakeCredential) GetCacheKey() string { |
||||
return c.key |
||||
} |
||||
|
||||
func (c *fakeCredential) Reset() { |
||||
c.initCalledTimes = 0 |
||||
c.calledTimes = 0 |
||||
} |
||||
|
||||
func (c *fakeCredential) Init() error { |
||||
c.initCalledTimes = c.initCalledTimes + 1 |
||||
if c.initFunc != nil { |
||||
return c.initFunc() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) { |
||||
c.calledTimes = c.calledTimes + 1 |
||||
if c.getAccessTokenFunc != nil { |
||||
return c.getAccessTokenFunc(ctx, scopes) |
||||
} |
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)} |
||||
return fakeAccessToken, nil |
||||
} |
||||
|
||||
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) { |
||||
ctx := context.Background() |
||||
|
||||
scopes1 := []string{"Scope1"} |
||||
scopes2 := []string{"Scope2"} |
||||
|
||||
t.Run("should request access token from credential", func(t *testing.T) { |
||||
cache := NewConcurrentTokenCache() |
||||
credential := &fakeCredential{key: "credential-1"} |
||||
|
||||
token, err := cache.GetAccessToken(ctx, credential, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-1", token) |
||||
|
||||
assert.Equal(t, 1, credential.calledTimes) |
||||
}) |
||||
|
||||
t.Run("should return cached token for same scopes", func(t *testing.T) { |
||||
var token1, token2 string |
||||
var err error |
||||
|
||||
cache := NewConcurrentTokenCache() |
||||
credential := &fakeCredential{key: "credential-1"} |
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-1", token1) |
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-2", token2) |
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-1", token1) |
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential, scopes2) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-2", token2) |
||||
|
||||
assert.Equal(t, 2, credential.calledTimes) |
||||
}) |
||||
|
||||
t.Run("should return cached token for same credentials", func(t *testing.T) { |
||||
var token1, token2 string |
||||
var err error |
||||
|
||||
cache := NewConcurrentTokenCache() |
||||
credential1 := &fakeCredential{key: "credential-1"} |
||||
credential2 := &fakeCredential{key: "credential-2"} |
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-1", token1) |
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-2-token-1", token2) |
||||
|
||||
token1, err = cache.GetAccessToken(ctx, credential1, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-1-token-1", token1) |
||||
|
||||
token2, err = cache.GetAccessToken(ctx, credential2, scopes1) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "credential-2-token-1", token2) |
||||
|
||||
assert.Equal(t, 1, credential1.calledTimes) |
||||
assert.Equal(t, 1, credential2.calledTimes) |
||||
}) |
||||
} |
||||
|
||||
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) { |
||||
t.Run("when credential init returns error", func(t *testing.T) { |
||||
credential := &fakeCredential{ |
||||
initFunc: func() error { |
||||
return errors.New("unable to initialize") |
||||
}, |
||||
} |
||||
|
||||
t.Run("should return error", func(t *testing.T) { |
||||
cacheEntry := &credentialCacheEntry{ |
||||
credential: credential, |
||||
} |
||||
|
||||
err := cacheEntry.ensureInitialized() |
||||
|
||||
assert.Error(t, err) |
||||
}) |
||||
|
||||
t.Run("should call init again each time and return error", func(t *testing.T) { |
||||
credential.Reset() |
||||
|
||||
cacheEntry := &credentialCacheEntry{ |
||||
credential: credential, |
||||
} |
||||
|
||||
var err error |
||||
err = cacheEntry.ensureInitialized() |
||||
assert.Error(t, err) |
||||
|
||||
err = cacheEntry.ensureInitialized() |
||||
assert.Error(t, err) |
||||
|
||||
err = cacheEntry.ensureInitialized() |
||||
assert.Error(t, err) |
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential init returns error only once", func(t *testing.T) { |
||||
var times = 0 |
||||
credential := &fakeCredential{ |
||||
initFunc: func() error { |
||||
times = times + 1 |
||||
if times == 1 { |
||||
return errors.New("unable to initialize") |
||||
} |
||||
return nil |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential init again only while it returns error", func(t *testing.T) { |
||||
cacheEntry := &credentialCacheEntry{ |
||||
credential: credential, |
||||
} |
||||
|
||||
var err error |
||||
err = cacheEntry.ensureInitialized() |
||||
assert.Error(t, err) |
||||
|
||||
err = cacheEntry.ensureInitialized() |
||||
assert.NoError(t, err) |
||||
|
||||
err = cacheEntry.ensureInitialized() |
||||
assert.NoError(t, err) |
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential init panics", func(t *testing.T) { |
||||
credential := &fakeCredential{ |
||||
initFunc: func() error { |
||||
panic(errors.New("unable to initialize")) |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential init again each time", func(t *testing.T) { |
||||
credential.Reset() |
||||
|
||||
cacheEntry := &credentialCacheEntry{ |
||||
credential: credential, |
||||
} |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_ = cacheEntry.ensureInitialized() |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_ = cacheEntry.ensureInitialized() |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_ = cacheEntry.ensureInitialized() |
||||
}() |
||||
|
||||
assert.Equal(t, 3, credential.initCalledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential init panics only once", func(t *testing.T) { |
||||
var times = 0 |
||||
credential := &fakeCredential{ |
||||
initFunc: func() error { |
||||
times = times + 1 |
||||
if times == 1 { |
||||
panic(errors.New("unable to initialize")) |
||||
} |
||||
return nil |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential init again only while it panics", func(t *testing.T) { |
||||
cacheEntry := &credentialCacheEntry{ |
||||
credential: credential, |
||||
} |
||||
|
||||
var err error |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_ = cacheEntry.ensureInitialized() |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.Nil(t, recover(), "credential not expected to panic") |
||||
}() |
||||
err = cacheEntry.ensureInitialized() |
||||
assert.NoError(t, err) |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.Nil(t, recover(), "credential not expected to panic") |
||||
}() |
||||
err = cacheEntry.ensureInitialized() |
||||
assert.NoError(t, err) |
||||
}() |
||||
|
||||
assert.Equal(t, 2, credential.initCalledTimes) |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
func TestScopesCacheEntry_GetAccessToken(t *testing.T) { |
||||
ctx := context.Background() |
||||
|
||||
scopes := []string{"Scope1"} |
||||
|
||||
t.Run("when credential getAccessToken returns error", func(t *testing.T) { |
||||
credential := &fakeCredential{ |
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { |
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)} |
||||
return invalidToken, errors.New("unable to get access token") |
||||
}, |
||||
} |
||||
|
||||
t.Run("should return error", func(t *testing.T) { |
||||
cacheEntry := &scopesCacheEntry{ |
||||
credential: credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
} |
||||
|
||||
accessToken, err := cacheEntry.getAccessToken(ctx) |
||||
|
||||
assert.Error(t, err) |
||||
assert.Equal(t, "", accessToken) |
||||
}) |
||||
|
||||
t.Run("should call credential again each time and return error", func(t *testing.T) { |
||||
credential.Reset() |
||||
|
||||
cacheEntry := &scopesCacheEntry{ |
||||
credential: credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
} |
||||
|
||||
var err error |
||||
_, err = cacheEntry.getAccessToken(ctx) |
||||
assert.Error(t, err) |
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx) |
||||
assert.Error(t, err) |
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx) |
||||
assert.Error(t, err) |
||||
|
||||
assert.Equal(t, 3, credential.calledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) { |
||||
var times = 0 |
||||
credential := &fakeCredential{ |
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { |
||||
times = times + 1 |
||||
if times == 1 { |
||||
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)} |
||||
return invalidToken, errors.New("unable to get access token") |
||||
} |
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)} |
||||
return fakeAccessToken, nil |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential again only while it returns error", func(t *testing.T) { |
||||
cacheEntry := &scopesCacheEntry{ |
||||
credential: credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
} |
||||
|
||||
var accessToken string |
||||
var err error |
||||
|
||||
_, err = cacheEntry.getAccessToken(ctx) |
||||
assert.Error(t, err) |
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, "token-2", accessToken) |
||||
|
||||
accessToken, err = cacheEntry.getAccessToken(ctx) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, "token-2", accessToken) |
||||
|
||||
assert.Equal(t, 2, credential.calledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential getAccessToken panics", func(t *testing.T) { |
||||
credential := &fakeCredential{ |
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { |
||||
panic(errors.New("unable to get access token")) |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential again each time", func(t *testing.T) { |
||||
credential.Reset() |
||||
|
||||
cacheEntry := &scopesCacheEntry{ |
||||
credential: credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
} |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_, _ = cacheEntry.getAccessToken(ctx) |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_, _ = cacheEntry.getAccessToken(ctx) |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_, _ = cacheEntry.getAccessToken(ctx) |
||||
}() |
||||
|
||||
assert.Equal(t, 3, credential.calledTimes) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("when credential getAccessToken panics only once", func(t *testing.T) { |
||||
var times = 0 |
||||
credential := &fakeCredential{ |
||||
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) { |
||||
times = times + 1 |
||||
if times == 1 { |
||||
panic(errors.New("unable to get access token")) |
||||
} |
||||
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)} |
||||
return fakeAccessToken, nil |
||||
}, |
||||
} |
||||
|
||||
t.Run("should call credential again only while it panics", func(t *testing.T) { |
||||
cacheEntry := &scopesCacheEntry{ |
||||
credential: credential, |
||||
scopes: scopes, |
||||
cond: sync.NewCond(&sync.Mutex{}), |
||||
} |
||||
|
||||
var accessToken string |
||||
var err error |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.NotNil(t, recover(), "credential expected to panic") |
||||
}() |
||||
_, _ = cacheEntry.getAccessToken(ctx) |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.Nil(t, recover(), "credential not expected to panic") |
||||
}() |
||||
accessToken, err = cacheEntry.getAccessToken(ctx) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, "token-2", accessToken) |
||||
}() |
||||
|
||||
func() { |
||||
defer func() { |
||||
assert.Nil(t, recover(), "credential not expected to panic") |
||||
}() |
||||
accessToken, err = cacheEntry.getAccessToken(ctx) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, "token-2", accessToken) |
||||
}() |
||||
|
||||
assert.Equal(t, 2, credential.calledTimes) |
||||
}) |
||||
}) |
||||
} |
Loading…
Reference in new issue