Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support to config the version white list #300

Merged
merged 4 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func logoutAndClose(conn *connection, sessionID int64) {
func TestConnection(t *testing.T) {
hostAddress := HostAddress{Host: address, Port: port}
conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestConnection(t *testing.T) {
func TestConnectionIPv6(t *testing.T) {
hostAddress := HostAddress{Host: addressIPv6, Port: port}
conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -245,6 +245,22 @@ func TestConfigs(t *testing.T) {
}
}

func TestVersionVerify(t *testing.T) {
const (
username = "root"
password = "nebula"
)

hostAddress := HostAddress{Host: address, Port: port}

conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "INVALID_VERSION")
if err != nil {
assert.Contains(t, err.Error(), "incompatible version between client and server")
}
defer conn.close()
}

func TestAuthentication(t *testing.T) {
const (
username = "dummy"
Expand All @@ -254,7 +270,7 @@ func TestAuthentication(t *testing.T) {
hostAddress := HostAddress{Host: address, Port: port}

conn := newConnection(hostAddress)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
t.Fatalf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -1421,7 +1437,7 @@ func prepareSpace(spaceName string) error {
conn := newConnection(hostAddress)
testPoolConfig := GetDefaultConf()

err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down Expand Up @@ -1458,7 +1474,7 @@ func dropSpace(spaceName string) error {
conn := newConnection(hostAddress)
testPoolConfig := GetDefaultConf()

err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil)
err := conn.open(hostAddress, testPoolConfig.TimeOut, nil, false, nil, "")
if err != nil {
return fmt.Errorf("fail to open connection, address: %s, port: %d, %s", address, port, err.Error())
}
Expand Down
11 changes: 11 additions & 0 deletions configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type PoolConfig struct {
UseHTTP2 bool
// HttpHeader is the http headers for the connection when using HTTP2
HttpHeader http.Header
// client version, make sure the client version is in the white list of NebulaGraph server
Version string
}

// validateConf validates config
Expand Down Expand Up @@ -64,6 +66,7 @@ func GetDefaultConf() PoolConfig {
MaxConnPoolSize: 10,
MinConnPoolSize: 0,
UseHTTP2: false,
Version: "",
}
}

Expand Down Expand Up @@ -138,6 +141,8 @@ type SessionPoolConf struct {
useHTTP2 bool
// httpHeader is the http headers for the connection
httpHeader http.Header
// client version, make sure the client version is in the white list of NebulaGraph server
version string
}

type SessionPoolConfOption func(*SessionPoolConf)
Expand Down Expand Up @@ -214,6 +219,12 @@ func WithHttpHeader(header http.Header) SessionPoolConfOption {
}
}

func WithVersion(version string) SessionPoolConfOption {
return func(conf *SessionPoolConf) {
conf.version = version
}
}

func (conf *SessionPoolConf) checkMandatoryFields() error {
// Check mandatory fields
if conf.username == "" {
Expand Down
10 changes: 8 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type connection struct {
sslConfig *tls.Config
useHTTP2 bool
httpHeader http.Header
version string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please run go fmt to make aligned

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

graph *graph.GraphServiceClient
}

Expand All @@ -39,19 +40,21 @@ func newConnection(severAddress HostAddress) *connection {
timeout: 0 * time.Millisecond,
returnedAt: time.Now(),
sslConfig: nil,
version: "",
graph: nil,
}
}

// open opens a transport for the connection
// if sslConfig is not nil, an SSL transport will be created
func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
ip := hostAddress.Host
port := hostAddress.Port
newAdd := net.JoinHostPort(ip, strconv.Itoa(port))
cn.timeout = timeout
cn.useHTTP2 = useHTTP2
cn.version = version

var (
err error
Expand Down Expand Up @@ -133,6 +136,9 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo

func (cn *connection) verifyClientVersion() error {
req := graph.NewVerifyClientVersionReq()
if cn.version != "" {
req.SetVersion([]byte(cn.version))
}
resp, err := cn.graph.VerifyClientVersion(req)
if err != nil {
cn.close()
Expand All @@ -150,7 +156,7 @@ func (cn *connection) verifyClientVersion() error {
// When the timeout occurs, the connection will be reopened to avoid the impact of the message.
func (cn *connection) reopen() error {
cn.close()
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader)
return cn.open(cn.severAddress, cn.timeout, cn.sslConfig, cn.useHTTP2, cn.httpHeader, cn.version)
}

// Authenticate
Expand Down
16 changes: 8 additions & 8 deletions connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func NewSslConnectionPool(addresses []HostAddress, conf PoolConfig, sslConfig *t
// initPool initializes the connection pool
func (pool *ConnectionPool) initPool() error {
if err := checkAddresses(pool.conf.TimeOut, pool.addresses, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
return fmt.Errorf("failed to open connection, error: %s ", err.Error())
}

Expand All @@ -76,7 +76,7 @@ func (pool *ConnectionPool) initPool() error {

// Open connection to host
if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
// If initialization failed, clean idle queue
idleLen := pool.idleConnectionQueue.Len()
for i := 0; i < idleLen; i++ {
Expand Down Expand Up @@ -194,7 +194,7 @@ func (pool *ConnectionPool) releaseAndBack(conn *connection, pushBack bool) {

// Ping checks availability of host
func (pool *ConnectionPool) Ping(host HostAddress, timeout time.Duration) error {
return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader)
return pingAddress(host, timeout, pool.sslConfig, pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version)
}

// Close closes all connection
Expand Down Expand Up @@ -246,7 +246,7 @@ func (pool *ConnectionPool) newConnToHost() (*connection, error) {
newConn := newConnection(host)
// Open connection to host
if err := newConn.open(newConn.severAddress, pool.conf.TimeOut, pool.sslConfig,
pool.conf.UseHTTP2, pool.conf.HttpHeader); err != nil {
pool.conf.UseHTTP2, pool.conf.HttpHeader, pool.conf.Version); err != nil {
return nil, err
}
// Add connection to active queue
Expand Down Expand Up @@ -354,24 +354,24 @@ func (pool *ConnectionPool) timeoutConnectionList() (closing []*connection) {
// It opens a temporary connection to each address and closes it immediately.
// If no error is returned, the addresses are available.
func checkAddresses(confTimeout time.Duration, addresses []HostAddress, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
var timeout = 3 * time.Second
if confTimeout != 0 && confTimeout < timeout {
timeout = confTimeout
}
for _, address := range addresses {
if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader); err != nil {
if err := pingAddress(address, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil {
return err
}
}
return nil
}

func pingAddress(address HostAddress, timeout time.Duration, sslConfig *tls.Config,
useHTTP2 bool, httpHeader http.Header) error {
useHTTP2 bool, httpHeader http.Header, version string) error {
newConn := newConnection(address)
// Open connection to host
if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader); err != nil {
if err := newConn.open(newConn.severAddress, timeout, sslConfig, useHTTP2, httpHeader, version); err != nil {
return err
}
defer newConn.close()
Expand Down
1 change: 1 addition & 0 deletions examples/basic_example/graph_client_basic_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func main() {
// Create configs for connection pool using default values
testPoolConfig := nebula.GetDefaultConf()
testPoolConfig.UseHTTP2 = useHTTP2
testPoolConfig.Version = "3.0.0"

// Initialize connection pool
pool, err := nebula.NewConnectionPool(hostList, testPoolConfig, log)
Expand Down
6 changes: 6 additions & 0 deletions nebula-docker-compose/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -328,6 +330,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down Expand Up @@ -374,6 +378,8 @@ services:
- --system_memory_high_watermark_ratio=0.99
- --heartbeat_interval_secs=1
- --max_sessions_per_ip_per_user=1200
- --enable_client_white_list=true
- --client_white_list=3.0.0:test
# ssl
- --ca_path=${ca_path}
- --cert_path=${cert_path}
Expand Down
4 changes: 2 additions & 2 deletions session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func NewSessionPool(conf SessionPoolConf, log Logger) (*SessionPool, error) {
func (pool *SessionPool) init() error {
// check the hosts status
if err := checkAddresses(pool.conf.timeOut, pool.conf.serviceAddrs, pool.conf.sslConfig,
pool.conf.useHTTP2, pool.conf.httpHeader); err != nil {
pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil {
return fmt.Errorf("failed to initialize the session pool, %s", err.Error())
}

Expand Down Expand Up @@ -287,7 +287,7 @@ func (pool *SessionPool) newSession() (*pureSession, error) {

// open a new connection
if err := cn.open(cn.severAddress, pool.conf.timeOut, pool.conf.sslConfig,
pool.conf.useHTTP2, pool.conf.httpHeader); err != nil {
pool.conf.useHTTP2, pool.conf.httpHeader, pool.conf.version); err != nil {
return nil, fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}

Expand Down
22 changes: 22 additions & 0 deletions session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,28 @@ func TestSessionPoolServerCheck(t *testing.T) {
}
}

func TestSessionPoolInvalidVersion(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
hostAddress := HostAddress{Host: address, Port: port}

// wrong version info
versionConfig, err := NewSessionPoolConf(
"root",
"nebula",
[]HostAddress{hostAddress},
"client_test",
)
versionConfig.version = "INVALID_VERSION"
versionConfig.minSize = 1

// create session pool
_, err = NewSessionPool(*versionConfig, DefaultLogger{})
if err != nil {
assert.Contains(t, err.Error(), "incompatible version between client and server")
}
}

func TestSessionPoolBasic(t *testing.T) {
prepareSpace("client_test")
defer dropSpace("client_test")
Expand Down
Loading