diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go
index d425b5ad8..819380e3f 100644
--- a/pgconn/pgconn.go
+++ b/pgconn/pgconn.go
@@ -2250,3 +2250,68 @@ func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
 func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
 	h.Conn.SetDeadline(time.Time{})
 }
+
+// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets
+// a deadline on a net.Conn as a fallback.
+type CancelRequestContextWatcherHandler struct {
+	Conn *PgConn
+
+	// CancelRequestDelay is the delay before sending the cancel request to the server.
+	CancelRequestDelay time.Duration
+
+	// DeadlineDelay is how far in the future the net.Conn deadline is set when the context is canceled.
+	DeadlineDelay time.Duration
+
+	cancelFinishedChan             chan struct{}
+	handleUnwatchAfterCancelCalled func()
+}
+
+// HandleCancel sets a deadline of DeadlineDelay from now on the connection and starts a goroutine that sends a cancel
+// request to the server once CancelRequestDelay has elapsed, unless HandleUnwatchAfterCancel is called first.
+func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) {
+	// Keep a local reference so the goroutine closes exactly the channel created by this call.
+	cancelFinishedChan := make(chan struct{})
+	h.cancelFinishedChan = cancelFinishedChan
+
+	handleUnwatchedAfterCancelCalledCtx, handleUnwatchAfterCancelCalled := context.WithCancel(context.Background())
+	h.handleUnwatchAfterCancelCalled = handleUnwatchAfterCancelCalled
+
+	deadline := time.Now().Add(h.DeadlineDelay)
+	h.Conn.conn.SetDeadline(deadline)
+
+	go func() {
+		defer close(cancelFinishedChan)
+
+		// Use an explicit timer instead of time.After so it is stopped as soon as this goroutine returns rather than
+		// lingering until CancelRequestDelay elapses.
+		cancelRequestTimer := time.NewTimer(h.CancelRequestDelay)
+		defer cancelRequestTimer.Stop()
+
+		select {
+		case <-handleUnwatchedAfterCancelCalledCtx.Done():
+			return
+		case <-cancelRequestTimer.C:
+		}
+
+		cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline)
+		defer cancel()
+		// Best effort: a failure to send the cancel request is not reported. The deadline set above is the fallback.
+		h.Conn.CancelRequest(cancelRequestCtx)
+
+		// CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point,
+		// it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is
+		// immediately used then it is possible the CancelRequest will actually cancel our next query. The
+		// TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time
+		// is arbitrary, but should be sufficient to prevent this error case.
+		time.Sleep(100 * time.Millisecond)
+	}()
+}
+
+// HandleUnwatchAfterCancel signals the goroutine started by HandleCancel to stop, waits for it to finish, and then
+// clears the deadline on the connection. It must only be called after HandleCancel.
+func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() {
+	h.handleUnwatchAfterCancelCalled()
+	<-h.cancelFinishedChan
+
+	h.Conn.conn.SetDeadline(time.Time{})
+}
diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go
index 704abe1cd..e91f699a4 100644
--- a/pgconn/pgconn_test.go
+++ b/pgconn/pgconn_test.go
@@ -3527,3 +3527,116 @@ func TestDeadlineContextWatcherHandler(t *testing.T) {
 		ensureConnValid(t, pgConn)
 	})
 }
+
+func TestCancelRequestContextWatcherHandler(t *testing.T) {
+	t.Parallel()
+
+	t.Run("DeadlineExceeded cancels request after CancelRequestDelay", func(t *testing.T) {
+		config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+		require.NoError(t, err)
+		config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
+			return &pgconn.CancelRequestContextWatcherHandler{
+				Conn:               conn,
+				CancelRequestDelay: 250 * time.Millisecond,
+				DeadlineDelay:      5000 * time.Millisecond,
+			}
+		}
+		config.ConnectTimeout = 5 * time.Second
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), config)
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		_, err = pgConn.Exec(ctx, "select 1, pg_sleep(3)").ReadAll()
+		require.Error(t, err)
+		var pgErr *pgconn.PgError
+		require.ErrorAs(t, err, &pgErr)
+
+		ensureConnValid(t, pgConn)
+	})
+
+	t.Run("DeadlineExceeded - do not send cancel request when query finishes in grace period", func(t *testing.T) {
+		config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+		require.NoError(t, err)
+		config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
+			return &pgconn.CancelRequestContextWatcherHandler{
+				Conn:               conn,
+				CancelRequestDelay: 1000 * time.Millisecond,
+				DeadlineDelay:      5000 * time.Millisecond,
+			}
+		}
+		config.ConnectTimeout = 5 * time.Second
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), config)
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+		defer cancel()
+
+		_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
+		require.NoError(t, err)
+
+		ensureConnValid(t, pgConn)
+	})
+
+	t.Run("DeadlineExceeded sets conn deadline with DeadlineDelay", func(t *testing.T) {
+		config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+		require.NoError(t, err)
+		config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
+			return &pgconn.CancelRequestContextWatcherHandler{
+				Conn:               conn,
+				CancelRequestDelay: 5000 * time.Millisecond, // purposely setting this higher than DeadlineDelay to ensure the cancel request never happens.
+				DeadlineDelay:      250 * time.Millisecond,
+			}
+		}
+		config.ConnectTimeout = 5 * time.Second
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), config)
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
+		require.Error(t, err)
+		require.ErrorIs(t, err, context.DeadlineExceeded)
+		require.True(t, pgConn.IsClosed())
+	})
+
+	for i := 0; i < 10; i++ {
+		t.Run(fmt.Sprintf("Stress %d", i), func(t *testing.T) {
+			t.Parallel()
+
+			config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+			require.NoError(t, err)
+			config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
+				return &pgconn.CancelRequestContextWatcherHandler{
+					Conn:               conn,
+					CancelRequestDelay: 5 * time.Millisecond,
+					DeadlineDelay:      1000 * time.Millisecond,
+				}
+			}
+			config.ConnectTimeout = 5 * time.Second
+
+			pgConn, err := pgconn.ConnectConfig(context.Background(), config)
+			require.NoError(t, err)
+			defer closeConn(t, pgConn)
+
+			for j := 0; j < 20; j++ {
+				func() {
+					ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
+					defer cancel()
+					// The result is deliberately ignored: the query may or may not be canceled. Only the state of the
+					// connection afterwards matters.
+					pgConn.Exec(ctx, "select 1, pg_sleep(0.010)").ReadAll()
+					ensureConnValid(t, pgConn)
+				}()
+			}
+		})
+	}
+}