Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure websocket connections persist until done on queue-proxy drain #15759

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/queue/sharedmain/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"net"
"net/http"
"sync/atomic"
"time"

"go.uber.org/zap"
Expand All @@ -43,6 +44,7 @@ func mainHandler(
prober func() bool,
stats *netstats.RequestStats,
logger *zap.SugaredLogger,
pendingRequests *atomic.Int32,
) (http.Handler, *pkghandler.Drainer) {
target := net.JoinHostPort("127.0.0.1", env.UserPort)

Expand Down Expand Up @@ -86,6 +88,8 @@ func mainHandler(

composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)

composedHandler = withRequestCounter(composedHandler, pendingRequests)

drainer := &pkghandler.Drainer{
QuietPeriod: drainSleepDuration,
// Add Activator probe header to the drainer so it can handle probes directly from activator
Expand All @@ -100,6 +104,7 @@ func mainHandler(
// Hence we need to have RequestLogHandler be the first one.
composedHandler = requestLogHandler(logger, composedHandler, env)
}

return composedHandler, drainer
}

Expand Down Expand Up @@ -139,3 +144,11 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo
h.ServeHTTP(w, r)
})
}

// withRequestCounter wraps h so that pendingRequests tracks how many requests
// are currently inside h.ServeHTTP. The counter is incremented when a request
// enters the wrapper and decremented when h.ServeHTTP returns.
//
// The decrement is deferred so that it still runs if h panics; otherwise a
// single panicking request would leave the counter permanently above zero.
//
// If pendingRequests is nil, counting is disabled and h is returned unchanged
// instead of panicking with a nil dereference on the first request.
func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler {
	if pendingRequests == nil {
		return h
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pendingRequests.Add(1)
		defer pendingRequests.Add(-1)
		h.ServeHTTP(w, r)
	})
}
21 changes: 20 additions & 1 deletion pkg/queue/sharedmain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"os"
"strconv"
"sync/atomic"
"time"

"github.com/kelseyhightower/envconfig"
Expand Down Expand Up @@ -169,6 +170,8 @@ func Main(opts ...Option) error {
d := Defaults{
Ctx: signals.NewContext(),
}
pendingRequests := atomic.Int32{}
pendingRequests.Store(0)

// Parse the environment.
var env config
Expand Down Expand Up @@ -234,7 +237,7 @@ func Main(opts ...Option) error {
// Enable TLS when certificate is mounted.
tlsEnabled := exists(logger, certPath) && exists(logger, keyPath)

mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger)
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger, &pendingRequests)
adminHandler := adminHandler(d.Ctx, logger, drainer)

// Enable TLS server when activator server certs are mounted.
Expand Down Expand Up @@ -304,8 +307,24 @@ func Main(opts ...Option) error {
case <-d.Ctx.Done():
logger.Info("Received TERM signal, attempting to gracefully shutdown servers.")
logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration)
time.Sleep(drainSleepDuration)
drainer.Drain()

// Wait for active requests to complete. This is done explicitly because
// net/http's `Server.Shutdown` neither closes nor waits for hijacked
// connections (e.g. WebSockets), so they would otherwise be dropped
// ungracefully when the process exits.
// See https://github.com/golang/go/issues/17721
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load())
WaitOnPendingRequests:
for range ticker.C {
if pendingRequests.Load() <= 0 {
logger.Infof("Drain: all pending requests completed")
break WaitOnPendingRequests
}
}

for name, srv := range httpServers {
logger.Info("Shutting down server: ", name)
if err := srv.Shutdown(context.Background()); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions test/e2e/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ func TestWebSocketWithTimeout(t *testing.T) {
idleTimeoutSeconds: 10,
delay: "20",
expectError: true,
}, {
name: "websocket does not drop after queue drain is called at 30s",
timeoutSeconds: 60,
delay: "45",
expectError: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand Down