Skip to content

Commit

Permalink
Allow customizing context canceled behavior for pgconn
Browse files Browse the repository at this point in the history
This feature made the ctxwatch package public.
  • Loading branch information
jackc committed May 8, 2024
1 parent 60a01d0 commit 42c9e90
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 52 deletions.
11 changes: 10 additions & 1 deletion pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)

Expand All @@ -39,7 +40,12 @@ type Config struct {
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)

// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler

RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)

KerberosSrvName string
KerberosSpn string
Expand Down Expand Up @@ -266,6 +272,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import (
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
handler Handler
unwatchChan chan struct{}

lock sync.Mutex
watchInProgress bool
Expand All @@ -20,11 +19,10 @@ type ContextWatcher struct {
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
handler: handler,
unwatchChan: make(chan struct{}),
}

return cw
Expand All @@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
Expand All @@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
cw.handler.HandleUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}

type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)

// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,32 @@ import (
"testing"
"time"

"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/stretchr/testify/require"
)

type testHandler struct {
handleCancel func(context.Context)
handleUnwatchAfterCancel func()
}

func (h *testHandler) HandleCancel(ctx context.Context) {
h.handleCancel(ctx)
}

func (h *testHandler) HandleUnwatchAfterCancel() {
h.handleUnwatchAfterCancel()
}

func TestContextWatcherContextCancelled(t *testing.T) {
canceledChan := make(chan struct{})
cleanupCalled := false
cw := ctxwatch.NewContextWatcher(func() {
canceledChan <- struct{}{}
}, func() {
cleanupCalled = true
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
canceledChan <- struct{}{}
}, handleUnwatchAfterCancel: func() {
cleanupCalled = true
},
})

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -35,10 +50,12 @@ func TestContextWatcherContextCancelled(t *testing.T) {
}

func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {
t.Error("cancel func should not have been called")
}, func() {
t.Error("cleanup func should not have been called")
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
t.Error("cancel func should not have been called")
}, handleUnwatchAfterCancel: func() {
t.Error("cleanup func should not have been called")
},
})

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -48,7 +65,7 @@ func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
}

func TestContextWatcherMultipleWatchPanics(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -61,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
}

func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
cw.Unwatch() // unwatch when not / never watching

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -72,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
}

func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
Expand All @@ -88,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) {
var cancelFuncCalls int64
var cleanupFuncCalls int64

cw := ctxwatch.NewContextWatcher(func() {
atomic.AddInt64(&cancelFuncCalls, 1)
}, func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
atomic.AddInt64(&cancelFuncCalls, 1)
}, handleUnwatchAfterCancel: func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
},
})

cycleCount := 100000
Expand Down Expand Up @@ -134,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) {
}

func BenchmarkContextWatcherUncancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

for i := 0; i < b.N; i++ {
cw.Watch(context.Background())
Expand All @@ -143,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
}

func BenchmarkContextWatcherCancelled(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

for i := 0; i < b.N; i++ {
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -154,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
}

func BenchmarkContextWatcherCancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
50 changes: 27 additions & 23 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import (

"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)

Expand Down Expand Up @@ -281,28 +281,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig

var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil {
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = netConn
pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx)

if fallbackConfig.TLSConfig != nil {
nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
pgConn.contextWatcher.Watch(ctx)
tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = nbTLSConn
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
pgConn.contextWatcher.Watch(ctx)
pgConn.conn = tlsConn
}

pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()

pgConn.parameterStatuses = make(map[string]string)
Expand Down Expand Up @@ -412,13 +410,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
}

func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
return ctxwatch.NewContextWatcher(
func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { conn.SetDeadline(time.Time{}) },
)
}

func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
Expand Down Expand Up @@ -988,10 +979,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
defer cancelConn.Close()

if ctx != context.Background() {
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
}
Expand Down Expand Up @@ -1939,7 +1927,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
cleanupDone: make(chan struct{}),
}

pgConn.contextWatcher = newContextWatcher(pgConn.conn)
pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
Expand Down Expand Up @@ -2246,3 +2234,19 @@ func (p *Pipeline) Close() error {

return p.err
}

// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn.
type DeadlineContextWatcherHandler struct {
Conn net.Conn

// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
DeadlineDelay time.Duration
}

func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
}

func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
h.Conn.SetDeadline(time.Time{})
}
47 changes: 47 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/internal/pgmock"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
)
Expand Down Expand Up @@ -3480,3 +3481,49 @@ func mustEncode(buf []byte, err error) []byte {
}
return buf
}

func TestDeadlineContextWatcherHandler(t *testing.T) {
t.Parallel()

t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) {
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()}
}
config.ConnectTimeout = 5 * time.Second

pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.True(t, pgConn.IsClosed())
})

t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) {
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond}
}
config.ConnectTimeout = 5 * time.Second

pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
require.NoError(t, err)

ensureConnValid(t, pgConn)
})
}

0 comments on commit 42c9e90

Please sign in to comment.