Skip to content

Commit

Permalink
Add ping and pong received callbacks
Browse files Browse the repository at this point in the history
This change adds two optional callbacks to both `DialOptions` and
`AcceptOptions`. These callbacks are invoked synchronously when a ping
or pong frame is received, allowing advanced users to log or inspect
payloads for metrics or debugging. If the callback needs to perform more
complex work or reuse the payload outside the callback, it is
recommended to perform processing in a separate goroutine.

The boolean return value of `OnPingReceived` is used to determine if the
subsequent pong frame should be sent. If `false` is returned, the pong
frame is not sent.

Tests confirm that the ping/pong callbacks are invoked as expected.

References #246
  • Loading branch information
igolaizola committed Jan 29, 2025
1 parent aec630d commit 85e8670
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 5 deletions.
19 changes: 19 additions & 0 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package websocket

import (
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -62,6 +63,22 @@ type AcceptOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int

// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool

// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}

func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
Expand Down Expand Up @@ -156,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,

br: brw.Reader,
bw: brw.Writer,
Expand Down
16 changes: 11 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ type Conn struct {
closeMu sync.Mutex // Protects following.
closed chan struct{}

pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}

type connConfig struct {
Expand All @@ -94,6 +96,8 @@ type connConfig struct {
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)

br *bufio.Reader
bw *bufio.Writer
Expand All @@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn {
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),

closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}

c.readMu = newMu(c)
Expand Down
79 changes: 79 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,85 @@ func TestConn(t *testing.T) {
assert.Contains(t, err, "failed to wait for pong")
})

t.Run("pingReceivedPongReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)

c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)

ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()

err := c1.Ping(ctx)
assert.Success(t, err)

c1.CloseNow()
c2.CloseNow()

assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
})

t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)

c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)

ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()

err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")

c1.CloseNow()
c2.CloseNow()

assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
})

t.Run("concurrentWrite", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)

Expand Down
18 changes: 18 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ type DialOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int

// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool

// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}

func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
Expand Down Expand Up @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
Expand Down
8 changes: 8 additions & 0 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

switch h.opcode {
case opPing:
if c.onPingReceived != nil {
if !c.onPingReceived(ctx, b) {
return nil
}
}
return c.writeControl(ctx, opPong, b)
case opPong:
if c.onPongReceived != nil {
c.onPongReceived(ctx, b)
}
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
Expand Down

0 comments on commit 85e8670

Please sign in to comment.