diff --git a/adapter_endpoints.go b/adapter_endpoints.go index 6576bb7..68f0069 100644 --- a/adapter_endpoints.go +++ b/adapter_endpoints.go @@ -2,10 +2,16 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) RegisterAdapterEndpoint(ctx context.Context, endpoint zigbee.Endpoint, appProfileId zigbee.ProfileID, appDeviceId uint16, appDeviceVersion uint8, inClusters []zigbee.ClusterID, outClusters []zigbee.ClusterID) error { + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := AFRegister{ Endpoint: endpoint, AppProfileId: appProfileId, diff --git a/adapter_endpoints_test.go b/adapter_endpoints_test.go index 7d84f20..e2279c9 100644 --- a/adapter_endpoints_test.go +++ b/adapter_endpoints_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_RegisterAdapterEndpoint(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() c := unpiMock.On(SREQ, AF, AFRegisterID).Return(Frame{ @@ -43,6 +45,7 @@ func Test_RegisterAdapterEndpoint(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, AF, AFRegisterID).Return(Frame{ diff --git a/adapter_info.go b/adapter_info.go index 3463f57..b3069fa 100644 --- a/adapter_info.go +++ b/adapter_info.go @@ -2,6 +2,7 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) @@ -30,6 +31,11 @@ func (z *ZStack) GetAdapterNetworkAddress(ctx context.Context) (zigbee.NetworkAd func (z *ZStack) getAddressInfo(ctx context.Context) (UtilGetDeviceInfoRequestReply, error) { resp := UtilGetDeviceInfoRequestReply{} + if err := z.sem.Acquire(ctx, 1); err != nil { + return resp, fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + err := z.requestResponder.RequestResponse(ctx, UtilGetDeviceInfoRequest{}, &resp) return resp, err } diff --git a/adapter_info_test.go b/adapter_info_test.go index e945068..d2d2bed 100644 --- a/adapter_info_test.go +++ b/adapter_info_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_GetAdapterIEEEAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, UTIL, UtilGetDeviceInfoRequestID).Return(Frame{ @@ -42,6 +44,7 @@ func Test_GetAdapterNetworkAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, UTIL, UtilGetDeviceInfoRequestID).Return(Frame{ diff --git a/adapter_initialise.go b/adapter_initialise.go index 494316f..78f352a 100644 --- a/adapter_initialise.go +++ b/adapter_initialise.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/shimmeringbee/retry" "github.com/shimmeringbee/zigbee" + "golang.org/x/sync/semaphore" "reflect" ) @@ -24,6 +25,12 @@ func (z *ZStack) Initialise(pctx context.Context, nc zigbee.NetworkConfiguration return err } + if version.IsV3() { + z.sem = semaphore.NewWeighted(16) + } else { + z.sem = semaphore.NewWeighted(2) + } + z.logger.LogInfo(ctx, "Verifying existing network configuration.") if valid, err := z.verifyAdapterNetworkConfig(ctx, version); err != nil { return err diff --git a/go.mod b/go.mod index f939e50..5c39e1d 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( github.com/shimmeringbee/unpi v0.0.0-20210525151328-7ede275a1033 github.com/shimmeringbee/zigbee v0.0.0-20210427191220-76676a734066 github.com/stretchr/testify v1.7.0 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index eead7a6..f12d9ce 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/joining.go b/joining.go index ef28a09..4b7703c 100644 --- a/joining.go +++ b/joining.go @@ -19,6 +19,11 @@ func (z *ZStack) DenyJoin(ctx context.Context) error { } func (z *ZStack) sendJoin(ctx context.Context, address zigbee.NetworkAddress, timeout uint8, newState JoinState) error { + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + response := ZDOMgmtPermitJoinRequestReply{} if err := z.requestResponder.RequestResponse(ctx, ZDOMgmtPermitJoinRequest{ diff --git a/joining_test.go b/joining_test.go index 49afb63..d482f84 100644 --- a/joining_test.go +++ b/joining_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_PermitJoin(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() c := unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{ @@ -42,6 +44,7 @@ func Test_PermitJoin(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() zstack.NetworkProperties.NetworkAddress = zigbee.NetworkAddress(0x0102) @@ -67,6 +70,7 @@ func Test_PermitJoin(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{ @@ -90,6 +94,7 @@ func Test_DenyJoin(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() c := unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{ @@ -115,6 +120,7 @@ func Test_DenyJoin(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, ZDO, ZDOMgmtPermitJoinRequestID).Return(Frame{ diff --git a/network_manager.go b/network_manager.go index ff007cf..9d926b4 100644 --- a/network_manager.go +++ b/network_manager.go @@ -123,6 +123,12 @@ func (z *ZStack) requestLQITable(node zigbee.Node, startIndex uint8) { ctx, cancel := context.WithTimeout(context.Background(), DefaultZStackTimeout) defer cancel() + if err := z.sem.Acquire(ctx, 1); err != nil { + z.logger.LogError(ctx, "Failed to request LQI table, failed to acquire semaphore ", logwrap.Datum("IEEEAddress", node.IEEEAddress.String()), logwrap.Datum("NetworkAddress", node.NetworkAddress), logwrap.Err(err)) + return + } + defer z.sem.Release(1) + resp := ZdoMGMTLQIReqReply{} z.logger.LogDebug(ctx, "Requesting LQI table from device.", logwrap.Datum("IEEEAddress", node.IEEEAddress.String()), logwrap.Datum("NetworkAddress", node.NetworkAddress), logwrap.Datum("StartIndex", startIndex)) if err := z.requestResponder.RequestResponse(ctx, ZdoMGMTLQIReq{DestinationAddress: node.NetworkAddress, StartIndex: startIndex}, &resp); err != nil { diff --git a/network_manager_test.go b/network_manager_test.go index 3c7eacd..0e1ea61 100644 --- a/network_manager_test.go +++ b/network_manager_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -15,6 +16,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("issues a lqi poll request only for coordinators or routers", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer zstack.Stop() @@ -39,6 +41,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("the coordinator is added to the node list as a coordinator", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer zstack.Stop() @@ -72,6 +75,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("a node is added to the node table when an ZdoIEEEAddrRsp messages are received", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -106,6 +110,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("a node is added to the node table when an ZdoNWKAddrRsp messages are received", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -143,6 +148,7 @@ func Test_NetworkManager(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -209,6 +215,7 @@ func Test_NetworkManager(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -257,6 +264,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("a new router will be queried for network state", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -311,6 +319,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("nodes in lqi query are added to network manager", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -373,6 +382,7 @@ func Test_NetworkManager(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() nt := NewNodeTable() zstack := New(unpiMock, nt) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -416,6 +426,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("nodes in lqi query are not added if Ext PANID does not match", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) @@ -472,6 +483,7 @@ func Test_NetworkManager(t *testing.T) { t.Run("nodes in lqi query are not added if it has an invalid IEEE address", func(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) zstack.NetworkProperties.IEEEAddress = zigbee.IEEEAddress(1) defer unpiMock.Stop() @@ -533,6 +545,7 @@ func Test_NetworkManager(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() defer unpiMock.AssertCalls(t) diff --git a/node_address.go b/node_address.go index b5b3837..4bf6e3b 100644 --- a/node_address.go +++ b/node_address.go @@ -2,6 +2,7 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/logwrap" "github.com/shimmeringbee/zigbee" ) @@ -25,6 +26,11 @@ func (z *ZStack) ResolveNodeNWKAddress(ctx context.Context, address zigbee.IEEEA } func (z *ZStack) QueryNodeIEEEAddress(ctx context.Context, address zigbee.NetworkAddress) (zigbee.IEEEAddress, error) { + if err := z.sem.Acquire(ctx, 1); err != nil { + return zigbee.EmptyIEEEAddress, fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoIEEEAddrReq{ NetworkAddress: address, ReqType: 0x00, @@ -47,6 +53,11 @@ func (z *ZStack) QueryNodeIEEEAddress(ctx context.Context, address zigbee.Networ } func (z *ZStack) QueryNodeNWKAddress(ctx context.Context, address zigbee.IEEEAddress) (zigbee.NetworkAddress, error) { + if err := z.sem.Acquire(ctx, 1); err != nil { + return zigbee.NetworkAddress(0x0), fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoNWKAddrReq{ IEEEAddress: address, ReqType: 0x00, diff --git a/node_address_test.go b/node_address_test.go index 44a423f..5899298 100644 --- a/node_address_test.go +++ b/node_address_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -19,6 +20,7 @@ func Test_ResolveNodeIEEEAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() zstack.nodeTable.addOrUpdate(0x1122334455667788, 0xaabb) @@ -35,6 +37,7 @@ func Test_ResolveNodeIEEEAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoIEEEAddrReqID).Return(Frame{ @@ -75,6 +78,7 @@ func Test_QueryNodeIEEEAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoIEEEAddrReqID).Return(Frame{ @@ -115,6 +119,7 @@ func Test_ResolveNodeNWKAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() zstack.nodeTable.addOrUpdate(0x1122334455667788, 0xaabb) @@ -131,6 +136,7 @@ func Test_ResolveNodeNWKAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoNWKAddrReqID).Return(Frame{ @@ -171,6 +177,7 @@ func Test_QueryNodeNWKAddress(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoNWKAddrReqID).Return(Frame{ diff --git a/node_bind.go b/node_bind.go index 676df73..8274e8e 100644 --- a/node_bind.go +++ b/node_bind.go @@ -2,16 +2,21 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) BindNodeToController(ctx context.Context, nodeAddress zigbee.IEEEAddress, sourceEndpoint zigbee.Endpoint, destinationEndpoint zigbee.Endpoint, cluster zigbee.ClusterID) error { networkAddress, err := z.ResolveNodeNWKAddress(ctx, nodeAddress) - if err != nil { return nil } + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoBindReq{ TargetAddress: networkAddress, SourceAddress: nodeAddress, diff --git a/node_bind_test.go b/node_bind_test.go index 868fadf..87d4e58 100644 --- a/node_bind_test.go +++ b/node_bind_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_BindToNode(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoBindReqReplyID).Return(Frame{ diff --git a/node_description.go b/node_description.go index 19ab6e0..90fe9b2 100644 --- a/node_description.go +++ b/node_description.go @@ -2,16 +2,21 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) QueryNodeDescription(ctx context.Context, ieeeAddress zigbee.IEEEAddress) (zigbee.NodeDescription, error) { nwkAddress, err := z.ResolveNodeNWKAddress(ctx, ieeeAddress) - if err != nil { return zigbee.NodeDescription{}, err } + if err := z.sem.Acquire(ctx, 1); err != nil { + return zigbee.NodeDescription{}, fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoNodeDescReq{ DestinationAddress: nwkAddress, OfInterestAddress: nwkAddress, diff --git a/node_description_test.go b/node_description_test.go index a864c67..24f64ee 100644 --- a/node_description_test.go +++ b/node_description_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_QueryNodeDescription(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, ZDO, ZdoNodeDescReqID).Return(Frame{ diff --git a/node_endpoint_description.go b/node_endpoint_description.go index 1f0eff9..fe412f2 100644 --- a/node_endpoint_description.go +++ b/node_endpoint_description.go @@ -2,16 +2,21 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) QueryNodeEndpointDescription(ctx context.Context, ieeeAddress zigbee.IEEEAddress, endpoint zigbee.Endpoint) (zigbee.EndpointDescription, error) { networkAddress, err := z.ResolveNodeNWKAddress(ctx, ieeeAddress) - if err != nil { return zigbee.EndpointDescription{}, err } + if err := z.sem.Acquire(ctx, 1); err != nil { + return zigbee.EndpointDescription{}, fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoSimpleDescReq{ DestinationAddress: networkAddress, OfInterestAddress: networkAddress, diff --git a/node_endpoint_description_test.go b/node_endpoint_description_test.go index 6defc6d..9f38c3f 100644 --- a/node_endpoint_description_test.go +++ b/node_endpoint_description_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_QueryNodeEndpointDescription(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, ZDO, ZdoSimpleDescReqID).Return(Frame{ diff --git a/node_endpoints.go b/node_endpoints.go index c6a0c64..ce94d02 100644 --- a/node_endpoints.go +++ b/node_endpoints.go @@ -2,16 +2,21 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) QueryNodeEndpoints(ctx context.Context, ieeeAddress zigbee.IEEEAddress) ([]zigbee.Endpoint, error) { networkAddress, err := z.ResolveNodeNWKAddress(ctx, ieeeAddress) - if err != nil { return []zigbee.Endpoint{}, err } + if err := z.sem.Acquire(ctx, 1); err != nil { + return nil, fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoActiveEpReq{ DestinationAddress: networkAddress, OfInterestAddress: networkAddress, diff --git a/node_endpoints_test.go b/node_endpoints_test.go index 96ddc42..6919827 100644 --- a/node_endpoints_test.go +++ b/node_endpoints_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_QueryNodeEndpoints(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() unpiMock.On(SREQ, ZDO, ZdoActiveEpReqID).Return(Frame{ diff --git a/node_remove.go b/node_remove.go index 3c5d7b3..d04743f 100644 --- a/node_remove.go +++ b/node_remove.go @@ -8,11 +8,15 @@ import ( func (z *ZStack) RequestNodeLeave(ctx context.Context, nodeAddress zigbee.IEEEAddress) error { networkAddress, err := z.ResolveNodeNWKAddress(ctx, nodeAddress) - if err != nil { return nil } + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoMgmtLeaveReq{ NetworkAddress: networkAddress, IEEEAddress: nodeAddress, diff --git a/node_remove_test.go b/node_remove_test.go index a6cf48c..dbe2c0d 100644 --- a/node_remove_test.go +++ b/node_remove_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func TestZStack_RequestNodeLeave(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoMgmtLeaveReqReplyID).Return(Frame{ diff --git a/node_send_message.go b/node_send_message.go index bec56ad..07ef271 100644 --- a/node_send_message.go +++ b/node_send_message.go @@ -3,6 +3,7 @@ package zstack import ( "context" "errors" + "fmt" "github.com/shimmeringbee/logwrap" "github.com/shimmeringbee/zigbee" ) @@ -11,12 +12,16 @@ const DefaultRadius uint8 = 0x20 func (z *ZStack) SendApplicationMessageToNode(ctx context.Context, destinationAddress zigbee.IEEEAddress, message zigbee.ApplicationMessage, requireAck bool) error { network, err := z.ResolveNodeNWKAddress(ctx, destinationAddress) - if err != nil { z.logger.LogError(ctx, "Failed to send AfDataRequest (application message), failed to resolve IEEE Address to Network Adddress.", logwrap.Err(err), logwrap.Datum("IEEEAddress", destinationAddress.String())) return err } + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + var transactionId uint8 select { diff --git a/node_send_message_test.go b/node_send_message_test.go index c3fd11f..a6e9851 100644 --- a/node_send_message_test.go +++ b/node_send_message_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -19,6 +20,7 @@ func Test_SendNodeMessage(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() zstack.nodeTable.addOrUpdate(zigbee.IEEEAddress(0x1122334455667788), zigbee.NetworkAddress(0x1000)) @@ -63,6 +65,7 @@ func Test_SendNodeMessage(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() defer unpiMock.AssertCalls(t) zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() zstack.nodeTable.addOrUpdate(zigbee.IEEEAddress(0x1122334455667788), zigbee.NetworkAddress(0x1000)) diff --git a/node_unbind.go b/node_unbind.go index 4648b61..d157d50 100644 --- a/node_unbind.go +++ b/node_unbind.go @@ -2,16 +2,21 @@ package zstack import ( "context" + "fmt" "github.com/shimmeringbee/zigbee" ) func (z *ZStack) UnbindNodeFromController(ctx context.Context, nodeAddress zigbee.IEEEAddress, sourceEndpoint zigbee.Endpoint, destinationEndpoint zigbee.Endpoint, cluster zigbee.ClusterID) error { networkAddress, err := z.ResolveNodeNWKAddress(ctx, nodeAddress) - if err != nil { return nil } + if err := z.sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore: %w", err) + } + defer z.sem.Release(1) + request := ZdoUnbindReq{ TargetAddress: networkAddress, SourceAddress: nodeAddress, diff --git a/node_unbind_test.go b/node_unbind_test.go index 0f46882..ae9e2fe 100644 --- a/node_unbind_test.go +++ b/node_unbind_test.go @@ -7,6 +7,7 @@ import ( unpiTest "github.com/shimmeringbee/unpi/testing" "github.com/shimmeringbee/zigbee" "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" "testing" "time" ) @@ -18,6 +19,7 @@ func Test_UnbindToNode(t *testing.T) { unpiMock := unpiTest.NewMockAdapter() zstack := New(unpiMock, NewNodeTable()) + zstack.sem = semaphore.NewWeighted(8) defer unpiMock.Stop() call := unpiMock.On(SREQ, ZDO, ZdoUnbindReqReplyID).Return(Frame{ diff --git a/zstack.go b/zstack.go index 91c7ead..c996a7a 100644 --- a/zstack.go +++ b/zstack.go @@ -7,6 +7,7 @@ import ( "github.com/shimmeringbee/unpi/broker" "github.com/shimmeringbee/unpi/library" "github.com/shimmeringbee/zigbee" + "golang.org/x/sync/semaphore" "io" "log" "os" @@ -42,6 +43,8 @@ type ZStack struct { nodeTable *NodeTable transactionIdStore chan uint8 + sem *semaphore.Weighted + logger logwrap.Logger }