Skip to content

Commit

Permalink
ensure websockets persist until done on drain
Browse files Browse the repository at this point in the history
  • Loading branch information
elijah-rou committed Feb 7, 2025
1 parent 6265a8e commit d6da0ea
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
17 changes: 17 additions & 0 deletions pkg/queue/sharedmain/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ import (
"context"
"net"
"net/http"
"strings"
"sync/atomic"
"time"

"go.uber.org/zap"
netheader "knative.dev/networking/pkg/http/header"
netproxy "knative.dev/networking/pkg/http/proxy"
netstats "knative.dev/networking/pkg/http/stats"
"knative.dev/pkg/network"
pkghandler "knative.dev/pkg/network/handlers"
"knative.dev/pkg/tracing"
tracingconfig "knative.dev/pkg/tracing/config"
Expand All @@ -43,6 +46,7 @@ func mainHandler(
prober func() bool,
stats *netstats.RequestStats,
logger *zap.SugaredLogger,
pendingRequests *atomic.Int32,
) (http.Handler, *pkghandler.Drainer) {
target := net.JoinHostPort("127.0.0.1", env.UserPort)

Expand Down Expand Up @@ -86,6 +90,8 @@ func mainHandler(

composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)

composedHandler = withRequestCounter(composedHandler, pendingRequests)

drainer := &pkghandler.Drainer{
QuietPeriod: drainSleepDuration,
// Add Activator probe header to the drainer so it can handle probes directly from activator
Expand All @@ -100,6 +106,7 @@ func mainHandler(
// Hence we need to have RequestLogHandler be the first one.
composedHandler = requestLogHandler(logger, composedHandler, env)
}

return composedHandler, drainer
}

Expand Down Expand Up @@ -139,3 +146,13 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo
h.ServeHTTP(w, r)
})
}

// withRequestCounter wraps h so that every user request increments
// pendingRequests for the duration of its handling. Health-check traffic —
// requests carrying the Knative probe header, or a kubelet probe identified
// by its "kube-probe/" User-Agent prefix — is passed through uncounted so
// that probes cannot hold up drain.
func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		isKnativeProbe := r.Header.Get(network.ProbeHeaderName) == network.ProbeHeaderValue
		isKubeletProbe := strings.HasPrefix(r.Header.Get("User-Agent"), "kube-probe/")
		if isKnativeProbe || isKubeletProbe {
			// Probe traffic: serve without touching the counter.
			h.ServeHTTP(w, r)
			return
		}
		pendingRequests.Add(1)
		defer pendingRequests.Add(-1)
		h.ServeHTTP(w, r)
	})
}
32 changes: 28 additions & 4 deletions pkg/queue/sharedmain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"os"
"strconv"
"sync/atomic"
"time"

"github.com/kelseyhightower/envconfig"
Expand Down Expand Up @@ -59,7 +60,7 @@ const (
// Duration the /wait-for-drain handler should wait before returning.
// This is to give networking a little bit more time to remove the pod
// from its configuration and propagate that to all loadbalancers and nodes.
drainSleepDuration = 30 * time.Second
drainSleepDuration = 15 * time.Second

// certPath is the path for the server certificate mounted by queue-proxy.
certPath = queue.CertDirectory + "/" + certificates.CertName
Expand Down Expand Up @@ -169,6 +170,8 @@ func Main(opts ...Option) error {
d := Defaults{
Ctx: signals.NewContext(),
}
pendingRequests := atomic.Int32{}
pendingRequests.Store(0)

// Parse the environment.
var env config
Expand Down Expand Up @@ -234,7 +237,7 @@ func Main(opts ...Option) error {
// Enable TLS when certificate is mounted.
tlsEnabled := exists(logger, certPath) && exists(logger, keyPath)

mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger)
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger, &pendingRequests)
adminHandler := adminHandler(d.Ctx, logger, drainer)

// Enable TLS server when activator server certs are mounted.
Expand Down Expand Up @@ -306,15 +309,36 @@ func Main(opts ...Option) error {
logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration)
drainer.Drain()

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(env.RevisionTimeoutSeconds)*time.Second)
defer cancel()
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load())
WaitOnPendingRequests:
for {
select {
case <-ctx.Done():
logger.Infof("Drain: timeout waiting for pending requests to complete")
break WaitOnPendingRequests
case <-ticker.C:
if pendingRequests.Load() <= 0 {
logger.Infof("Drain: all pending requests completed")
break WaitOnPendingRequests
}
}
}
time.Sleep(drainSleepDuration)

for name, srv := range httpServers {
logger.Info("Shutting down server: ", name)
if err := srv.Shutdown(context.Background()); err != nil {
if err := srv.Shutdown(ctx); err != nil {
logger.Errorw("Failed to shutdown server", zap.String("server", name), zap.Error(err))
}
}
for name, srv := range tlsServers {
logger.Info("Shutting down server: ", name)
if err := srv.Shutdown(context.Background()); err != nil {
if err := srv.Shutdown(ctx); err != nil {
logger.Errorw("Failed to shutdown server", zap.String("server", name), zap.Error(err))
}
}
Expand Down

0 comments on commit d6da0ea

Please sign in to comment.