diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index d0e93177d8b..7673d80ea91 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -21,6 +21,7 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/sqlstore/sqlutil" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" "github.com/go-sql-driver/mysql" "github.com/go-xorm/xorm" @@ -222,13 +223,9 @@ func (ss *SqlStore) buildConnectionString() (string, error) { cnnstr += "&tls=custom" } case migrator.POSTGRES: - var host, port = "127.0.0.1", "5432" - fields := strings.Split(ss.dbCfg.Host, ":") - if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 { - host = fields[0] - } - if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 { - port = fields[1] + host, port, err := util.SplitIpPort(ss.dbCfg.Host, "5432") + if err != nil { + return "", err } if ss.dbCfg.Pwd == "" { ss.dbCfg.Pwd = "''" diff --git a/pkg/services/sqlstore/sqlstore_test.go b/pkg/services/sqlstore/sqlstore_test.go new file mode 100644 index 00000000000..76402d6a50d --- /dev/null +++ b/pkg/services/sqlstore/sqlstore_test.go @@ -0,0 +1,101 @@ +package sqlstore + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "github.com/grafana/grafana/pkg/setting" +) + +type sqlStoreTest struct { + name string + dbType string + dbHost string + connStrValues []string +} + +var sqlStoreTestCases = []sqlStoreTest{ + { + name: "MySQL IPv4", + dbType: "mysql", + dbHost: "1.2.3.4:5678", + connStrValues: []string{"tcp(1.2.3.4:5678)"}, + }, + { + name: "Postgres IPv4", + dbType: "postgres", + dbHost: "1.2.3.4:5678", + connStrValues: []string{"host=1.2.3.4", "port=5678"}, + }, + { + name: "Postgres IPv4 (Default Port)", + dbType: "postgres", + dbHost: "1.2.3.4", + connStrValues: []string{"host=1.2.3.4", "port=5432"}, + }, + { + name: "MySQL IPv4 (Default Port)", + dbType: "mysql", + dbHost: "1.2.3.4", + connStrValues: []string{"tcp(1.2.3.4)"}, + }, + { + name: "MySQL IPv6", + dbType: "mysql", + dbHost: "[fe80::24e8:31b2:91df:b177]:1234", + connStrValues: []string{"tcp([fe80::24e8:31b2:91df:b177]:1234)"}, + }, + { + name: "Postgres IPv6", + dbType: "postgres", + dbHost: "[fe80::24e8:31b2:91df:b177]:1234", + connStrValues: []string{"host=fe80::24e8:31b2:91df:b177", "port=1234"}, + }, + { + name: "MySQL IPv6 (Default Port)", + dbType: "mysql", + dbHost: "::1", + connStrValues: []string{"tcp(::1)"}, + }, + { + name: "Postgres IPv6 (Default Port)", + dbType: "postgres", + dbHost: "::1", + connStrValues: []string{"host=::1", "port=5432"}, + }, +} + +func TestSqlConnectionString(t *testing.T) { + Convey("Testing SQL Connection Strings", t, func() { + t.Helper() + + for _, testCase := range sqlStoreTestCases { + Convey(testCase.name, func() { + sqlstore := &SqlStore{} + sqlstore.Cfg = makeSqlStoreTestConfig(testCase.dbType, testCase.dbHost) + sqlstore.readConfig() + + connStr, err := sqlstore.buildConnectionString() + + So(err, ShouldBeNil) + for _, connSubStr := range testCase.connStrValues { + So(connStr, ShouldContainSubstring, connSubStr) + } + }) + } + }) +} + +func makeSqlStoreTestConfig(dbType string, host string) *setting.Cfg { + cfg := setting.NewCfg() + + sec, _ := cfg.Raw.NewSection("database") + sec.NewKey("type", dbType) + sec.NewKey("host", host) + sec.NewKey("user", "user") + sec.NewKey("name", "test_db") + sec.NewKey("password", "pass") + + return cfg +} diff --git a/pkg/tsdb/mssql/mssql.go b/pkg/tsdb/mssql/mssql.go index 469d6baa5de..bb2e06ed673 100644 --- a/pkg/tsdb/mssql/mssql.go +++ b/pkg/tsdb/mssql/mssql.go @@ -4,13 +4,13 @@ import ( "database/sql" "fmt" "strconv" - "strings" _ "github.com/denisenkom/go-mssqldb" "github.com/go-xorm/core" "github.com/grafana/grafana/pkg/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/tsdb" + "github.com/grafana/grafana/pkg/util" ) func init() { @@ -20,7 +20,10 @@ func init() { func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) { logger := log.New("tsdb.mssql") - cnnstr := generateConnectionString(datasource) + cnnstr, err := generateConnectionString(datasource) + if err != nil { + return nil, err + } logger.Debug("getEngine", "connection", cnnstr) config := tsdb.SqlQueryEndpointConfiguration{ @@ -37,7 +40,7 @@ func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger) } -func generateConnectionString(datasource *models.DataSource) string { +func generateConnectionString(datasource *models.DataSource) (string, error) { password := "" for key, value := range datasource.SecureJsonData.Decrypt() { if key == "password" { @@ -46,12 +49,11 @@ func generateConnectionString(datasource *models.DataSource) string { } } - hostParts := strings.Split(datasource.Url, ":") - if len(hostParts) < 2 { - hostParts = append(hostParts, "1433") + server, port, err := util.SplitIpPort(datasource.Url, "1433") + if err != nil { + return "", err } - server, port := hostParts[0], hostParts[1] encrypt := datasource.JsonData.Get("encrypt").MustString("false") connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;", server, @@ -63,7 +65,7 @@ func generateConnectionString(datasource *models.DataSource) string { if encrypt != "false" { connStr += fmt.Sprintf("encrypt=%s;", encrypt) } - return connStr + return connStr, nil } type mssqlRowTransformer struct { diff --git a/pkg/util/ip.go b/pkg/util/ip.go new file mode 100644 index 00000000000..351abd7a03b --- /dev/null +++ b/pkg/util/ip.go @@ -0,0 +1,24 @@ +package util + +import ( + "net" +) + +func SplitIpPort(ipStr string, portDefault string) (ip string, port string, err error) { + ipAddr := net.ParseIP(ipStr) + + if ipAddr == nil { + // Port was included + ip, port, err = net.SplitHostPort(ipStr) + + if err != nil { + return "", "", err + } + } else { + // No port was included + ip = ipAddr.String() + port = portDefault + } + + return ip, port, nil +} diff --git a/pkg/util/ip_test.go b/pkg/util/ip_test.go new file mode 100644 index 00000000000..f938c182b02 --- /dev/null +++ b/pkg/util/ip_test.go @@ -0,0 +1,43 @@ +package util + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestSplitIpPort(t *testing.T) { + + Convey("When parsing an IPv4 without explicit port", t, func() { + ip, port, err := SplitIpPort("1.2.3.4", "5678") + + So(err, ShouldEqual, nil) + So(ip, ShouldEqual, "1.2.3.4") + So(port, ShouldEqual, "5678") + }) + + Convey("When parsing an IPv6 without explicit port", t, func() { + ip, port, err := SplitIpPort("::1", "5678") + + So(err, ShouldEqual, nil) + So(ip, ShouldEqual, "::1") + So(port, ShouldEqual, "5678") + }) + + Convey("When parsing an IPv4 with explicit port", t, func() { + ip, port, err := SplitIpPort("1.2.3.4:56", "78") + + So(err, ShouldEqual, nil) + So(ip, ShouldEqual, "1.2.3.4") + So(port, ShouldEqual, "56") + }) + + Convey("When parsing an IPv6 with explicit port", t, func() { + ip, port, err := SplitIpPort("[::1]:56", "78") + + So(err, ShouldEqual, nil) + So(ip, ShouldEqual, "::1") + So(port, ShouldEqual, "56") + }) + +}