Skip to content

Commit

Permalink
Add a weighted semaphore to protect the concurrency of requests being…
Browse files Browse the repository at this point in the history
… handled by coordinator. Hopes to address #22, values for concurrency taken from Koenkk/zigbee-herdsman.
  • Loading branch information
pwood committed Aug 7, 2021
1 parent f3b4c8f commit 338b576
Show file tree
Hide file tree
Showing 28 changed files with 135 additions and 7 deletions.
6 changes: 6 additions & 0 deletions adapter_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package zstack

import (
"context"
"fmt"
"github.com/shimmeringbee/zigbee"
)

func (z *ZStack) RegisterAdapterEndpoint(ctx context.Context, endpoint zigbee.Endpoint, appProfileId zigbee.ProfileID, appDeviceId uint16, appDeviceVersion uint8, inClusters []zigbee.ClusterID, outClusters []zigbee.ClusterID) error {
if err := z.sem.Acquire(ctx, 1); err != nil {
return fmt.Errorf("failed to acquire semaphore: %w", err)
}
defer z.sem.Release(1)

request := AFRegister{
Endpoint: endpoint,
AppProfileId: appProfileId,
Expand Down
3 changes: 3 additions & 0 deletions adapter_endpoints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
unpiTest "github.com/shimmeringbee/unpi/testing"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"testing"
"time"
)
Expand All @@ -18,6 +19,7 @@ func Test_RegisterAdapterEndpoint(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

c := unpiMock.On(SREQ, AF, AFRegisterID).Return(Frame{
Expand All @@ -43,6 +45,7 @@ func Test_RegisterAdapterEndpoint(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

unpiMock.On(SREQ, AF, AFRegisterID).Return(Frame{
Expand Down
6 changes: 6 additions & 0 deletions adapter_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zstack

import (
"context"
"fmt"
"github.com/shimmeringbee/zigbee"
)

Expand Down Expand Up @@ -30,6 +31,11 @@ func (z *ZStack) GetAdapterNetworkAddress(ctx context.Context) (zigbee.NetworkAd
func (z *ZStack) getAddressInfo(ctx context.Context) (UtilGetDeviceInfoRequestReply, error) {
resp := UtilGetDeviceInfoRequestReply{}

if err := z.sem.Acquire(ctx, 1); err != nil {
return resp, fmt.Errorf("failed to acquire semaphore: %w", err)
}
defer z.sem.Release(1)

err := z.requestResponder.RequestResponse(ctx, UtilGetDeviceInfoRequest{}, &resp)
return resp, err
}
Expand Down
3 changes: 3 additions & 0 deletions adapter_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
unpiTest "github.com/shimmeringbee/unpi/testing"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"testing"
"time"
)
Expand All @@ -18,6 +19,7 @@ func Test_GetAdapterIEEEAddress(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

unpiMock.On(SREQ, UTIL, UtilGetDeviceInfoRequestID).Return(Frame{
Expand All @@ -42,6 +44,7 @@ func Test_GetAdapterNetworkAddress(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

unpiMock.On(SREQ, UTIL, UtilGetDeviceInfoRequestID).Return(Frame{
Expand Down
7 changes: 7 additions & 0 deletions adapter_initialise.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"github.com/shimmeringbee/retry"
"github.com/shimmeringbee/zigbee"
"golang.org/x/sync/semaphore"
"reflect"
)

Expand All @@ -24,6 +25,12 @@ func (z *ZStack) Initialise(pctx context.Context, nc zigbee.NetworkConfiguration
return err
}

if version.IsV3() {
z.sem = semaphore.NewWeighted(16)
} else {
z.sem = semaphore.NewWeighted(2)
}

z.logger.LogInfo(ctx, "Verifying existing network configuration.")
if valid, err := z.verifyAdapterNetworkConfig(ctx, version); err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ require (
github.com/shimmeringbee/unpi v0.0.0-20210525151328-7ede275a1033
github.com/shimmeringbee/zigbee v0.0.0-20210427191220-76676a734066
github.com/stretchr/testify v1.7.0
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
5 changes: 5 additions & 0 deletions joining.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ func (z *ZStack) DenyJoin(ctx context.Context) error {
}

func (z *ZStack) sendJoin(ctx context.Context, address zigbee.NetworkAddress, timeout uint8, newState JoinState) error {
if err := z.sem.Acquire(ctx, 1); err != nil {
return fmt.Errorf("failed to acquire semaphore: %w", err)
}
defer z.sem.Release(1)

response := ZDOMgmtPermitJoinRequestReply{}

if err := z.requestResponder.RequestResponse(ctx, ZDOMgmtPermitJoinRequest{
Expand Down
6 changes: 6 additions & 0 deletions joining_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
unpiTest "github.com/shimmeringbee/unpi/testing"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"testing"
"time"
)
Expand All @@ -18,6 +19,7 @@ func Test_PermitJoin(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

c := unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{
Expand All @@ -42,6 +44,7 @@ func Test_PermitJoin(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
zstack.NetworkProperties.NetworkAddress = zigbee.NetworkAddress(0x0102)

Expand All @@ -67,6 +70,7 @@ func Test_PermitJoin(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{
Expand All @@ -90,6 +94,7 @@ func Test_DenyJoin(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

c := unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{
Expand All @@ -115,6 +120,7 @@ func Test_DenyJoin(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{
Expand Down
6 changes: 6 additions & 0 deletions network_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ func (z *ZStack) requestLQITable(node zigbee.Node, startIndex uint8) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultZStackTimeout)
defer cancel()

if err := z.sem.Acquire(ctx, 1); err != nil {
z.logger.LogError(ctx, "Failed to request LQI table, failed to acquire semaphore ", logwrap.Datum("IEEEAddress", node.IEEEAddress.String()), logwrap.Datum("NetworkAddress", node.NetworkAddress), logwrap.Err(err))
return
}
defer z.sem.Release(1)

resp := ZdoMGMTLQIReqReply{}
z.logger.LogDebug(ctx, "Requesting LQI table from device.", logwrap.Datum("IEEEAddress", node.IEEEAddress.String()), logwrap.Datum("NetworkAddress", node.NetworkAddress), logwrap.Datum("StartIndex", startIndex))
if err := z.requestResponder.RequestResponse(ctx, ZdoMGMTLQIReq{DestinationAddress: node.NetworkAddress, StartIndex: startIndex}, &resp); err != nil {
Expand Down
13 changes: 13 additions & 0 deletions network_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
unpiTest "github.com/shimmeringbee/unpi/testing"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"testing"
"time"
)
Expand All @@ -15,6 +16,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("issues a lqi poll request only for coordinators or routers", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer zstack.Stop()

Expand All @@ -39,6 +41,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("the coordinator is added to the node list as a coordinator", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer zstack.Stop()

Expand Down Expand Up @@ -72,6 +75,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("a node is added to the node table when an ZdoIEEEAddrRsp messages are received", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -106,6 +110,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("a node is added to the node table when an ZdoNWKAddrRsp messages are received", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -143,6 +148,7 @@ func Test_NetworkManager(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -209,6 +215,7 @@ func Test_NetworkManager(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -257,6 +264,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("a new router will be queried for network state", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -311,6 +319,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("nodes in lqi query are added to network manager", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -373,6 +382,7 @@ func Test_NetworkManager(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
nt := NewNodeTable()
zstack := New(unpiMock, nt)
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -416,6 +426,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("nodes in lqi query are not added if Ext PANID does not match", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down Expand Up @@ -472,6 +483,7 @@ func Test_NetworkManager(t *testing.T) {
t.Run("nodes in lqi query are not added if it has an invalid IEEE address", func(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
zstack.NetworkProperties.IEEEAddress = zigbee.IEEEAddress(1)

defer unpiMock.Stop()
Expand Down Expand Up @@ -533,6 +545,7 @@ func Test_NetworkManager(t *testing.T) {

unpiMock := unpiTest.NewMockAdapter()
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()
defer unpiMock.AssertCalls(t)

Expand Down
11 changes: 11 additions & 0 deletions node_address.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zstack

import (
"context"
"fmt"
"github.com/shimmeringbee/logwrap"
"github.com/shimmeringbee/zigbee"
)
Expand All @@ -25,6 +26,11 @@ func (z *ZStack) ResolveNodeNWKAddress(ctx context.Context, address zigbee.IEEEA
}

func (z *ZStack) QueryNodeIEEEAddress(ctx context.Context, address zigbee.NetworkAddress) (zigbee.IEEEAddress, error) {
if err := z.sem.Acquire(ctx, 1); err != nil {
return zigbee.EmptyIEEEAddress, fmt.Errorf("failed to acquire semaphore: %w", err)
}
defer z.sem.Release(1)

request := ZdoIEEEAddrReq{
NetworkAddress: address,
ReqType: 0x00,
Expand All @@ -47,6 +53,11 @@ func (z *ZStack) QueryNodeIEEEAddress(ctx context.Context, address zigbee.Networ
}

func (z *ZStack) QueryNodeNWKAddress(ctx context.Context, address zigbee.IEEEAddress) (zigbee.NetworkAddress, error) {
if err := z.sem.Acquire(ctx, 1); err != nil {
return zigbee.NetworkAddress(0x0), fmt.Errorf("failed to acquire semaphore: %w", err)
}
defer z.sem.Release(1)

request := ZdoNWKAddrReq{
IEEEAddress: address,
ReqType: 0x00,
Expand Down
7 changes: 7 additions & 0 deletions node_address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
unpiTest "github.com/shimmeringbee/unpi/testing"
"github.com/shimmeringbee/zigbee"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"testing"
"time"
)
Expand All @@ -19,6 +20,7 @@ func Test_ResolveNodeIEEEAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

zstack.nodeTable.addOrUpdate(0x1122334455667788, 0xaabb)
Expand All @@ -35,6 +37,7 @@ func Test_ResolveNodeIEEEAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

call := unpiMock.On(SREQ, ZDO, ZdoIEEEAddrReqID).Return(Frame{
Expand Down Expand Up @@ -75,6 +78,7 @@ func Test_QueryNodeIEEEAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

call := unpiMock.On(SREQ, ZDO, ZdoIEEEAddrReqID).Return(Frame{
Expand Down Expand Up @@ -115,6 +119,7 @@ func Test_ResolveNodeNWKAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

zstack.nodeTable.addOrUpdate(0x1122334455667788, 0xaabb)
Expand All @@ -131,6 +136,7 @@ func Test_ResolveNodeNWKAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

call := unpiMock.On(SREQ, ZDO, ZdoNWKAddrReqID).Return(Frame{
Expand Down Expand Up @@ -171,6 +177,7 @@ func Test_QueryNodeNWKAddress(t *testing.T) {
unpiMock := unpiTest.NewMockAdapter()
defer unpiMock.AssertCalls(t)
zstack := New(unpiMock, NewNodeTable())
zstack.sem = semaphore.NewWeighted(8)
defer unpiMock.Stop()

call := unpiMock.On(SREQ, ZDO, ZdoNWKAddrReqID).Return(Frame{
Expand Down
Loading

0 comments on commit 338b576

Please sign in to comment.