AzureMonitor: Use auth middleware for QueryData requests (#35343)

pull/35578/head
Andres Martinez Gotor 4 years ago committed by GitHub
parent 36c997a625
commit 7109285ac9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      pkg/api/pluginproxy/ds_auth_provider.go
  2. 2
      pkg/api/pluginproxy/token_provider.go
  3. 2
      pkg/api/pluginproxy/token_provider_gce.go
  4. 2
      pkg/api/pluginproxy/token_provider_generic.go
  5. 2
      pkg/api/pluginproxy/token_provider_jwt.go
  6. 18
      pkg/api/pluginproxy/token_provider_test.go
  7. 63
      pkg/tsdb/azuremonitor/applicationinsights-datasource.go
  8. 134
      pkg/tsdb/azuremonitor/applicationinsights-datasource_test.go
  9. 62
      pkg/tsdb/azuremonitor/azure-log-analytics-datasource.go
  10. 114
      pkg/tsdb/azuremonitor/azure-log-analytics-datasource_test.go
  11. 60
      pkg/tsdb/azuremonitor/azure-resource-graph-datasource.go
  12. 42
      pkg/tsdb/azuremonitor/azure-resource-graph-datasource_test.go
  13. 63
      pkg/tsdb/azuremonitor/azuremonitor-datasource.go
  14. 41
      pkg/tsdb/azuremonitor/azuremonitor-datasource_test.go
  15. 72
      pkg/tsdb/azuremonitor/azuremonitor.go
  16. 127
      pkg/tsdb/azuremonitor/azuremonitor_test.go
  17. 92
      pkg/tsdb/azuremonitor/credentials.go
  18. 54
      pkg/tsdb/azuremonitor/credentials_test.go
  19. 65
      pkg/tsdb/azuremonitor/insights-analytics-datasource.go
  20. 53
      pkg/tsdb/azuremonitor/insights-analytics-datasource_test.go
  21. 121
      pkg/tsdb/azuremonitor/routes.go
  22. 33
      pkg/tsdb/azuremonitor/tokenprovider/authentication_middleware.go
  23. 184
      pkg/tsdb/azuremonitor/tokenprovider/token_cache.go
  24. 457
      pkg/tsdb/azuremonitor/tokenprovider/token_cache_test.go
  25. 25
      pkg/tsdb/azuremonitor/tokenprovider/token_provider.go
  26. 26
      pkg/tsdb/azuremonitor/tokenprovider/token_provider_test.go

@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
"github.com/grafana/grafana/pkg/util"
)
@ -57,7 +58,7 @@ func ApplyRoute(ctx context.Context, req *http.Request, proxyPath string, route
if tokenProvider, err := getTokenProvider(ctx, cfg, ds, route, data); err != nil {
logger.Error("Failed to resolve auth token provider", "error", err)
} else if tokenProvider != nil {
if token, err := tokenProvider.getAccessToken(); err != nil {
if token, err := tokenProvider.GetAccessToken(); err != nil {
logger.Error("Failed to get access token", "error", err)
} else {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
@ -90,7 +91,7 @@ func getTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSour
if tokenAuth == nil {
return nil, fmt.Errorf("'tokenAuth' not configured for authentication type '%s'", authType)
}
provider := newAzureAccessTokenProvider(ctx, cfg, ds, pluginRoute, tokenAuth)
provider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
return provider, nil
case "gce":

@ -3,7 +3,7 @@ package pluginproxy
import "time"
type accessTokenProvider interface {
getAccessToken() (string, error)
GetAccessToken() (string, error)
}
var (

@ -27,7 +27,7 @@ func newGceAccessTokenProvider(ctx context.Context, ds *models.DataSource, plugi
}
}
func (provider *gceAccessTokenProvider) getAccessToken() (string, error) {
func (provider *gceAccessTokenProvider) GetAccessToken() (string, error) {
tokenSrc, err := google.DefaultTokenSource(provider.ctx, provider.authParams.Scopes...)
if err != nil {
logger.Error("Failed to get default token from meta data server", "error", err)

@ -78,7 +78,7 @@ func newGenericAccessTokenProvider(ds *models.DataSource, pluginRoute *plugins.A
}
}
func (provider *genericAccessTokenProvider) getAccessToken() (string, error) {
func (provider *genericAccessTokenProvider) GetAccessToken() (string, error) {
tokenCache.Lock()
defer tokenCache.Unlock()
if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {

@ -42,7 +42,7 @@ func newJwtAccessTokenProvider(ctx context.Context, ds *models.DataSource, plugi
}
}
func (provider *jwtAccessTokenProvider) getAccessToken() (string, error) {
func (provider *jwtAccessTokenProvider) GetAccessToken() (string, error) {
oauthJwtTokenCache.Lock()
defer oauthJwtTokenCache.Unlock()
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {

@ -70,7 +70,7 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
return &oauth2.Token{AccessToken: "abc"}, nil
})
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
token, err := provider.getAccessToken()
token, err := provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token)
@ -89,7 +89,7 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
})
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
_, err := provider.getAccessToken()
_, err := provider.GetAccessToken()
require.NoError(t, err)
})
@ -100,14 +100,14 @@ func TestAccessToken_pluginWithJWTTokenAuthRoute(t *testing.T) {
Expiry: time.Now().Add(1 * time.Minute)}, nil
})
provider := newJwtAccessTokenProvider(context.Background(), ds, pluginRoute, authParams)
token1, err := provider.getAccessToken()
token1, err := provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token1)
getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, error) {
return &oauth2.Token{AccessToken: "error: cache not used"}, nil
}
token2, err := provider.getAccessToken()
token2, err := provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, "abc", token2)
})
@ -224,12 +224,12 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
token["expires_on"] = testCase.expiresOn
}
accessToken, err := provider.getAccessToken()
accessToken, err := provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
// getAccessToken should use internal cache
accessToken, err = provider.getAccessToken()
// GetAccessToken should use internal cache
accessToken, err = provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
assert.Equal(t, 1, authCalls)
@ -259,13 +259,13 @@ func TestAccessToken_pluginWithTokenAuthRoute(t *testing.T) {
"token_type": "3600",
"refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",
}
accessToken, err := provider.getAccessToken()
accessToken, err := provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
mockTimeNow(timeNow().Add(3601 * time.Second))
accessToken, err = provider.getAccessToken()
accessToken, err = provider.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, token["access_token"], accessToken)
assert.Equal(t, 2, authCalls)

