diff --git a/pgconn/config.go b/pgconn/config.go index 13def1e80..95b7c912f 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -19,6 +19,7 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -39,7 +40,12 @@ type Config struct { DialFunc DialFunc // e.g. net.Dialer.DialContext LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called + // when a context passed to a PgConn method is canceled. + BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler + + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) KerberosSrvName string KerberosSpn string @@ -266,6 +272,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler { + return &DeadlineContextWatcherHandler{Conn: pgConn.conn} + }, OnPgError: func(_ *PgConn, pgErr *PgError) bool { // we want to automatically close any fatal errors if strings.EqualFold(pgErr.Severity, "FATAL") { diff --git a/pgconn/internal/ctxwatch/context_watcher.go b/pgconn/ctxwatch/context_watcher.go similarity index 71% rename from pgconn/internal/ctxwatch/context_watcher.go rename to pgconn/ctxwatch/context_watcher.go index b39cb3ee5..db8884eb8 100644 --- a/pgconn/internal/ctxwatch/context_watcher.go +++ b/pgconn/ctxwatch/context_watcher.go @@ -8,9 +8,8 @@ import ( // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { - onCancel func() - onUnwatchAfterCancel func() - unwatchChan chan struct{} + handler Handler + unwatchChan chan struct{} lock sync.Mutex watchInProgress bool @@ -20,11 +19,10 @@ type ContextWatcher struct { // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and // onCancel called. -func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { +func NewContextWatcher(handler Handler) *ContextWatcher { cw := &ContextWatcher{ - onCancel: onCancel, - onUnwatchAfterCancel: onUnwatchAfterCancel, - unwatchChan: make(chan struct{}), + handler: handler, + unwatchChan: make(chan struct{}), } return cw @@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { go func() { select { case <-ctx.Done(): - cw.onCancel() + cw.handler.HandleCancel(ctx) cw.onCancelWasCalled = true <-cw.unwatchChan case <-cw.unwatchChan: @@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() { if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { - cw.onUnwatchAfterCancel() + cw.handler.HandleUnwatchAfterCancel() } cw.watchInProgress = false } } + +type Handler interface { + // HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the + // context that was canceled. + HandleCancel(canceledCtx context.Context) + + // HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched. + HandleUnwatchAfterCancel() +} diff --git a/pgconn/internal/ctxwatch/context_watcher_test.go b/pgconn/ctxwatch/context_watcher_test.go similarity index 66% rename from pgconn/internal/ctxwatch/context_watcher_test.go rename to pgconn/ctxwatch/context_watcher_test.go index d62b97d3e..302aabe3b 100644 --- a/pgconn/internal/ctxwatch/context_watcher_test.go +++ b/pgconn/ctxwatch/context_watcher_test.go @@ -6,17 +6,32 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/stretchr/testify/require" ) +type testHandler struct { + handleCancel func(context.Context) + handleUnwatchAfterCancel func() +} + +func (h *testHandler) HandleCancel(ctx context.Context) { + h.handleCancel(ctx) +} + +func (h *testHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancel() +} + func TestContextWatcherContextCancelled(t *testing.T) { canceledChan := make(chan struct{}) cleanupCalled := false - cw := ctxwatch.NewContextWatcher(func() { - canceledChan <- struct{}{} - }, func() { - cleanupCalled = true + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + canceledChan <- struct{}{} + }, handleUnwatchAfterCancel: func() { + cleanupCalled = true + }, }) ctx, cancel := context.WithCancel(context.Background()) @@ -35,10 +50,12 @@ func TestContextWatcherContextCancelled(t *testing.T) { } func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { - cw := ctxwatch.NewContextWatcher(func() { - t.Error("cancel func should not have been called") - }, func() { - t.Error("cleanup func should not have been called") + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + t.Error("cancel func should not have been called") + }, handleUnwatchAfterCancel: func() { + t.Error("cleanup func should not have been called") + }, }) ctx, cancel := context.WithCancel(context.Background()) @@ -48,7 +65,7 @@ func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { } func TestContextWatcherMultipleWatchPanics(t *testing.T) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -61,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { } func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) cw.Unwatch() // unwatch when not / never watching ctx, cancel := context.WithCancel(context.Background()) @@ -72,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { } func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -88,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 - cw := ctxwatch.NewContextWatcher(func() { - atomic.AddInt64(&cancelFuncCalls, 1) - }, func() { - atomic.AddInt64(&cleanupFuncCalls, 1) + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + atomic.AddInt64(&cancelFuncCalls, 1) + }, handleUnwatchAfterCancel: func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }, }) cycleCount := 100000 @@ -134,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) { } func BenchmarkContextWatcherUncancellable(b *testing.B) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) for i := 0; i < b.N; i++ { cw.Watch(context.Background()) @@ -143,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) { } func BenchmarkContextWatcherCancelled(b *testing.B) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) for i := 0; i < b.N; i++ { ctx, cancel := context.WithCancel(context.Background()) @@ -154,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) { } func BenchmarkContextWatcherCancellable(b *testing.B) { - cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 170f18db6..d425b5ad8 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -18,8 +18,8 @@ import ( "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" - "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -281,28 +281,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - netConn, err := config.DialFunc(ctx, network, address) + pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) - pgConn.contextWatcher.Watch(ctx) - if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) + pgConn.contextWatcher.Watch(ctx) + tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { - netConn.Close() + pgConn.conn.Close() return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} } - pgConn.conn = nbTLSConn - pgConn.contextWatcher = newContextWatcher(nbTLSConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) @@ -412,13 +410,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } -func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { - return ctxwatch.NewContextWatcher( - func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { conn.SetDeadline(time.Time{}) }, - ) -} - func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { @@ -988,10 +979,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer cancelConn.Close() if ctx != context.Background() { - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) + contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() } @@ -1939,7 +1927,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), func() { @@ -2246,3 +2234,19 @@ func (p *Pipeline) Close() error { return p.err } + +// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. +type DeadlineContextWatcherHandler struct { + Conn net.Conn + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration +} + +func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { + h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) +} + +func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { + h.Conn.SetDeadline(time.Time{}) +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 604aa9e61..704abe1cd 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -24,6 +24,7 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" ) @@ -3480,3 +3481,49 @@ func mustEncode(buf []byte, err error) []byte { } return buf } + +func TestDeadlineContextWatcherHandler(t *testing.T) { + t.Parallel() + + t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll() + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) + }) + + t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll() + require.NoError(t, err) + + ensureConnValid(t, pgConn) + }) +}