@ -2,12 +2,13 @@ package sync
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
@ -18,9 +19,11 @@ import (
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOa uthTokenSync_SyncOAuthTokenHook ( t * testing . T ) {
func TestOA uthTokenSync_SyncOAuthTokenHook ( t * testing . T ) {
type testCase struct {
desc string
identity * authn . Identity
@ -95,6 +98,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectedHasEntryToken : & login . UserAuth { OAuthExpiry : time . Now ( ) . Add ( - 10 * time . Minute ) } ,
oauthInfo : & social . OAuthInfo { UseRefreshToken : false } ,
} ,
{
desc : "should refresh access token when ID token has expired" ,
identity : & authn . Identity { ID : "user:1" , SessionToken : & auth . UserToken { } } ,
expectHasEntryCalled : true ,
expectTryRefreshTokenCalled : true ,
expectedHasEntryToken : & login . UserAuth { OAuthExpiry : time . Now ( ) . Add ( 10 * time . Minute ) , OAuthIdToken : fakeIDToken ( t , time . Now ( ) . Add ( - 10 * time . Minute ) ) } ,
} ,
}
for _ , tt := range tests {
@ -155,3 +165,93 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
} )
}
}
// fakeIDToken is used to create a fake invalid token to verify expiry logic
func fakeIDToken ( t * testing . T , expiryDate time . Time ) string {
type Header struct {
Kid string ` json:"kid" `
Alg string ` json:"alg" `
}
type Payload struct {
Iss string ` json:"iss" `
Sub string ` json:"sub" `
Exp int64 ` json:"exp" `
}
header , err := json . Marshal ( Header { Kid : "123" , Alg : "none" } )
require . NoError ( t , err )
u := expiryDate . UTC ( ) . Unix ( )
payload , err := json . Marshal ( Payload { Iss : "fake" , Sub : "a-sub" , Exp : u } )
require . NoError ( t , err )
fakeSignature := [ ] byte ( "6ICJm" )
return fmt . Sprintf ( "%s.%s.%s" , base64 . RawURLEncoding . EncodeToString ( header ) , base64 . RawURLEncoding . EncodeToString ( payload ) , base64 . RawURLEncoding . EncodeToString ( fakeSignature ) )
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL ( t * testing . T ) {
defaultTime := time . Now ( )
tests := [ ] struct {
name string
accessTokenExpiry time . Time
idTokenExpiry time . Time
want time . Duration
} {
{
name : "should return maxOAuthTokenCacheTTL when no expiry is given" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : time . Time { } ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl" ,
accessTokenExpiry : time . Time { } ,
idTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : time . Time { } ,
want : maxOAuthTokenCacheTTL ,
} ,
{
name : "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given" ,
accessTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : time . Time { } ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry" ,
accessTokenExpiry : defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 5 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( - 3 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : time . Until ( defaultTime . Add ( - 3 * time . Minute + maxOAuthTokenCacheTTL ) ) ,
} ,
{
name : "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl" ,
accessTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
idTokenExpiry : defaultTime . Add ( 5 * time . Minute + maxOAuthTokenCacheTTL ) ,
want : maxOAuthTokenCacheTTL ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
got := getOAuthTokenCacheTTL ( tt . accessTokenExpiry , tt . idTokenExpiry )
assert . Equal ( t , tt . want . Round ( time . Second ) , got . Round ( time . Second ) )
} )
}
}