Infra: Azure authentication in HttpClientProvider (#36932)

* Azure middleware in HttpClientProxy

* Azure authentication under feature flag

* Minor fixes

* Add prefixes to not clash with JsonData

* Return error if JsonData cannot be parsed

* Return original string if URL invalid

* Tests for datasource_cache
pull/36932/merge
Sergey Kostrukov 4 years ago committed by GitHub
parent 4664cba935
commit c1963024ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 113
      pkg/infra/httpclient/httpclientprovider/azure_middleware.go
  2. 4
      pkg/infra/httpclient/httpclientprovider/http_client_provider.go
  3. 34
      pkg/models/datasource_cache.go
  4. 104
      pkg/models/datasource_cache_test.go
  5. 83
      pkg/tsdb/azuremonitor/azcredentials/builder.go
  6. 20
      pkg/tsdb/azuremonitor/aztokenprovider/middleware.go

@ -0,0 +1,113 @@
package httpclientprovider
import (
"fmt"
"net/http"
"net/url"
"path"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/aztokenprovider"
)
const azureMiddlewareName = "AzureAuthentication.Provider"
func AzureMiddleware(cfg *setting.Cfg) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(azureMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
if enabled, err := isAzureAuthenticationEnabled(opts.CustomOptions); err != nil {
return errorResponse(err)
} else if !enabled {
return next
}
credentials, err := getAzureCredentials(opts.CustomOptions)
if err != nil {
return errorResponse(err)
} else if credentials == nil {
credentials = getDefaultAzureCredentials(cfg)
}
tokenProvider, err := aztokenprovider.NewAzureAccessTokenProvider(cfg, credentials)
if err != nil {
return errorResponse(err)
}
scopes, err := getAzureEndpointScopes(opts.CustomOptions)
if err != nil {
return errorResponse(err)
}
return aztokenprovider.ApplyAuth(tokenProvider, scopes, next)
})
}
func errorResponse(err error) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("invalid Azure configuration: %s", err)
})
}
func isAzureAuthenticationEnabled(customOptions map[string]interface{}) (bool, error) {
if untypedValue, ok := customOptions["_azureAuth"]; !ok {
return false, nil
} else if value, ok := untypedValue.(bool); !ok {
err := fmt.Errorf("the field 'azureAuth' should be a bool")
return false, err
} else {
return value, nil
}
}
func getAzureCredentials(customOptions map[string]interface{}) (azcredentials.AzureCredentials, error) {
if untypedValue, ok := customOptions["_azureCredentials"]; !ok {
return nil, nil
} else if value, ok := untypedValue.(azcredentials.AzureCredentials); !ok {
err := fmt.Errorf("the field 'azureCredentials' should be a valid credentials object")
return nil, err
} else {
return value, nil
}
}
func getDefaultAzureCredentials(cfg *setting.Cfg) azcredentials.AzureCredentials {
if cfg.Azure.ManagedIdentityEnabled {
return &azcredentials.AzureManagedIdentityCredentials{}
} else {
return &azcredentials.AzureClientSecretCredentials{
AzureCloud: cfg.Azure.Cloud,
}
}
}
func getAzureEndpointResourceId(customOptions map[string]interface{}) (*url.URL, error) {
var value string
if untypedValue, ok := customOptions["azureEndpointResourceId"]; !ok {
err := fmt.Errorf("the field 'azureEndpointResourceId' should be set")
return nil, err
} else if value, ok = untypedValue.(string); !ok {
err := fmt.Errorf("the field 'azureEndpointResourceId' should be a string")
return nil, err
}
resourceId, err := url.Parse(value)
if err != nil || resourceId.Scheme == "" || resourceId.Host == "" {
err := fmt.Errorf("invalid endpoint Resource ID URL '%s'", value)
return nil, err
}
return resourceId, nil
}
func getAzureEndpointScopes(customOptions map[string]interface{}) ([]string, error) {
resourceId, err := getAzureEndpointResourceId(customOptions)
if err != nil {
return nil, err
}
resourceId.Path = path.Join(resourceId.Path, ".default")
scopes := []string{resourceId.String()}
return scopes, nil
}

