From d20afa2a390dc3b2ffae348c1760bb0bdb23907e Mon Sep 17 00:00:00 2001 From: Andrej Ocenas Date: Thu, 23 Jun 2022 14:48:16 +0200 Subject: [PATCH] Prometheus: Use contextual middleware for req headers and simplify client creation (#51061) * Use contextual middleware and simplify client creation * Fix tests * Add test for the header propagation * Fix tests and lint * Update pkg/tsdb/prometheus/prometheus.go Co-authored-by: ismail simsek Co-authored-by: ismail simsek --- .../provider_azure.go => azureauth/azure.go} | 22 +-- .../buffered/azureauth/azure_test.go | 117 +++++++++++ pkg/tsdb/prometheus/buffered/client.go | 80 ++++++++ pkg/tsdb/prometheus/buffered/framing_test.go | 3 +- .../prometheus/buffered/promclient/cache.go | 56 ------ .../buffered/promclient/cache_test.go | 131 ------------- .../buffered/promclient/provider.go | 106 ---------- .../promclient/provider_azure_test.go | 153 --------------- .../buffered/promclient/provider_test.go | 181 ------------------ .../buffered/prometeus_bench_test.go | 4 +- .../prometheus/buffered/time_series_query.go | 92 ++++----- .../buffered/time_series_query_test.go | 57 +++++- pkg/tsdb/prometheus/buffered/types.go | 4 - pkg/tsdb/prometheus/middleware/req_headers.go | 24 +++ pkg/tsdb/prometheus/prometheus.go | 13 +- pkg/tsdb/prometheus/prometheus_test.go | 4 + pkg/tsdb/prometheus/utils/utils.go | 42 ++++ 17 files changed, 393 insertions(+), 696 deletions(-) rename pkg/tsdb/prometheus/buffered/{promclient/provider_azure.go => azureauth/azure.go} (79%) create mode 100644 pkg/tsdb/prometheus/buffered/azureauth/azure_test.go create mode 100644 pkg/tsdb/prometheus/buffered/client.go delete mode 100644 pkg/tsdb/prometheus/buffered/promclient/cache.go delete mode 100644 pkg/tsdb/prometheus/buffered/promclient/cache_test.go delete mode 100644 pkg/tsdb/prometheus/buffered/promclient/provider.go delete mode 100644 pkg/tsdb/prometheus/buffered/promclient/provider_azure_test.go delete mode 100644 pkg/tsdb/prometheus/buffered/promclient/provider_test.go create mode 100644 pkg/tsdb/prometheus/middleware/req_headers.go create mode 100644 pkg/tsdb/prometheus/utils/utils.go diff --git a/pkg/tsdb/prometheus/buffered/promclient/provider_azure.go b/pkg/tsdb/prometheus/buffered/azureauth/azure.go similarity index 79% rename from pkg/tsdb/prometheus/buffered/promclient/provider_azure.go rename to pkg/tsdb/prometheus/buffered/azureauth/azure.go index c9f45ed704f..b80bfcf41cc 100644 --- a/pkg/tsdb/prometheus/buffered/promclient/provider_azure.go +++ b/pkg/tsdb/prometheus/buffered/azureauth/azure.go @@ -1,4 +1,4 @@ -package promclient +package azureauth import ( "fmt" @@ -8,9 +8,10 @@ import ( "github.com/grafana/grafana-azure-sdk-go/azcredentials" "github.com/grafana/grafana-azure-sdk-go/azhttpclient" "github.com/grafana/grafana-azure-sdk-go/azsettings" + "github.com/grafana/grafana-plugin-sdk-go/backend" sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/tsdb/prometheus/utils" - "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/util/maputil" ) @@ -22,13 +23,12 @@ var ( } ) -func (p *Provider) configureAzureAuthentication(opts *sdkhttpclient.Options) error { - // Azure authentication is experimental (#35857) - if !p.features.IsEnabled(featuremgmt.FlagPrometheusAzureAuth) { - return nil +func ConfigureAzureAuthentication(settings backend.DataSourceInstanceSettings, azureSettings *azsettings.AzureSettings, opts *sdkhttpclient.Options) error { + jsonData, err := utils.GetJsonData(settings) + if err != nil { + return fmt.Errorf("failed to get jsonData: %w", err) } - - credentials, err := azcredentials.FromDatasourceData(p.jsonData, p.settings.DecryptedSecureJSONData) + credentials, err := azcredentials.FromDatasourceData(jsonData, settings.DecryptedSecureJSONData) if err != nil { err = fmt.Errorf("invalid Azure credentials: %w", err) return err @@ -37,17 +37,17 @@ func (p *Provider) configureAzureAuthentication(opts *sdkhttpclient.Options) err if credentials != nil { var scopes []string - if scopes, err = getOverriddenScopes(p.jsonData); err != nil { + if scopes, err = getOverriddenScopes(jsonData); err != nil { return err } if scopes == nil { - if scopes, err = getPrometheusScopes(p.cfg.Azure, credentials); err != nil { + if scopes, err = getPrometheusScopes(azureSettings, credentials); err != nil { return err } } - azhttpclient.AddAzureAuthentication(opts, p.cfg.Azure, credentials, scopes) + azhttpclient.AddAzureAuthentication(opts, azureSettings, credentials, scopes) } return nil diff --git a/pkg/tsdb/prometheus/buffered/azureauth/azure_test.go b/pkg/tsdb/prometheus/buffered/azureauth/azure_test.go new file mode 100644 index 00000000000..8888c13cdec --- /dev/null +++ b/pkg/tsdb/prometheus/buffered/azureauth/azure_test.go @@ -0,0 +1,117 @@ +package azureauth + +import ( + "testing" + + "github.com/grafana/grafana-azure-sdk-go/azsettings" + "github.com/grafana/grafana-plugin-sdk-go/backend" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigureAzureAuthentication(t *testing.T) { + cfg := &setting.Cfg{ + Azure: &azsettings.AzureSettings{}, + } + + t.Run("should set Azure middleware when JsonData contains valid credentials", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ + "httpMethod": "POST", + "azureCredentials": { + "authType": "msi" + } + }`), + } + + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + require.NoError(t, err) + + require.NotNil(t, opts.Middlewares) + assert.Len(t, opts.Middlewares, 1) + }) + + t.Run("should not set Azure middleware when JsonData doesn't contain valid credentials", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ "httpMethod": "POST" }`), + } + + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + require.NoError(t, err) + + assert.NotContains(t, opts.CustomOptions, "_azureCredentials") + }) + + t.Run("should return error when JsonData contains invalid credentials", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ + "httpMethod": "POST", + "azureCredentials": "invalid" + }`), + } + + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + assert.Error(t, err) + }) + + t.Run("should set Azure middleware when JsonData contains credentials and valid audience", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ + "httpMethod": "POST", + "azureCredentials": { + "authType": "msi" + }, + "azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5" + }`), + } + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + require.NoError(t, err) + + require.NotNil(t, opts.Middlewares) + assert.Len(t, opts.Middlewares, 1) + }) + + t.Run("should not set Azure middleware when JsonData doesn't contain credentials", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ + "httpMethod": "POST", + "azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5" + }`), + } + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + require.NoError(t, err) + + if opts.Middlewares != nil { + assert.Len(t, opts.Middlewares, 0) + } + }) + + t.Run("should return error when JsonData contains invalid audience", func(t *testing.T) { + settings := backend.DataSourceInstanceSettings{ + JSONData: []byte(`{ + "httpMethod": "POST", + "azureCredentials": { + "authType": "msi" + }, + "azureEndpointResourceId": "invalid" + }`), + } + + var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} + + err := ConfigureAzureAuthentication(settings, cfg.Azure, opts) + assert.Error(t, err) + }) +} diff --git a/pkg/tsdb/prometheus/buffered/client.go b/pkg/tsdb/prometheus/buffered/client.go new file mode 100644 index 00000000000..4b4e02c18eb --- /dev/null +++ b/pkg/tsdb/prometheus/buffered/client.go @@ -0,0 +1,80 @@ +package buffered + +import ( + "fmt" + "net/http" + "strings" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/azureauth" + "github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" + "github.com/grafana/grafana/pkg/tsdb/prometheus/utils" + "github.com/grafana/grafana/pkg/util/maputil" + "github.com/prometheus/client_golang/api" + apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" +) + +// CreateTransportOptions creates options for the http client. Probably should be shared and should not live in the +// buffered package. +func CreateTransportOptions(settings backend.DataSourceInstanceSettings, cfg *setting.Cfg, features featuremgmt.FeatureToggles, logger log.Logger) (*sdkhttpclient.Options, error) { + opts, err := settings.HTTPClientOptions() + if err != nil { + return nil, err + } + + jsonData, err := utils.GetJsonData(settings) + if err != nil { + return nil, fmt.Errorf("error reading settings: %w", err) + } + httpMethod, _ := maputil.GetStringOptional(jsonData, "httpMethod") + + opts.Middlewares = middlewares(logger, httpMethod) + + // Set SigV4 service namespace + if opts.SigV4 != nil { + opts.SigV4.Service = "aps" + } + + // Azure authentication is experimental (#35857) + if features.IsEnabled(featuremgmt.FlagPrometheusAzureAuth) { + err = azureauth.ConfigureAzureAuthentication(settings, cfg.Azure, &opts) + if err != nil { + return nil, fmt.Errorf("error configuring Azure auth: %v", err) + } + } + + return &opts, nil +} + +func CreateClient(roundTripper http.RoundTripper, url string) (apiv1.API, error) { + cfg := api.Config{ + Address: url, + RoundTripper: roundTripper, + } + + client, err := api.NewClient(cfg) + if err != nil { + return nil, err + } + + return apiv1.NewAPI(client), nil +} + +func middlewares(logger log.Logger, httpMethod string) []sdkhttpclient.Middleware { + middlewares := []sdkhttpclient.Middleware{ + // TODO: probably isn't needed anymore and should by done by http infra code + middleware.CustomQueryParameters(logger), + sdkhttpclient.CustomHeadersMiddleware(), + } + + // Needed to control GET vs POST method of the requests + if strings.ToLower(httpMethod) == "get" { + middlewares = append(middlewares, middleware.ForceHttpGet(logger)) + } + + return middlewares +} diff --git a/pkg/tsdb/prometheus/buffered/framing_test.go b/pkg/tsdb/prometheus/buffered/framing_test.go index c1864f98baa..f3c83890d63 100644 --- a/pkg/tsdb/prometheus/buffered/framing_test.go +++ b/pkg/tsdb/prometheus/buffered/framing_test.go @@ -136,8 +136,9 @@ func runQuery(response []byte, query PrometheusQuery) (*backend.QueryDataRespons tracer: tracer, TimeInterval: "15s", log: &fakeLogger{}, + client: api, } - return s.runQueries(context.Background(), api, []*PrometheusQuery{&query}) + return s.runQueries(context.Background(), []*PrometheusQuery{&query}) } type fakeLogger struct { diff --git a/pkg/tsdb/prometheus/buffered/promclient/cache.go b/pkg/tsdb/prometheus/buffered/promclient/cache.go deleted file mode 100644 index 35cad0568dc..00000000000 --- a/pkg/tsdb/prometheus/buffered/promclient/cache.go +++ /dev/null @@ -1,56 +0,0 @@ -package promclient - -import ( - "sort" - "strings" - - lru "github.com/hashicorp/golang-lru" - apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" -) - -type ProviderCache struct { - provider promClientProvider - cache *lru.Cache -} - -type promClientProvider interface { - GetClient(map[string]string) (apiv1.API, error) -} - -func NewProviderCache(p promClientProvider) (*ProviderCache, error) { - cache, err := lru.New(500) - if err != nil { - return nil, err - } - - return &ProviderCache{ - provider: p, - cache: cache, - }, nil -} - -func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error) { - key := c.key(headers) - if client, ok := c.cache.Get(key); ok { - return client.(apiv1.API), nil - } - - client, err := c.provider.GetClient(headers) - if err != nil { - return nil, err - } - - c.cache.Add(key, client) - return client, nil -} - -func (c *ProviderCache) key(headers map[string]string) string { - vals := make([]string, len(headers)) - var i int - for _, v := range headers { - vals[i] = v - i++ - } - sort.Strings(vals) - return strings.Join(vals, "") -} diff --git a/pkg/tsdb/prometheus/buffered/promclient/cache_test.go b/pkg/tsdb/prometheus/buffered/promclient/cache_test.go deleted file mode 100644 index d1cbd7b8f5d..00000000000 --- a/pkg/tsdb/prometheus/buffered/promclient/cache_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package promclient_test - -import ( - "context" - "errors" - "sort" - "strings" - "testing" - - "github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/promclient" - - apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" - - "github.com/stretchr/testify/require" -) - -func TestCache_GetClient(t *testing.T) { - t.Run("it caches the client for a set of auth headers", func(t *testing.T) { - tc := setupCacheContext() - - c, err := tc.providerCache.GetClient(headers) - require.Nil(t, err) - - c2, err := tc.providerCache.GetClient(headers) - require.Nil(t, err) - - require.Equal(t, c, c2) - require.Equal(t, 1, tc.clientProvider.numCalls) - }) - - t.Run("it returns different clients when the headers differ", func(t *testing.T) { - tc := setupCacheContext() - h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} - h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} - - c, err := tc.providerCache.GetClient(h1) - require.Nil(t, err) - - c2, err := tc.providerCache.GetClient(h2) - require.Nil(t, err) - - require.NotEqual(t, c, c2) - require.Equal(t, 2, tc.clientProvider.numCalls) - }) - - t.Run("it returns from the cache when headers are the same", func(t *testing.T) { - tc := setupCacheContext() - h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} - h2 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} - - c, err := tc.providerCache.GetClient(h1) - require.Nil(t, err) - - c2, err := tc.providerCache.GetClient(h2) - require.Nil(t, err) - - require.Equal(t, c, c2) - require.Equal(t, 1, tc.clientProvider.numCalls) - }) - - t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) { - tc := setupCacheContext() - tc.clientProvider.errors <- errors.New("something bad") - - _, err := tc.providerCache.GetClient(headers) - require.EqualError(t, err, "something bad") - - c, err := tc.providerCache.GetClient(headers) - require.Nil(t, err) - - require.NotNil(t, c) - require.Equal(t, 2, tc.clientProvider.numCalls) - }) -} - -type cacheTestContext struct { - providerCache *promclient.ProviderCache - clientProvider *fakePromClientProvider -} - -func setupCacheContext() *cacheTestContext { - fp := newFakePromClientProvider() - p, err := promclient.NewProviderCache(fp) - if err != nil { - panic(err) - } - - return &cacheTestContext{ - providerCache: p, - clientProvider: fp, - } -} - -func newFakePromClientProvider() *fakePromClientProvider { - return &fakePromClientProvider{ - errors: make(chan error, 1), - } -} - -type fakePromClientProvider struct { - headers map[string]string - numCalls int - errors chan error -} - -func (p *fakePromClientProvider) GetClient(h map[string]string) (apiv1.API, error) { - p.headers = h - p.numCalls++ - - var err error - select { - case err = <-p.errors: - default: - } - - var config []string - for _, v := range h { - config = append(config, v) - } - sort.Strings(config) //because map - return &fakePromClient{config: strings.Join(config, "")}, err -} - -type fakePromClient struct { - apiv1.API - config string -} - -func (c *fakePromClient) Config(ctx context.Context) (apiv1.ConfigResult, error) { - return apiv1.ConfigResult{YAML: c.config}, nil -} diff --git a/pkg/tsdb/prometheus/buffered/promclient/provider.go b/pkg/tsdb/prometheus/buffered/promclient/provider.go deleted file mode 100644 index 3ccff28004c..00000000000 --- a/pkg/tsdb/prometheus/buffered/promclient/provider.go +++ /dev/null @@ -1,106 +0,0 @@ -package promclient - -import ( - "strings" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" - "github.com/grafana/grafana/pkg/util/maputil" - - sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient" - "github.com/grafana/grafana/pkg/infra/log" - "github.com/prometheus/client_golang/api" - apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" -) - -type Provider struct { - settings backend.DataSourceInstanceSettings - jsonData map[string]interface{} - httpMethod string - clientProvider httpclient.Provider - cfg *setting.Cfg - features featuremgmt.FeatureToggles - log log.Logger -} - -func NewProvider( - settings backend.DataSourceInstanceSettings, - jsonData map[string]interface{}, - clientProvider httpclient.Provider, - cfg *setting.Cfg, - features featuremgmt.FeatureToggles, - log log.Logger, -) *Provider { - httpMethod, _ := maputil.GetStringOptional(jsonData, "httpMethod") - return &Provider{ - settings: settings, - jsonData: jsonData, - httpMethod: httpMethod, - clientProvider: clientProvider, - cfg: cfg, - features: features, - log: log, - } -} - -func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) { - opts, err := p.settings.HTTPClientOptions() - if err != nil { - return nil, err - } - - opts.Middlewares = p.middlewares() - opts.Headers = reqHeaders(headers) - - // Set SigV4 service namespace - if opts.SigV4 != nil { - opts.SigV4.Service = "aps" - } - - // Azure authentication - err = p.configureAzureAuthentication(&opts) - if err != nil { - return nil, err - } - - roundTripper, err := p.clientProvider.GetTransport(opts) - if err != nil { - return nil, err - } - - cfg := api.Config{ - Address: p.settings.URL, - RoundTripper: roundTripper, - } - - client, err := api.NewClient(cfg) - if err != nil { - return nil, err - } - - return apiv1.NewAPI(client), nil -} - -func (p *Provider) middlewares() []sdkhttpclient.Middleware { - middlewares := []sdkhttpclient.Middleware{ - middleware.CustomQueryParameters(p.log), - sdkhttpclient.CustomHeadersMiddleware(), - } - if strings.ToLower(p.httpMethod) == "get" { - middlewares = append(middlewares, middleware.ForceHttpGet(p.log)) - } - - return middlewares -} - -func reqHeaders(headers map[string]string) map[string]string { - // copy to avoid changing the original map - h := make(map[string]string, len(headers)) - for k, v := range headers { - h[k] = v - } - return h -} diff --git a/pkg/tsdb/prometheus/buffered/promclient/provider_azure_test.go b/pkg/tsdb/prometheus/buffered/promclient/provider_azure_test.go deleted file mode 100644 index 74c60e94ecf..00000000000 --- a/pkg/tsdb/prometheus/buffered/promclient/provider_azure_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package promclient - -import ( - "testing" - - "github.com/grafana/grafana-azure-sdk-go/azsettings" - "github.com/grafana/grafana-plugin-sdk-go/backend" - sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - - "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/setting" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConfigureAzureAuthentication(t *testing.T) { - cfg := &setting.Cfg{ - Azure: &azsettings.AzureSettings{}, - } - settings := backend.DataSourceInstanceSettings{} - - t.Run("given feature flag enabled", func(t *testing.T) { - features := featuremgmt.WithFeatures(featuremgmt.FlagPrometheusAzureAuth) - - t.Run("should set Azure middleware when JsonData contains valid credentials", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureCredentials": map[string]interface{}{ - "authType": "msi", - }, - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - require.NoError(t, err) - - require.NotNil(t, opts.Middlewares) - assert.Len(t, opts.Middlewares, 1) - }) - - t.Run("should not set Azure middleware when JsonData doesn't contain valid credentials", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - require.NoError(t, err) - - assert.NotContains(t, opts.CustomOptions, "_azureCredentials") - }) - - t.Run("should return error when JsonData contains invalid credentials", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureCredentials": "invalid", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - assert.Error(t, err) - }) - - t.Run("should set Azure middleware when JsonData contains credentials and valid audience", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureCredentials": map[string]interface{}{ - "authType": "msi", - }, - "azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - require.NoError(t, err) - - require.NotNil(t, opts.Middlewares) - assert.Len(t, opts.Middlewares, 1) - }) - - t.Run("should not set Azure middleware when JsonData doesn't contain credentials", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - require.NoError(t, err) - - if opts.Middlewares != nil { - assert.Len(t, opts.Middlewares, 0) - } - }) - - t.Run("should return error when JsonData contains invalid audience", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureCredentials": map[string]interface{}{ - "authType": "msi", - }, - "azureEndpointResourceId": "invalid", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - assert.Error(t, err) - }) - }) - - t.Run("given feature flag not enabled", func(t *testing.T) { - features := featuremgmt.WithFeatures() - - t.Run("should not set Azure Credentials even when JsonData contains credentials", func(t *testing.T) { - jsonData := map[string]interface{}{ - "httpMethod": "POST", - "azureCredentials": map[string]interface{}{ - "authType": "msi", - }, - "azureEndpointResourceId": "https://api.example.com/abd5c4ce-ca73-41e9-9cb2-bed39aa2adb5", - } - - var p = NewProvider(settings, jsonData, nil, cfg, features, nil) - - var opts = &sdkhttpclient.Options{CustomOptions: map[string]interface{}{}} - - err := p.configureAzureAuthentication(opts) - require.NoError(t, err) - - if opts.Middlewares != nil { - assert.Len(t, opts.Middlewares, 0) - } - }) - }) -} diff --git a/pkg/tsdb/prometheus/buffered/promclient/provider_test.go b/pkg/tsdb/prometheus/buffered/promclient/provider_test.go deleted file mode 100644 index cc340e3d330..00000000000 --- a/pkg/tsdb/prometheus/buffered/promclient/provider_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package promclient_test - -import ( - "encoding/json" - "net/http" - "testing" - - "github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/promclient" - - "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/setting" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - - sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" - "github.com/grafana/grafana/pkg/infra/httpclient" - - "github.com/stretchr/testify/require" -) - -var headers = map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} - -func TestGetClient(t *testing.T) { - t.Run("it sets the SigV4 service if it exists", func(t *testing.T) { - tc := setup(`{"sigV4Auth":true}`) - - setting.SigV4AuthEnabled = true - defer func() { setting.SigV4AuthEnabled = false }() - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.Equal(t, "aps", tc.httpProvider.opts.SigV4.Service) - }) - - t.Run("it always uses the custom params and custom headers middlewares", func(t *testing.T) { - tc := setup() - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.Len(t, tc.httpProvider.middlewares(), 2) - require.Contains(t, tc.httpProvider.middlewares(), "prom-custom-query-parameters") - require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders") - }) - - t.Run("extra headers", func(t *testing.T) { - t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) { - tc := setup(`{"oauthPassThru":true}`) - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.Equal(t, headers, tc.httpProvider.opts.Headers) - }) - - t.Run("it sets all headers", func(t *testing.T) { - withNonAuth := map[string]string{"X-Not-Auth": "stuff"} - - tc := setup(`{"oauthPassThru":true}`) - _, err := tc.promClientProvider.GetClient(withNonAuth) - require.Nil(t, err) - - require.Equal(t, map[string]string{"X-Not-Auth": "stuff"}, tc.httpProvider.opts.Headers) - }) - - t.Run("it does not error when headers are nil", func(t *testing.T) { - tc := setup(`{"oauthPassThru":true}`) - - _, err := tc.promClientProvider.GetClient(nil) - require.Nil(t, err) - }) - }) - - t.Run("force get middleware", func(t *testing.T) { - t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { - tc := setup(`{"httpMethod":"get"}`) - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.Len(t, tc.httpProvider.middlewares(), 3) - require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - - t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { - tc := setup(`{"httpMethod":"GET"}`) - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.Len(t, tc.httpProvider.middlewares(), 3) - require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - - t.Run("it does not add the force-get middleware when httpMethod is POST", func(t *testing.T) { - tc := setup(`{"httpMethod":"POST"}`) - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - - t.Run("it does not add the force-get middleware when json data is nil", func(t *testing.T) { - tc := setup() - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - - t.Run("it does not add the force-get middleware when json data is empty", func(t *testing.T) { - tc := setup(`{}`) - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - - t.Run("it does not add the force-get middleware httpMethod is null", func(t *testing.T) { - tc := setup(`{"httpMethod":null}`) - - _, err := tc.promClientProvider.GetClient(headers) - require.Nil(t, err) - - require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") - }) - }) -} - -func setup(jsonData ...string) *testContext { - var rawData []byte - if len(jsonData) > 0 { - rawData = []byte(jsonData[0]) - } - - var jd map[string]interface{} - _ = json.Unmarshal(rawData, &jd) - - cfg := &setting.Cfg{} - settings := backend.DataSourceInstanceSettings{URL: "test-url", JSONData: rawData} - features := featuremgmt.WithFeatures() - hp := &fakeHttpClientProvider{} - p := promclient.NewProvider(settings, jd, hp, cfg, features, nil) - - return &testContext{ - httpProvider: hp, - promClientProvider: p, - } -} - -type testContext struct { - httpProvider *fakeHttpClientProvider - promClientProvider *promclient.Provider -} - -type fakeHttpClientProvider struct { - httpclient.Provider - - opts sdkhttpclient.Options -} - -func (p *fakeHttpClientProvider) GetTransport(opts ...sdkhttpclient.Options) (http.RoundTripper, error) { - p.opts = opts[0] - return http.DefaultTransport, nil -} - -func (p *fakeHttpClientProvider) middlewares() []string { - var middlewareNames []string - for _, m := range p.opts.Middlewares { - mw, ok := m.(sdkhttpclient.MiddlewareName) - if !ok { - panic("unexpected middleware type") - } - - middlewareNames = append(middlewareNames, mw.MiddlewareName()) - } - return middlewareNames -} diff --git a/pkg/tsdb/prometheus/buffered/prometeus_bench_test.go b/pkg/tsdb/prometheus/buffered/prometeus_bench_test.go index 319950f40b0..b4d0655c22d 100644 --- a/pkg/tsdb/prometheus/buffered/prometeus_bench_test.go +++ b/pkg/tsdb/prometheus/buffered/prometeus_bench_test.go @@ -21,11 +21,11 @@ func BenchmarkJson(b *testing.B) { api, err := makeMockedApi(resp) require.NoError(b, err) - s := Buffered{tracer: tracing.InitializeTracerForTest(), log: &fakeLogger{}} + s := Buffered{tracer: tracing.InitializeTracerForTest(), log: &fakeLogger{}, client: api} b.ResetTimer() for n := 0; n < b.N; n++ { - _, err := s.runQueries(context.Background(), api, []*PrometheusQuery{&query}) + _, err := s.runQueries(context.Background(), []*PrometheusQuery{&query}) require.NoError(b, err) } } diff --git a/pkg/tsdb/prometheus/buffered/time_series_query.go b/pkg/tsdb/prometheus/buffered/time_series_query.go index 40db2d9100d..8d2634c35d1 100644 --- a/pkg/tsdb/prometheus/buffered/time_series_query.go +++ b/pkg/tsdb/prometheus/buffered/time_series_query.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "net/http" "regexp" "sort" "strconv" @@ -12,14 +13,13 @@ import ( "time" "github.com/grafana/grafana-plugin-sdk-go/backend" + sdkHTTPClient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana-plugin-sdk-go/data" - "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" - "github.com/grafana/grafana/pkg/services/featuremgmt" - "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/intervalv2" - "github.com/grafana/grafana/pkg/tsdb/prometheus/buffered/promclient" + "github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" + "github.com/grafana/grafana/pkg/tsdb/prometheus/utils" "github.com/grafana/grafana/pkg/util/maputil" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" "github.com/prometheus/common/model" @@ -57,41 +57,59 @@ var ( type Buffered struct { intervalCalculator intervalv2.Calculator tracer tracing.Tracer - getClient clientGetter + client apiv1.API log log.Logger ID int64 URL string TimeInterval string } -func New(httpClientProvider httpclient.Provider, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer, settings backend.DataSourceInstanceSettings, plog log.Logger) (*Buffered, error) { - var jsonData map[string]interface{} - if err := json.Unmarshal(settings.JSONData, &jsonData); err != nil { - return nil, fmt.Errorf("error reading settings: %w", err) +// New creates and object capable of executing and parsing a Prometheus queries. It's "buffered" because there is +// another implementation capable of streaming parse the response. +func New(roundTripper http.RoundTripper, tracer tracing.Tracer, settings backend.DataSourceInstanceSettings, plog log.Logger) (*Buffered, error) { + promClient, err := CreateClient(roundTripper, settings.URL) + if err != nil { + return nil, fmt.Errorf("error creating prom client: %v", err) } - timeInterval, err := maputil.GetStringOptional(jsonData, "timeInterval") + jsonData, err := utils.GetJsonData(settings) if err != nil { - return nil, err + return nil, fmt.Errorf("error getting jsonData: %w", err) } - p := promclient.NewProvider(settings, jsonData, httpClientProvider, cfg, features, plog) - pc, err := promclient.NewProviderCache(p) + timeInterval, err := maputil.GetStringOptional(jsonData, "timeInterval") if err != nil { return nil, err } + return &Buffered{ intervalCalculator: intervalv2.NewCalculator(), tracer: tracer, log: plog, - getClient: pc.GetClient, + client: promClient, TimeInterval: timeInterval, ID: settings.ID, URL: settings.URL, }, nil } -func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []*PrometheusQuery) (*backend.QueryDataResponse, error) { +func (b *Buffered) ExecuteTimeSeriesQuery(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + // Add headers from the request to context so they are added later on by a context middleware. This is because + // prom client does not allow us to do this directly. + ctxWithHeaders := sdkHTTPClient.WithContextualMiddleware(ctx, middleware.ReqHeadersMiddleware(req.Headers)) + + queries, err := b.parseTimeSeriesQuery(req) + if err != nil { + result := backend.QueryDataResponse{ + Responses: backend.Responses{}, + } + return &result, fmt.Errorf("error parsing time series query: %v", err) + } + + return b.runQueries(ctxWithHeaders, queries) +} + +func (b *Buffered) runQueries(ctx context.Context, queries []*PrometheusQuery) (*backend.QueryDataResponse, error) { result := backend.QueryDataResponse{ Responses: backend.Responses{}, } @@ -99,11 +117,12 @@ func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []* for _, query := range queries { b.log.Debug("Sending query", "start", query.Start, "end", query.End, "step", query.Step, "query", query.Expr) - ctx, span := b.tracer.Start(ctx, "datasource.prometheus") - span.SetAttributes("expr", query.Expr, attribute.Key("expr").String(query.Expr)) - span.SetAttributes("start_unixnano", query.Start, attribute.Key("start_unixnano").Int64(query.Start.UnixNano())) - span.SetAttributes("stop_unixnano", query.End, attribute.Key("stop_unixnano").Int64(query.End.UnixNano())) - defer span.End() + ctx, endSpan := utils.StartTrace(ctx, b.tracer, "datasource.prometheus", []utils.Attribute{ + {Key: "expr", Value: query.Expr, Kv: attribute.Key("expr").String(query.Expr)}, + {Key: "start_unixnano", Value: query.Start, Kv: attribute.Key("start_unixnano").Int64(query.Start.UnixNano())}, + {Key: "stop_unixnano", Value: query.End, Kv: attribute.Key("stop_unixnano").Int64(query.End.UnixNano())}, + }) + defer endSpan() response := make(map[TimeSeriesQueryType]interface{}) @@ -115,7 +134,7 @@ func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []* } if query.RangeQuery { - rangeResponse, _, err := client.QueryRange(ctx, query.Expr, timeRange) + rangeResponse, _, err := b.client.QueryRange(ctx, query.Expr, timeRange) if err != nil { b.log.Error("Range query failed", "query", query.Expr, "err", err) result.Responses[query.RefId] = backend.DataResponse{Error: err} @@ -125,7 +144,7 @@ func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []* } if query.InstantQuery { - instantResponse, _, err := client.Query(ctx, query.Expr, query.End) + instantResponse, _, err := b.client.Query(ctx, query.Expr, query.End) if err != nil { b.log.Error("Instant query failed", "query", query.Expr, "err", err) result.Responses[query.RefId] = backend.DataResponse{Error: err} @@ -137,7 +156,7 @@ func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []* // This is a special case // If exemplar query returns error, we want to only log it and continue with other results processing if query.ExemplarQuery { - exemplarResponse, err := client.QueryExemplars(ctx, query.Expr, timeRange.Start, timeRange.End) + exemplarResponse, err := b.client.QueryExemplars(ctx, query.Expr, timeRange.Start, timeRange.End) if err != nil { b.log.Error("Exemplar query failed", "query", query.Expr, "err", err) } else { @@ -163,23 +182,6 @@ func (b *Buffered) runQueries(ctx context.Context, client apiv1.API, queries []* return &result, nil } -func (b *Buffered) ExecuteTimeSeriesQuery(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - client, err := b.getClient(req.Headers) - if err != nil { - return nil, err - } - - queries, err := b.parseTimeSeriesQuery(req) - if err != nil { - result := backend.QueryDataResponse{ - Responses: backend.Responses{}, - } - return &result, err - } - - return b.runQueries(ctx, client, queries) -} - func formatLegend(metric model.Metric, query *PrometheusQuery) string { var legend = metric.String() @@ -209,18 +211,18 @@ func formatLegend(metric model.Metric, query *PrometheusQuery) string { return legend } -func (b *Buffered) parseTimeSeriesQuery(queryContext *backend.QueryDataRequest) ([]*PrometheusQuery, error) { +func (b *Buffered) parseTimeSeriesQuery(req *backend.QueryDataRequest) ([]*PrometheusQuery, error) { qs := []*PrometheusQuery{} - for _, query := range queryContext.Queries { + for _, query := range req.Queries { model := &QueryModel{} err := json.Unmarshal(query.JSON, model) if err != nil { - return nil, err + return nil, fmt.Errorf("error unmarshaling query model: %v", err) } //Final interval value interval, err := calculatePrometheusInterval(model, b.TimeInterval, query, b.intervalCalculator) if err != nil { - return nil, err + return nil, fmt.Errorf("error calculating interval: %v", err) } // Interpolate variables in expr @@ -234,7 +236,7 @@ func (b *Buffered) parseTimeSeriesQuery(queryContext *backend.QueryDataRequest) // We never want to run exemplar query for alerting exemplarQuery := model.ExemplarQuery - if queryContext.Headers["FromAlert"] == "true" { + if req.Headers["FromAlert"] == "true" { exemplarQuery = false } diff --git a/pkg/tsdb/prometheus/buffered/time_series_query_test.go b/pkg/tsdb/prometheus/buffered/time_series_query_test.go index d0660c27624..0e868608a58 100644 --- a/pkg/tsdb/prometheus/buffered/time_series_query_test.go +++ b/pkg/tsdb/prometheus/buffered/time_series_query_test.go @@ -1,13 +1,17 @@ package buffered import ( + "context" "math" + "net/http" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/grafana/grafana-plugin-sdk-go/backend" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana-plugin-sdk-go/data" + "github.com/grafana/grafana/pkg/infra/log/logtest" "github.com/grafana/grafana/pkg/tsdb/intervalv2" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" p "github.com/prometheus/common/model" @@ -16,7 +20,58 @@ import ( var now = time.Now() -func TestPrometheus_timeSeriesQuery_formatLeged(t *testing.T) { +type FakeRoundTripper struct { + Req *http.Request +} + +func (frt *FakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + frt.Req = req + return &http.Response{}, nil +} + +func FakeMiddleware(rt *FakeRoundTripper) sdkhttpclient.Middleware { + return sdkhttpclient.NamedMiddlewareFunc("fake", func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { + return rt + }) +} + +func TestPrometheus_ExecuteTimeSeriesQuery(t *testing.T) { + t.Run("adding req headers", func(t *testing.T) { + // This makes sure we add req headers from the front end request to the request to prometheus. We do that + // through contextual middleware so this setup is a bit complex and the test itself goes a bit too much into + // internals. + + // This ends the trip and saves the request on the instance so we can inspect it. + rt := &FakeRoundTripper{} + // DefaultMiddlewares also contain contextual middleware which is the one we need to use. + middlewares := sdkhttpclient.DefaultMiddlewares() + middlewares = append(middlewares, FakeMiddleware(rt)) + + // Setup http client in at least similar way to how grafana provides it to the service + provider := sdkhttpclient.NewProvider(sdkhttpclient.ProviderOptions{Middlewares: sdkhttpclient.DefaultMiddlewares()}) + roundTripper, err := provider.GetTransport(sdkhttpclient.Options{ + Middlewares: middlewares, + }) + require.NoError(t, err) + + buffered, err := New(roundTripper, nil, backend.DataSourceInstanceSettings{JSONData: []byte("{}")}, &logtest.Fake{}) + require.NoError(t, err) + + _, err = buffered.ExecuteTimeSeriesQuery(context.Background(), &backend.QueryDataRequest{ + PluginContext: backend.PluginContext{}, + // This header should end up in the outgoing request to prometheus + Headers: map[string]string{"foo": "bar"}, + Queries: []backend.DataQuery{{ + JSON: []byte(`{"expr": "metric{label=\"test\"}", "rangeQuery": true}`), + }}, + }) + require.NoError(t, err) + require.NotNil(t, rt.Req) + require.Equal(t, http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}, "foo": []string{"bar"}}, rt.Req.Header) + }) +} + +func TestPrometheus_timeSeriesQuery_formatLegend(t *testing.T) { t.Run("converting metric name", func(t *testing.T) { metric := map[p.LabelName]p.LabelValue{ p.LabelName("app"): p.LabelValue("backend"), diff --git a/pkg/tsdb/prometheus/buffered/types.go b/pkg/tsdb/prometheus/buffered/types.go index 8dec9d19dca..804d6dcdec1 100644 --- a/pkg/tsdb/prometheus/buffered/types.go +++ b/pkg/tsdb/prometheus/buffered/types.go @@ -2,12 +2,8 @@ package buffered import ( "time" - - apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" ) -type clientGetter func(map[string]string) (apiv1.API, error) - type PrometheusQuery struct { Expr string Step time.Duration diff --git a/pkg/tsdb/prometheus/middleware/req_headers.go b/pkg/tsdb/prometheus/middleware/req_headers.go new file mode 100644 index 00000000000..16bae1f8b5d --- /dev/null +++ b/pkg/tsdb/prometheus/middleware/req_headers.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "net/http" + + sdkHTTPClient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" +) + +// ReqHeadersMiddleware is used so that we can pass req headers through the prometheus go client as it does not allow +// access to the request directly. Should be used together with WithContextualMiddleware so that it is attached to +// the context of each request with its unique headers. +func ReqHeadersMiddleware(headers map[string]string) sdkHTTPClient.Middleware { + return sdkHTTPClient.NamedMiddlewareFunc("prometheus-req-headers-middleware", func(opts sdkHTTPClient.Options, next http.RoundTripper) http.RoundTripper { + return sdkHTTPClient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for k, v := range headers { + // As custom headers middleware is before contextual we may overwrite custom headers here with those + // that came with the request which probably makes sense. + req.Header[k] = []string{v} + } + + return next.RoundTrip(req) + }) + }) +} diff --git a/pkg/tsdb/prometheus/prometheus.go b/pkg/tsdb/prometheus/prometheus.go index 310cb5a69a5..4febc222925 100644 --- a/pkg/tsdb/prometheus/prometheus.go +++ b/pkg/tsdb/prometheus/prometheus.go @@ -2,7 +2,6 @@ package prometheus import ( "context" - "encoding/json" "errors" "fmt" @@ -43,13 +42,17 @@ func ProvideService(httpClientProvider httpclient.Provider, cfg *setting.Cfg, fe func newInstanceSettings(httpClientProvider httpclient.Provider, cfg *setting.Cfg, features featuremgmt.FeatureToggles, tracer tracing.Tracer) datasource.InstanceFactoryFunc { return func(settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { - var jsonData map[string]interface{} - err := json.Unmarshal(settings.JSONData, &jsonData) + // Creates a http roundTripper. Probably should be used for both buffered and streaming/querydata instances. + opts, err := buffered.CreateTransportOptions(settings, cfg, features, plog) if err != nil { - return nil, fmt.Errorf("error reading settings: %w", err) + return nil, fmt.Errorf("error creating transport options: %v", err) + } + roundTripper, err := httpClientProvider.GetTransport(*opts) + if err != nil { + return nil, fmt.Errorf("error creating http client: %v", err) } - b, err := buffered.New(httpClientProvider, cfg, features, tracer, settings, plog) + b, err := buffered.New(roundTripper, tracer, settings, plog) if err != nil { return nil, err } diff --git a/pkg/tsdb/prometheus/prometheus_test.go b/pkg/tsdb/prometheus/prometheus_test.go index d5ee22300a7..29a496c0139 100644 --- a/pkg/tsdb/prometheus/prometheus_test.go +++ b/pkg/tsdb/prometheus/prometheus_test.go @@ -45,6 +45,10 @@ func (provider *fakeHTTPClientProvider) New(opts ...sdkHttpClient.Options) (*htt return client, nil } +func (provider *fakeHTTPClientProvider) GetTransport(opts ...sdkHttpClient.Options) (http.RoundTripper, error) { + return &fakeRoundtripper{}, nil +} + func TestClient(t *testing.T) { t.Run("Service", func(t *testing.T) { t.Run("CallResource", func(t *testing.T) { diff --git a/pkg/tsdb/prometheus/utils/utils.go b/pkg/tsdb/prometheus/utils/utils.go new file mode 100644 index 00000000000..8930a214f91 --- /dev/null +++ b/pkg/tsdb/prometheus/utils/utils.go @@ -0,0 +1,42 @@ +package utils + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana/pkg/infra/tracing" + "go.opentelemetry.io/otel/attribute" +) + +// GetJsonData just gets the json in easier to work with type. It's used on multiple places which isn't super effective +// but only when creating a client which should not happen often anyway. +func GetJsonData(settings backend.DataSourceInstanceSettings) (map[string]interface{}, error) { + var jsonData map[string]interface{} + err := json.Unmarshal(settings.JSONData, &jsonData) + if err != nil { + return nil, fmt.Errorf("error unmarshalling JSONData: %w", err) + } + return jsonData, nil +} + +type Attribute struct { + Key string + Value interface{} + Kv attribute.KeyValue +} + +// StartTrace setups a trace but does not panic if tracer is nil which helps with testing +func StartTrace(ctx context.Context, tracer tracing.Tracer, name string, attributes []Attribute) (context.Context, func()) { + if tracer == nil { + return ctx, func() {} + } + ctx, span := tracer.Start(ctx, name) + for _, attr := range attributes { + span.SetAttributes(attr.Key, attr.Value, attr.Kv) + } + return ctx, func() { + span.End() + } +}