Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add context support to router #285

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions autopaho/examples/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ func main() {
}

router := paho.NewStandardRouter()
router.DefaultHandler(func(p *paho.Publish) { fmt.Printf("defaulthandler received message with topic: %s\n", p.Topic) })
router.DefaultHandler(func(c context.Context, p *paho.Publish) {
fmt.Printf("defaulthandler received message with topic: %s\n", p.Topic)
})

cliCfg := autopaho.ClientConfig{
ServerUrls: []*url.URL{u},
Expand All @@ -61,7 +63,7 @@ func main() {
// You can write the function(s) yourself or use the supplied Router
OnPublishReceived: []func(paho.PublishReceived) (bool, error){
func(pr paho.PublishReceived) (bool, error) {
router.Route(pr.Packet.Packet())
router.Route(context.TODO(), pr.Packet.Packet())
return true, nil // we assume that the router handles all messages (todo: amend router API)
}},
OnClientError: func(err error) { fmt.Printf("client error: %s\n", err) },
Expand Down Expand Up @@ -96,10 +98,16 @@ func main() {

// Handlers can be registered/deregistered at any time. It's important to note that you need to subscribe AND create
// a handler
router.RegisterHandler("test/test/#", func(p *paho.Publish) { fmt.Printf("test/test/# received message with topic: %s\n", p.Topic) })
router.RegisterHandler("test/test/foo", func(p *paho.Publish) { fmt.Printf("test/test/foo received message with topic: %s\n", p.Topic) })
router.RegisterHandler("test/nomatch", func(p *paho.Publish) { fmt.Printf("test/nomatch received message with topic: %s\n", p.Topic) })
router.RegisterHandler("test/quit", func(p *paho.Publish) { stop() }) // Context will be cancelled if we receive a matching message
router.RegisterHandler("test/test/#", func(c context.Context, p *paho.Publish) {
fmt.Printf("test/test/# received message with topic: %s\n", p.Topic)
})
router.RegisterHandler("test/test/foo", func(c context.Context, p *paho.Publish) {
fmt.Printf("test/test/foo received message with topic: %s\n", p.Topic)
})
router.RegisterHandler("test/nomatch", func(c context.Context, p *paho.Publish) {
fmt.Printf("test/nomatch received message with topic: %s\n", p.Topic)
})
router.RegisterHandler("test/quit", func(c context.Context, p *paho.Publish) { stop() }) // Context will be cancelled if we receive a matching message

// We publish three messages to test out the various route handlers
topics := []string{"test/test", "test/test/foo", "test/xxNoMatch", "test/quit"}
Expand Down
2 changes: 1 addition & 1 deletion autopaho/examples/rpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func main() {
router := paho.NewStandardRouter()
cliCfg.OnPublishReceived = []func(paho.PublishReceived) (bool, error){
func(p paho.PublishReceived) (bool, error) {
router.Route(p.Packet.Packet())
router.Route(context.TODO(), p.Packet.Packet())
return false, nil
}}

Expand Down
2 changes: 1 addition & 1 deletion autopaho/extensions/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (h *Handler) Request(ctx context.Context, pb *paho.Publish) (resp *paho.Pub
}
}

func (h *Handler) responseHandler(pb *paho.Publish) {
func (h *Handler) responseHandler(ctx context.Context, pb *paho.Publish) {
if pb.Properties == nil || pb.Properties.CorrelationData == nil {
return
}
Expand Down
10 changes: 4 additions & 6 deletions paho/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type (
PingHandler Pinger
defaultPinger bool

// Router - new inbound messages will be passed to the `Route(*packets.Publish)` function.
// Router - new inbound messages will be passed to the `Route(context.Context, *packets.Publish)` function.
//
// Depreciated: If a router is provided, it will now be added to the end of the OnPublishReceived
// slice (which provides a more flexible approach to handling incoming messages).
Expand Down Expand Up @@ -194,7 +194,7 @@ func NewClient(conf ClientConfig) *Client {
r := c.config.Router
c.onPublishReceived = append(c.onPublishReceived,
func(p PublishReceived) (bool, error) {
r.Route(p.Packet.Packet())
r.Route(context.TODO(), p.Packet.Packet())
return false, nil
})
}
Expand Down Expand Up @@ -443,9 +443,7 @@ func (c *Client) routePublishPackets() {
// Copy onPublishReceived so lock is only held briefly
c.onPublishReceivedMu.Lock()
handlers := make([]func(PublishReceived) (bool, error), len(c.onPublishReceived))
for i := range c.onPublishReceived {
handlers[i] = c.onPublishReceived[i]
}
copy(handlers, c.onPublishReceived)
c.onPublishReceivedMu.Unlock()

if c.config.EnableManualAcknowledgment && pb.QoS != 0 {
Expand Down Expand Up @@ -887,7 +885,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish, o Publis
}

// From this point on the message is in store, and ret will receive something regardless of whether we succeed in
// writing the packet to the connection
// writing the packet to the connection or not
if _, err := pb.WriteTo(c.config.Conn); err != nil {
c.debug.Printf("failed to write packet %d to connection: %s", pb.PacketID, err)
if o.Method == PublishMethod_AsyncSend {
Expand Down
11 changes: 6 additions & 5 deletions paho/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package paho

import (
"context"
"strings"
"sync"

Expand All @@ -28,7 +29,7 @@ import (
// MessageHandlers should complete quickly (start a go routine for
// long-running processes) and should not call functions within the
// paho instance that triggered them (due to potential deadlocks).
type MessageHandler func(*Publish)
type MessageHandler func(context.Context, *Publish)

// Router is an interface of the functions for a struct that is
// used to handle invoking MessageHandlers depending on the
Expand All @@ -42,7 +43,7 @@ type MessageHandler func(*Publish)
type Router interface {
RegisterHandler(string, MessageHandler)
UnregisterHandler(string)
Route(*packets.Publish)
Route(context.Context, *packets.Publish)
SetDebugLogger(log.Logger)
}

Expand Down Expand Up @@ -96,7 +97,7 @@ func (r *StandardRouter) UnregisterHandler(topic string) {

// Route is the library provided StandardRouter's implementation
// of the required interface function()
func (r *StandardRouter) Route(pb *packets.Publish) {
func (r *StandardRouter) Route(ctx context.Context, pb *packets.Publish) {
r.debug.Println("routing message for:", pb.Topic)
r.RLock()
defer r.RUnlock()
Expand Down Expand Up @@ -124,14 +125,14 @@ func (r *StandardRouter) Route(pb *packets.Publish) {
if match(route, topic) {
r.debug.Println("found handler for:", route)
for _, handler := range handlers {
handler(m)
handler(ctx, m)
handlerCalled = true
}
}
}

if !handlerCalled && r.defaultHandler != nil {
r.defaultHandler(m)
r.defaultHandler(ctx, m)
}
}

Expand Down
67 changes: 60 additions & 7 deletions paho/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package paho

import (
"context"
"reflect"
"testing"

Expand Down Expand Up @@ -94,36 +95,88 @@ func Test_routeSplit(t *testing.T) {
func Test_routeDefault(t *testing.T) {
var r1Count, r2Count int

r1 := func(p *Publish) { r1Count++ }
r2 := func(p *Publish) { r2Count++ }
ctx := context.Background()
r1 := func(c context.Context, p *Publish) { r1Count++ }
r2 := func(c context.Context, p *Publish) { r2Count++ }

r := NewStandardRouter()
r.RegisterHandler("test", r1)

r.Route(&packets.Publish{Topic: "test", Properties: &packets.Properties{}})
r.Route(ctx, &packets.Publish{Topic: "test", Properties: &packets.Properties{}})
if r1Count != 1 {
t.Errorf("router1 should have been called r1: %d, r2: %d", r1Count, r2Count)
}
// Confirm that unset default does not cause issue
r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
if r1Count != 1 {
t.Errorf("router1 should not have been called r1: %d, r2: %d", r1Count, r2Count)
}

r.DefaultHandler(r2)
r.Route(&packets.Publish{Topic: "test", Properties: &packets.Properties{}})
r.Route(ctx, &packets.Publish{Topic: "test", Properties: &packets.Properties{}})
if r1Count != 2 || r2Count != 0 {
t.Errorf("router1 should been called r1: %d, r2: %d", r1Count, r2Count)
}
r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
if r1Count != 2 || r2Count != 1 {
t.Errorf("router2 should have been called r1: %d, r2: %d", r1Count, r2Count)
}

r.DefaultHandler(nil)
r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}})
if r1Count != 2 || r2Count != 1 {
t.Errorf("no router should have been called r1: %d, r2: %d", r1Count, r2Count)
}

}

func Test_routeContextPropagation(t *testing.T) {
type ctxKey string
testKey := ctxKey("test-key")
testValue := "test-value"

var receivedValue string
handler := func(ctx context.Context, p *Publish) {
if v, ok := ctx.Value(testKey).(string); ok {
receivedValue = v
}
}

r := NewStandardRouter()
r.RegisterHandler("test/topic", handler)

// Create a context with a test value
ctx := context.WithValue(context.Background(), testKey, testValue)

// Route a message with the context
r.Route(ctx, &packets.Publish{
Topic: "test/topic",
Properties: &packets.Properties{},
})

// Verify the context value was correctly propagated
if receivedValue != testValue {
t.Errorf("context value not propagated correctly, got: %v, want: %v", receivedValue, testValue)
}

// Test with a cancelled context
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()

var contextCancelled bool
cancelHandler := func(ctx context.Context, p *Publish) {
if ctx.Err() == context.Canceled {
contextCancelled = true
}
}

r.RegisterHandler("test/cancel", cancelHandler)
r.Route(cancelCtx, &packets.Publish{
Topic: "test/cancel",
Properties: &packets.Properties{},
})

if !contextCancelled {
t.Error("cancelled context was not properly propagated to handler")
}
}