The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/api/pluginproxy/token_cache_test.go

286 lines
7.8 KiB

package pluginproxy
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeCredential struct {
key string
calledTimes int
getAccessTokenFunc func(ctx context.Context, scopes []string) (*AccessToken, error)
}
func (c *fakeCredential) GetCacheKey() string {
return c.key
}
func (c *fakeCredential) GetAccessToken(ctx context.Context, scopes []string) (*AccessToken, error) {
c.calledTimes = c.calledTimes + 1
if c.getAccessTokenFunc != nil {
return c.getAccessTokenFunc(ctx, scopes)
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("%v-token-%v", c.key, c.calledTimes), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
}
func TestConcurrentTokenCache_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes1 := []string{"Scope1"}
scopes2 := []string{"Scope2"}
t.Run("should request access token from credential", func(t *testing.T) {
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token, err := cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token)
assert.Equal(t, 1, credential.calledTimes)
})
t.Run("should return cached token for same scopes", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential := &fakeCredential{key: "credential-1"}
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
token1, err = cache.GetAccessToken(ctx, credential, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential, scopes2)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-2", token2)
assert.Equal(t, 2, credential.calledTimes)
})
t.Run("should return cached token for same credentials", func(t *testing.T) {
var token1, token2 string
var err error
cache := NewConcurrentTokenCache()
credential1 := &fakeCredential{key: "credential-1"}
credential2 := &fakeCredential{key: "credential-2"}
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
token1, err = cache.GetAccessToken(ctx, credential1, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-1-token-1", token1)
token2, err = cache.GetAccessToken(ctx, credential2, scopes1)
require.NoError(t, err)
assert.Equal(t, "credential-2-token-1", token2)
assert.Equal(t, 1, credential1.calledTimes)
assert.Equal(t, 1, credential2.calledTimes)
})
}
func TestScopesCacheEntry_GetAccessToken(t *testing.T) {
ctx := context.Background()
scopes := []string{"Scope1"}
t.Run("when credential returns error", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
},
}
t.Run("should return error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
accessToken, err := cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, "", accessToken)
})
t.Run("should call credential again each time and return error", func(t *testing.T) {
credential.calledTimes = 0
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential returns error only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
invalidToken := &AccessToken{Token: "invalid_token", ExpiresOn: timeNow().Add(time.Hour)}
return invalidToken, errors.New("unable to get access token")
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it returns error", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
_, err = cacheEntry.getAccessToken(ctx)
assert.Error(t, err)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
assert.Equal(t, 2, credential.calledTimes)
})
})
t.Run("when credential panics", func(t *testing.T) {
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
panic(errors.New("unable to get access token"))
},
}
t.Run("should call credential again each time", func(t *testing.T) {
credential.calledTimes = 0
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
assert.Equal(t, 3, credential.calledTimes)
})
})
t.Run("when credential panics only once", func(t *testing.T) {
var times = 0
credential := &fakeCredential{
getAccessTokenFunc: func(ctx context.Context, scopes []string) (*AccessToken, error) {
times = times + 1
if times == 1 {
panic(errors.New("unable to get access token"))
}
fakeAccessToken := &AccessToken{Token: fmt.Sprintf("token-%v", times), ExpiresOn: timeNow().Add(time.Hour)}
return fakeAccessToken, nil
},
}
t.Run("should call credential again only while it panics", func(t *testing.T) {
cacheEntry := &scopesCacheEntry{
credential: credential,
scopes: scopes,
cond: sync.NewCond(&sync.Mutex{}),
}
var accessToken string
var err error
func() {
defer func() {
assert.NotNil(t, recover(), "credential expected to panic")
}()
_, _ = cacheEntry.getAccessToken(ctx)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
func() {
defer func() {
assert.Nil(t, recover(), "credential not expected to panic")
}()
accessToken, err = cacheEntry.getAccessToken(ctx)
assert.NoError(t, err)
assert.Equal(t, "token-2", accessToken)
}()
assert.Equal(t, 2, credential.calledTimes)
})
})
}