mirror of https://github.com/grafana/grafana
* create the prom client
* implement lru cache of prometheus clients based on auth headers
* linter
(cherry picked from commit 20b3b2a448)
pull/43841/head
parent
38f86645e7
commit
d3d3e99a3f
@ -0,0 +1,55 @@ |
||||
package promclient |
||||
|
||||
import ( |
||||
lru "github.com/hashicorp/golang-lru" |
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
) |
||||
|
||||
const ( |
||||
noPassThrough = "no-pass-through" |
||||
) |
||||
|
||||
type ProviderCache struct { |
||||
provider promClientProvider |
||||
cache *lru.Cache |
||||
jsonData JsonData |
||||
} |
||||
|
||||
type promClientProvider interface { |
||||
GetClient(map[string]string) (apiv1.API, error) |
||||
} |
||||
|
||||
func NewProviderCache(p promClientProvider, jd JsonData) (*ProviderCache, error) { |
||||
cache, err := lru.New(500) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &ProviderCache{ |
||||
provider: p, |
||||
cache: cache, |
||||
jsonData: jd, |
||||
}, nil |
||||
} |
||||
|
||||
func (c *ProviderCache) GetClient(headers map[string]string) (apiv1.API, error) { |
||||
key := c.key(headers) |
||||
if client, ok := c.cache.Get(key); ok { |
||||
return client.(apiv1.API), nil |
||||
} |
||||
|
||||
client, err := c.provider.GetClient(headers) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.cache.Add(key, client) |
||||
return client, nil |
||||
} |
||||
|
||||
func (c *ProviderCache) key(headers map[string]string) string { |
||||
if c.jsonData.OauthPassThru { |
||||
return headers[authHeader] + headers[idTokenHeader] |
||||
} |
||||
return noPassThrough |
||||
} |
||||
@ -0,0 +1,144 @@ |
||||
package promclient_test |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"sort" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient" |
||||
|
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestCache_GetClient(t *testing.T) { |
||||
t.Run("it caches the client for a set of auth headers", func(t *testing.T) { |
||||
tc := setupCacheContext(true) |
||||
|
||||
c, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, c, c2) |
||||
require.Equal(t, 1, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it returns different clients when the auth headers differ", func(t *testing.T) { |
||||
tc := setupCacheContext(true) |
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} |
||||
|
||||
c, err := tc.providerCache.GetClient(h1) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(h2) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotEqual(t, c, c2) |
||||
require.Equal(t, 2, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it always returns from the cache when 'oauthPassThru' not set", func(t *testing.T) { |
||||
tc := setupCacheContext(false) |
||||
h1 := map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
h2 := map[string]string{"Authorization": "token2", "X-ID-Token": "id-token"} |
||||
|
||||
c, err := tc.providerCache.GetClient(h1) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(h2) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, c, c2) |
||||
require.Equal(t, 1, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it only accounts for auth headers", func(t *testing.T) { |
||||
tc := setupCacheContext(true) |
||||
|
||||
c, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "stuff"}) |
||||
require.Nil(t, err) |
||||
|
||||
c2, err := tc.providerCache.GetClient(map[string]string{"X-Not-Auth": "other-stuff"}) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, c, c2) |
||||
require.Equal(t, 1, tc.clientProvider.numCalls) |
||||
}) |
||||
|
||||
t.Run("it doesn't cache anything when an error occurs", func(t *testing.T) { |
||||
tc := setupCacheContext(true) |
||||
tc.clientProvider.errors <- errors.New("something bad") |
||||
|
||||
_, err := tc.providerCache.GetClient(headers) |
||||
require.EqualError(t, err, "something bad") |
||||
|
||||
c, err := tc.providerCache.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotNil(t, c) |
||||
require.Equal(t, 2, tc.clientProvider.numCalls) |
||||
}) |
||||
} |
||||
|
||||
type cacheTestContext struct { |
||||
providerCache *promclient.ProviderCache |
||||
clientProvider *fakePromClientProvider |
||||
} |
||||
|
||||
func setupCacheContext(oauthPassTrough bool) *cacheTestContext { |
||||
fp := newFakePromClientProvider() |
||||
p, err := promclient.NewProviderCache(fp, promclient.JsonData{OauthPassThru: oauthPassTrough}) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
return &cacheTestContext{ |
||||
providerCache: p, |
||||
clientProvider: fp, |
||||
} |
||||
} |
||||
|
||||
func newFakePromClientProvider() *fakePromClientProvider { |
||||
return &fakePromClientProvider{ |
||||
errors: make(chan error, 1), |
||||
} |
||||
} |
||||
|
||||
type fakePromClientProvider struct { |
||||
headers map[string]string |
||||
numCalls int |
||||
errors chan error |
||||
} |
||||
|
||||
func (p *fakePromClientProvider) GetClient(h map[string]string) (apiv1.API, error) { |
||||
p.headers = h |
||||
p.numCalls++ |
||||
|
||||
var err error |
||||
select { |
||||
case err = <-p.errors: |
||||
default: |
||||
} |
||||
|
||||
var config []string |
||||
for _, v := range h { |
||||
config = append(config, v) |
||||
} |
||||
sort.Strings(config) //because map
|
||||
return &fakePromClient{config: strings.Join(config, "")}, err |
||||
} |
||||
|
||||
type fakePromClient struct { |
||||
apiv1.API |
||||
config string |
||||
} |
||||
|
||||
func (c *fakePromClient) Config(ctx context.Context) (apiv1.ConfigResult, error) { |
||||
return apiv1.ConfigResult{YAML: c.config}, nil |
||||
} |
||||
@ -0,0 +1,106 @@ |
||||
package promclient |
||||
|
||||
import ( |
||||
"strings" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/middleware" |
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/prometheus/client_golang/api" |
||||
apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" |
||||
) |
||||
|
||||
const ( |
||||
authHeader = "Authorization" |
||||
idTokenHeader = "X-ID-Token" |
||||
) |
||||
|
||||
type Provider struct { |
||||
settings backend.DataSourceInstanceSettings |
||||
jsonData JsonData |
||||
clientProvider httpclient.Provider |
||||
log log.Logger |
||||
} |
||||
|
||||
func NewProvider( |
||||
settings backend.DataSourceInstanceSettings, |
||||
jsonData JsonData, |
||||
clientProvider httpclient.Provider, |
||||
log log.Logger, |
||||
) *Provider { |
||||
return &Provider{ |
||||
settings: settings, |
||||
jsonData: jsonData, |
||||
clientProvider: clientProvider, |
||||
log: log, |
||||
} |
||||
} |
||||
|
||||
type JsonData struct { |
||||
Method string `json:"httpMethod"` |
||||
OauthPassThru bool `json:"oauthPassThru"` |
||||
TimeInterval string `json:"timeInterval"` |
||||
} |
||||
|
||||
func (p *Provider) GetClient(headers map[string]string) (apiv1.API, error) { |
||||
opts, err := p.settings.HTTPClientOptions() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
opts.Middlewares = p.middlewares() |
||||
if p.jsonData.OauthPassThru { |
||||
opts.Headers = authHeaders(headers) |
||||
} |
||||
|
||||
// Set SigV4 service namespace
|
||||
if opts.SigV4 != nil { |
||||
opts.SigV4.Service = "aps" |
||||
} |
||||
|
||||
roundTripper, err := p.clientProvider.GetTransport(opts) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cfg := api.Config{ |
||||
Address: p.settings.URL, |
||||
RoundTripper: roundTripper, |
||||
} |
||||
|
||||
client, err := api.NewClient(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return apiv1.NewAPI(client), nil |
||||
} |
||||
|
||||
func (p *Provider) middlewares() []sdkhttpclient.Middleware { |
||||
middlewares := []sdkhttpclient.Middleware{ |
||||
middleware.CustomQueryParameters(p.log), |
||||
sdkhttpclient.CustomHeadersMiddleware(), |
||||
} |
||||
if strings.ToLower(p.jsonData.Method) == "get" { |
||||
middlewares = append(middlewares, middleware.ForceHttpGet(p.log)) |
||||
} |
||||
|
||||
return middlewares |
||||
} |
||||
|
||||
func authHeaders(headers map[string]string) map[string]string { |
||||
authHeaders := make(map[string]string) |
||||
if v, ok := headers[authHeader]; ok { |
||||
authHeaders[authHeader] = v |
||||
} |
||||
|
||||
if v, ok := headers[idTokenHeader]; ok { |
||||
authHeaders[idTokenHeader] = v |
||||
} |
||||
|
||||
return authHeaders |
||||
} |
||||
@ -0,0 +1,186 @@ |
||||
package promclient_test |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/prometheus/promclient" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" |
||||
"github.com/grafana/grafana/pkg/infra/httpclient" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
var headers = map[string]string{"Authorization": "token", "X-ID-Token": "id-token"} |
||||
|
||||
func TestGetClient(t *testing.T) { |
||||
t.Run("it sets the SigV4 service if it exists", func(t *testing.T) { |
||||
tc := setup(`{"sigV4Auth":true}`) |
||||
|
||||
setting.SigV4AuthEnabled = true |
||||
defer func() { setting.SigV4AuthEnabled = false }() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, "aps", tc.httpProvider.opts.SigV4.Service) |
||||
}) |
||||
|
||||
t.Run("it always uses the custom params and custom headers middlewares", func(t *testing.T) { |
||||
tc := setup() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 2) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "prom-custom-query-parameters") |
||||
require.Contains(t, tc.httpProvider.middlewares(), "CustomHeaders") |
||||
}) |
||||
|
||||
t.Run("oauth pass through", func(t *testing.T) { |
||||
t.Run("it sets the headers when 'oauthPassThru' is true and auth headers are passed", func(t *testing.T) { |
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, headers, tc.httpProvider.opts.Headers) |
||||
}) |
||||
|
||||
t.Run("it only sets auth headers", func(t *testing.T) { |
||||
withNonAuth := map[string]string{"X-Not-Auth": "stuff"} |
||||
|
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
_, err := tc.promClientProvider.GetClient(withNonAuth) |
||||
require.Nil(t, err) |
||||
|
||||
require.Equal(t, map[string]string{}, tc.httpProvider.opts.Headers) |
||||
}) |
||||
|
||||
t.Run("it does not error when headers are nil", func(t *testing.T) { |
||||
tc := setup(`{"oauthPassThru":true}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(nil) |
||||
require.Nil(t, err) |
||||
}) |
||||
|
||||
t.Run("it does not set the headers when 'oauthPassThru' is false", func(t *testing.T) { |
||||
tc := setup() |
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.opts.Headers, 0) |
||||
}) |
||||
}) |
||||
|
||||
t.Run("force get middleware", func(t *testing.T) { |
||||
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"get"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 3) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it add the force-get middleware when httpMethod is get", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"GET"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.Len(t, tc.httpProvider.middlewares(), 3) |
||||
require.Contains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when httpMethod is POST", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":"POST"}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when json data is nil", func(t *testing.T) { |
||||
tc := setup() |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware when json data is empty", func(t *testing.T) { |
||||
tc := setup(`{}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
|
||||
t.Run("it does not add the force-get middleware httpMethod is null", func(t *testing.T) { |
||||
tc := setup(`{"httpMethod":null}`) |
||||
|
||||
_, err := tc.promClientProvider.GetClient(headers) |
||||
require.Nil(t, err) |
||||
|
||||
require.NotContains(t, tc.httpProvider.middlewares(), "force-http-get") |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
func setup(jsonData ...string) *testContext { |
||||
var rawData []byte |
||||
if len(jsonData) > 0 { |
||||
rawData = []byte(jsonData[0]) |
||||
} |
||||
|
||||
var jd promclient.JsonData |
||||
_ = json.Unmarshal(rawData, &jd) |
||||
|
||||
settings := backend.DataSourceInstanceSettings{URL: "test-url", JSONData: rawData} |
||||
hp := &fakeHttpClientProvider{} |
||||
p := promclient.NewProvider(settings, jd, hp, nil) |
||||
|
||||
return &testContext{ |
||||
httpProvider: hp, |
||||
promClientProvider: p, |
||||
} |
||||
} |
||||
|
||||
type testContext struct { |
||||
httpProvider *fakeHttpClientProvider |
||||
promClientProvider *promclient.Provider |
||||
} |
||||
|
||||
type fakeHttpClientProvider struct { |
||||
httpclient.Provider |
||||
|
||||
opts sdkhttpclient.Options |
||||
} |
||||
|
||||
func (p *fakeHttpClientProvider) GetTransport(opts ...sdkhttpclient.Options) (http.RoundTripper, error) { |
||||
p.opts = opts[0] |
||||
return http.DefaultTransport, nil |
||||
} |
||||
|
||||
func (p *fakeHttpClientProvider) middlewares() []string { |
||||
var middlewareNames []string |
||||
for _, m := range p.opts.Middlewares { |
||||
mw, ok := m.(sdkhttpclient.MiddlewareName) |
||||
if !ok { |
||||
panic("unexpected middleware type") |
||||
} |
||||
|
||||
middlewareNames = append(middlewareNames, mw.MiddlewareName()) |
||||
} |
||||
return middlewareNames |
||||
} |
||||
Loading…
Reference in new issue