diff --git a/pkg/components/simplejson/simplejson.go b/pkg/components/simplejson/simplejson.go index 4e2a4f36d6c..82b61af71ff 100644 --- a/pkg/components/simplejson/simplejson.go +++ b/pkg/components/simplejson/simplejson.go @@ -52,7 +52,7 @@ func New() *Json { } } -// New returns a pointer to a new, empty `Json` object +// NewFromAny returns a pointer to a new `Json` object with provided data. func NewFromAny(data interface{}) *Json { return &Json{data: data} } diff --git a/pkg/tsdb/postgres/postgres.go b/pkg/tsdb/postgres/postgres.go index 36c53150dfe..de2628e6e44 100644 --- a/pkg/tsdb/postgres/postgres.go +++ b/pkg/tsdb/postgres/postgres.go @@ -3,11 +3,11 @@ package postgres import ( "database/sql" "fmt" - "net/url" "strconv" "strings" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util/errutil" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" @@ -47,20 +47,51 @@ func newPostgresQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndp timescaledb := datasource.JsonData.Get("timescaledb").MustBool(false) endpoint, err := sqleng.NewSqlQueryEndpoint(&config, &queryResultTransformer, newPostgresMacroEngine(timescaledb), logger) - if err == nil { - logger.Debug("Successfully connected to Postgres") - } else { + if err != nil { logger.Debug("Failed connecting to Postgres", "err", err) + return nil, err } + + logger.Debug("Successfully connected to Postgres") return endpoint, err } +// escape single quotes and backslashes in Postgres connection string parameters. +func escape(input string) string { + return strings.Replace(strings.Replace(input, `\`, `\\`, -1), "'", `\'`, -1) +} + func generateConnectionString(datasource *models.DataSource, logger log.Logger) (string, error) { sslMode := strings.TrimSpace(strings.ToLower(datasource.JsonData.Get("sslmode").MustString("verify-full"))) isSSLDisabled := sslMode == "disable" - // Always pass SSL mode - sslOpts := fmt.Sprintf("sslmode=%s", url.QueryEscape(sslMode)) + var host string + var port int + if strings.HasPrefix(datasource.Url, "/") { + host = datasource.Url + logger.Debug("Generating connection string with Unix socket specifier", "socket", host) + } else { + sp := strings.SplitN(datasource.Url, ":", 2) + host = sp[0] + if len(sp) > 1 { + var err error + port, err = strconv.Atoi(sp[1]) + if err != nil { + return "", errutil.Wrapf(err, "invalid port in host specifier %q", sp[1]) + } + + logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port) + } else { + logger.Debug("Generating connection string with network host", "host", host) + } + } + + connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s' sslmode='%s'", + escape(datasource.User), escape(datasource.DecryptedPassword()), escape(host), escape(datasource.Database), + escape(sslMode)) + if port > 0 { + connStr += fmt.Sprintf(" port=%d", port) + } if isSSLDisabled { logger.Debug("Postgres SSL is disabled") } else { @@ -69,7 +100,7 @@ func generateConnectionString(datasource *models.DataSource, logger log.Logger) // Attach root certificate if provided if sslRootCert := datasource.JsonData.Get("sslRootCertFile").MustString(""); sslRootCert != "" { logger.Debug("Setting server root certificate", "sslRootCert", sslRootCert) - sslOpts = fmt.Sprintf("%s&sslrootcert=%s", sslOpts, url.QueryEscape(sslRootCert)) + connStr += fmt.Sprintf(" sslrootcert='%s'", sslRootCert) } // Attach client certificate and key if both are provided @@ -77,20 +108,14 @@ func generateConnectionString(datasource *models.DataSource, logger log.Logger) sslKey := datasource.JsonData.Get("sslKeyFile").MustString("") if sslCert != "" && sslKey != "" { logger.Debug("Setting SSL client auth", "sslCert", sslCert, "sslKey", sslKey) - sslOpts = fmt.Sprintf("%s&sslcert=%s&sslkey=%s", sslOpts, url.QueryEscape(sslCert), url.QueryEscape(sslKey)) + connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", sslCert, sslKey) } else if sslCert != "" || sslKey != "" { return "", fmt.Errorf("SSL client certificate and key must both be specified") } } - u := &url.URL{ - Scheme: "postgres", - User: url.UserPassword(datasource.User, datasource.DecryptedPassword()), - Host: datasource.Url, Path: datasource.Database, - RawQuery: sslOpts, - } - - return u.String(), nil + logger.Debug("Generated Postgres connection string successfully") + return connStr, nil } type postgresQueryResultTransformer struct { diff --git a/pkg/tsdb/postgres/postgres_test.go b/pkg/tsdb/postgres/postgres_test.go index 8ccbaf1ab5d..bff66a9dd88 100644 --- a/pkg/tsdb/postgres/postgres_test.go +++ b/pkg/tsdb/postgres/postgres_test.go @@ -10,17 +10,109 @@ import ( "github.com/grafana/grafana/pkg/components/securejsondata" "github.com/grafana/grafana/pkg/components/simplejson" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/sqlutil" "github.com/grafana/grafana/pkg/tsdb" "github.com/grafana/grafana/pkg/tsdb/sqleng" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "xorm.io/xorm" _ "github.com/lib/pq" . "github.com/smartystreets/goconvey/convey" ) +// Test generateConnectionString. +func TestGenerateConnectionString(t *testing.T) { + logger := log.New("tsdb.postgres") + + testCases := []struct { + desc string + host string + user string + password string + database string + sslMode string + expConnStr string + expErr string + }{ + { + desc: "Unix socket host", + host: "/var/run/postgresql", + user: "user", + password: "password", + database: "database", + expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", + }, + { + desc: "TCP host", + host: "host", + user: "user", + password: "password", + database: "database", + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", + }, + { + desc: "TCP/port host", + host: "host:1234", + user: "user", + password: "password", + database: "database", + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' port=1234", + }, + { + desc: "Invalid port", + host: "host:invalid", + user: "user", + database: "database", + expErr: "invalid port in host specifier", + }, + { + desc: "Password with single quote and backslash", + host: "host", + user: "user", + password: `p'\assword`, + database: "database", + expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, + }, + { + desc: "Custom SSL mode", + host: "host", + user: "user", + password: "password", + database: "database", + sslMode: "disable", + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", + }, + } + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + data := map[string]interface{}{} + if tt.sslMode != "" { + data["sslmode"] = tt.sslMode + } + ds := &models.DataSource{ + Url: tt.host, + User: tt.user, + Password: tt.password, + Database: tt.database, + JsonData: simplejson.NewFromAny(data), + } + connStr, err := generateConnectionString(ds, logger) + if tt.expErr == "" { + require.NoError(t, err, tt.desc) + assert.Equal(t, tt.expConnStr, connStr, tt.desc) + } else { + require.Error(t, err, tt.desc) + assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), + fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr)) + } + }) + } +} + // To run this test, set runPostgresTests=true // Or from the commandline: GRAFANA_TEST_DB=postgres go test -v ./pkg/tsdb/postgres // The tests require a PostgreSQL db named grafanadstest and a user/password grafanatest/grafanatest!