@ -3,7 +3,6 @@ package azuremonitor
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@ -15,22 +14,13 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/api/pluginproxy"
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/context/ctxhttp"
)
// ApplicationInsightsDatasource calls the application insights query API.
type ApplicationInsightsDatasource struct {
pluginManager plugins.Manager
cfg *setting.Cfg
}
type ApplicationInsightsDatasource struct{}
// ApplicationInsightsQuery is the model that holds the information
// needed to make a metrics query to Application Insights, and the information
@ -164,7 +154,7 @@ func (e *ApplicationInsightsDatasource) executeQuery(ctx context.Context, query
}
azlog.Debug("ApplicationInsights", "Request URL", req.URL.String())
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
res, err := ctxhttp.Do(ctx, dsInfo.Services[appInsights].HTTPClient, req)
if err != nil {
dataResponse.Error = err
return dataResponse, nil
@ -204,63 +194,20 @@ func (e *ApplicationInsightsDatasource) executeQuery(ctx context.Context, query
}
func (e *ApplicationInsightsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
// find plugin
plugin := e.pluginManager.GetDataSource(dsName)
if plugin == nil {
return nil, errors.New("unable to find datasource plugin Azure Application Insights")
}
appInsightsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
if err != nil {
return nil, err
}
appInsightsAppID := dsInfo.Settings.AppInsightsAppId
u, err := url.Parse(dsInfo.URL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, fmt.Sprintf("/v1/apps/%s", appInsightsAppID))
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[appInsights].URL, nil)
if err != nil {
azlog.Debug("Failed to create request", "error", err)
return nil, errutil.Wrap("Failed to create request", err)
}
req.Header.Set("X-API-Key", dsInfo.DecryptedSecureJSONData["appInsightsApiKey"])
// TODO: Use backend authentication instead
proxyPass := fmt.Sprintf("%s/v1/apps/%s", routeName, appInsightsAppID)
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, &models.DataSource{
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
}, e.cfg)
req.URL.Path = fmt.Sprintf("/v1/apps/%s", appInsightsAppID)
return req, nil
}
func (e *ApplicationInsightsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
cloud, err := getAzureCloud(e.cfg, dsInfo)
if err != nil {
return nil, "", err
}
routeName, err := getAppInsightsApiRoute(cloud)
if err != nil {
return nil, "", err
}
var pluginRoute *plugins.AppPluginRoute
for _, route := range plugin.Routes {
if route.Path == routeName {
pluginRoute = route
break
}
}
return pluginRoute, routeName, nil
}
// formatApplicationInsightsLegendKey builds the legend key or timeseries name
// Alias patterns like {{metric}} are replaced with the appropriate data values.
func formatApplicationInsightsLegendKey(alias string, metricName string, labels data.Labels) string {

@ -1,15 +1,14 @@
package azuremonitor
import (
"context"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
. "github.com/smartystreets/goconvey/convey"
@ -159,92 +158,6 @@ func TestApplicationInsightsDatasource(t *testing.T) {
})
}
func TestAppInsightsPluginRoutes(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzurePublic,
ManagedIdentityEnabled: true,
},
}
plugin := &plugins.DataSourcePlugin{
Routes: []*plugins.AppPluginRoute{
{
Path: "appinsights",
Method: "GET",
URL: "https://api.applicationinsights.io",
Headers: []plugins.AppPluginRouteHeader{
{Name: "X-API-Key", Content: "{{.SecureJsonData.appInsightsApiKey}}"},
{Name: "x-ms-app", Content: "Grafana"},
},
},
{
Path: "chinaappinsights",
Method: "GET",
URL: "https://api.applicationinsights.azure.cn",
Headers: []plugins.AppPluginRouteHeader{
{Name: "X-API-Key", Content: "{{.SecureJsonData.appInsightsApiKey}}"},
{Name: "x-ms-app", Content: "Grafana"},
},
},
},
}
tests := []struct {
name string
datasource *ApplicationInsightsDatasource
dsInfo datasourceInfo
expectedRouteName string
expectedRouteURL string
Err require.ErrorAssertionFunc
}{
{
name: "plugin proxy route for the Azure public cloud",
dsInfo: datasourceInfo{
Settings: azureMonitorSettings{
AzureAuthType: AzureAuthClientSecret,
CloudName: "azuremonitor",
},
},
datasource: &ApplicationInsightsDatasource{
cfg: cfg,
},
expectedRouteName: "appinsights",
expectedRouteURL: "https://api.applicationinsights.io",
Err: require.NoError,
},
{
name: "plugin proxy route for the Azure China cloud",
dsInfo: datasourceInfo{
Settings: azureMonitorSettings{
AzureAuthType: AzureAuthClientSecret,
CloudName: "chinaazuremonitor",
},
},
datasource: &ApplicationInsightsDatasource{
cfg: cfg,
},
expectedRouteName: "chinaappinsights",
expectedRouteURL: "https://api.applicationinsights.azure.cn",
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, routeName, err := tt.datasource.getPluginRoute(plugin, tt.dsInfo)
tt.Err(t, err)
if diff := cmp.Diff(tt.expectedRouteURL, route.URL, cmpopts.EquateNaNs()); diff != "" {
t.Errorf("Result mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedRouteName, routeName, cmpopts.EquateNaNs()); diff != "" {
t.Errorf("Result mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestInsightsDimensionsUnmarshalJSON(t *testing.T) {
a := []byte(`"foo"`)
b := []byte(`["foo"]`)
@ -291,3 +204,46 @@ func TestInsightsDimensionsUnmarshalJSON(t *testing.T) {
require.NoError(t, err)
require.Empty(t, gs)
}
func TestAppInsightsCreateRequest(t *testing.T) {
ctx := context.Background()
dsInfo := datasourceInfo{
Settings: azureMonitorSettings{AppInsightsAppId: "foo"},
Services: map[string]datasourceService{
appInsights: {URL: "http://ds"},
},
DecryptedSecureJSONData: map[string]string{
"appInsightsApiKey": "key",
},
}
tests := []struct {
name string
expectedURL string
expectedHeaders http.Header
Err require.ErrorAssertionFunc
}{
{
name: "creates a request",
expectedURL: "http://ds/v1/apps/foo",
expectedHeaders: http.Header{
"X-Api-Key": []string{"key"},
},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds := ApplicationInsightsDatasource{}
req, err := ds.createRequest(ctx, dsInfo)
tt.Err(t, err)
if req.URL.String() != tt.expectedURL {
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
}
if !cmp.Equal(req.Header, tt.expectedHeaders) {
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
}
})
}
}

@ -5,7 +5,6 @@ import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@ -16,22 +15,14 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/api/pluginproxy"
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/context/ctxhttp"
)
// AzureLogAnalyticsDatasource calls the Azure Log Analytics API's
type AzureLogAnalyticsDatasource struct {
pluginManager plugins.Manager
cfg *setting.Cfg
}
type AzureLogAnalyticsDatasource struct{}
// AzureLogAnalyticsQuery is the query request that is built from the saved values for
// from the UI
@ -170,7 +161,7 @@ func (e *AzureLogAnalyticsDatasource) executeQuery(ctx context.Context, query *A
}
azlog.Debug("AzureLogAnalytics", "Request ApiURL", req.URL.String())
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureLogAnalytics].HTTPClient, req)
if err != nil {
return dataResponseErrorWithExecuted(err)
}
@ -220,62 +211,17 @@ func (e *AzureLogAnalyticsDatasource) executeQuery(ctx context.Context, query *A
}
func (e *AzureLogAnalyticsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
// find plugin
plugin := e.pluginManager.GetDataSource(dsName)
if plugin == nil {
return nil, errors.New("unable to find datasource plugin Azure Monitor")
}
logAnalyticsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
if err != nil {
return nil, err
}
u, err := url.Parse(dsInfo.URL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, "render")
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[azureLogAnalytics].URL, nil)
if err != nil {
azlog.Debug("Failed to create request", "error", err)
return nil, errutil.Wrap("failed to create request", err)
}
req.URL.Path = "/"
req.Header.Set("Content-Type", "application/json")
// TODO: Use backend authentication instead
pluginproxy.ApplyRoute(ctx, req, routeName, logAnalyticsRoute, &models.DataSource{
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
}, e.cfg)
return req, nil
}
func (e *AzureLogAnalyticsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
cloud, err := getAzureCloud(e.cfg, dsInfo)
if err != nil {
return nil, "", err
}
routeName, err := getLogAnalyticsApiRoute(cloud)
if err != nil {
return nil, "", err
}
var pluginRoute *plugins.AppPluginRoute
for _, route := range plugin.Routes {
if route.Path == routeName {
pluginRoute = route
break
}
}
return pluginRoute, routeName, nil
}
// GetPrimaryResultTable returns the first table in the response named "PrimaryResult", or an
// error if there is no table by that name.
func (ar *AzureLogAnalyticsResponse) GetPrimaryResultTable() (*AzureResponseTable, error) {

@ -1,16 +1,15 @@
package azuremonitor
import (
"context"
"fmt"
"net/http"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
@ -179,109 +178,38 @@ func TestBuildingAzureLogAnalyticsQueries(t *testing.T) {
}
}
func TestPluginRoutes(t *testing.T) {
cfg := &setting.Cfg{
Azure: setting.AzureSettings{
Cloud: setting.AzurePublic,
ManagedIdentityEnabled: true,
},
}
plugin := &plugins.DataSourcePlugin{
Routes: []*plugins.AppPluginRoute{
{
Path: "loganalyticsazure",
Method: "GET",
URL: "https://api.loganalytics.io/",
Headers: []plugins.AppPluginRouteHeader{
{Name: "x-ms-app", Content: "Grafana"},
},
},
{
Path: "chinaloganalyticsazure",
Method: "GET",
URL: "https://api.loganalytics.azure.cn/",
Headers: []plugins.AppPluginRouteHeader{
{Name: "x-ms-app", Content: "Grafana"},
},
},
{
Path: "govloganalyticsazure",
Method: "GET",
URL: "https://api.loganalytics.us/",
Headers: []plugins.AppPluginRouteHeader{
{Name: "x-ms-app", Content: "Grafana"},
},
},
func TestLogAnalyticsCreateRequest(t *testing.T) {
ctx := context.Background()
dsInfo := datasourceInfo{
Services: map[string]datasourceService{
azureLogAnalytics: {URL: "http://ds"},
},
}
tests := []struct {
name string
dsInfo datasourceInfo
datasource *AzureLogAnalyticsDatasource
expectedProxypass string
expectedRouteURL string
Err require.ErrorAssertionFunc
name string
expectedURL string
expectedHeaders http.Header
Err require.ErrorAssertionFunc
}{
{
name: "plugin proxy route for the Azure public cloud",
dsInfo: datasourceInfo{
Settings: azureMonitorSettings{
AzureAuthType: AzureAuthClientSecret,
CloudName: "azuremonitor",
},
},
datasource: &AzureLogAnalyticsDatasource{
cfg: cfg,
},
expectedProxypass: "loganalyticsazure",
expectedRouteURL: "https://api.loganalytics.io/",
Err: require.NoError,
},
{
name: "plugin proxy route for the Azure China cloud",
dsInfo: datasourceInfo{
Settings: azureMonitorSettings{
AzureAuthType: AzureAuthClientSecret,
CloudName: "chinaazuremonitor",
},
},
datasource: &AzureLogAnalyticsDatasource{
cfg: cfg,
},
expectedProxypass: "chinaloganalyticsazure",
expectedRouteURL: "https://api.loganalytics.azure.cn/",
Err: require.NoError,
},
{
name: "plugin proxy route for the Azure Gov cloud",
dsInfo: datasourceInfo{
Settings: azureMonitorSettings{
AzureAuthType: AzureAuthClientSecret,
CloudName: "govazuremonitor",
},
},
datasource: &AzureLogAnalyticsDatasource{
cfg: cfg,
},
expectedProxypass: "govloganalyticsazure",
expectedRouteURL: "https://api.loganalytics.us/",
Err: require.NoError,
name: "creates a request",
expectedURL: "http://ds/",
expectedHeaders: http.Header{"Content-Type": []string{"application/json"}},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, proxypass, err := tt.datasource.getPluginRoute(plugin, tt.dsInfo)
ds := AzureLogAnalyticsDatasource{}
req, err := ds.createRequest(ctx, dsInfo)
tt.Err(t, err)
if diff := cmp.Diff(tt.expectedRouteURL, route.URL, cmpopts.EquateNaNs()); diff != "" {
t.Errorf("Result mismatch (-want +got):\n%s", diff)
if req.URL.String() != tt.expectedURL {
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
}
if diff := cmp.Diff(tt.expectedProxypass, proxypass, cmpopts.EquateNaNs()); diff != "" {
t.Errorf("Result mismatch (-want +got):\n%s", diff)
if !cmp.Equal(req.Header, tt.expectedHeaders) {
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
}
})
}

@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@ -15,11 +14,7 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/api/pluginproxy"
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/opentracing/opentracing-go"
@ -27,10 +22,7 @@ import (
)
// AzureResourceGraphDatasource calls the Azure Resource Graph API's
type AzureResourceGraphDatasource struct {
pluginManager plugins.Manager
cfg *setting.Cfg
}
type AzureResourceGraphDatasource struct{}
// AzureResourceGraphQuery is the query request that is built from the saved values for
// from the UI
@ -167,7 +159,7 @@ func (e *AzureResourceGraphDatasource) executeQuery(ctx context.Context, query *
}
azlog.Debug("AzureResourceGraph", "Request ApiURL", req.URL.String())
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureResourceGraph].HTTPClient, req)
if err != nil {
return dataResponseErrorWithExecuted(err)
}
@ -191,62 +183,18 @@ func (e *AzureResourceGraphDatasource) executeQuery(ctx context.Context, query *
}
func (e *AzureResourceGraphDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo, reqBody []byte) (*http.Request, error) {
// find plugin
plugin := e.pluginManager.GetDataSource(dsName)
if plugin == nil {
return nil, errors.New("unable to find datasource plugin Azure Monitor")
}
argRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
if err != nil {
return nil, err
}
u, err := url.Parse(dsInfo.URL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, "render")
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(reqBody))
req, err := http.NewRequest(http.MethodPost, dsInfo.Services[azureResourceGraph].URL, bytes.NewBuffer(reqBody))
if err != nil {
azlog.Debug("Failed to create request", "error", err)
return nil, errutil.Wrap("failed to create request", err)
}
req.URL.Path = "/"
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
// TODO: Use backend authentication instead
pluginproxy.ApplyRoute(ctx, req, routeName, argRoute, &models.DataSource{
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
}, e.cfg)
return req, nil
}
func (e *AzureResourceGraphDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
cloud, err := getAzureCloud(e.cfg, dsInfo)
if err != nil {
return nil, "", err
}
routeName, err := getManagementApiRoute(cloud)
if err != nil {
return nil, "", err
}
var pluginRoute *plugins.AppPluginRoute
for _, route := range plugin.Routes {
if route.Path == routeName {
pluginRoute = route
break
}
}
return pluginRoute, routeName, nil
}
func (e *AzureResourceGraphDatasource) unmarshalResponse(res *http.Response) (AzureResourceGraphResponse, error) {
body, err := ioutil.ReadAll(res.Body)
if err != nil {

@ -1,7 +1,9 @@
package azuremonitor
import (
"context"
"fmt"
"net/http"
"testing"
"time"
@ -71,3 +73,43 @@ func TestBuildingAzureResourceGraphQueries(t *testing.T) {
})
}
}
func TestAzureResourceGraphCreateRequest(t *testing.T) {
ctx := context.Background()
dsInfo := datasourceInfo{
Services: map[string]datasourceService{
azureResourceGraph: {URL: "http://ds"},
},
}
tests := []struct {
name string
expectedURL string
expectedHeaders http.Header
Err require.ErrorAssertionFunc
}{
{
name: "creates a request",
expectedURL: "http://ds/",
expectedHeaders: http.Header{
"Content-Type": []string{"application/json"},
"User-Agent": []string{"Grafana/"},
},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds := AzureResourceGraphDatasource{}
req, err := ds.createRequest(ctx, dsInfo, []byte{})
tt.Err(t, err)
if req.URL.String() != tt.expectedURL {
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
}
if !cmp.Equal(req.Header, tt.expectedHeaders) {
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
}
})
}
}

@ -3,7 +3,6 @@ package azuremonitor
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@ -15,11 +14,6 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/api/pluginproxy"
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
opentracing "github.com/opentracing/opentracing-go"
@ -27,10 +21,7 @@ import (
)
// AzureMonitorDatasource calls the Azure Monitor API - one of the four API's supported
type AzureMonitorDatasource struct {
pluginManager plugins.Manager
cfg *setting.Cfg
}
type AzureMonitorDatasource struct{}
var (
// 1m, 5m, 15m, 30m, 1h, 6h, 12h, 1d in milliseconds
@ -189,7 +180,7 @@ func (e *AzureMonitorDatasource) executeQuery(ctx context.Context, query *AzureM
azlog.Debug("AzureMonitor", "Request ApiURL", req.URL.String())
azlog.Debug("AzureMonitor", "Target", query.Target)
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
res, err := ctxhttp.Do(ctx, dsInfo.Services[azureMonitor].HTTPClient, req)
if err != nil {
dataResponse.Error = err
return dataResponse, AzureMonitorResponse{}, nil
@ -210,63 +201,17 @@ func (e *AzureMonitorDatasource) executeQuery(ctx context.Context, query *AzureM
}
func (e *AzureMonitorDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
// find plugin
plugin := e.pluginManager.GetDataSource(dsName)
if plugin == nil {
return nil, errors.New("unable to find datasource plugin Azure Monitor")
}
azureMonitorRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
if err != nil {
return nil, err
}
u, err := url.Parse(dsInfo.URL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, "render")
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[azureMonitor].URL, nil)
if err != nil {
azlog.Debug("Failed to create request", "error", err)
return nil, errutil.Wrap("Failed to create request", err)
}
req.URL.Path = "/subscriptions"
req.Header.Set("Content-Type", "application/json")
// TODO: Use backend authentication instead
proxyPass := fmt.Sprintf("%s/subscriptions", routeName)
pluginproxy.ApplyRoute(ctx, req, proxyPass, azureMonitorRoute, &models.DataSource{
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
}, e.cfg)
return req, nil
}
func (e *AzureMonitorDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
cloud, err := getAzureCloud(e.cfg, dsInfo)
if err != nil {
return nil, "", err
}
routeName, err := getManagementApiRoute(cloud)
if err != nil {
return nil, "", err
}
var pluginRoute *plugins.AppPluginRoute
for _, route := range plugin.Routes {
if route.Path == routeName {
pluginRoute = route
break
}
}
return pluginRoute, routeName, nil
}
func (e *AzureMonitorDatasource) unmarshalResponse(res *http.Response) (AzureMonitorResponse, error) {
body, err := ioutil.ReadAll(res.Body)
if err != nil {

@ -1,9 +1,11 @@
package azuremonitor
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path/filepath"
"testing"
@ -509,3 +511,42 @@ func loadTestFile(t *testing.T, name string) AzureMonitorResponse {
require.NoError(t, err)
return azData
}
func TestAzureMonitorCreateRequest(t *testing.T) {
ctx := context.Background()
dsInfo := datasourceInfo{
Services: map[string]datasourceService{
azureMonitor: {URL: "http://ds"},
},
}
tests := []struct {
name string
expectedURL string
expectedHeaders http.Header
Err require.ErrorAssertionFunc
}{
{
name: "creates a request",
expectedURL: "http://ds/subscriptions",
expectedHeaders: http.Header{
"Content-Type": []string{"application/json"},
},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds := AzureMonitorDatasource{}
req, err := ds.createRequest(ctx, dsInfo)
tt.Err(t, err)
if req.URL.String() != tt.expectedURL {
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
}
if !cmp.Equal(req.Header, tt.expectedHeaders) {
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
}
})
}
}

@ -9,8 +9,8 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
@ -39,7 +39,6 @@ func init() {
type Service struct {
PluginManager plugins.Manager `inject:""`
HTTPClientProvider httpclient.Provider `inject:""`
Cfg *setting.Cfg `inject:""`
BackendPluginManager backendplugin.Manager `inject:""`
}
@ -59,30 +58,26 @@ type azureMonitorSettings struct {
}
type datasourceInfo struct {
Settings azureMonitorSettings
Settings azureMonitorSettings
Services map[string]datasourceService
Routes map[string]azRoute
HTTPCliOpts httpclient.Options
HTTPClient *http.Client
URL string
JSONData map[string]interface{}
DecryptedSecureJSONData map[string]string
DatasourceID int64
OrgID int64
}
func NewInstanceSettings(httpClientProvider httpclient.Provider) datasource.InstanceFactoryFunc {
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
opts, err := settings.HTTPClientOptions()
if err != nil {
return nil, err
}
client, err := httpClientProvider.New(opts)
if err != nil {
return nil, err
}
type datasourceService struct {
URL string
HTTPClient *http.Client
}
func NewInstanceSettings() datasource.InstanceFactoryFunc {
return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
jsonData := map[string]interface{}{}
err = json.Unmarshal(settings.JSONData, &jsonData)
err := json.Unmarshal(settings.JSONData, &jsonData)
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
@ -92,15 +87,20 @@ func NewInstanceSettings(httpClientProvider httpclient.Provider) datasource.Inst
if err != nil {
return nil, fmt.Errorf("error reading settings: %w", err)
}
httpCliOpts, err := settings.HTTPClientOptions()
if err != nil {
return nil, fmt.Errorf("error getting http options: %w", err)
}
model := datasourceInfo{
Settings: azMonitorSettings,
HTTPClient: client,
URL: settings.URL,
JSONData: jsonData,
DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
DatasourceID: settings.ID,
Services: map[string]datasourceService{},
Routes: routes[azMonitorSettings.CloudName],
HTTPCliOpts: httpCliOpts,
}
return model, nil
}
}
@ -109,15 +109,8 @@ type azDatasourceExecutor interface {
executeTimeSeriesQuery(ctx context.Context, originalQueries []backend.DataQuery, dsInfo datasourceInfo) (*backend.QueryDataResponse, error)
}
func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC httpclient.Provider, cfg *setting.Cfg) *datasource.QueryTypeMux {
func newExecutor(im instancemgmt.InstanceManager, cfg *setting.Cfg, executors map[string]azDatasourceExecutor) *datasource.QueryTypeMux {
mux := datasource.NewQueryTypeMux()
executors := map[string]azDatasourceExecutor{
"Azure Monitor": &AzureMonitorDatasource{pm, cfg},
"Application Insights": &ApplicationInsightsDatasource{pm, cfg},
"Azure Log Analytics": &AzureLogAnalyticsDatasource{pm, cfg},
"Insights Analytics": &InsightsAnalyticsDatasource{pm, cfg},
"Azure Resource Graph": &AzureResourceGraphDatasource{pm, cfg},
}
for dsType := range executors {
// Make a copy of the string to keep the reference after the iterator
dst := dsType
@ -129,6 +122,18 @@ func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC http
dsInfo := i.(datasourceInfo)
dsInfo.OrgID = req.PluginContext.OrgID
ds := executors[dst]
if _, ok := dsInfo.Services[dst]; !ok {
// Create an HTTP Client if it has not been created before
route := dsInfo.Routes[dst]
client, err := newHTTPClient(ctx, route, dsInfo, cfg)
if err != nil {
return nil, err
}
dsInfo.Services[dst] = datasourceService{
URL: dsInfo.Routes[dst].URL,
HTTPClient: client,
}
}
return ds.executeTimeSeriesQuery(ctx, req.Queries, dsInfo)
})
}
@ -136,9 +141,16 @@ func newExecutor(im instancemgmt.InstanceManager, pm plugins.Manager, httpC http
}
func (s *Service) Init() error {
im := datasource.NewInstanceManager(NewInstanceSettings(s.HTTPClientProvider))
im := datasource.NewInstanceManager(NewInstanceSettings())
executors := map[string]azDatasourceExecutor{
azureMonitor: &AzureMonitorDatasource{},
appInsights: &ApplicationInsightsDatasource{},
azureLogAnalytics: &AzureLogAnalyticsDatasource{},
insightsAnalytics: &InsightsAnalyticsDatasource{},
azureResourceGraph: &AzureResourceGraphDatasource{},
}
factory := coreplugin.New(backend.ServeOpts{
QueryDataHandler: newExecutor(im, s.PluginManager, s.HTTPClientProvider, s.Cfg),
QueryDataHandler: newExecutor(im, s.Cfg, executors),
})
if err := s.BackendPluginManager.Register(dsName, factory); err != nil {

@ -0,0 +1,127 @@
package azuremonitor
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
func TestNewInstanceSettings(t *testing.T) {
tests := []struct {
name string
settings backend.DataSourceInstanceSettings
expectedModel datasourceInfo
Err require.ErrorAssertionFunc
}{
{
name: "creates an instance",
settings: backend.DataSourceInstanceSettings{
JSONData: []byte(`{"cloudName":"azuremonitor"}`),
DecryptedSecureJSONData: map[string]string{"key": "value"},
ID: 40,
},
expectedModel: datasourceInfo{
Settings: azureMonitorSettings{CloudName: "azuremonitor"},
Routes: routes["azuremonitor"],
JSONData: map[string]interface{}{"cloudName": string("azuremonitor")},
DatasourceID: 40,
DecryptedSecureJSONData: map[string]string{"key": "value"},
},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
factory := NewInstanceSettings()
instance, err := factory(tt.settings)
tt.Err(t, err)
if !cmp.Equal(instance, tt.expectedModel, cmpopts.IgnoreFields(datasourceInfo{}, "Services", "HTTPCliOpts")) {
t.Errorf("Unexpected instance: %v", cmp.Diff(instance, tt.expectedModel))
}
})
}
}
type fakeInstance struct{}
func (f *fakeInstance) Get(pluginContext backend.PluginContext) (instancemgmt.Instance, error) {
return datasourceInfo{
Services: map[string]datasourceService{},
Routes: routes[azureMonitorPublic],
}, nil
}
func (f *fakeInstance) Do(pluginContext backend.PluginContext, fn instancemgmt.InstanceCallbackFunc) error {
return nil
}
type fakeExecutor struct {
t *testing.T
queryType string
expectedURL string
}
func (f *fakeExecutor) executeTimeSeriesQuery(ctx context.Context, originalQueries []backend.DataQuery, dsInfo datasourceInfo) (*backend.QueryDataResponse, error) {
if s, ok := dsInfo.Services[f.queryType]; !ok {
f.t.Errorf("The HTTP client for %s is missing", f.queryType)
} else {
if s.URL != f.expectedURL {
f.t.Errorf("Unexpected URL %s wanted %s", s.URL, f.expectedURL)
}
}
return &backend.QueryDataResponse{}, nil
}
func Test_newExecutor(t *testing.T) {
cfg := &setting.Cfg{}
tests := []struct {
name string
queryType string
expectedURL string
Err require.ErrorAssertionFunc
}{
{
name: "creates an Azure Monitor executor",
queryType: azureMonitor,
expectedURL: routes[azureMonitorPublic][azureMonitor].URL,
Err: require.NoError,
},
{
name: "creates an Azure Log Analytics executor",
queryType: azureLogAnalytics,
expectedURL: routes[azureMonitorPublic][azureLogAnalytics].URL,
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mux := newExecutor(&fakeInstance{}, cfg, map[string]azDatasourceExecutor{
tt.queryType: &fakeExecutor{
t: t,
queryType: tt.queryType,
expectedURL: tt.expectedURL,
},
})
res, err := mux.QueryData(context.TODO(), &backend.QueryDataRequest{
PluginContext: backend.PluginContext{},
Queries: []backend.DataQuery{
{QueryType: tt.queryType},
},
})
tt.Err(t, err)
// Dummy response from the fake implementation
if res == nil {
t.Errorf("Expecting a response")
}
})
}
}

@ -1,14 +1,13 @@
package azuremonitor
import (
"fmt"
"context"
"net/http"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
)
const (
AzureAuthManagedIdentity = "msi"
AzureAuthClientSecret = "clientsecret"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/tokenprovider"
)
// Azure cloud names specific to Azure Monitor
@ -19,59 +18,40 @@ const (
azureMonitorGermany = "germanyazuremonitor"
)
func getAuthType(cfg *setting.Cfg, dsInfo datasourceInfo) string {
if dsInfo.Settings.AzureAuthType != "" {
return dsInfo.Settings.AzureAuthType
} else {
tenantId := dsInfo.Settings.TenantId
clientId := dsInfo.Settings.ClientId
// If authentication type isn't explicitly specified and datasource has client credentials,
// then this is existing datasource which is configured for app registration (client secret)
if tenantId != "" && clientId != "" {
return AzureAuthClientSecret
}
// Azure cloud query types
const (
azureMonitor = "Azure Monitor"
appInsights = "Application Insights"
azureLogAnalytics = "Azure Log Analytics"
insightsAnalytics = "Insights Analytics"
azureResourceGraph = "Azure Resource Graph"
)
// For newly created datasource with no configuration, managed identity is the default authentication type
// if they are enabled in Grafana config
if cfg.Azure.ManagedIdentityEnabled {
return AzureAuthManagedIdentity
} else {
return AzureAuthClientSecret
func httpClientProvider(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) *httpclient.Provider {
if len(route.Scopes) > 0 {
tokenAuth := &plugins.JwtTokenAuth{
Url: route.URL,
Scopes: route.Scopes,
Params: map[string]string{
"azure_auth_type": model.Settings.AzureAuthType,
"azure_cloud": cfg.Azure.Cloud,
"tenant_id": model.Settings.TenantId,
"client_id": model.Settings.ClientId,
"client_secret": model.DecryptedSecureJSONData["clientSecret"],
},
}
tokenProvider := tokenprovider.NewAzureAccessTokenProvider(ctx, cfg, tokenAuth)
return httpclient.NewProvider(httpclient.ProviderOptions{
Middlewares: []httpclient.Middleware{
tokenprovider.AuthMiddleware(tokenProvider),
},
})
} else {
return httpclient.NewProvider()
}
}
func getDefaultAzureCloud(cfg *setting.Cfg) (string, error) {
switch cfg.Azure.Cloud {
case setting.AzurePublic:
return azureMonitorPublic, nil
case setting.AzureChina:
return azureMonitorChina, nil
case setting.AzureUSGovernment:
return azureMonitorUSGovernment, nil
case setting.AzureGermany:
return azureMonitorGermany, nil
default:
err := fmt.Errorf("the cloud '%s' not supported", cfg.Azure.Cloud)
return "", err
}
}
func getAzureCloud(cfg *setting.Cfg, dsInfo datasourceInfo) (string, error) {
authType := getAuthType(cfg, dsInfo)
switch authType {
case AzureAuthManagedIdentity:
// In case of managed identity, the cloud is always same as where Grafana is hosted
return getDefaultAzureCloud(cfg)
case AzureAuthClientSecret:
if dsInfo.Settings.CloudName != "" {
return dsInfo.Settings.CloudName, nil
} else {
return getDefaultAzureCloud(cfg)
}
default:
err := fmt.Errorf("the authentication type '%s' not supported", authType)
return "", err
}
func newHTTPClient(ctx context.Context, route azRoute, model datasourceInfo, cfg *setting.Cfg) (*http.Client, error) {
model.HTTPCliOpts.Headers = route.Headers
return httpClientProvider(ctx, route, model, cfg).New(model.HTTPCliOpts)
}

@ -0,0 +1,54 @@
package azuremonitor
import (
"context"
"testing"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
func Test_httpCliProvider(t *testing.T) {
ctx := context.TODO()
cfg := &setting.Cfg{}
model := datasourceInfo{
Settings: azureMonitorSettings{},
DecryptedSecureJSONData: map[string]string{"clientSecret": "content"},
}
tests := []struct {
name string
route azRoute
expectedMiddlewares int
Err require.ErrorAssertionFunc
}{
{
name: "creates an HTTP client with a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{"http://route/.default"},
},
expectedMiddlewares: 1,
Err: require.NoError,
},
{
name: "creates an HTTP client without a middleware",
route: azRoute{
URL: "http://route",
Scopes: []string{},
},
// httpclient.NewProvider returns a client with 2 middlewares by default
expectedMiddlewares: 2,
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli := httpClientProvider(ctx, tt.route, model, cfg)
// Cannot test that the cli middleware works properly since the azcore sdk
// rejects the TLS certs (if provided)
if len(cli.Opts.Middlewares) != tt.expectedMiddlewares {
t.Errorf("Unexpected middlewares: %v", cli.Opts.Middlewares)
}
})
}
}

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@ -13,21 +12,12 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/api/pluginproxy"
"github.com/grafana/grafana/pkg/components/securejsondata"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/context/ctxhttp"
)
type InsightsAnalyticsDatasource struct {
pluginManager plugins.Manager
cfg *setting.Cfg
}
type InsightsAnalyticsDatasource struct{}
type InsightsAnalyticsQuery struct {
RefID string
@ -122,7 +112,7 @@ func (e *InsightsAnalyticsDatasource) executeQuery(ctx context.Context, query *I
}
azlog.Debug("ApplicationInsights", "Request URL", req.URL.String())
res, err := ctxhttp.Do(ctx, dsInfo.HTTPClient, req)
res, err := ctxhttp.Do(ctx, dsInfo.Services[appInsights].HTTPClient, req)
if err != nil {
return dataResponseError(err)
}
@ -179,59 +169,14 @@ func (e *InsightsAnalyticsDatasource) executeQuery(ctx context.Context, query *I
}
func (e *InsightsAnalyticsDatasource) createRequest(ctx context.Context, dsInfo datasourceInfo) (*http.Request, error) {
// find plugin
plugin := e.pluginManager.GetDataSource(dsName)
if plugin == nil {
return nil, errors.New("unable to find datasource plugin Azure Application Insights")
}
appInsightsRoute, routeName, err := e.getPluginRoute(plugin, dsInfo)
if err != nil {
return nil, err
}
appInsightsAppID := dsInfo.Settings.AppInsightsAppId
u, err := url.Parse(dsInfo.URL)
if err != nil {
return nil, fmt.Errorf("unable to parse url for Application Insights Analytics datasource: %w", err)
}
u.Path = path.Join(u.Path, fmt.Sprintf("/v1/apps/%s", appInsightsAppID))
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
req, err := http.NewRequest(http.MethodGet, dsInfo.Services[insightsAnalytics].URL, nil)
if err != nil {
azlog.Debug("Failed to create request", "error", err)
return nil, errutil.Wrap("Failed to create request", err)
}
// TODO: Use backend authentication instead
proxyPass := fmt.Sprintf("%s/v1/apps/%s", routeName, appInsightsAppID)
pluginproxy.ApplyRoute(ctx, req, proxyPass, appInsightsRoute, &models.DataSource{
JsonData: simplejson.NewFromAny(dsInfo.JSONData),
SecureJsonData: securejsondata.GetEncryptedJsonData(dsInfo.DecryptedSecureJSONData),
}, e.cfg)
req.Header.Set("X-API-Key", dsInfo.DecryptedSecureJSONData["appInsightsApiKey"])
req.URL.Path = fmt.Sprintf("/v1/apps/%s", appInsightsAppID)
return req, nil
}
func (e *InsightsAnalyticsDatasource) getPluginRoute(plugin *plugins.DataSourcePlugin, dsInfo datasourceInfo) (*plugins.AppPluginRoute, string, error) {
cloud, err := getAzureCloud(e.cfg, dsInfo)
if err != nil {
return nil, "", err
}
routeName, err := getAppInsightsApiRoute(cloud)
if err != nil {
return nil, "", err
}
var pluginRoute *plugins.AppPluginRoute
for _, route := range plugin.Routes {
if route.Path == routeName {
pluginRoute = route
break
}
}
return pluginRoute, routeName, nil
}

@ -0,0 +1,53 @@
package azuremonitor
import (
"context"
"net/http"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
)
func TestInsightsAnalyticsCreateRequest(t *testing.T) {
ctx := context.Background()
dsInfo := datasourceInfo{
Settings: azureMonitorSettings{AppInsightsAppId: "foo"},
Services: map[string]datasourceService{
insightsAnalytics: {URL: "http://ds"},
},
DecryptedSecureJSONData: map[string]string{
"appInsightsApiKey": "key",
},
}
tests := []struct {
name string
expectedURL string
expectedHeaders http.Header
Err require.ErrorAssertionFunc
}{
{
name: "creates a request",
expectedURL: "http://ds/v1/apps/foo",
expectedHeaders: http.Header{
"X-Api-Key": []string{"key"},
},
Err: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ds := InsightsAnalyticsDatasource{}
req, err := ds.createRequest(ctx, dsInfo)
tt.Err(t, err)
if req.URL.String() != tt.expectedURL {
t.Errorf("Expecting %s, got %s", tt.expectedURL, req.URL.String())
}
if !cmp.Equal(req.Header, tt.expectedHeaders) {
t.Errorf("Unexpected HTTP headers: %v", cmp.Diff(req.Header, tt.expectedHeaders))
}
})
}
}

@ -1,45 +1,90 @@
package azuremonitor
import "fmt"
func getManagementApiRoute(azureCloud string) (string, error) {
switch azureCloud {
case azureMonitorPublic:
return "azuremonitor", nil
case azureMonitorChina:
return "chinaazuremonitor", nil
case azureMonitorUSGovernment:
return "govazuremonitor", nil
case azureMonitorGermany:
return "germanyazuremonitor", nil
default:
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
return "", err
}
type azRoute struct {
URL string
Scopes []string
Headers map[string]string
}
func getLogAnalyticsApiRoute(azureCloud string) (string, error) {
switch azureCloud {
case azureMonitorPublic:
return "loganalyticsazure", nil
case azureMonitorChina:
return "chinaloganalyticsazure", nil
case azureMonitorUSGovernment:
return "govloganalyticsazure", nil
default:
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
return "", err
}
var azManagement = azRoute{
URL: "https://management.azure.com",
Scopes: []string{"https://management.azure.com/.default"},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
func getAppInsightsApiRoute(azureCloud string) (string, error) {
switch azureCloud {
case azureMonitorPublic:
return "appinsights", nil
case azureMonitorChina:
return "chinaappinsights", nil
default:
err := fmt.Errorf("the cloud '%s' not supported", azureCloud)
return "", err
}
var azUSGovManagement = azRoute{
URL: "https://management.usgovcloudapi.net",
Scopes: []string{"https://management.usgovcloudapi.net/.default"},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
var azGermanyManagement = azRoute{
URL: "https://management.microsoftazure.de",
Scopes: []string{"https://management.microsoftazure.de/.default"},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
var azChinaManagement = azRoute{
URL: "https://management.chinacloudapi.cn",
Scopes: []string{"https://management.chinacloudapi.cn/.default"},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
var azAppInsights = azRoute{
URL: "https://api.applicationinsights.io",
Scopes: []string{},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
var azChinaAppInsights = azRoute{
URL: "https://api.applicationinsights.azure.cn",
Scopes: []string{},
Headers: map[string]string{"x-ms-app": "Grafana"},
}
var azLogAnalytics = azRoute{
URL: "https://api.loganalytics.io",
Scopes: []string{"https://api.loganalytics.io/.default"},
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
}
var azChinaLogAnalytics = azRoute{
URL: "https://api.loganalytics.azure.cn",
Scopes: []string{"https://api.loganalytics.azure.cn/.default"},
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
}
var azUSGovLogAnalytics = azRoute{
URL: "https://api.loganalytics.us",
Scopes: []string{"https://api.loganalytics.us/.default"},
Headers: map[string]string{"x-ms-app": "Grafana", "Cache-Control": "public, max-age=60"},
}
var (
// The different Azure routes are identified by its cloud (e.g. public or gov)
// and the service to query (e.g. Azure Monitor or Azure Log Analytics)
routes = map[string]map[string]azRoute{
azureMonitorPublic: {
azureMonitor: azManagement,
azureLogAnalytics: azLogAnalytics,
azureResourceGraph: azManagement,
appInsights: azAppInsights,
insightsAnalytics: azAppInsights,
},
azureMonitorUSGovernment: {
azureMonitor: azUSGovManagement,
azureLogAnalytics: azUSGovLogAnalytics,
azureResourceGraph: azUSGovManagement,
},
azureMonitorGermany: {
azureMonitor: azGermanyManagement,
},
azureMonitorChina: {
azureMonitor: azChinaManagement,
azureLogAnalytics: azChinaLogAnalytics,
azureResourceGraph: azChinaManagement,
appInsights: azChinaAppInsights,
insightsAnalytics: azChinaAppInsights,
},
}
)

@ -0,0 +1,33 @@
package tokenprovider
import (
"fmt"
"net/http"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
var (
// timeNow makes it possible to test usage of time
timeNow = time.Now
)
type TokenProvider interface {
GetAccessToken() (string, error)
}
const authenticationMiddlewareName = "AzureAuthentication"
func AuthMiddleware(tokenProvider TokenProvider) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenProvider.GetAccessToken()
if err != nil {
return nil, fmt.Errorf("failed to retrieve azure access token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return next.RoundTrip(req)
})
})
}

@ -0,0 +1,184 @@
package tokenprovider
import (
"context"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
type AccessToken struct {
Token string
ExpiresOn time.Time
}
type TokenCredential interface {
GetCacheKey() string
Init() error
GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error)
}
type ConcurrentTokenCache interface {
GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error)
}
func NewConcurrentTokenCache() ConcurrentTokenCache {
return &tokenCacheImpl{}
}
type tokenCacheImpl struct {
cache sync.Map // of *credentialCacheEntry
}
type credentialCacheEntry struct {
credential TokenCredential
credInit uint32
credMutex sync.Mutex
cache sync.Map // of *scopesCacheEntry
}
type scopesCacheEntry struct {
credential TokenCredential
scopes []string
cond *sync.Cond
refreshing bool
accessToken *AccessToken
}
func (c *tokenCacheImpl) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
return c.getEntryFor(credential).getAccessToken(ctx, scopes)
}
func (c *tokenCacheImpl) getEntryFor(credential TokenCredential) *credentialCacheEntry {
var entry interface{}
var ok bool
key := credential.GetCacheKey()
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &credentialCacheEntry{
credential: credential,
})
}
return entry.(*credentialCacheEntry)
}
func (c *credentialCacheEntry) getAccessToken(ctx context.Context, scopes []string) (string, error) {
err := c.ensureInitialized()
if err != nil {
return "", err
}
return c.getEntryFor(scopes).getAccessToken(ctx)
}
func (c *credentialCacheEntry) ensureInitialized() error {
if atomic.LoadUint32(&c.credInit) == 0 {
c.credMutex.Lock()
defer c.credMutex.Unlock()
if c.credInit == 0 {
// Initialize credential
err := c.credential.Init()
if err != nil {
return err
}
atomic.StoreUint32(&c.credInit, 1)
}
}
return nil
}
func (c *credentialCacheEntry) getEntryFor(scopes []string) *scopesCacheEntry {
var entry interface{}
var ok bool
key := getKeyForScopes(scopes)
if entry, ok = c.cache.Load(key); !ok {
entry, _ = c.cache.LoadOrStore(key, &scopesCacheEntry{
credential: c.credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
})
}
return entry.(*scopesCacheEntry)
}
func (c *scopesCacheEntry) getAccessToken(ctx context.Context) (string, error) {
var accessToken *AccessToken
var err error
shouldRefresh := false
c.cond.L.Lock()
for {
if c.accessToken != nil && c.accessToken.ExpiresOn.After(time.Now().Add(2*time.Minute)) {
// Use the cached token since it's available and not expired yet
accessToken = c.accessToken
break
}
if !c.refreshing {
// Start refreshing the token
c.refreshing = true
shouldRefresh = true
break
}
// Wait for the token to be refreshed
c.cond.Wait()
}
c.cond.L.Unlock()
if shouldRefresh {
accessToken, err = c.refreshAccessToken(ctx)
if err != nil {
return "", err
}
}
return accessToken.Token, nil
}
func (c *scopesCacheEntry) refreshAccessToken(ctx context.Context) (*AccessToken, error) {
var accessToken *AccessToken
// Safeguarding from panic caused by credential implementation
defer func() {
c.cond.L.Lock()
c.refreshing = false
if accessToken != nil {
c.accessToken = accessToken
}
c.cond.Broadcast()
c.cond.L.Unlock()
}()
token, err := c.credential.GetAccessToken(ctx, c.scopes)
if err != nil {
return nil, err
}
accessToken = token
return accessToken, nil
}
func getKeyForScopes(scopes []string) string {
if len(scopes) > 1 {
arr := make([]string, len(scopes))
copy(arr, scopes)
sort.Strings(arr)
scopes = arr
}
return strings.Join(scopes, " ")
}

@ -0,0 +1,457 @@
package tokenprovider
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeCredential struct {
key string
initCalledTimes int
calledTimes int
initFunc func() error
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
func (c *fakeCredential) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) Reset() {
c.initCalledTimes = 0
c.calledTimes = 0
}
func (c *fakeCredential) Init() error {
c.initCalledTimes = c.initCalledTimes + 1
if c.initFunc != nil {
return c.initFunc()
}
return nil
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
return c.getAccessTokenFunc(ctx, scopes)
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
}
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes1 := []string{"Scope1"}
scopes2 := []string{"Scope2"}
t.Run("should request access token from credential", func(t *testing.T) {
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token, err := cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token)
assert.Equal(t, 1, credential.calledTimes)
})
t.Run("should return cached token for same scopes", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
assert.Equal(t, 2, credential.calledTimes)
})
t.Run("should return cached token for same credentials", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential1 := &fakeCredential{key: "credential-1"}
credential2 := &fakeCredential{key: "credential-2"}
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
assert.Equal(t, 1, credential1.calledTimes)
assert.Equal(t, 1, credential2.calledTimes)
})
}
func TestCredentialCacheEntry_EnsureInitialized(t *testing.T) {
t.Run("when credential init returns error", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
return errors.New("unable to initialize")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
err := cacheEntry.ensureInitialized()
assert.Error(t, err)
})
t.Run("should call init again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
return errors.New("unable to initialize")
}
return nil
},
}
t.Run("should call credential init again only while it returns error", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
err = cacheEntry.ensureInitialized()
assert.Error(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
assert.Equal(t, 2, credential.initCalledTimes)
})
})
t.Run("when credential init panics", func(t *testing.T) {
credential := &fakeCredential{
initFunc: func() error {
panic(errors.New("unable to initialize"))
},
}
t.Run("should call credential init again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &credentialCacheEntry{
credential: credential,
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
assert.Equal(t, 3, credential.initCalledTimes)
})
})
t.Run("when credential init panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
initFunc: func() error {
times = times + 1
if times == 1 {
panic(errors.New("unable to initialize"))
}
return nil
},
}
t.Run("should call credential init again only while it panics", func(t *testing.T) {
cacheEntry := &credentialCacheEntry{
credential: credential,
}
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_ = cacheEntry.ensureInitialized()
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
err = cacheEntry.ensureInitialized()
assert.NoError(t, err)
}()
assert.Equal(t, 2, credential.initCalledTimes)
})
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential getAccessToken returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
accessToken, err := cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, "", accessToken)
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential getAccessToken returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it returns error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
assert.Equal(t, 2, credential.calledTimes)
})
})
t.Run("when credential getAccessToken panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
},
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.Reset()
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential getAccessToken panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
panic(errors.New("unable to get access token"))
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it panics", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
assert.Equal(t, 2, credential.calledTimes)
})
})
}