@ -34,6 +34,10 @@ func New(cfg *setting.Cfg) httpclient.Provider {
setDefaultTimeoutOptions(cfg)
if cfg.FeatureToggles["httpclientprovider_azure_auth"] {
middlewares = append(middlewares, AzureMiddleware(cfg))
}
return newProviderFunc(sdkhttpclient.ProviderOptions{
Middlewares: middlewares,
ConfigureTransport: func(opts sdkhttpclient.Options, transport *http.Transport) {

@ -11,6 +11,7 @@ import (
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
)
func (ds *DataSource) getTimeout() time.Duration {
@ -66,10 +67,14 @@ func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider, customMiddl
return t.roundTripper, nil
}
opts := ds.HTTPClientOptions()
opts, err := ds.HTTPClientOptions()
if err != nil {
return nil, err
}
opts.Middlewares = customMiddlewares
rt, err := provider.GetTransport(opts)
rt, err := provider.GetTransport(*opts)
if err != nil {
return nil, err
}
@ -82,7 +87,7 @@ func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider, customMiddl
return rt, nil
}
func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
func (ds *DataSource) HTTPClientOptions() (*sdkhttpclient.Options, error) {
tlsOptions := ds.TLSOptions()
timeouts := &sdkhttpclient.TimeoutOptions{
Timeout: ds.getTimeout(),
@ -95,7 +100,7 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
MaxIdleConnsPerHost: sdkhttpclient.DefaultTimeoutOptions.MaxIdleConnsPerHost,
IdleConnTimeout: sdkhttpclient.DefaultTimeoutOptions.IdleConnTimeout,
}
opts := sdkhttpclient.Options{
opts := &sdkhttpclient.Options{
Timeouts: timeouts,
Headers: getCustomHeaders(ds.JsonData, ds.DecryptedValues()),
Labels: map[string]string{
@ -121,6 +126,19 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
}
}
if ds.JsonData != nil && ds.JsonData.Get("azureAuth").MustBool() {
credentials, err := azcredentials.FromDatasourceData(ds.JsonData.MustMap(), ds.DecryptedValues())
if err != nil {
err = fmt.Errorf("invalid Azure credentials: %s", err)
return nil, err
}
opts.CustomOptions["_azureAuth"] = true
if credentials != nil {
opts.CustomOptions["_azureCredentials"] = credentials
}
}
if ds.JsonData != nil && ds.JsonData.Get("sigV4Auth").MustBool(false) {
opts.SigV4 = &sdkhttpclient.SigV4Config{
Service: awsServiceNamespace(ds.Type),
@ -140,7 +158,7 @@ func (ds *DataSource) HTTPClientOptions() sdkhttpclient.Options {
}
}
return opts
return opts, nil
}
func (ds *DataSource) TLSOptions() sdkhttpclient.TLSOptions {
@ -180,7 +198,11 @@ func (ds *DataSource) TLSOptions() sdkhttpclient.TLSOptions {
}
func (ds *DataSource) GetTLSConfig(httpClientProvider httpclient.Provider) (*tls.Config, error) {
return httpClientProvider.GetTLSConfig(ds.HTTPClientOptions())
opts, err := ds.HTTPClientOptions()
if err != nil {
return nil, err
}
return httpClientProvider.GetTLSConfig(*opts)
}
// getCustomHeaders returns a map with all the to be set headers

@ -13,6 +13,7 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor/azcredentials"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -394,6 +395,109 @@ func TestDataSource_DecryptedValue(t *testing.T) {
})
}
func TestDataSource_HTTPClientOptions(t *testing.T) {
emptyJsonData := simplejson.New()
emptySecureJsonData := map[string][]byte{}
ds := DataSource{
Id: 1,
Url: "https://api.example.com",
Type: "prometheus",
}
t.Run("Azure authentication", func(t *testing.T) {
t.Run("should be disabled if not enabled in JsonData", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
opts, err := ds.HTTPClientOptions()
require.NoError(t, err)
assert.NotEqual(t, true, opts.CustomOptions["_azureAuth"])
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
})
t.Run("should be enabled if enabled in JsonData without credentials configured", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
"azureAuth": true,
})
opts, err := ds.HTTPClientOptions()
require.NoError(t, err)
assert.Equal(t, true, opts.CustomOptions["_azureAuth"])
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
})
t.Run("should be enabled if enabled in JsonData with credentials configured", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
"azureAuth": true,
"azureCredentials": map[string]interface{}{
"authType": "msi",
},
})
opts, err := ds.HTTPClientOptions()
require.NoError(t, err)
assert.Equal(t, true, opts.CustomOptions["_azureAuth"])
require.Contains(t, opts.CustomOptions, "_azureCredentials")
credentials := opts.CustomOptions["_azureCredentials"]
assert.IsType(t, &azcredentials.AzureManagedIdentityCredentials{}, credentials)
})
t.Run("should be disabled if disabled in JsonData even with credentials configured", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
"azureAuth": false,
"azureCredentials": map[string]interface{}{
"authType": "msi",
},
})
opts, err := ds.HTTPClientOptions()
require.NoError(t, err)
assert.NotEqual(t, true, opts.CustomOptions["_azureAuth"])
assert.NotContains(t, opts.CustomOptions, "_azureCredentials")
})
t.Run("should fail if credentials are invalid", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
"azureAuth": true,
"azureCredentials": "invalid",
})
_, err := ds.HTTPClientOptions()
assert.Error(t, err)
})
t.Run("should pass resourceId from JsonData", func(t *testing.T) {
t.Cleanup(func() { ds.JsonData = emptyJsonData; ds.SecureJsonData = emptySecureJsonData })
ds.JsonData = simplejson.NewFromAny(map[string]interface{}{
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5",
})
opts, err := ds.HTTPClientOptions()
require.NoError(t, err)
require.Contains(t, opts.CustomOptions, "azureEndpointResourceId")
azureEndpointResourceId := opts.CustomOptions["azureEndpointResourceId"]
assert.Equal(t, "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", azureEndpointResourceId)
})
})
}
func clearDSProxyCache(t *testing.T) {
t.Helper()

@ -0,0 +1,83 @@
package azcredentials
import (
"fmt"
)
func FromDatasourceData(data map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
if credentialsObj, err := getMapOptional(data, "azureCredentials"); err != nil {
return nil, err
} else if credentialsObj == nil {
return nil, nil
} else {
return getFromCredentialsObject(credentialsObj, secureData)
}
}
func getFromCredentialsObject(credentialsObj map[string]interface{}, secureData map[string]string) (AzureCredentials, error) {
authType, err := getStringValue(credentialsObj, "authType")
if err != nil {
return nil, err
}
switch authType {
case AzureAuthManagedIdentity:
credentials := &AzureManagedIdentityCredentials{}
return credentials, nil
case AzureAuthClientSecret:
cloud, err := getStringValue(credentialsObj, "azureCloud")
if err != nil {
return nil, err
}
tenantId, err := getStringValue(credentialsObj, "tenantId")
if err != nil {
return nil, err
}
clientId, err := getStringValue(credentialsObj, "clientId")
if err != nil {
return nil, err
}
clientSecret := secureData["azureClientSecret"]
credentials := &AzureClientSecretCredentials{
AzureCloud: cloud,
TenantId: tenantId,
ClientId: clientId,
ClientSecret: clientSecret,
}
return credentials, nil
default:
err := fmt.Errorf("the authentication type '%s' not supported", authType)
return nil, err
}
}
func getMapOptional(obj map[string]interface{}, key string) (map[string]interface{}, error) {
if untypedValue, ok := obj[key]; ok {
if value, ok := untypedValue.(map[string]interface{}); ok {
return value, nil
} else {
err := fmt.Errorf("the field '%s' should be an object", key)
return nil, err
}
} else {
// Value optional, not error
return nil, nil
}
}
func getStringValue(obj map[string]interface{}, key string) (string, error) {
if untypedValue, ok := obj[key]; ok {
if value, ok := untypedValue.(string); ok {
return value, nil
} else {
err := fmt.Errorf("the field '%s' should be a string", key)
return "", err
}
} else {
err := fmt.Errorf("the field '%s' should be set", key)
return "", err
}
}

@ -11,13 +11,17 @@ const authenticationMiddlewareName = "AzureAuthentication"
func AuthMiddleware(tokenProvider AzureTokenProvider, scopes []string) httpclient.Middleware {
return httpclient.NamedMiddlewareFunc(authenticationMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
if err != nil {
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return next.RoundTrip(req)
})
return ApplyAuth(tokenProvider, scopes, next)
})
}
func ApplyAuth(tokenProvider AzureTokenProvider, scopes []string, next http.RoundTripper) http.RoundTripper {
return httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
token, err := tokenProvider.GetAccessToken(req.Context(), scopes)
if err != nil {
return nil, fmt.Errorf("failed to retrieve Azure access token: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return next.RoundTrip(req)
})
}

Loading…
Cancel
Save