mirror of https://github.com/grafana/grafana
Postgres: Switch the datasource plugin from lib/pq to pgx (#81353)
* postgres: switch from lib/pq to pgx * postgres: improved tls handlingpull/83231/head
parent
e8df62941b
commit
8c18d06386
@ -1,85 +0,0 @@ |
||||
package postgres |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
) |
||||
|
||||
// locker is a named reader/writer mutual exclusion lock.
|
||||
// The lock for each particular key can be held by an arbitrary number of readers or a single writer.
|
||||
type locker struct { |
||||
locks map[any]*sync.RWMutex |
||||
locksRW *sync.RWMutex |
||||
} |
||||
|
||||
func newLocker() *locker { |
||||
return &locker{ |
||||
locks: make(map[any]*sync.RWMutex), |
||||
locksRW: new(sync.RWMutex), |
||||
} |
||||
} |
||||
|
||||
// Lock locks named rw mutex with specified key for writing.
|
||||
// If the lock with the same key is already locked for reading or writing,
|
||||
// Lock blocks until the lock is available.
|
||||
func (lkr *locker) Lock(key any) { |
||||
lk, ok := lkr.getLock(key) |
||||
if !ok { |
||||
lk = lkr.newLock(key) |
||||
} |
||||
lk.Lock() |
||||
} |
||||
|
||||
// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is
|
||||
// not locked for writing on entry to Unlock.
|
||||
func (lkr *locker) Unlock(key any) { |
||||
lk, ok := lkr.getLock(key) |
||||
if !ok { |
||||
panic(fmt.Errorf("lock for key '%s' not initialized", key)) |
||||
} |
||||
lk.Unlock() |
||||
} |
||||
|
||||
// RLock locks named rw mutex with specified key for reading.
|
||||
//
|
||||
// It should not be used for recursive read locking for the same key; a blocked Lock
|
||||
// call excludes new readers from acquiring the lock. See the
|
||||
// documentation on the golang RWMutex type.
|
||||
func (lkr *locker) RLock(key any) { |
||||
lk, ok := lkr.getLock(key) |
||||
if !ok { |
||||
lk = lkr.newLock(key) |
||||
} |
||||
lk.RLock() |
||||
} |
||||
|
||||
// RUnlock undoes a single RLock call for specified key;
|
||||
// it does not affect other simultaneous readers of locker for specified key.
|
||||
// It is a run-time error if locker for specified key is not locked for reading
|
||||
func (lkr *locker) RUnlock(key any) { |
||||
lk, ok := lkr.getLock(key) |
||||
if !ok { |
||||
panic(fmt.Errorf("lock for key '%s' not initialized", key)) |
||||
} |
||||
lk.RUnlock() |
||||
} |
||||
|
||||
func (lkr *locker) newLock(key any) *sync.RWMutex { |
||||
lkr.locksRW.Lock() |
||||
defer lkr.locksRW.Unlock() |
||||
|
||||
if lk, ok := lkr.locks[key]; ok { |
||||
return lk |
||||
} |
||||
lk := new(sync.RWMutex) |
||||
lkr.locks[key] = lk |
||||
return lk |
||||
} |
||||
|
||||
func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) { |
||||
lkr.locksRW.RLock() |
||||
defer lkr.locksRW.RUnlock() |
||||
|
||||
lock, ok := lkr.locks[key] |
||||
return lock, ok |
||||
} |
@ -1,63 +0,0 @@ |
||||
package postgres |
||||
|
||||
import ( |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestIntegrationLocker(t *testing.T) { |
||||
if testing.Short() { |
||||
t.Skip("Tests with Sleep") |
||||
} |
||||
const notUpdated = "not_updated" |
||||
const atThread1 = "at_thread_1" |
||||
const atThread2 = "at_thread_2" |
||||
t.Run("Should lock for same keys", func(t *testing.T) { |
||||
updated := notUpdated |
||||
locker := newLocker() |
||||
locker.Lock(1) |
||||
var wg sync.WaitGroup |
||||
wg.Add(1) |
||||
defer func() { |
||||
locker.Unlock(1) |
||||
wg.Wait() |
||||
}() |
||||
|
||||
go func() { |
||||
locker.RLock(1) |
||||
defer func() { |
||||
locker.RUnlock(1) |
||||
wg.Done() |
||||
}() |
||||
require.Equal(t, atThread1, updated, "Value should be updated in different thread") |
||||
updated = atThread2 |
||||
}() |
||||
time.Sleep(time.Millisecond * 10) |
||||
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") |
||||
updated = atThread1 |
||||
}) |
||||
|
||||
t.Run("Should not lock for different keys", func(t *testing.T) { |
||||
updated := notUpdated |
||||
locker := newLocker() |
||||
locker.Lock(1) |
||||
defer locker.Unlock(1) |
||||
var wg sync.WaitGroup |
||||
wg.Add(1) |
||||
go func() { |
||||
locker.RLock(2) |
||||
defer func() { |
||||
locker.RUnlock(2) |
||||
wg.Done() |
||||
}() |
||||
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") |
||||
updated = atThread2 |
||||
}() |
||||
wg.Wait() |
||||
require.Equal(t, atThread2, updated, "Value should be updated in different thread") |
||||
updated = atThread1 |
||||
}) |
||||
} |
@ -0,0 +1,147 @@ |
||||
package tls |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"errors" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/sqleng" |
||||
) |
||||
|
||||
// we support 4 postgres tls modes:
|
||||
// disable - no tls
|
||||
// require - use tls
|
||||
// verify-ca - use tls, verify root cert but not the hostname
|
||||
// verify-full - use tls, verify root cert
|
||||
// (for all the options except `disable`, you can optionally use client certificates)
|
||||
|
||||
var errNoRootCert = errors.New("tls: missing root certificate") |
||||
|
||||
func getTLSConfigRequire(certs *Certs) (*tls.Config, error) { |
||||
// we may have a client-cert, we do not have a root-cert
|
||||
|
||||
// see https://www.postgresql.org/docs/12/libpq-ssl.html ,
|
||||
// mode=require + provided root-cert should behave as mode=verify-ca
|
||||
if certs.rootCerts != nil { |
||||
return getTLSConfigVerifyCA(certs) |
||||
} |
||||
|
||||
return &tls.Config{ |
||||
InsecureSkipVerify: true, // we do not verify the root cert
|
||||
Certificates: certs.clientCerts, |
||||
}, nil |
||||
} |
||||
|
||||
// to implement the verify-ca mode, we need to do this:
|
||||
// - for the root certificate
|
||||
// - verify that the certificate we receive from the server is trusted,
|
||||
// meaning it relates to our root certificate
|
||||
// - we DO NOT verify that the hostname of the database matches
|
||||
// the hostname in the certificate
|
||||
//
|
||||
// the problem is, `go“ does not offer such an option.
|
||||
// by default, it will verify both things.
|
||||
//
|
||||
// so what we do is:
|
||||
// - we turn off the default-verification with `InsecureSkipVerify`
|
||||
// - we implement our own verification using `VerifyConnection`
|
||||
//
|
||||
// extra info about this:
|
||||
// - there is a rejected feature-request about this at https://github.com/golang/go/issues/21971
|
||||
// - the recommended workaround is based on VerifyPeerCertificate
|
||||
// - there is even example code at https://github.com/golang/go/commit/29cfb4d3c3a97b6f426d1b899234da905be699aa
|
||||
// - but later the example code was changed to use VerifyConnection instead:
|
||||
// https://github.com/golang/go/commit/7eb5941b95a588a23f18fa4c22fe42ff0119c311
|
||||
//
|
||||
// a verifyConnection example is at https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection .
|
||||
//
|
||||
// this is how the `pgx` library handles verify-ca:
|
||||
//
|
||||
// https://github.com/jackc/pgx/blob/5c63f646f820ca9696fc3515c1caf2a557d562e5/pgconn/config.go#L657-L690
|
||||
// (unfortunately pgx only handles this for certificate-provided-as-path, so we cannot rely on it)
|
||||
func getTLSConfigVerifyCA(certs *Certs) (*tls.Config, error) { |
||||
// we must have a root certificate
|
||||
if certs.rootCerts == nil { |
||||
return nil, errNoRootCert |
||||
} |
||||
|
||||
conf := tls.Config{ |
||||
Certificates: certs.clientCerts, |
||||
InsecureSkipVerify: true, // we turn off the default-verification, we'll do VerifyConnection instead
|
||||
VerifyConnection: func(state tls.ConnectionState) error { |
||||
// we add all the certificates to the pool, we skip the first cert.
|
||||
intermediates := x509.NewCertPool() |
||||
for _, c := range state.PeerCertificates[1:] { |
||||
intermediates.AddCert(c) |
||||
} |
||||
|
||||
opts := x509.VerifyOptions{ |
||||
Roots: certs.rootCerts, |
||||
Intermediates: intermediates, |
||||
} |
||||
|
||||
// we call `Verify()` on the first cert (that we skipped previously)
|
||||
_, err := state.PeerCertificates[0].Verify(opts) |
||||
return err |
||||
}, |
||||
RootCAs: certs.rootCerts, |
||||
} |
||||
|
||||
return &conf, nil |
||||
} |
||||
|
||||
func getTLSConfigVerifyFull(certs *Certs, serverName string) (*tls.Config, error) { |
||||
// we must have a root certificate
|
||||
if certs.rootCerts == nil { |
||||
return nil, errNoRootCert |
||||
} |
||||
|
||||
conf := tls.Config{ |
||||
Certificates: certs.clientCerts, |
||||
ServerName: serverName, |
||||
RootCAs: certs.rootCerts, |
||||
} |
||||
|
||||
return &conf, nil |
||||
} |
||||
|
||||
func IsTLSEnabled(dsInfo sqleng.DataSourceInfo) bool { |
||||
mode := dsInfo.JsonData.Mode |
||||
return mode != "disable" |
||||
} |
||||
|
||||
// returns `nil` if tls is disabled
|
||||
func GetTLSConfig(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc, serverName string) (*tls.Config, error) { |
||||
mode := dsInfo.JsonData.Mode |
||||
// we need to special-case the no-tls-mode
|
||||
if mode == "disable" { |
||||
return nil, nil |
||||
} |
||||
|
||||
// for all the remaining cases we need to load
|
||||
// both the root-cert if exists, and the client-cert if exists.
|
||||
certBytes, err := loadCertificateBytes(dsInfo, readFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
certs, err := createCertificates(certBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch mode { |
||||
// `disable` already handled
|
||||
case "": |
||||
// for backward-compatibility reasons this is the same as `require`
|
||||
return getTLSConfigRequire(certs) |
||||
case "require": |
||||
return getTLSConfigRequire(certs) |
||||
case "verify-ca": |
||||
return getTLSConfigVerifyCA(certs) |
||||
case "verify-full": |
||||
return getTLSConfigVerifyFull(certs, serverName) |
||||
default: |
||||
return nil, errors.New("tls: invalid mode " + mode) |
||||
} |
||||
} |
@ -0,0 +1,101 @@ |
||||
package tls |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"errors" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/sqleng" |
||||
) |
||||
|
||||
// this file deals with locating and loading the certificates,
|
||||
// from json-data or from disk.
|
||||
|
||||
type CertBytes struct { |
||||
rootCert []byte |
||||
clientKey []byte |
||||
clientCert []byte |
||||
} |
||||
|
||||
type ReadFileFunc = func(name string) ([]byte, error) |
||||
|
||||
var errPartialClientCertNoKey = errors.New("tls: client cert provided but client key missing") |
||||
var errPartialClientCertNoCert = errors.New("tls: client key provided but client cert missing") |
||||
|
||||
// certificates can be stored either as encrypted-json-data, or as file-path
|
||||
func loadCertificateBytes(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc) (*CertBytes, error) { |
||||
if dsInfo.JsonData.ConfigurationMethod == "file-content" { |
||||
return &CertBytes{ |
||||
rootCert: []byte(dsInfo.DecryptedSecureJSONData["tlsCACert"]), |
||||
clientKey: []byte(dsInfo.DecryptedSecureJSONData["tlsClientKey"]), |
||||
clientCert: []byte(dsInfo.DecryptedSecureJSONData["tlsClientCert"]), |
||||
}, nil |
||||
} else { |
||||
c := CertBytes{} |
||||
|
||||
if dsInfo.JsonData.RootCertFile != "" { |
||||
rootCert, err := readFile(dsInfo.JsonData.RootCertFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.rootCert = rootCert |
||||
} |
||||
|
||||
if dsInfo.JsonData.CertKeyFile != "" { |
||||
clientKey, err := readFile(dsInfo.JsonData.CertKeyFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.clientKey = clientKey |
||||
} |
||||
|
||||
if dsInfo.JsonData.CertFile != "" { |
||||
clientCert, err := readFile(dsInfo.JsonData.CertFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.clientCert = clientCert |
||||
} |
||||
|
||||
return &c, nil |
||||
} |
||||
} |
||||
|
||||
type Certs struct { |
||||
clientCerts []tls.Certificate |
||||
rootCerts *x509.CertPool |
||||
} |
||||
|
||||
func createCertificates(certBytes *CertBytes) (*Certs, error) { |
||||
certs := Certs{} |
||||
|
||||
if len(certBytes.rootCert) > 0 { |
||||
pool := x509.NewCertPool() |
||||
ok := pool.AppendCertsFromPEM(certBytes.rootCert) |
||||
if !ok { |
||||
return nil, errors.New("tls: failed to add root certificate") |
||||
} |
||||
certs.rootCerts = pool |
||||
} |
||||
|
||||
hasClientKey := len(certBytes.clientKey) > 0 |
||||
hasClientCert := len(certBytes.clientCert) > 0 |
||||
|
||||
if hasClientKey && hasClientCert { |
||||
cert, err := tls.X509KeyPair(certBytes.clientCert, certBytes.clientKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
certs.clientCerts = []tls.Certificate{cert} |
||||
} |
||||
|
||||
if hasClientKey && (!hasClientCert) { |
||||
return nil, errPartialClientCertNoCert |
||||
} |
||||
|
||||
if hasClientCert && (!hasClientKey) { |
||||
return nil, errPartialClientCertNoKey |
||||
} |
||||
|
||||
return &certs, nil |
||||
} |
@ -0,0 +1,382 @@ |
||||
package tls |
||||
|
||||
import ( |
||||
"errors" |
||||
"os" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/tsdb/sqleng" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func noReadFile(path string) ([]byte, error) { |
||||
return nil, errors.New("not implemented") |
||||
} |
||||
|
||||
func TestTLSNoMode(t *testing.T) { |
||||
// for backward-compatibility reason,
|
||||
// when mode is unset, it defaults to `require`
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
ConfigurationMethod: "", |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.True(t, c.InsecureSkipVerify) |
||||
} |
||||
|
||||
func TestTLSDisable(t *testing.T) { |
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "disable", |
||||
ConfigurationMethod: "", |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.Nil(t, c) |
||||
} |
||||
|
||||
func TestTLSRequire(t *testing.T) { |
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "", |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.True(t, c.InsecureSkipVerify) |
||||
require.Nil(t, c.RootCAs) |
||||
} |
||||
|
||||
func TestTLSRequireWithRootCert(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsCACert": string(rootCertBytes), |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.True(t, c.InsecureSkipVerify) |
||||
require.NotNil(t, c.VerifyConnection) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSVerifyCA(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-ca", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsCACert": string(rootCertBytes), |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.True(t, c.InsecureSkipVerify) |
||||
require.NotNil(t, c.VerifyConnection) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSVerifyCAMisingRootCert(t *testing.T) { |
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-ca", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{}, |
||||
} |
||||
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.ErrorIs(t, err, errNoRootCert) |
||||
} |
||||
|
||||
func TestTLSClientCert(t *testing.T) { |
||||
clientKey, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsClientCert": string(clientCert), |
||||
"tlsClientKey": string(clientKey), |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.Len(t, c.Certificates, 1) |
||||
} |
||||
|
||||
func TestTLSMethodFileContentClientCertMissingKey(t *testing.T) { |
||||
_, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsClientCert": string(clientCert), |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.ErrorIs(t, err, errPartialClientCertNoKey) |
||||
} |
||||
|
||||
func TestTLSMethodFileContentClientCertMissingCert(t *testing.T) { |
||||
clientKey, _, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsClientKey": string(clientKey), |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.ErrorIs(t, err, errPartialClientCertNoCert) |
||||
} |
||||
|
||||
func TestTLSMethodFilePathClientCertMissingKey(t *testing.T) { |
||||
clientKey, _, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"path1": clientKey, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-path", |
||||
CertKeyFile: "path1", |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.ErrorIs(t, err, errPartialClientCertNoCert) |
||||
} |
||||
|
||||
func TestTLSMethodFilePathClientCertMissingCert(t *testing.T) { |
||||
_, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"path1": clientCert, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-path", |
||||
CertFile: "path1", |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.ErrorIs(t, err, errPartialClientCertNoKey) |
||||
} |
||||
|
||||
func TestTLSVerifyFull(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsCACert": string(rootCertBytes), |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.False(t, c.InsecureSkipVerify) |
||||
require.Nil(t, c.VerifyConnection) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSMethodFileContent(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
clientKey, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{ |
||||
"tlsCACert": string(rootCertBytes), |
||||
"tlsClientCert": string(clientCert), |
||||
"tlsClientKey": string(clientKey), |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.Len(t, c.Certificates, 1) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSMethodFilePath(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
clientKey, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"root-cert-path": rootCertBytes, |
||||
"client-key-path": clientKey, |
||||
"client-cert-path": clientCert, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-path", |
||||
RootCertFile: "root-cert-path", |
||||
CertKeyFile: "client-key-path", |
||||
CertFile: "client-cert-path", |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.Len(t, c.Certificates, 1) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSMethodFilePathRootCertDoesNotExist(t *testing.T) { |
||||
readFile := newMockReadFile(map[string]([]byte){}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-path", |
||||
RootCertFile: "path1", |
||||
}, |
||||
} |
||||
_, err := GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.ErrorIs(t, err, os.ErrNotExist) |
||||
} |
||||
|
||||
func TestTLSMethodFilePathClientCertKeyDoesNotExist(t *testing.T) { |
||||
_, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"cert-path": clientCert, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-path", |
||||
CertKeyFile: "key-path", |
||||
CertFile: "cert-path", |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.ErrorIs(t, err, os.ErrNotExist) |
||||
} |
||||
|
||||
func TestTLSMethodFilePathClientCertCertDoesNotExist(t *testing.T) { |
||||
clientKey, _, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"key-path": clientKey, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-path", |
||||
CertKeyFile: "key-path", |
||||
CertFile: "cert-path", |
||||
}, |
||||
} |
||||
_, err = GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.ErrorIs(t, err, os.ErrNotExist) |
||||
} |
||||
|
||||
// method="" equals to method="file-path"
|
||||
func TestTLSMethodEmpty(t *testing.T) { |
||||
rootCertBytes, err := CreateRandomRootCertBytes() |
||||
require.NoError(t, err) |
||||
|
||||
clientKey, clientCert, err := CreateRandomClientCert() |
||||
require.NoError(t, err) |
||||
|
||||
readFile := newMockReadFile(map[string]([]byte){ |
||||
"root-cert-path": rootCertBytes, |
||||
"client-key-path": clientKey, |
||||
"client-cert-path": clientCert, |
||||
}) |
||||
|
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "", |
||||
RootCertFile: "root-cert-path", |
||||
CertKeyFile: "client-key-path", |
||||
CertFile: "client-cert-path", |
||||
}, |
||||
} |
||||
c, err := GetTLSConfig(dsInfo, readFile, "localhost") |
||||
require.NoError(t, err) |
||||
require.NotNil(t, c) |
||||
require.Len(t, c.Certificates, 1) |
||||
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
|
||||
} |
||||
|
||||
func TestTLSVerifyFullMisingRootCert(t *testing.T) { |
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
DecryptedSecureJSONData: map[string]string{}, |
||||
} |
||||
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.ErrorIs(t, err, errNoRootCert) |
||||
} |
||||
|
||||
func TestTLSInvalidMode(t *testing.T) { |
||||
dsInfo := sqleng.DataSourceInfo{ |
||||
JsonData: sqleng.JsonData{ |
||||
Mode: "not-a-valid-mode", |
||||
}, |
||||
} |
||||
|
||||
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost") |
||||
require.Error(t, err) |
||||
} |
@ -0,0 +1,105 @@ |
||||
package tls |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"crypto/x509" |
||||
"crypto/x509/pkix" |
||||
"encoding/pem" |
||||
"math/big" |
||||
"os" |
||||
"time" |
||||
) |
||||
|
||||
func CreateRandomRootCertBytes() ([]byte, error) { |
||||
cert := x509.Certificate{ |
||||
SerialNumber: big.NewInt(42), |
||||
Subject: pkix.Name{ |
||||
CommonName: "test1", |
||||
}, |
||||
NotBefore: time.Now(), |
||||
NotAfter: time.Now().AddDate(10, 0, 0), |
||||
IsCA: true, |
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, |
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
||||
BasicConstraintsValid: true, |
||||
} |
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
bytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &key.PublicKey, key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return pem.EncodeToMemory(&pem.Block{ |
||||
Type: "CERTIFICATE", |
||||
Bytes: bytes, |
||||
}), nil |
||||
} |
||||
|
||||
func CreateRandomClientCert() ([]byte, []byte, error) { |
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
keyBytes := pem.EncodeToMemory(&pem.Block{ |
||||
Type: "RSA PRIVATE KEY", |
||||
Bytes: x509.MarshalPKCS1PrivateKey(key), |
||||
}) |
||||
|
||||
caCert := x509.Certificate{ |
||||
SerialNumber: big.NewInt(42), |
||||
Subject: pkix.Name{ |
||||
CommonName: "test1", |
||||
}, |
||||
NotBefore: time.Now(), |
||||
NotAfter: time.Now().AddDate(10, 0, 0), |
||||
IsCA: true, |
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, |
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
||||
BasicConstraintsValid: true, |
||||
} |
||||
|
||||
cert := x509.Certificate{ |
||||
SerialNumber: big.NewInt(2019), |
||||
Subject: pkix.Name{ |
||||
CommonName: "test1", |
||||
}, |
||||
NotBefore: time.Now(), |
||||
NotAfter: time.Now().AddDate(10, 0, 0), |
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, |
||||
KeyUsage: x509.KeyUsageDigitalSignature, |
||||
} |
||||
|
||||
certData, err := x509.CreateCertificate(rand.Reader, &cert, &caCert, &key.PublicKey, caKey) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
certBytes := pem.EncodeToMemory(&pem.Block{ |
||||
Type: "CERTIFICATE", |
||||
Bytes: certData, |
||||
}) |
||||
|
||||
return keyBytes, certBytes, nil |
||||
} |
||||
|
||||
func newMockReadFile(data map[string]([]byte)) ReadFileFunc { |
||||
return func(path string) ([]byte, error) { |
||||
bytes, ok := data[path] |
||||
if !ok { |
||||
return nil, os.ErrNotExist |
||||
} |
||||
return bytes, nil |
||||
} |
||||
} |
@ -1,249 +0,0 @@ |
||||
package postgres |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/log" |
||||
"github.com/grafana/grafana/pkg/tsdb/sqleng" |
||||
) |
||||
|
||||
var validateCertFunc = validateCertFilePaths |
||||
var writeCertFileFunc = writeCertFile |
||||
|
||||
type certFileType int |
||||
|
||||
const ( |
||||
rootCert = iota |
||||
clientCert |
||||
clientKey |
||||
) |
||||
|
||||
type tlsSettingsProvider interface { |
||||
getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) |
||||
} |
||||
|
||||
type datasourceCacheManager struct { |
||||
locker *locker |
||||
cache sync.Map |
||||
} |
||||
|
||||
type tlsManager struct { |
||||
logger log.Logger |
||||
dsCacheInstance datasourceCacheManager |
||||
dataPath string |
||||
} |
||||
|
||||
func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider { |
||||
return &tlsManager{ |
||||
logger: logger, |
||||
dataPath: dataPath, |
||||
dsCacheInstance: datasourceCacheManager{locker: newLocker()}, |
||||
} |
||||
} |
||||
|
||||
type tlsSettings struct { |
||||
Mode string |
||||
ConfigurationMethod string |
||||
RootCertFile string |
||||
CertFile string |
||||
CertKeyFile string |
||||
} |
||||
|
||||
func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { |
||||
tlsconfig := tlsSettings{ |
||||
Mode: dsInfo.JsonData.Mode, |
||||
} |
||||
|
||||
isTLSDisabled := (tlsconfig.Mode == "disable") |
||||
|
||||
if isTLSDisabled { |
||||
m.logger.Debug("Postgres TLS/SSL is disabled") |
||||
return tlsconfig, nil |
||||
} |
||||
|
||||
m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode) |
||||
|
||||
tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod |
||||
tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile |
||||
tlsconfig.CertFile = dsInfo.JsonData.CertFile |
||||
tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile |
||||
|
||||
if tlsconfig.ConfigurationMethod == "file-content" { |
||||
if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil { |
||||
return tlsconfig, err |
||||
} |
||||
} else { |
||||
if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil { |
||||
return tlsconfig, err |
||||
} |
||||
} |
||||
return tlsconfig, nil |
||||
} |
||||
|
||||
func (t certFileType) String() string { |
||||
switch t { |
||||
case rootCert: |
||||
return "root certificate" |
||||
case clientCert: |
||||
return "client certificate" |
||||
case clientKey: |
||||
return "client key" |
||||
default: |
||||
panic(fmt.Sprintf("Unrecognized certFileType %d", t)) |
||||
} |
||||
} |
||||
|
||||
func getFileName(dataDir string, fileType certFileType) string { |
||||
var filename string |
||||
switch fileType { |
||||
case rootCert: |
||||
filename = "root.crt" |
||||
case clientCert: |
||||
filename = "client.crt" |
||||
case clientKey: |
||||
filename = "client.key" |
||||
default: |
||||
panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String())) |
||||
} |
||||
generatedFilePath := filepath.Join(dataDir, filename) |
||||
return generatedFilePath |
||||
} |
||||
|
||||
// writeCertFile writes a certificate file.
|
||||
func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error { |
||||
fileContent = strings.TrimSpace(fileContent) |
||||
if fileContent != "" { |
||||
logger.Debug("Writing cert file", "path", generatedFilePath) |
||||
if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil { |
||||
return err |
||||
} |
||||
// Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail
|
||||
if err := os.Chmod(generatedFilePath, 0600); err != nil { |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath) |
||||
exists, err := fileExists(generatedFilePath) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if exists { |
||||
if err := os.Remove(generatedFilePath); err != nil { |
||||
return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error { |
||||
m.logger.Debug("Writing TLS certificate files to disk") |
||||
tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"] |
||||
tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"] |
||||
tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"] |
||||
if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" { |
||||
m.logger.Debug("No TLS/SSL certificates provided") |
||||
} |
||||
|
||||
// Calculate all files path
|
||||
workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts") |
||||
tlsconfig.RootCertFile = getFileName(workDir, rootCert) |
||||
tlsconfig.CertFile = getFileName(workDir, clientCert) |
||||
tlsconfig.CertKeyFile = getFileName(workDir, clientKey) |
||||
|
||||
// Find datasource in the cache, if found, skip writing files
|
||||
cacheKey := strconv.Itoa(int(dsInfo.ID)) |
||||
m.dsCacheInstance.locker.RLock(cacheKey) |
||||
item, ok := m.dsCacheInstance.cache.Load(cacheKey) |
||||
m.dsCacheInstance.locker.RUnlock(cacheKey) |
||||
if ok { |
||||
if !item.(time.Time).Before(dsInfo.Updated) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
m.dsCacheInstance.locker.Lock(cacheKey) |
||||
defer m.dsCacheInstance.locker.Unlock(cacheKey) |
||||
|
||||
item, ok = m.dsCacheInstance.cache.Load(cacheKey) |
||||
if ok { |
||||
if !item.(time.Time).Before(dsInfo.Updated) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
// Write certification directory and files
|
||||
exists, err := fileExists(workDir) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !exists { |
||||
if err := os.MkdirAll(workDir, 0700); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil { |
||||
return err |
||||
} |
||||
if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil { |
||||
return err |
||||
} |
||||
if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// we do not want to point to cert-files that do not exist
|
||||
if tlsRootCert == "" { |
||||
tlsconfig.RootCertFile = "" |
||||
} |
||||
|
||||
if tlsClientCert == "" { |
||||
tlsconfig.CertFile = "" |
||||
} |
||||
|
||||
if tlsClientKey == "" { |
||||
tlsconfig.CertKeyFile = "" |
||||
} |
||||
|
||||
// Update datasource cache
|
||||
m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated) |
||||
return nil |
||||
} |
||||
|
||||
// validateCertFilePaths validates configured certificate file paths.
|
||||
func validateCertFilePaths(rootCert, clientCert, clientKey string) error { |
||||
for _, fpath := range []string{rootCert, clientCert, clientKey} { |
||||
if fpath == "" { |
||||
continue |
||||
} |
||||
exists, err := fileExists(fpath) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !exists { |
||||
return fmt.Errorf("certificate file %q doesn't exist", fpath) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Exists determines whether a file/directory exists or not.
|
||||
func fileExists(fpath string) (bool, error) { |
||||
_, err := os.Stat(fpath) |
||||
if err != nil { |
||||
if !os.IsNotExist(err) { |
||||
return false, err |
||||
} |
||||
return false, nil |
||||
} |
||||
|
||||
return true, nil |
||||
} |
@ -1,332 +0,0 @@ |
||||
package postgres |
||||
|
||||
import ( |
||||
"fmt" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/log" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/tsdb/sqleng" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
_ "github.com/lib/pq" |
||||
) |
||||
|
||||
var writeCertFileCallNum int |
||||
|
||||
// TestDataSourceCacheManager is to test the Cache manager
|
||||
func TestDataSourceCacheManager(t *testing.T) { |
||||
cfg := setting.NewCfg() |
||||
cfg.DataPath = t.TempDir() |
||||
mng := tlsManager{ |
||||
logger: backend.NewLoggerWith("logger", "tsdb.postgres"), |
||||
dsCacheInstance: datasourceCacheManager{locker: newLocker()}, |
||||
dataPath: cfg.DataPath, |
||||
} |
||||
jsonData := sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
} |
||||
secureJSONData := map[string]string{ |
||||
"tlsClientCert": "I am client certification", |
||||
"tlsClientKey": "I am client key", |
||||
"tlsCACert": "I am CA certification", |
||||
} |
||||
|
||||
updateTime := time.Now().Add(-5 * time.Minute) |
||||
|
||||
mockValidateCertFilePaths() |
||||
t.Cleanup(resetValidateCertFilePaths) |
||||
|
||||
t.Run("Check datasource cache creation", func(t *testing.T) { |
||||
var wg sync.WaitGroup |
||||
wg.Add(10) |
||||
for id := int64(1); id <= 10; id++ { |
||||
go func(id int64) { |
||||
ds := sqleng.DataSourceInfo{ |
||||
ID: id, |
||||
Updated: updateTime, |
||||
Database: "database", |
||||
JsonData: jsonData, |
||||
DecryptedSecureJSONData: secureJSONData, |
||||
UID: "testData", |
||||
} |
||||
s := tlsSettings{} |
||||
err := mng.writeCertFiles(ds, &s) |
||||
require.NoError(t, err) |
||||
wg.Done() |
||||
}(id) |
||||
} |
||||
wg.Wait() |
||||
|
||||
t.Run("check cache creation is succeed", func(t *testing.T) { |
||||
for id := int64(1); id <= 10; id++ { |
||||
updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id))) |
||||
require.True(t, ok) |
||||
require.Equal(t, updateTime, updated) |
||||
} |
||||
}) |
||||
}) |
||||
|
||||
t.Run("Check datasource cache modification", func(t *testing.T) { |
||||
t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) { |
||||
mockWriteCertFile() |
||||
t.Cleanup(resetWriteCertFile) |
||||
var wg1 sync.WaitGroup |
||||
wg1.Add(5) |
||||
for id := int64(1); id <= 5; id++ { |
||||
go func(id int64) { |
||||
ds := sqleng.DataSourceInfo{ |
||||
ID: 1, |
||||
Updated: updateTime, |
||||
Database: "database", |
||||
JsonData: jsonData, |
||||
DecryptedSecureJSONData: secureJSONData, |
||||
UID: "testData", |
||||
} |
||||
s := tlsSettings{} |
||||
err := mng.writeCertFiles(ds, &s) |
||||
require.NoError(t, err) |
||||
wg1.Done() |
||||
}(id) |
||||
} |
||||
wg1.Wait() |
||||
assert.Equal(t, writeCertFileCallNum, 0) |
||||
}) |
||||
|
||||
t.Run("cache is updated with the last datasource version", func(t *testing.T) { |
||||
dsV2 := sqleng.DataSourceInfo{ |
||||
ID: 1, |
||||
Updated: updateTime.Add(time.Minute), |
||||
Database: "database", |
||||
JsonData: jsonData, |
||||
DecryptedSecureJSONData: secureJSONData, |
||||
UID: "testData", |
||||
} |
||||
dsV3 := sqleng.DataSourceInfo{ |
||||
ID: 1, |
||||
Updated: updateTime.Add(2 * time.Minute), |
||||
Database: "database", |
||||
JsonData: jsonData, |
||||
DecryptedSecureJSONData: secureJSONData, |
||||
UID: "testData", |
||||
} |
||||
s := tlsSettings{} |
||||
err := mng.writeCertFiles(dsV2, &s) |
||||
require.NoError(t, err) |
||||
err = mng.writeCertFiles(dsV3, &s) |
||||
require.NoError(t, err) |
||||
version, ok := mng.dsCacheInstance.cache.Load("1") |
||||
require.True(t, ok) |
||||
require.Equal(t, updateTime.Add(2*time.Minute), version) |
||||
}) |
||||
}) |
||||
} |
||||
|
||||
// Test getFileName
|
||||
|
||||
func TestGetFileName(t *testing.T) { |
||||
testCases := []struct { |
||||
desc string |
||||
datadir string |
||||
fileType certFileType |
||||
expErr string |
||||
expectedGeneratedPath string |
||||
}{ |
||||
{ |
||||
desc: "Get File Name for root certification", |
||||
datadir: ".", |
||||
fileType: rootCert, |
||||
expectedGeneratedPath: "root.crt", |
||||
}, |
||||
{ |
||||
desc: "Get File Name for client certification", |
||||
datadir: ".", |
||||
fileType: clientCert, |
||||
expectedGeneratedPath: "client.crt", |
||||
}, |
||||
{ |
||||
desc: "Get File Name for client certification", |
||||
datadir: ".", |
||||
fileType: clientKey, |
||||
expectedGeneratedPath: "client.key", |
||||
}, |
||||
} |
||||
for _, tt := range testCases { |
||||
t.Run(tt.desc, func(t *testing.T) { |
||||
generatedPath := getFileName(tt.datadir, tt.fileType) |
||||
assert.Equal(t, tt.expectedGeneratedPath, generatedPath) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// Test getTLSSettings.
|
||||
func TestGetTLSSettings(t *testing.T) { |
||||
cfg := setting.NewCfg() |
||||
cfg.DataPath = t.TempDir() |
||||
|
||||
mockValidateCertFilePaths() |
||||
t.Cleanup(resetValidateCertFilePaths) |
||||
|
||||
updatedTime := time.Now() |
||||
|
||||
testCases := []struct { |
||||
desc string |
||||
expErr string |
||||
jsonData sqleng.JsonData |
||||
secureJSONData map[string]string |
||||
uid string |
||||
tlsSettings tlsSettings |
||||
updated time.Time |
||||
}{ |
||||
{ |
||||
desc: "Custom TLS authentication disabled", |
||||
updated: updatedTime, |
||||
jsonData: sqleng.JsonData{ |
||||
Mode: "disable", |
||||
RootCertFile: "i/am/coding/ca.crt", |
||||
CertFile: "i/am/coding/client.crt", |
||||
CertKeyFile: "i/am/coding/client.key", |
||||
ConfigurationMethod: "file-path", |
||||
}, |
||||
tlsSettings: tlsSettings{Mode: "disable"}, |
||||
}, |
||||
{ |
||||
desc: "Custom TLS authentication with file path", |
||||
updated: updatedTime.Add(time.Minute), |
||||
jsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-path", |
||||
RootCertFile: "i/am/coding/ca.crt", |
||||
CertFile: "i/am/coding/client.crt", |
||||
CertKeyFile: "i/am/coding/client.key", |
||||
}, |
||||
tlsSettings: tlsSettings{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-path", |
||||
RootCertFile: "i/am/coding/ca.crt", |
||||
CertFile: "i/am/coding/client.crt", |
||||
CertKeyFile: "i/am/coding/client.key", |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "Custom TLS mode verify-full with certificate files content", |
||||
updated: updatedTime.Add(2 * time.Minute), |
||||
uid: "xxx", |
||||
jsonData: sqleng.JsonData{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
secureJSONData: map[string]string{ |
||||
"tlsCACert": "I am CA certification", |
||||
"tlsClientCert": "I am client certification", |
||||
"tlsClientKey": "I am client key", |
||||
}, |
||||
tlsSettings: tlsSettings{ |
||||
Mode: "verify-full", |
||||
ConfigurationMethod: "file-content", |
||||
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), |
||||
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), |
||||
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "Custom TLS mode verify-ca with no client certificates with certificate files content", |
||||
updated: updatedTime.Add(3 * time.Minute), |
||||
uid: "xxx", |
||||
jsonData: sqleng.JsonData{ |
||||
Mode: "verify-ca", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
secureJSONData: map[string]string{ |
||||
"tlsCACert": "I am CA certification", |
||||
}, |
||||
tlsSettings: tlsSettings{ |
||||
Mode: "verify-ca", |
||||
ConfigurationMethod: "file-content", |
||||
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), |
||||
CertFile: "", |
||||
CertKeyFile: "", |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content", |
||||
updated: updatedTime.Add(4 * time.Minute), |
||||
uid: "xxx", |
||||
jsonData: sqleng.JsonData{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
}, |
||||
secureJSONData: map[string]string{ |
||||
"tlsClientCert": "I am client certification", |
||||
"tlsClientKey": "I am client key", |
||||
}, |
||||
tlsSettings: tlsSettings{ |
||||
Mode: "require", |
||||
ConfigurationMethod: "file-content", |
||||
RootCertFile: "", |
||||
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), |
||||
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), |
||||
}, |
||||
}, |
||||
} |
||||
for _, tt := range testCases { |
||||
t.Run(tt.desc, func(t *testing.T) { |
||||
var settings tlsSettings |
||||
var err error |
||||
mng := tlsManager{ |
||||
logger: backend.NewLoggerWith("logger", "tsdb.postgres"), |
||||
dsCacheInstance: datasourceCacheManager{locker: newLocker()}, |
||||
dataPath: cfg.DataPath, |
||||
} |
||||
|
||||
ds := sqleng.DataSourceInfo{ |
||||
JsonData: tt.jsonData, |
||||
DecryptedSecureJSONData: tt.secureJSONData, |
||||
UID: tt.uid, |
||||
Updated: tt.updated, |
||||
} |
||||
|
||||
settings, err = mng.getTLSSettings(ds) |
||||
|
||||
if tt.expErr == "" { |
||||
require.NoError(t, err, tt.desc) |
||||
assert.Equal(t, tt.tlsSettings, settings) |
||||
} else { |
||||
require.Error(t, err, tt.desc) |
||||
assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), |
||||
fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr)) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func mockValidateCertFilePaths() { |
||||
validateCertFunc = func(rootCert, clientCert, clientKey string) error { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
func resetValidateCertFilePaths() { |
||||
validateCertFunc = validateCertFilePaths |
||||
} |
||||
|
||||
func mockWriteCertFile() { |
||||
writeCertFileCallNum = 0 |
||||
writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error { |
||||
writeCertFileCallNum++ |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
func resetWriteCertFile() { |
||||
writeCertFileCallNum = 0 |
||||
writeCertFileFunc = writeCertFile |
||||
} |
Loading…
Reference in new issue