diff --git a/autopaho/examples/router/router.go b/autopaho/examples/router/router.go index 4bb8a9a..c699b3c 100644 --- a/autopaho/examples/router/router.go +++ b/autopaho/examples/router/router.go @@ -44,7 +44,9 @@ func main() { } router := paho.NewStandardRouter() - router.DefaultHandler(func(p *paho.Publish) { fmt.Printf("defaulthandler received message with topic: %s\n", p.Topic) }) + router.DefaultHandler(func(c context.Context, p *paho.Publish) { + fmt.Printf("defaulthandler received message with topic: %s\n", p.Topic) + }) cliCfg := autopaho.ClientConfig{ ServerUrls: []*url.URL{u}, @@ -61,7 +63,7 @@ func main() { // You can write the function(s) yourself or use the supplied Router OnPublishReceived: []func(paho.PublishReceived) (bool, error){ func(pr paho.PublishReceived) (bool, error) { - router.Route(pr.Packet.Packet()) + router.Route(context.TODO(), pr.Packet.Packet()) return true, nil // we assume that the router handles all messages (todo: amend router API) }}, OnClientError: func(err error) { fmt.Printf("client error: %s\n", err) }, @@ -96,10 +98,16 @@ func main() { // Handlers can be registered/deregistered at any time. It's important to note that you need to subscribe AND create // a handler - router.RegisterHandler("test/test/#", func(p *paho.Publish) { fmt.Printf("test/test/# received message with topic: %s\n", p.Topic) }) - router.RegisterHandler("test/test/foo", func(p *paho.Publish) { fmt.Printf("test/test/foo received message with topic: %s\n", p.Topic) }) - router.RegisterHandler("test/nomatch", func(p *paho.Publish) { fmt.Printf("test/nomatch received message with topic: %s\n", p.Topic) }) - router.RegisterHandler("test/quit", func(p *paho.Publish) { stop() }) // Context will be cancelled if we receive a matching message + router.RegisterHandler("test/test/#", func(c context.Context, p *paho.Publish) { + fmt.Printf("test/test/# received message with topic: %s\n", p.Topic) + }) + router.RegisterHandler("test/test/foo", func(c context.Context, p *paho.Publish) { + fmt.Printf("test/test/foo received message with topic: %s\n", p.Topic) + }) + router.RegisterHandler("test/nomatch", func(c context.Context, p *paho.Publish) { + fmt.Printf("test/nomatch received message with topic: %s\n", p.Topic) + }) + router.RegisterHandler("test/quit", func(c context.Context, p *paho.Publish) { stop() }) // Context will be cancelled if we receive a matching message // We publish three messages to test out the various route handlers topics := []string{"test/test", "test/test/foo", "test/xxNoMatch", "test/quit"} diff --git a/autopaho/examples/rpc/main.go b/autopaho/examples/rpc/main.go index 00a1801..a594607 100644 --- a/autopaho/examples/rpc/main.go +++ b/autopaho/examples/rpc/main.go @@ -185,7 +185,7 @@ func main() { router := paho.NewStandardRouter() cliCfg.OnPublishReceived = []func(paho.PublishReceived) (bool, error){ func(p paho.PublishReceived) (bool, error) { - router.Route(p.Packet.Packet()) + router.Route(context.TODO(), p.Packet.Packet()) return false, nil }} diff --git a/autopaho/extensions/rpc/rpc.go b/autopaho/extensions/rpc/rpc.go index c702d05..82fadbb 100644 --- a/autopaho/extensions/rpc/rpc.go +++ b/autopaho/extensions/rpc/rpc.go @@ -107,7 +107,7 @@ func (h *Handler) Request(ctx context.Context, pb *paho.Publish) (resp *paho.Pub } } -func (h *Handler) responseHandler(pb *paho.Publish) { +func (h *Handler) responseHandler(ctx context.Context, pb *paho.Publish) { if pb.Properties == nil || pb.Properties.CorrelationData == nil { return } diff --git a/paho/client.go b/paho/client.go index a772da8..8187a38 100644 --- a/paho/client.go +++ b/paho/client.go @@ -69,7 +69,7 @@ type ( PingHandler Pinger defaultPinger bool - // Router - new inbound messages will be passed to the `Route(*packets.Publish)` function. + // Router - new inbound messages will be passed to the `Route(context.Context, *packets.Publish)` function. // // Depreciated: If a router is provided, it will now be added to the end of the OnPublishReceived // slice (which provides a more flexible approach to handling incoming messages). @@ -194,7 +194,7 @@ func NewClient(conf ClientConfig) *Client { r := c.config.Router c.onPublishReceived = append(c.onPublishReceived, func(p PublishReceived) (bool, error) { - r.Route(p.Packet.Packet()) + r.Route(context.TODO(), p.Packet.Packet()) return false, nil }) } @@ -443,9 +443,7 @@ func (c *Client) routePublishPackets() { // Copy onPublishReceived so lock is only held briefly c.onPublishReceivedMu.Lock() handlers := make([]func(PublishReceived) (bool, error), len(c.onPublishReceived)) - for i := range c.onPublishReceived { - handlers[i] = c.onPublishReceived[i] - } + copy(handlers, c.onPublishReceived) c.onPublishReceivedMu.Unlock() if c.config.EnableManualAcknowledgment && pb.QoS != 0 { @@ -887,7 +885,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish, o Publis } // From this point on the message is in store, and ret will receive something regardless of whether we succeed in - // writing the packet to the connection + // writing the packet to the connection or not if _, err := pb.WriteTo(c.config.Conn); err != nil { c.debug.Printf("failed to write packet %d to connection: %s", pb.PacketID, err) if o.Method == PublishMethod_AsyncSend { diff --git a/paho/router.go b/paho/router.go index dc6fa70..cdd870f 100644 --- a/paho/router.go +++ b/paho/router.go @@ -16,6 +16,7 @@ package paho import ( + "context" "strings" "sync" @@ -28,7 +29,7 @@ import ( // MessageHandlers should complete quickly (start a go routine for // long-running processes) and should not call functions within the // paho instance that triggered them (due to potential deadlocks). -type MessageHandler func(*Publish) +type MessageHandler func(context.Context, *Publish) // Router is an interface of the functions for a struct that is // used to handle invoking MessageHandlers depending on the @@ -42,7 +43,7 @@ type MessageHandler func(*Publish) type Router interface { RegisterHandler(string, MessageHandler) UnregisterHandler(string) - Route(*packets.Publish) + Route(context.Context, *packets.Publish) SetDebugLogger(log.Logger) } @@ -96,7 +97,7 @@ func (r *StandardRouter) UnregisterHandler(topic string) { // Route is the library provided StandardRouter's implementation // of the required interface function() -func (r *StandardRouter) Route(pb *packets.Publish) { +func (r *StandardRouter) Route(ctx context.Context, pb *packets.Publish) { r.debug.Println("routing message for:", pb.Topic) r.RLock() defer r.RUnlock() @@ -124,14 +125,14 @@ func (r *StandardRouter) Route(pb *packets.Publish) { if match(route, topic) { r.debug.Println("found handler for:", route) for _, handler := range handlers { - handler(m) + handler(ctx, m) handlerCalled = true } } } if !handlerCalled && r.defaultHandler != nil { - r.defaultHandler(m) + r.defaultHandler(ctx, m) } } diff --git a/paho/router_test.go b/paho/router_test.go index 2001ab1..64e1827 100644 --- a/paho/router_test.go +++ b/paho/router_test.go @@ -16,6 +16,7 @@ package paho import ( + "context" "reflect" "testing" @@ -94,36 +95,88 @@ func Test_routeSplit(t *testing.T) { func Test_routeDefault(t *testing.T) { var r1Count, r2Count int - r1 := func(p *Publish) { r1Count++ } - r2 := func(p *Publish) { r2Count++ } + ctx := context.Background() + r1 := func(c context.Context, p *Publish) { r1Count++ } + r2 := func(c context.Context, p *Publish) { r2Count++ } r := NewStandardRouter() r.RegisterHandler("test", r1) - r.Route(&packets.Publish{Topic: "test", Properties: &packets.Properties{}}) + r.Route(ctx, &packets.Publish{Topic: "test", Properties: &packets.Properties{}}) if r1Count != 1 { t.Errorf("router1 should have been called r1: %d, r2: %d", r1Count, r2Count) } // Confirm that unset default does not cause issue - r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) + r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) if r1Count != 1 { t.Errorf("router1 should not have been called r1: %d, r2: %d", r1Count, r2Count) } r.DefaultHandler(r2) - r.Route(&packets.Publish{Topic: "test", Properties: &packets.Properties{}}) + r.Route(ctx, &packets.Publish{Topic: "test", Properties: &packets.Properties{}}) if r1Count != 2 || r2Count != 0 { t.Errorf("router1 should been called r1: %d, r2: %d", r1Count, r2Count) } - r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) + r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) if r1Count != 2 || r2Count != 1 { t.Errorf("router2 should have been called r1: %d, r2: %d", r1Count, r2Count) } r.DefaultHandler(nil) - r.Route(&packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) + r.Route(ctx, &packets.Publish{Topic: "xxyy", Properties: &packets.Properties{}}) if r1Count != 2 || r2Count != 1 { t.Errorf("no router should have been called r1: %d, r2: %d", r1Count, r2Count) } } + +func Test_routeContextPropagation(t *testing.T) { + type ctxKey string + testKey := ctxKey("test-key") + testValue := "test-value" + + var receivedValue string + handler := func(ctx context.Context, p *Publish) { + if v, ok := ctx.Value(testKey).(string); ok { + receivedValue = v + } + } + + r := NewStandardRouter() + r.RegisterHandler("test/topic", handler) + + // Create a context with a test value + ctx := context.WithValue(context.Background(), testKey, testValue) + + // Route a message with the context + r.Route(ctx, &packets.Publish{ + Topic: "test/topic", + Properties: &packets.Properties{}, + }) + + // Verify the context value was correctly propagated + if receivedValue != testValue { + t.Errorf("context value not propagated correctly, got: %v, want: %v", receivedValue, testValue) + } + + // Test with a cancelled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + var contextCancelled bool + cancelHandler := func(ctx context.Context, p *Publish) { + if ctx.Err() == context.Canceled { + contextCancelled = true + } + } + + r.RegisterHandler("test/cancel", cancelHandler) + r.Route(cancelCtx, &packets.Publish{ + Topic: "test/cancel", + Properties: &packets.Properties{}, + }) + + if !contextCancelled { + t.Error("cancelled context was not properly propagated to handler") + } +}