mirror of https://github.com/grafana/grafana
Unified Storage: Fixes bug with postgres connection string and adds tests (#87656)
parent
8c585c4a79
commit
3bf39d6d9a
@ -0,0 +1,71 @@ |
||||
package dbimpl |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/util" |
||||
"xorm.io/xorm" |
||||
) |
||||
|
||||
func getEngineMySQL(cfgSection *setting.DynamicSection) (*xorm.Engine, error) { |
||||
dbHost := cfgSection.Key("db_host").MustString("") |
||||
dbName := cfgSection.Key("db_name").MustString("") |
||||
dbUser := cfgSection.Key("db_user").MustString("") |
||||
dbPass := cfgSection.Key("db_pass").MustString("") |
||||
|
||||
// TODO: support all mysql connection options
|
||||
protocol := "tcp" |
||||
if strings.HasPrefix(dbHost, "/") { |
||||
protocol = "unix" |
||||
} |
||||
|
||||
connectionString := connectionStringMySQL(dbUser, dbPass, protocol, dbHost, dbName) |
||||
|
||||
engine, err := xorm.NewEngine("mysql", connectionString) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
engine.SetMaxOpenConns(0) |
||||
engine.SetMaxIdleConns(2) |
||||
engine.SetConnMaxLifetime(time.Second * time.Duration(14400)) |
||||
|
||||
return engine, nil |
||||
} |
||||
|
||||
func getEnginePostgres(cfgSection *setting.DynamicSection) (*xorm.Engine, error) { |
||||
dbHost := cfgSection.Key("db_host").MustString("") |
||||
dbName := cfgSection.Key("db_name").MustString("") |
||||
dbUser := cfgSection.Key("db_user").MustString("") |
||||
dbPass := cfgSection.Key("db_pass").MustString("") |
||||
|
||||
// TODO: support all postgres connection options
|
||||
dbSslMode := cfgSection.Key("db_sslmode").MustString("disable") |
||||
|
||||
addr, err := util.SplitHostPortDefault(dbHost, "127.0.0.1", "5432") |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid host specifier '%s': %w", dbHost, err) |
||||
} |
||||
|
||||
connectionString := connectionStringPostgres(dbUser, dbPass, addr.Host, addr.Port, dbName, dbSslMode) |
||||
|
||||
engine, err := xorm.NewEngine("postgres", connectionString) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return engine, nil |
||||
} |
||||
|
||||
func connectionStringMySQL(user, password, protocol, host, dbName string) string { |
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true", user, password, protocol, host, dbName) |
||||
} |
||||
|
||||
func connectionStringPostgres(user, password, host, port, dbName, sslMode string) string { |
||||
return fmt.Sprintf( |
||||
"user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", // sslcert='%s' sslkey='%s' sslrootcert='%s'",
|
||||
user, password, host, port, dbName, sslMode, // ss.dbCfg.ClientCertPath, ss.dbCfg.ClientKeyPath, ss.dbCfg.CaCertPath
|
||||
) |
||||
} |
@ -0,0 +1,54 @@ |
||||
package dbimpl |
||||
|
||||
import ( |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestGetEnginePostgresFromConfig(t *testing.T) { |
||||
cfg := setting.NewCfg() |
||||
s, err := cfg.Raw.NewSection("entity_api") |
||||
require.NoError(t, err) |
||||
s.Key("db_type").SetValue("mysql") |
||||
s.Key("db_host").SetValue("localhost") |
||||
s.Key("db_name").SetValue("grafana") |
||||
s.Key("db_user").SetValue("user") |
||||
s.Key("db_password").SetValue("password") |
||||
|
||||
engine, err := getEnginePostgres(cfg.SectionWithEnvOverrides("entity_api")) |
||||
|
||||
assert.NotNil(t, engine) |
||||
assert.NoError(t, err) |
||||
assert.True(t, strings.Contains(engine.DataSourceName(), "dbname=grafana")) |
||||
} |
||||
|
||||
func TestGetEngineMySQLFromConfig(t *testing.T) { |
||||
cfg := setting.NewCfg() |
||||
s, err := cfg.Raw.NewSection("entity_api") |
||||
require.NoError(t, err) |
||||
s.Key("db_type").SetValue("mysql") |
||||
s.Key("db_host").SetValue("localhost") |
||||
s.Key("db_name").SetValue("grafana") |
||||
s.Key("db_user").SetValue("user") |
||||
s.Key("db_password").SetValue("password") |
||||
|
||||
engine, err := getEngineMySQL(cfg.SectionWithEnvOverrides("entity_api")) |
||||
|
||||
assert.NotNil(t, engine) |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestGetConnectionStrings(t *testing.T) { |
||||
t.Run("generate mysql connection string", func(t *testing.T) { |
||||
expected := "user:password@tcp(localhost)/grafana?collation=utf8mb4_unicode_ci&allowNativePasswords=true&clientFoundRows=true" |
||||
assert.Equal(t, expected, connectionStringMySQL("user", "password", "tcp", "localhost", "grafana")) |
||||
}) |
||||
t.Run("generate postgres connection string", func(t *testing.T) { |
||||
expected := "user=user password=password host=localhost port=5432 dbname=grafana sslmode=disable" |
||||
assert.Equal(t, expected, connectionStringPostgres("user", "password", "localhost", "5432", "grafana", "disable")) |
||||
}) |
||||
} |
Loading…
Reference in new issue