Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add context for connect and close functions #77

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connection

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -107,6 +108,10 @@ func New(addr string, spec *iso8583.MessageSpec, mlReader MessageLengthReader, m
}
}

if err := opts.Validate(); err != nil {
return nil, fmt.Errorf("validating options: %w", err)
}
adamdecaf marked this conversation as resolved.
Show resolved Hide resolved

return &Connection{
addr: addr,
Opts: opts,
Expand Down Expand Up @@ -188,6 +193,49 @@ func (c *Connection) Connect() error {
return nil
}

// ConnectCtx establishes the connection to the server using configured Addr
func (c *Connection) ConnectCtx(ctx context.Context) error {
var conn net.Conn
var err error

if c.conn != nil {
return nil
}
bpross marked this conversation as resolved.
Show resolved Hide resolved

d := &net.Dialer{Timeout: c.Opts.ConnectTimeout}

if c.Opts.TLSConfig != nil {
conn, err = tls.DialWithDialer(d, "tcp", c.addr, c.Opts.TLSConfig)
} else {
conn, err = d.Dial("tcp", c.addr)
}

if err != nil {
return fmt.Errorf("connecting to server %s: %w", c.addr, err)
}

c.conn = conn

c.run()

if c.Opts.OnConnectCtx != nil {
if err := c.Opts.OnConnectCtx(ctx, c); err != nil {
// close connection if OnConnect failed
// but ignore the potential error from Close()
// as it's a rare case
_ = c.Close()

return fmt.Errorf("on connect callback %s: %w", c.addr, err)
}
}

if c.Opts.ConnectionEstablishedHandler != nil {
go c.Opts.ConnectionEstablishedHandler(c)
}

return nil
}

// Write writes data directly to the connection. It is crucial to note that the
// Write operation is atomic in nature, meaning it completes in a single
// uninterrupted step.
Expand Down Expand Up @@ -261,7 +309,6 @@ func (c *Connection) handleConnectionError(err error) {
case <-done:
return
}

}
}()

Expand Down Expand Up @@ -317,6 +364,27 @@ func (c *Connection) Close() error {
return c.close()
}

// CloseCtx waits for pending requests to complete and then closes network
// connection with ISO 8583 server
func (c *Connection) CloseCtx(ctx context.Context) error {
if c.Opts.OnCloseCtx != nil {
if err := c.Opts.OnCloseCtx(ctx, c); err != nil {
return fmt.Errorf("on close callback: %w", err)
}
}

c.mutex.Lock()
// if we are closing already, just return
if c.closing {
c.mutex.Unlock()
return nil
}
c.closing = true
c.mutex.Unlock()

return c.close()
}

func (c *Connection) Done() <-chan struct{} {
return c.done
}
Expand Down
94 changes: 89 additions & 5 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connection_test

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
Expand All @@ -27,8 +28,10 @@ type baseFields struct {
STAN *field.String `index:"11"`
}

var stan int
var stanMu sync.Mutex
var (
stan int
stanMu sync.Mutex
)

func getSTAN() string {
stanMu.Lock()
Expand Down Expand Up @@ -145,6 +148,31 @@ func TestClient_Connect(t *testing.T) {
}, 100*time.Millisecond, 20*time.Millisecond, "onConnect should be called")
})

t.Run("OnConnectCtx is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
defer server.Close()

var onConnectCalled int32
onConnectCtx := func(ctx context.Context, c *connection.Connection) error {
// increase the counter
atomic.AddInt32(&onConnectCalled, 1)
return nil
}

c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.OnConnectCtx(onConnectCtx))
require.NoError(t, err)

err = c.ConnectCtx(context.Background())
require.NoError(t, err)
defer c.Close()

// eventually the onConnectCounter should be 1
require.Eventually(t, func() bool {
return atomic.LoadInt32(&onConnectCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onConnect should be called")
})

t.Run("OnClose is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
Expand All @@ -171,6 +199,33 @@ func TestClient_Connect(t *testing.T) {
return atomic.LoadInt32(&onClosedCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onClose should be called")
})

t.Run("OnCloseCtx is called", func(t *testing.T) {
server, err := NewTestServer()
require.NoError(t, err)
defer server.Close()

var onClosedCalled int32
onCloseCtx := func(ctx context.Context, c *connection.Connection) error {
// increase the counter
atomic.AddInt32(&onClosedCalled, 1)
return nil
}

c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.OnCloseCtx(onCloseCtx))
require.NoError(t, err)

// err = c.Connect()
// require.NoError(t, err)

err = c.CloseCtx(context.Background())
require.NoError(t, err)

// eventually the onClosedCalled should be 1
require.Eventually(t, func() bool {
return atomic.LoadInt32(&onClosedCalled) == 1
}, 100*time.Millisecond, 20*time.Millisecond, "onClose should be called")
})
}

func TestClient_Write(t *testing.T) {
Expand Down Expand Up @@ -682,7 +737,6 @@ func TestClient_Send(t *testing.T) {

// and that response for the first message was received second
require.Equal(t, receivedSTANs[1], stan1)

})

