Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add context for connect and close functions #77

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ This section explains the various stages at which different handler functions ar

#### On connection establishment:

- **`OnConnect`**: This handler is invoked immediately after the TCP connection is made. It can be utilized for operations that should be performed before the connection is officially considered established (e.g., sending `SignOn` message and receiving its response).
- **`OnConnect`** or **`OnConnectCtx`**: This handler is invoked immediately after the TCP connection is made. It can be utilized for operations that should be performed before the connection is officially considered established (e.g., sending `SignOn` message and receiving its response). **NOTE** If both `OnConnect` and `OnConnectCtx` are defined, `OnConnectCtx` will be used.

- **`ConnectionEstablishedHandler (async)`**: This asynchronous handler is triggered when the connection is logically considered established.

Expand All @@ -100,7 +100,7 @@ This section explains the various stages at which different handler functions ar

- **`ConnectionClosedHandlers (async)`**: These asynchronous handlers are invoked when a connection is closed, either by the server or due to a connection error.

- **`OnClose`**: This handler is activated when we manually close the connection.
- **`OnClose`** or **`OnCloseCtx`**: This handler is activated when we manually close the connection. **NOTE** If both `OnClose` and `OnCloseCtx` are defined, `OnCloseCtx` will be used.


### (m)TLS connection
Expand Down Expand Up @@ -269,6 +269,30 @@ Following options are supported:
* `MinConnections` is the number of connections required to be established when we connect the pool
* `ConnectionsFilter` is a function to filter connections in the pool for `Get`, `IsDegraded` or `IsUp` methods

## Context
You can provide context to the Connect and Close functions in addition to defining `OnConnectCtx` and `OnCloseCtx` in the connection options. This will allow you to pass along telemetry or any other information on contexts through from the Connect/Close calls to your handler functions:

```go
c, err := connection.New("127.0.0.1:9999", brandSpec, readMessageLength, writeMessageLength,
connection.SendTimeout(100*time.Millisecond),
connection.IdleTime(50*time.Millisecond),
connect.OnConnectCtx(func(ctx context.Context, c *connection.Connection){
return signOnFunc(ctx, c)
}),
connect.OnCloseCtx(func(ctx context.Context, c *connection.Connection){
return signOffFunc(ctx, c)
}),
)

ctx := context.Background()
c.ConnectCtx(ctx)

...

c.CloseCtx(ctx)

```

## Benchmark