@ -1,4 +1,4 @@
package pluginproxy
package tokenprovider
import (
"context"
@ -8,7 +8,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
)
@ -18,27 +17,21 @@ var (
)
type azureAccessTokenProvider struct {
datasourceId int64
datasourceVersion int
ctx context.Context
cfg *setting.Cfg
route *plugins.AppPluginRoute
authParams *plugins.JwtTokenAuth
ctx context.Context
cfg *setting.Cfg
authParams *plugins.JwtTokenAuth
}
func newAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg, ds *models.DataSource, pluginRoute *plugins.AppPluginRoute,
func NewAzureAccessTokenProvider(ctx context.Context, cfg *setting.Cfg,
authParams *plugins.JwtTokenAuth) *azureAccessTokenProvider {
return &azureAccessTokenProvider{
datasourceId: ds.Id,
datasourceVersion: ds.Version,
ctx: ctx,
cfg: cfg,
route: pluginRoute,
authParams: authParams,
ctx: ctx,
cfg: cfg,
authParams: authParams,
}
}
func (provider *azureAccessTokenProvider) getAccessToken() (string, error) {
func (provider *azureAccessTokenProvider) GetAccessToken() (string, error) {
var credential TokenCredential
if provider.isManagedIdentityCredential() {

@ -1,10 +1,9 @@
package pluginproxy
package tokenprovider
import (
"context"
"testing"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
@ -15,7 +14,7 @@ var getAccessTokenFunc func(credential TokenCredential, scopes []string)
type tokenCacheFake struct{}
func (c *tokenCacheFake) GetAccessToken(_ context.Context, credential TokenCredential, scopes []string) (string, error) {
func (c *tokenCacheFake) GetAccessToken(ctx context.Context, credential TokenCredential, scopes []string) (string, error) {
getAccessTokenFunc(credential, scopes)
return "4cb83b87-0ffb-4abd-82f6-48a8c08afc53", nil
}
@ -25,9 +24,6 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
cfg := &setting.Cfg{}
ds := &models.DataSource{Id: 1, Version: 2}
route := &plugins.AppPluginRoute{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
@ -41,7 +37,7 @@ func TestAzureTokenProvider_isManagedIdentityCredential(t *testing.T) {
},
}
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
t.Run("when managed identities enabled", func(t *testing.T) {
cfg.Azure.ManagedIdentityEnabled = true
@ -114,9 +110,6 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
cfg := &setting.Cfg{}
ds := &models.DataSource{Id: 1, Version: 2}
route := &plugins.AppPluginRoute{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
@ -130,7 +123,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
},
}
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
original := azureTokenCache
azureTokenCache = &tokenCacheFake{}
@ -148,7 +141,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.IsType(t, &managedIdentityCredential{}, credential)
}
_, err := provider.getAccessToken()
_, err := provider.GetAccessToken()
require.NoError(t, err)
})
@ -161,7 +154,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.IsType(t, &clientSecretCredential{}, credential)
}
_, err := provider.getAccessToken()
_, err := provider.GetAccessToken()
require.NoError(t, err)
})
})
@ -178,7 +171,7 @@ func TestAzureTokenProvider_getAccessToken(t *testing.T) {
assert.Fail(t, "token cache not expected to be called")
}
_, err := provider.getAccessToken()
_, err := provider.GetAccessToken()
require.Error(t, err)
})
})
@ -189,9 +182,6 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
cfg := &setting.Cfg{}
ds := &models.DataSource{Id: 1, Version: 2}
route := &plugins.AppPluginRoute{}
authParams := &plugins.JwtTokenAuth{
Scopes: []string{
"https://management.azure.com/.default",
@ -205,7 +195,7 @@ func TestAzureTokenProvider_getClientSecretCredential(t *testing.T) {
},
}
provider := newAzureAccessTokenProvider(ctx, cfg, ds, route, authParams)
provider := NewAzureAccessTokenProvider(ctx, cfg, authParams)
t.Run("should return clientSecretCredential with values", func(t *testing.T) {
result := provider.getClientSecretCredential()
Loading…
Cancel
Save