t.Run("automatically sends ping messages after ping interval", func(t *testing.T) {
Expand Down Expand Up @@ -985,7 +1039,6 @@ func TestClient_Send(t *testing.T) {
return server.ReceivedPings() > 0
}, 200*time.Millisecond, 50*time.Millisecond, "no ping messages were sent after read timeout")
})

}

func TestClient_Options(t *testing.T) {
Expand Down Expand Up @@ -1019,7 +1072,6 @@ func TestClient_Options(t *testing.T) {
require.Eventually(t, func() bool {
return atomic.LoadInt32(&callsCounter) > 0
}, 500*time.Millisecond, 50*time.Millisecond, "error handler was never called")

})

t.Run("ClosedHandler is called when connection is closed", func(t *testing.T) {
Expand Down Expand Up @@ -1108,6 +1160,36 @@ func TestClient_Options(t *testing.T) {

require.Equal(t, 1, callsCounter)
})

t.Run("Cannot configure OnConnect and OnConnectCtx", func(t *testing.T) {
c, err := connection.New(
"localhost:0",
testSpec,
readMessageLength,
writeMessageLength,
connection.SendTimeout(500*time.Millisecond),
connection.OnConnect(func(c *connection.Connection) error { return nil }),
connection.OnConnectCtx(func(ctx context.Context, c *connection.Connection) error { return nil }),
)

require.Error(t, err)
require.Nil(t, c)
})

t.Run("Cannot configure OnClose and OnCloseCtx", func(t *testing.T) {
c, err := connection.New(
"localhost:0",
testSpec,
readMessageLength,
writeMessageLength,
connection.SendTimeout(500*time.Millisecond),
connection.OnClose(func(c *connection.Connection) error { return nil }),
connection.OnCloseCtx(func(ctx context.Context, c *connection.Connection) error { return nil }),
)

require.Error(t, err)
require.Nil(t, c)
})
}

func TestClientWithMessageReaderAndWriter(t *testing.T) {
Expand Down Expand Up @@ -1319,9 +1401,11 @@ func (m *TrackingRWCloser) Write(p []byte) (n int, err error) {

return 0, nil
}

func (m *TrackingRWCloser) Read(p []byte) (n int, err error) {
return 0, nil
}

func (m *TrackingRWCloser) Close() error {
return nil
}
Expand Down
36 changes: 36 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package connection

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -60,12 +61,21 @@ type Options struct {
// returned to the caller
ErrorHandler func(err error)

// Only define a single OnConnect and OnClose functions
// the config will error out if both are set

// OnConnect is called synchronously when a connection is established
OnConnect func(c *Connection) error

// OnConnectCtx is called synchronously when a connection is established
OnConnectCtx func(ctx context.Context, c *Connection) error

// OnClose is called synchronously before a connection is closed
OnClose func(c *Connection) error

// OnCloseCtx is called synchronously before a connection is closed
OnCloseCtx func(ctx context.Context, c *Connection) error

// RequestIDGenerator is used to generate a unique identifier for a request
// so that responses from the server can be matched to the original request.
RequestIDGenerator RequestIDGenerator
Expand Down Expand Up @@ -101,6 +111,16 @@ func GetDefaultOptions() Options {
}
}

func (o *Options) Validate() error {
if o.OnConnect != nil && o.OnConnectCtx != nil {
return fmt.Errorf("OnConnect and OnConnectCtx are mutually exclusive")
}
if o.OnClose != nil && o.OnCloseCtx != nil {
return fmt.Errorf("OnClose and OnCloseCtx are mutually exclusive")
}
return nil
}

// IdleTime sets an IdleTime option
func IdleTime(d time.Duration) Option {
return func(o *Options) error {
Expand Down Expand Up @@ -191,13 +211,29 @@ func OnConnect(h func(c *Connection) error) Option {
}
}

// OnConnectCtx sets a callback that will be synchronously called when connection is established.
// If it returns error, then connections will be closed and re-connect will be attempted
func OnConnectCtx(h func(ctx context.Context, c *Connection) error) Option {
return func(opts *Options) error {
opts.OnConnectCtx = h
return nil
}
}

func OnClose(h func(c *Connection) error) Option {
return func(opts *Options) error {
opts.OnClose = h
return nil
}
}

func OnCloseCtx(h func(ctx context.Context, c *Connection) error) Option {
return func(opts *Options) error {
opts.OnCloseCtx = h
return nil
}
}

func defaultTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
Expand Down
3 changes: 1 addition & 2 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (p *Pool) handleClosedConnection(closedConn *Connection) {
return
}

var connIndex = -1
connIndex := -1
for i, conn := range p.connections {
if conn == closedConn {
connIndex = i
Expand Down Expand Up @@ -264,7 +264,6 @@ func (p *Pool) Close() error {
p.handleError(fmt.Errorf("closing connection on pool close: %w", err))
}
}(conn)

}
wg.Wait()

Expand Down