To benchmark the connection, run:
Expand Down
37 changes: 31 additions & 6 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connection

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -147,6 +148,11 @@ func (c *Connection) SetOptions(options ...Option) error {

// Connect establishes the connection to the server using configured Addr
func (c *Connection) Connect() error {
return c.ConnectCtx(context.Background())
}

// ConnectCtx establishes the connection to the server using configured Addr
func (c *Connection) ConnectCtx(ctx context.Context) error {
var conn net.Conn
var err error

Expand All @@ -170,12 +176,19 @@ func (c *Connection) Connect() error {

c.run()

if c.Opts.OnConnect != nil {
if err := c.Opts.OnConnect(c); err != nil {
onConnect := c.Opts.OnConnectCtx
if onConnect == nil && c.Opts.OnConnect != nil {
onConnect = func(_ context.Context, c *Connection) error {
return c.Opts.OnConnect(c)
}
}

if onConnect != nil {
if err := onConnect(ctx, c); err != nil {
// close connection if OnConnect failed
// but ignore the potential error from Close()
// as it's a rare case
_ = c.Close()
_ = c.CloseCtx(ctx)

return fmt.Errorf("on connect callback %s: %w", c.addr, err)
}
Expand Down Expand Up @@ -261,7 +274,6 @@ func (c *Connection) handleConnectionError(err error) {
case <-done:
return
}

}
}()

Expand Down Expand Up @@ -299,8 +311,21 @@ func (c *Connection) close() error {
// Close waits for pending requests to complete and then closes network
// connection with ISO 8583 server
func (c *Connection) Close() error {
if c.Opts.OnClose != nil {
if err := c.Opts.OnClose(c); err != nil {
return c.CloseCtx(context.Background())
}

// CloseCtx waits for pending requests to complete and then closes network
// connection with ISO 8583 server
func (c *Connection) CloseCtx(ctx context.Context) error {
onClose := c.Opts.OnCloseCtx
if onClose == nil && c.Opts.OnClose != nil {
onClose = func(_ context.Context, c *Connection) error {
return c.Opts.OnClose(c)
}
}

if onClose != nil {
if err := onClose(ctx, c); err != nil {
return fmt.Errorf("on close callback: %w", err)
}
}
Expand Down
64 changes: 56 additions & 8 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connection_test

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -27,8 +28,10 @@ type baseFields struct {
STAN *field.String `index:"11"`
}

var stan int
var stanMu sync.Mutex
var (
stan int
stanMu sync.Mutex
)

func getSTAN() string {
stanMu.Lock()
Expand Down Expand Up @@ -145,6 +148,31 @@ func TestClient_Connect(t *testing.T) {
}, 100*time.Millisecond, 20*time.Millisecond, "onConnect should be called")
})

t.Run("OnConnectCtx is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
defer server.Close()

var onConnectCalled int32
onConnectCtx := func(ctx context.Context, c *connection.Connection) error {
// increase the counter
atomic.AddInt32(&onConnectCalled, 1)
return nil
}

c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.OnConnectCtx(onConnectCtx))
require.NoError(t, err)

err = c.ConnectCtx(context.Background())
require.NoError(t, err)
defer c.Close()

// eventually the onConnectCounter should be 1
require.Eventually(t, func() bool {
return atomic.LoadInt32(&onConnectCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onConnect should be called")
})

t.Run("OnClose is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
Expand All @@ -160,9 +188,6 @@ func TestClient_Connect(t *testing.T) {
c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.OnClose(onClose))
require.NoError(t, err)

// err = c.Connect()
// require.NoError(t, err)

err = c.Close()
require.NoError(t, err)

Expand All @@ -171,6 +196,30 @@ func TestClient_Connect(t *testing.T) {
return atomic.LoadInt32(&onClosedCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onClose should be called")
})

t.Run("OnCloseCtx is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
defer server.Close()

var onClosedCalled int32
onCloseCtx := func(ctx context.Context, c *connection.Connection) error {
// increase the counter
atomic.AddInt32(&onClosedCalled, 1)
return nil
}

c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.OnCloseCtx(onCloseCtx))
require.NoError(t, err)

err = c.CloseCtx(context.Background())
require.NoError(t, err)

// eventually the onClosedCalled should be 1
require.Eventually(t, func() bool {
return atomic.LoadInt32(&onClosedCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onClose should be called")
})
}

func TestClient_Write(t *testing.T) {
Expand Down Expand Up @@ -682,7 +731,6 @@ func TestClient_Send(t *testing.T) {

// and that response for the first message was received second
require.Equal(t, receivedSTANs[1], stan1)

})

t.Run("automatically sends ping messages after ping interval", func(t *testing.T) {
Expand Down Expand Up @@ -985,7 +1033,6 @@ func TestClient_Send(t *testing.T) {
return server.ReceivedPings() > 0
}, 200*time.Millisecond, 50*time.Millisecond, "no ping messages were sent after read timeout")
})

}

func TestClient_Options(t *testing.T) {
Expand Down Expand Up @@ -1019,7 +1066,6 @@ func TestClient_Options(t *testing.T) {
require.Eventually(t, func() bool {
return atomic.LoadInt32(&callsCounter) > 0
}, 500*time.Millisecond, 50*time.Millisecond, "error handler was never called")

})

t.Run("ClosedHandler is called when connection is closed", func(t *testing.T) {
Expand Down Expand Up @@ -1319,9 +1365,11 @@ func (m *TrackingRWCloser) Write(p []byte) (n int, err error) {

return 0, nil
}

func (m *TrackingRWCloser) Read(p []byte) (n int, err error) {
return 0, nil
}

func (m *TrackingRWCloser) Close() error {
return nil
}
Expand Down
25 changes: 25 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package connection

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -60,12 +61,20 @@ type Options struct {
// returned to the caller
ErrorHandler func(err error)

// If both OnConnect and OnConnectCtx are set, OnConnectCtx will be used
// OnConnect is called synchronously when a connection is established
OnConnect func(c *Connection) error

// OnConnectCtx is called synchronously when a connection is established
OnConnectCtx func(ctx context.Context, c *Connection) error

// If both OnClose and OnCloseCtx are set, OnCloseCtx will be used
// OnClose is called synchronously before a connection is closed
OnClose func(c *Connection) error

// OnCloseCtx is called synchronously before a connection is closed
OnCloseCtx func(ctx context.Context, c *Connection) error

// RequestIDGenerator is used to generate a unique identifier for a request
// so that responses from the server can be matched to the original request.
RequestIDGenerator RequestIDGenerator
Expand Down Expand Up @@ -191,13 +200,29 @@ func OnConnect(h func(c *Connection) error) Option {
}
}

// OnConnectCtx sets a callback that will be synchronously called when connection is established.
// If it returns error, then connections will be closed and re-connect will be attempted
func OnConnectCtx(h func(ctx context.Context, c *Connection) error) Option {
return func(opts *Options) error {
opts.OnConnectCtx = h
return nil
}
}

func OnClose(h func(c *Connection) error) Option {
return func(opts *Options) error {
opts.OnClose = h
return nil
}
}

func OnCloseCtx(h func(ctx context.Context, c *Connection) error) Option {
return func(opts *Options) error {
opts.OnCloseCtx = h
return nil
}
}

func defaultTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
Expand Down
20 changes: 15 additions & 5 deletions pool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package connection

import (
"context"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -62,6 +63,11 @@ func (p *Pool) handleError(err error) {

// Connect creates poll of connections by calling Factory method and connect them all
func (p *Pool) Connect() error {
return p.ConnectCtx(context.Background())
}

// Connect creates poll of connections by calling Factory method and connect them all
func (p *Pool) ConnectCtx(ctx context.Context) error {
// We need to close pool (with all potentially running goroutines) if
// connection creation fails. Example of such situation is when we
// successfully created 2 connections, but 3rd failed and minimum
Expand All @@ -72,7 +78,7 @@ func (p *Pool) Connect() error {
var connectErr error
defer func() {
if connectErr != nil {
p.Close()
p.CloseCtx(ctx)
}
}()

Expand All @@ -95,7 +101,7 @@ func (p *Pool) Connect() error {
// set own handler when connection is closed
conn.SetOptions(ConnectionClosedHandler(p.handleClosedConnection))

err = conn.Connect()
err = conn.ConnectCtx(ctx)
if err != nil {
errs = append(errs, fmt.Errorf("connecting to %s: %w", addr, err))
p.handleError(fmt.Errorf("failed to connect to %s: %w", conn.addr, err))
Expand Down Expand Up @@ -165,7 +171,7 @@ func (p *Pool) handleClosedConnection(closedConn *Connection) {
return
}

var connIndex = -1
connIndex := -1
for i, conn := range p.connections {
if conn == closedConn {
connIndex = i
Expand Down Expand Up @@ -234,6 +240,11 @@ func (p *Pool) recreateConnection(closedConn *Connection) {

// Close closes all connections in the pool
func (p *Pool) Close() error {
return p.CloseCtx(context.Background())
}

// CloseCtx closes all connections in the pool
func (p *Pool) CloseCtx(ctx context.Context) error {
p.mu.Lock()
if p.isClosed {
p.mu.Unlock()
Expand All @@ -259,12 +270,11 @@ func (p *Pool) Close() error {
for _, conn := range p.connections {
go func(conn *Connection) {
defer wg.Done()
err := conn.Close()
err := conn.CloseCtx(ctx)
if err != nil {
p.handleError(fmt.Errorf("closing connection on pool close: %w", err))
}
}(conn)

}
wg.Wait()

Expand Down
Loading