mirror of https://github.com/grafana/grafana
Prometheus: Use contextual middleware for req headers and simplify client creation (#51061)
* Use contextual middleware and simplify client creation * Fix tests * Add test for the header propagation * Fix tests and lint * Update pkg/tsdb/prometheus/prometheus.go Co-authored-by: ismail simsek <ismailsimsek09@gmail.com> Co-authored-by: ismail simsek <ismailsimsek09@gmail.com>pull/51324/head^2
parent
a8eb29f1d7
commit
d20afa2a39
@ -0,0 +1,117 @@ |
||||
package azureauth |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings" |
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestConfigureAzureAuthentication(t *testing.T) { |
||||
cfg := &setting.Cfg{ |
||||
Azure: &azsettings.AzureSettings{}, |
||||
} |
||||
|
||||
t.Run("should set Azure middleware when JsonData contains valid credentials", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": { |
||||
"authType": "msi" |
||||
} |
||||
}`), |
||||
} |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
require.NoError(t, err) |
||||
|
||||
require.NotNil(t, opts.Middlewares) |
||||
assert.Len(t, opts.Middlewares, 1) |
||||
}) |
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain valid credentials", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ "httpMethod": "POST" }`), |
||||
} |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
require.NoError(t, err) |
||||
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials") |
||||
}) |
||||
|
||||
t.Run("should return error when JsonData contains invalid credentials", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": "invalid" |
||||
}`), |
||||
} |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
assert.Error(t, err) |
||||
}) |
||||
|
||||
t.Run("should set Azure middleware when JsonData contains credentials and valid audience", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": { |
||||
"authType": "msi" |
||||
}, |
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5" |
||||
}`), |
||||
} |
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
require.NoError(t, err) |
||||
|
||||
require.NotNil(t, opts.Middlewares) |
||||
assert.Len(t, opts.Middlewares, 1) |
||||
}) |
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain credentials", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ |
||||
"httpMethod": "POST", |
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5" |
||||
}`), |
||||
} |
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
require.NoError(t, err) |
||||
|
||||
if opts.Middlewares != nil { |
||||
assert.Len(t, opts.Middlewares, 0) |
||||
} |
||||
}) |
||||
|
||||
t.Run("should return error when JsonData contains invalid audience", func(t *testing.T) { |
||||
settings := backend.DataSourceInstanceSettings{ |
||||
JSONData: []byte(`{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": { |
||||
"authType": "msi" |
||||
}, |
||||
"azureEndpointResourceId": "invalid" |
||||
}`), |
||||
} |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) |
||||
assert.Error(t, err) |
||||
}) |
||||
} |
@ -0,0 +1,80 @@ |
||||
package buffered |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http" |
||||
"strings" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/azureauth" |
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" |
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/utils" |
||||
"github.com/grafana/grafana/pkg/util/maputil" |
||||
"github.com/prometheus/client_golang/api" |
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
) |
||||
|
||||
// CreateTransportOptions creates options for the http client. Probably should be shared and should not live in the
|
||||
// buffered package.
|
||||
func CreateTransportOptions(settings backend.DataSourceInstanceSettings, cfg *setting.Cfg, features featuremgmt.FeatureToggles, logger log.Logger) (*sdkhttpclient.Options, error) { |
||||
opts, err := settings.HTTPClientOptions() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
jsonData, err := utils.GetJsonData(settings) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("error reading settings: %w", err) |
||||
} |
||||
httpMethod, _ := maputil.GetStringOptional(jsonData, "httpMethod") |
||||
|
||||
opts.Middlewares = middlewares(logger, httpMethod) |
||||
|
||||
// Set SigV4 service namespace
|
||||
if opts.SigV4 != nil { |
||||
opts.SigV4.Service = "aps" |
||||
} |
||||
|
||||
// Azure authentication is experimental (#35857)
|
||||
if features.IsEnabled(featuremgmt.FlagPrometheusAzureAuth) { |
||||
err = azureauth.ConfigureAzureAuthentication(settings, cfg.Azure, &opts) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("error configuring Azure auth: %v", err) |
||||
} |
||||
} |
||||
|
||||
return &opts, nil |
||||
} |
||||
|
||||
func CreateClient(roundTripper http.RoundTripper, url string) (apiv1.API, error) { |
||||
cfg := api.Config{ |
||||
Address: url, |
||||
RoundTripper: roundTripper, |
||||
} |
||||
|
||||
client, err := api.NewClient(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return apiv1.NewAPI(client), nil |
||||
} |
||||
|
||||
func middlewares(logger log.Logger, httpMethod string) []sdkhttpclient.Middleware { |
||||
middlewares := []sdkhttpclient.Middleware{ |
||||
// TODO: probably isn't needed anymore and should by done by http infra code
|
||||
middleware.CustomQueryParameters(logger), |
||||
sdkhttpclient.CustomHeadersMiddleware(), |
||||
} |
||||
|
||||
// Needed to control GET vs POST method of the requests
|
||||
if strings.ToLower(httpMethod) == "get" { |
||||
middlewares = append(middlewares, middleware.ForceHttpGet(logger)) |
||||
} |
||||
|
||||
return middlewares |
||||
} |
@ -1,56 +0,0 @@ |
||||
package promclient |
||||
|
||||
import ( |
||||
"sort" |
||||
"strings" |
||||
|
||||
lru "github.com/hashicorp/golang-lru" |
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
) |
||||
|
||||
type ProviderCache struct { |
||||
provider promClientProvider |
||||
cache *lru.Cache |
||||
} |
||||
|
||||
type promClientProvider interface { |
||||
GetClient(map[string]string) (apiv1.API, error) |
||||
} |
||||
|
||||
func NewProviderCache(p promClientProvider) (*ProviderCache, error) { |
||||
cache, err := lru.New(500) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &ProviderCache{ |
||||
provider: p, |
||||
cache: cache, |
||||
}, nil |
||||
} |
||||
|
||||
func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error) { |
||||
key := c.key(headers) |
||||
if client, ok := c.cache.Get(key); ok { |
||||
return client.(apiv1.API), nil |
||||
} |
||||
|
||||
client, err := c.provider.GetClient(headers) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.cache.Add(key, client) |
||||
return client, nil |
||||
} |
||||
|
||||
func (c *ProviderCache) key(headers map[string]string) string { |
||||
vals := make([]string, len(headers)) |
||||
var i int |
||||
for _, v := range headers { |
||||
vals[i] = v |
||||
i++ |
||||
} |
||||
sort.Strings(vals) |
||||
return strings.Join(vals, "") |
||||
} |
@ -1,131 +0,0 @@ |
||||
package promclient_test |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"sort" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/promclient" |
||||
|
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestCache_GetClient(t *testing.T) { |
||||
t.Run("it caches the client for a set of auth headers", func(t *testing.T) { |
||||
tc := setupCacheContext() |
||||
|
||||
c, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, c, c2) |
||||
require.Equal(t, 1, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it returns different clients when the headers differ", func(t *testing.T) { |
||||
tc := setupCacheContext() |
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} |
||||
|
||||
c, err := tc.providerCache.GetClient(h1) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(h2) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotEqual(t, c, c2) |
||||
require.Equal(t, 2, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it returns from the cache when headers are the same", func(t *testing.T) { |
||||
tc := setupCacheContext() |
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
|
||||
c, err := tc.providerCache.GetClient(h1) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(h2) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, c, c2) |
||||
require.Equal(t, 1, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) { |
||||
tc := setupCacheContext() |
||||
tc.clientProvider.errors <- errors.New("something bad") |
||||
|
||||
_, err := tc.providerCache.GetClient(headers) |
||||
require.EqualError(t, err, "something bad") |
||||
|
||||
c, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotNil(t, c) |
||||
require.Equal(t, 2, tc.clientProvider.numCalls) |
||||
}) |
||||
} |
||||
|
||||
type cacheTestContext struct { |
||||
providerCache *promclient.ProviderCache |
||||
clientProvider *fakePromClientProvider |
||||
} |
||||
|
||||
func setupCacheContext() *cacheTestContext { |
||||
fp := newFakePromClientProvider() |
||||
p, err := promclient.NewProviderCache(fp) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
return &cacheTestContext{ |
||||
providerCache: p, |
||||
clientProvider: fp, |
||||
} |
||||
} |
||||
|
||||
func newFakePromClientProvider() *fakePromClientProvider { |
||||
return &fakePromClientProvider{ |
||||
errors: make(chan error, 1), |
||||
} |
||||
} |
||||
|
||||
type fakePromClientProvider struct { |
||||
headers map[string]string |
||||
numCalls int |
||||
errors chan error |
||||
} |
||||
|
||||
func (p *fakePromClientProvider) GetClient(h map[string]string) (apiv1.API, error) { |
||||
p.headers = h |
||||
p.numCalls++ |
||||
|
||||
var err error |
||||
select { |
||||
case err = <-p.errors: |
||||
default: |
||||
} |
||||
|
||||
var config []string |
||||
for _, v := range h { |
||||
config = append(config, v) |
||||
} |
||||
sort.Strings(config) //because map
|
||||
return &fakePromClient{config: strings.Join(config, "")}, err |
||||
} |
||||
|
||||
type fakePromClient struct { |
||||
apiv1.API |
||||
config string |
||||
} |
||||
|
||||
func (c *fakePromClient) Config(ctx context.Context) (apiv1.ConfigResult, error) { |
||||
return apiv1.ConfigResult{YAML: c.config}, nil |
||||
} |
@ -1,106 +0,0 @@ |
||||
package promclient |
||||
|
||||
import ( |
||||
"strings" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" |
||||
"github.com/grafana/grafana/pkg/util/maputil" |
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/prometheus/client_golang/api" |
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
) |
||||
|
||||
type Provider struct { |
||||
settings backend.DataSourceInstanceSettings |
||||
jsonData map[string]interface{} |
||||
httpMethod string |
||||
clientProvider httpclient.Provider |
||||
cfg *setting.Cfg |
||||
features featuremgmt.FeatureToggles |
||||
log log.Logger |
||||
} |
||||
|
||||
func NewProvider( |
||||
settings backend.DataSourceInstanceSettings, |
||||
jsonData map[string]interface{}, |
||||
clientProvider httpclient.Provider, |
||||
cfg *setting.Cfg, |
||||
features featuremgmt.FeatureToggles, |
||||
log log.Logger, |
||||
) *Provider { |
||||
httpMethod, _ := maputil.GetStringOptional(jsonData, "httpMethod") |
||||
return &Provider{ |
||||
settings: settings, |
||||
jsonData: jsonData, |
||||
httpMethod: httpMethod, |
||||
clientProvider: clientProvider, |
||||
cfg: cfg, |
||||
features: features, |
||||
log: log, |
||||
} |
||||
} |
||||
|
||||
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) { |
||||
opts, err := p.settings.HTTPClientOptions() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
opts.Middlewares = p.middlewares() |
||||
opts.Headers = reqHeaders(headers) |
||||
|
||||
// Set SigV4 service namespace
|
||||
if opts.SigV4 != nil { |
||||
opts.SigV4.Service = "aps" |
||||
} |
||||
|
||||
// Azure authentication
|
||||
err = p.configureAzureAuthentication(&opts) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
roundTripper, err := p.clientProvider.GetTransport(opts) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cfg := api.Config{ |
||||
Address: p.settings.URL, |
||||
RoundTripper: roundTripper, |
||||
} |
||||
|
||||
client, err := api.NewClient(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return apiv1.NewAPI(client), nil |
||||
} |
||||
|
||||
func (p *Provider) middlewares() []sdkhttpclient.Middleware { |
||||
middlewares := []sdkhttpclient.Middleware{ |
||||
middleware.CustomQueryParameters(p.log), |
||||
sdkhttpclient.CustomHeadersMiddleware(), |
||||
} |
||||
if strings.ToLower(p.httpMethod) == "get" { |
||||
middlewares = append(middlewares, middleware.ForceHttpGet(p.log)) |
||||
} |
||||
|
||||
return middlewares |
||||
} |
||||
|
||||
func reqHeaders(headers map[string]string) map[string]string { |
||||
// copy to avoid changing the original map
|
||||
h := make(map[string]string, len(headers)) |
||||
for k, v := range headers { |
||||
h[k] = v |
||||
} |
||||
return h |
||||
} |
@ -1,153 +0,0 @@ |
||||
package promclient |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana-azure-sdk-go/azsettings" |
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestConfigureAzureAuthentication(t *testing.T) { |
||||
cfg := &setting.Cfg{ |
||||
Azure: &azsettings.AzureSettings{}, |
||||
} |
||||
settings := backend.DataSourceInstanceSettings{} |
||||
|
||||
t.Run("given feature flag enabled", func(t *testing.T) { |
||||
features := featuremgmt.WithFeatures(featuremgmt.FlagPrometheusAzureAuth) |
||||
|
||||
t.Run("should set Azure middleware when JsonData contains valid credentials", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": map[string]interface{}{ |
||||
"authType": "msi", |
||||
}, |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
require.NoError(t, err) |
||||
|
||||
require.NotNil(t, opts.Middlewares) |
||||
assert.Len(t, opts.Middlewares, 1) |
||||
}) |
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain valid credentials", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
require.NoError(t, err) |
||||
|
||||
assert.NotContains(t, opts.CustomOptions, "_azureCredentials") |
||||
}) |
||||
|
||||
t.Run("should return error when JsonData contains invalid credentials", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": "invalid", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
assert.Error(t, err) |
||||
}) |
||||
|
||||
t.Run("should set Azure middleware when JsonData contains credentials and valid audience", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": map[string]interface{}{ |
||||
"authType": "msi", |
||||
}, |
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
require.NoError(t, err) |
||||
|
||||
require.NotNil(t, opts.Middlewares) |
||||
assert.Len(t, opts.Middlewares, 1) |
||||
}) |
||||
|
||||
t.Run("should not set Azure middleware when JsonData doesn't contain credentials", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
require.NoError(t, err) |
||||
|
||||
if opts.Middlewares != nil { |
||||
assert.Len(t, opts.Middlewares, 0) |
||||
} |
||||
}) |
||||
|
||||
t.Run("should return error when JsonData contains invalid audience", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": map[string]interface{}{ |
||||
"authType": "msi", |
||||
}, |
||||
"azureEndpointResourceId": "invalid", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
assert.Error(t, err) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("given feature flag not enabled", func(t *testing.T) { |
||||
features := featuremgmt.WithFeatures() |
||||
|
||||
t.Run("should not set Azure Credentials even when JsonData contains credentials", func(t *testing.T) { |
||||
jsonData := map[string]interface{}{ |
||||
"httpMethod": "POST", |
||||
"azureCredentials": map[string]interface{}{ |
||||
"authType": "msi", |
||||
}, |
||||
"azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", |
||||
} |
||||
|
||||
var p = NewProvider(settings, jsonData, nil, cfg, features, nil) |
||||
|
||||
var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} |
||||
|
||||
err := p.configureAzureAuthentication(opts) |
||||
require.NoError(t, err) |
||||
|
||||
if opts.Middlewares != nil { |
||||
assert.Len(t, opts.Middlewares, 0) |
||||
} |
||||
}) |
||||
}) |
||||
} |
@ -1,181 +0,0 @@ |
||||
package promclient_test |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/promclient" |
||||
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/httpclient" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
var headers = map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
|
||||
func TestGetClient(t *testing.T) { |
||||
t.Run("it sets the SigV4 service if it exists", func(t *testing.T) { |
||||
tc := setup(`{"sigV4Auth":true}`) |
||||
|
||||
setting.SigV4AuthEnabled = true |
||||
defer func() { setting.SigV4AuthEnabled = false }() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, "aps", tc.httpProvider.opts.SigV4.Service) |
||||
}) |
||||
|
||||
t.Run("it always uses the custom params and custom headers middlewares", func(t *testing.T) { |
||||
tc := setup() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 2) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "prom-custom-query-parameters") |
||||
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders") |
||||
}) |
||||
|
||||
t.Run("extra headers", func(t *testing.T) { |
||||
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) { |
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, headers, tc.httpProvider.opts.Headers) |
||||
}) |
||||
|
||||
t.Run("it sets all headers", func(t *testing.T) { |
||||
withNonAuth := map[string]string{"X-Not-Auth": "stuff"} |
||||
|
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
_, err := tc.promClientProvider.GetClient(withNonAuth) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers) |
||||
}) |
||||
|
||||
t.Run("it does not error when headers are nil", func(t *testing.T) { |
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(nil) |
||||
require.Nil(t, err) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("force get middleware", func(t *testing.T) { |
||||
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"get"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 3) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"GET"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 3) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when httpMethod is POST", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"POST"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when json data is nil", func(t *testing.T) { |
||||
tc := setup() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when json data is empty", func(t *testing.T) { |
||||
tc := setup(`{}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware httpMethod is null", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":null}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
func setup(jsonData ...string) *testContext { |
||||
var rawData []byte |
||||
if len(jsonData) > 0 { |
||||
rawData = []byte(jsonData[0]) |
||||
} |
||||
|
||||
var jd map[string]interface{} |
||||
_ = json.Unmarshal(rawData, &jd) |
||||
|
||||
cfg := &setting.Cfg{} |
||||
settings := backend.DataSourceInstanceSettings{URL: "test-url", JSONData: rawData} |
||||
features := featuremgmt.WithFeatures() |
||||
hp := &fakeHttpClientProvider{} |
||||
p := promclient.NewProvider(settings, jd, hp, cfg, features, nil) |
||||
|
||||
return &testContext{ |
||||
httpProvider: hp, |
||||
promClientProvider: p, |
||||
} |
||||
} |
||||
|
||||
type testContext struct { |
||||
httpProvider *fakeHttpClientProvider |
||||
promClientProvider *promclient.Provider |
||||
} |
||||
|
||||
type fakeHttpClientProvider struct { |
||||
httpclient.Provider |
||||
|
||||
opts sdkhttpclient.Options |
||||
} |
||||
|
||||
func (p *fakeHttpClientProvider) GetTransport(opts ...sdkhttpclient.Options) (http.RoundTripper, error) { |
||||
p.opts = opts[0] |
||||
return http.DefaultTransport, nil |
||||
} |
||||
|
||||
func (p *fakeHttpClientProvider) middlewares() []string { |
||||
var middlewareNames []string |
||||
for _, m := range p.opts.Middlewares { |
||||
mw, ok := m.(sdkhttpclient.MiddlewareName) |
||||
if !ok { |
||||
panic("unexpected middleware type") |
||||
} |
||||
|
||||
middlewareNames = append(middlewareNames, mw.MiddlewareName()) |
||||
} |
||||
return middlewareNames |
||||
} |
@ -0,0 +1,24 @@ |
||||
package middleware |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
sdkHTTPClient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
) |
||||
|
||||
// ReqHeadersMiddleware is used so that we can pass req headers through the prometheus go client as it does not allow
|
||||
// access to the request directly. Should be used together with WithContextualMiddleware so that it is attached to
|
||||
// the context of each request with its unique headers.
|
||||
func ReqHeadersMiddleware(headers map[string]string) sdkHTTPClient.Middleware { |
||||
return sdkHTTPClient.NamedMiddlewareFunc("prometheus-req-headers-middleware", func(opts sdkHTTPClient.Options, next http.RoundTripper) http.RoundTripper { |
||||
return sdkHTTPClient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
||||
for k, v := range headers { |
||||
// As custom headers middleware is before contextual we may overwrite custom headers here with those
|
||||
// that came with the request which probably makes sense.
|
||||
req.Header[k] = []string{v} |
||||
} |
||||
|
||||
return next.RoundTrip(req) |
||||
}) |
||||
}) |
||||
} |
@ -0,0 +1,42 @@ |
||||
package utils |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
"github.com/grafana/grafana/pkg/infra/tracing" |
||||
"go.opentelemetry.io/otel/attribute" |
||||
) |
||||
|
||||
// GetJsonData just gets the json in easier to work with type. It's used on multiple places which isn't super effective
|
||||
// but only when creating a client which should not happen often anyway.
|
||||
func GetJsonData(settings backend.DataSourceInstanceSettings) (map[string]interface{}, error) { |
||||
var jsonData map[string]interface{} |
||||
err := json.Unmarshal(settings.JSONData, &jsonData) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("error unmarshalling JSONData: %w", err) |
||||
} |
||||
return jsonData, nil |
||||
} |
||||
|
||||
type Attribute struct { |
||||
Key string |
||||
Value interface{} |
||||
Kv attribute.KeyValue |
||||
} |
||||
|
||||
// StartTrace setups a trace but does not panic if tracer is nil which helps with testing
|
||||
func StartTrace(ctx context.Context, tracer tracing.Tracer, name string, attributes []Attribute) (context.Context, func()) { |
||||
if tracer == nil { |
||||
return ctx, func() {} |
||||
} |
||||
ctx, span := tracer.Start(ctx, name) |
||||
for _, attr := range attributes { |
||||
span.SetAttributes(attr.Key, attr.Value, attr.Kv) |
||||
} |
||||
return ctx, func() { |
||||
span.End() |
||||
} |
||||
} |
Loading…
Reference in new issue