Skip to content

Commit

Permalink
clean up legacy nats bus features
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe committed Mar 26, 2024
1 parent a39fe33 commit c99daa1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 66 deletions.
2 changes: 1 addition & 1 deletion internal/bus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const (
)

type Channel struct {
Legacy, Server, Local string
Legacy, Server, Server2, Local string
}

type MessageBus interface {
Expand Down
3 changes: 0 additions & 3 deletions internal/bus/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package bus

import (
"google.golang.org/protobuf/proto"

"github.com/livekit/psrpc/internal/logger"
)

type Subscription[MessageType proto.Message] interface {
Expand All @@ -42,7 +40,6 @@ func newSubscription[MessageType proto.Message](sub Reader, size int) Subscripti

p, err := deserialize(b)
if err != nil {
logger.Error(err, "failed to deserialize message")
continue
}
msgChan <- p.(MessageType)
Expand Down
93 changes: 42 additions & 51 deletions internal/test/my_service/my_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,38 @@ import (

func TestGeneratedService(t *testing.T) {
t.Run("Local", func(t *testing.T) {
testGeneratedService(t, psrpc.NewLocalMessageBus())
testGeneratedService(t, (func() func() psrpc.MessageBus {
bus := psrpc.NewLocalMessageBus()
return func() psrpc.MessageBus { return bus }
})())
})

t.Run("Redis", func(t *testing.T) {
rc := redis.NewUniversalClient(&redis.UniversalOptions{Addrs: []string{"localhost:6379"}})
testGeneratedService(t, psrpc.NewRedisMessageBus(rc))
testGeneratedService(t, func() psrpc.MessageBus {
rc := redis.NewUniversalClient(&redis.UniversalOptions{Addrs: []string{"localhost:6379"}})
return psrpc.NewRedisMessageBus(rc)
})
})

t.Run("Nats", func(t *testing.T) {
nc, _ := nats.Connect(nats.DefaultURL)
testGeneratedService(t, psrpc.NewNatsMessageBus(nc))
testGeneratedService(t, func() psrpc.MessageBus {
nc, _ := nats.Connect(nats.DefaultURL)
return psrpc.NewNatsMessageBus(nc)
})
})
}

func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
func testGeneratedService(t *testing.T, bus func() psrpc.MessageBus) {
ctx := context.Background()
req := &MyRequest{}
update := &MyUpdate{}
sA := createServer(t, bus)
sB := createServer(t, bus)
sA := createServer(t, bus())
sB := createServer(t, bus())

t.Cleanup(func() {
shutdown(t, sA)
shutdown(t, sB)
})

requestCount := 0
requestHook := func(ctx context.Context, req proto.Message, rpcInfo psrpc.RPCInfo) {
Expand All @@ -60,8 +72,8 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
responseHook := func(ctx context.Context, req proto.Message, rpcInfo psrpc.RPCInfo, res proto.Message, err error) {
responseCount++
}
cA := createClient(t, bus, psrpc.WithClientRequestHooks(requestHook), psrpc.WithClientResponseHooks(responseHook))
cB := createClient(t, bus)
cA := createClient(t, bus(), psrpc.WithClientRequestHooks(requestHook), psrpc.WithClientResponseHooks(responseHook))
cB := createClient(t, bus())

// rpc NormalRPC(MyRequest) returns (MyResponse);
_, err := cA.NormalRPC(ctx, req)
Expand Down Expand Up @@ -101,7 +113,7 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
require.NotNil(t, res)
require.NoError(t, res.Err)
case <-time.After(time.Second * 3):
t.Fatalf("timed out")
require.FailNow(t, "timed out")
}
}

Expand All @@ -121,7 +133,7 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
require.NoError(t, stream.Close(nil))

// let the service goroutine run
time.Sleep(time.Millisecond * 100)
time.Sleep(time.Second)

sA.Lock()
sB.Lock()
Expand All @@ -138,7 +150,7 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
require.NoError(t, sA.server.RegisterGetRegionStatsTopic("regionB"))
sA.server.DeregisterGetRegionStatsTopic("regionB")
require.NoError(t, sB.server.RegisterGetRegionStatsTopic("regionB"))
time.Sleep(time.Millisecond * 100)
time.Sleep(time.Second)

respChan, err = cB.GetRegionStats(ctx, "regionB", req)
require.NoError(t, err)
Expand All @@ -147,7 +159,7 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
require.NotNil(t, res)
require.NoError(t, res.Err)
case <-time.After(time.Second):
t.Fatalf("timed out")
require.FailNow(t, "timed out")
}

sA.Lock()
Expand All @@ -157,56 +169,35 @@ func testGeneratedService(t *testing.T, bus psrpc.MessageBus) {
sA.Unlock()
sB.Unlock()

// rpc ProcessUpdate(Ignored) returns (MyUpdate) {
// option (psrpc.options).subscription = true;
subA, err := cA.SubscribeProcessUpdate(ctx)
require.NoError(t, err)
subB, err := cB.SubscribeProcessUpdate(ctx)
require.NoError(t, err)
time.Sleep(time.Millisecond * 100)

require.NoError(t, sA.server.PublishProcessUpdate(ctx, update))
requireOne(t, subA, subB)
require.NoError(t, subA.Close())
require.NoError(t, subB.Close())

// rpc UpdateRegionState(Ignored) returns (MyUpdate) {
// option (psrpc.options).subscription = true;
// option (psrpc.options).topics = true;
// option (psrpc.options).type = MULTI;
subA, err = cA.SubscribeUpdateRegionState(ctx, "regionA")
subA, err := cA.SubscribeUpdateRegionState(ctx, "regionA")
require.NoError(t, err)
subB, err = cB.SubscribeUpdateRegionState(ctx, "regionA")
subB, err := cB.SubscribeUpdateRegionState(ctx, "regionA")
require.NoError(t, err)
time.Sleep(time.Millisecond * 100)
time.Sleep(time.Second)

require.NoError(t, sB.server.PublishUpdateRegionState(ctx, "regionA", update))
requireTwo(t, subA, subB)
require.NoError(t, subA.Close())
require.NoError(t, subB.Close())

shutdown(t, sA)
shutdown(t, sB)
}

func requireOne(t *testing.T, subA, subB psrpc.Subscription[*MyUpdate]) {
for i := 0; i < 2; i++ {
select {
case <-subA.Channel():
if i == 0 {
continue
}
case <-subB.Channel():
if i == 0 {
continue
}
case <-time.After(time.Second):
if i == 1 {
continue
}
}
t.Fatalf("%d responses received", i*2)
var n int
select {
case <-subA.Channel():
n++
case <-time.After(time.Second):
}
select {
case <-subB.Channel():
n++
case <-time.After(time.Second):
}
require.Equal(t, 1, n, "expected one response")
}

func requireTwo(t *testing.T, subA, subB psrpc.Subscription[*MyUpdate]) {
Expand All @@ -215,7 +206,7 @@ func requireTwo(t *testing.T, subA, subB psrpc.Subscription[*MyUpdate]) {
case <-subA.Channel():
case <-subB.Channel():
case <-time.After(time.Second):
t.Fatalf("timed out")
require.FailNow(t, "timed out")
}
}
}
Expand Down Expand Up @@ -246,7 +237,7 @@ func shutdown(t *testing.T, s *MyService) {
case <-done:
// continue
case <-time.After(time.Second * 3):
t.Fatalf("shutdown not returning")
require.FailNow(t, "shutdown not returning")
}
}

Expand Down
27 changes: 16 additions & 11 deletions internal/test/psrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ func TestRPC(t *testing.T) {
}{
{
label: "Local",
bus: func() psrpc.MessageBus { return psrpc.NewLocalMessageBus() },
bus: (func() func() psrpc.MessageBus {
bus := psrpc.NewLocalMessageBus()
return func() psrpc.MessageBus { return bus }
})(),
},
{
label: "Redis",
Expand All @@ -61,29 +64,29 @@ func TestRPC(t *testing.T) {
for _, c := range cases {
c := c
t.Run(fmt.Sprintf("RPC/%s", c.label), func(t *testing.T) {
testRPC(t, c.bus())
testRPC(t, c.bus)
})
t.Run(fmt.Sprintf("Stream/%s", c.label), func(t *testing.T) {
testStream(t, c.bus())
testStream(t, c.bus)
})
}
}

func testRPC(t *testing.T, bus psrpc.MessageBus) {
func testRPC(t *testing.T, bus func() psrpc.MessageBus) {
serviceName := "test"

serverA := server.NewRPCServer(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())
serverB := server.NewRPCServer(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())
serverC := server.NewRPCServer(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())

t.Cleanup(func() {
serverA.Close(true)
Expand All @@ -94,7 +97,7 @@ func testRPC(t *testing.T, bus psrpc.MessageBus) {
c, err := client.NewRPCClient(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())
require.NoError(t, err)

retErr := psrpc.NewErrorf(psrpc.Internal, "foo")
Expand All @@ -119,6 +122,7 @@ func testRPC(t *testing.T, bus psrpc.MessageBus) {
require.NoError(t, err)
err = server.RegisterHandler[*internal.Request, *internal.Response](serverB, rpc, nil, addOne, nil)
require.NoError(t, err)
time.Sleep(time.Second)

ctx := context.Background()
requestID := rand.NewRequestID()
Expand All @@ -141,6 +145,7 @@ func testRPC(t *testing.T, bus psrpc.MessageBus) {
require.NoError(t, err)
err = server.RegisterHandler[*internal.Request, *internal.Response](serverC, multiRpc, nil, returnError, nil)
require.NoError(t, err)
time.Sleep(time.Second)

requestID = rand.NewRequestID()
resChan, err := client.RequestMulti[*internal.Response](
Expand Down Expand Up @@ -168,13 +173,13 @@ func testRPC(t *testing.T, bus psrpc.MessageBus) {
}
}

func testStream(t *testing.T, bus psrpc.MessageBus) {
func testStream(t *testing.T, bus func() psrpc.MessageBus) {
serviceName := "test_stream"

serverA := server.NewRPCServer(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())

t.Cleanup(func() {
serverA.Close(true)
Expand All @@ -183,7 +188,7 @@ func testStream(t *testing.T, bus psrpc.MessageBus) {
c, err := client.NewRPCClientWithStreams(&info.ServiceDefinition{
Name: serviceName,
ID: rand.NewString(),
}, bus)
}, bus())
require.NoError(t, err)

serverClose := make(chan struct{})
Expand Down

0 comments on commit c99daa1

Please sign in to comment.