diff --git a/connection_producer.go b/connection_producer.go index 8f7363a..4675f63 100644 --- a/connection_producer.go +++ b/connection_producer.go @@ -40,6 +40,12 @@ type YugabyteDBConnectionProducer struct { LoadBalance bool `json:"load_balance" mapstructure:"load_balance" structs:"load_balance"` YbServersRefreshInterval int `json:"yb_servers_refresh_interval" mapstructure:"yb_servers_refresh_interval" structs:"yb_servers_refresh_interval"` TopologyKeys string `json:"topology_keys" mapstructure:"topology_keys" structs:"topology_keys"` + SslMode string `json:"sslmode" mapstructure:"sslmode" structs:"sslmode"` + SslRootCert string `json:"sslrootcert" mapstructure:"sslrootcert" structs:"sslrootcert"` + SslSni string `json:"sslsni" mapstructure:"sslsni" structs:"sslsni"` + SslKey string `json:"sslkey" mapstructure:"sslkey" structs:"sslkey"` + SslCert string `json:"sslcert" mapstructure:"sslcert" structs:"sslcert"` + SslPassword string `json:"sslpassword" mapstructure:"sslpassword" structs:"sslpassword"` Type string RawConfig map[string]interface{} @@ -134,13 +140,37 @@ func (c *YugabyteDBConnectionProducer) Connection(ctx context.Context) (interfac c.db.Close() } + if c.SslMode == "" { + c.SslMode = "prefer" //default sslmode + } + var conn string if c.TopologyKeys != "" { conn = fmt.Sprintf("host=%s port=%d user=%s "+ - "password=%s dbname=%s sslmode=disable load_balance=%v yb_servers_refresh_interval=%d topology_keys=%s ", c.Host, c.Port, c.Username, c.Password, c.DbName, c.LoadBalance, c.YbServersRefreshInterval, c.TopologyKeys) + "password=%s dbname=%s sslmode=%s load_balance=%v yb_servers_refresh_interval=%d topology_keys=%s ", c.Host, c.Port, c.Username, c.Password, c.DbName, c.SslMode, c.LoadBalance, c.YbServersRefreshInterval, c.TopologyKeys) } else { conn = fmt.Sprintf("host=%s port=%d user=%s "+ - "password=%s dbname=%s sslmode=disable load_balance=%v yb_servers_refresh_interval=%d ", c.Host, c.Port, c.Username, c.Password, c.DbName, c.LoadBalance, c.YbServersRefreshInterval) + "password=%s dbname=%s sslmode=%s load_balance=%v yb_servers_refresh_interval=%d ", c.Host, c.Port, c.Username, c.Password, c.DbName, c.SslMode, c.LoadBalance, c.YbServersRefreshInterval) + } + + if c.SslRootCert != "" { + conn = fmt.Sprintf(conn + fmt.Sprintf("sslrootcert=%s ", c.SslRootCert)) + } + + if c.SslCert != "" { + conn = fmt.Sprintf(conn + fmt.Sprintf("sslcert=%s ", c.SslCert)) + } + + if c.SslKey != "" { + conn = fmt.Sprintf(conn + fmt.Sprintf("sslkey=%s ", c.SslKey)) + } + + if c.SslPassword != "" { + conn = fmt.Sprintf(conn + fmt.Sprintf("sslpassword=%s ", c.SslPassword)) + } + + if c.SslSni != "" { + conn = fmt.Sprintf(conn + fmt.Sprintf("sslsni=%s", c.SslSni)) } if len(c.ConnectionURL) != 0 { diff --git a/ysql_test.go b/ysql_test.go index 7cd769e..c4f2c8d 100644 --- a/ysql_test.go +++ b/ysql_test.go @@ -440,6 +440,7 @@ func TestUpdateUser_Expiration(t *testing.T) { // Shared test container for speed - there should not be any overlap between the tests db, cleanup := getysql(t, nil) defer cleanup() + db.db.SetMaxOpenConns(1) for name, test := range tests { t.Run(name, func(t *testing.T) { diff --git a/ysqlhelper/ysqlhelper.go b/ysqlhelper/ysqlhelper.go index 64c1ab4..f7d4ce4 100644 --- a/ysqlhelper/ysqlhelper.go +++ b/ysqlhelper/ysqlhelper.go @@ -18,6 +18,7 @@ import ( "fmt" "net/url" "testing" + "time" "github.com/hashicorp/vault/sdk/helper/docker" @@ -51,6 +52,8 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) { } func connectYugabyteDB(ctx context.Context, host string, port int) (docker.ServiceConfig, error) { + fmt.Println("Waiting for container to be up...") + time.Sleep(30 * time.Second) u := url.URL{ Scheme: "postgres", User: url.UserPassword("yugabyte", "testsecret"),