diff --git a/go.mod b/go.mod index 60c864d5..8c2db5c3 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.4.0 go.uber.org/zap v1.27.0 - google.golang.org/grpc v1.66.1 + google.golang.org/grpc v1.66.2 google.golang.org/protobuf v1.34.2 ) diff --git a/go.sum b/go.sum index 80a606f5..1fe265ad 100644 --- a/go.sum +++ b/go.sum @@ -1027,8 +1027,8 @@ google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.66.1 h1:hO5qAXR19+/Z44hmvIM4dQFMSYX9XcWsByfoxutBpAM= -google.golang.org/grpc v1.66.1/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo= +google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/peers/app_request_network.go b/peers/app_request_network.go index f25a8518..df4e20e0 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -212,12 +212,12 @@ type ConnectedCanonicalValidators struct { ConnectedWeight uint64 TotalValidatorWeight uint64 ValidatorSet []*warp.Validator - nodeValidatorIndexMap map[ids.NodeID]int + NodeValidatorIndexMap map[ids.NodeID]int } // Returns the Warp Validator and its index in the canonical Validator ordering for a given nodeID func (c *ConnectedCanonicalValidators) GetValidator(nodeID ids.NodeID) (*warp.Validator, int) { - return c.ValidatorSet[c.nodeValidatorIndexMap[nodeID]], c.nodeValidatorIndexMap[nodeID] + return c.ValidatorSet[c.NodeValidatorIndexMap[nodeID]], c.NodeValidatorIndexMap[nodeID] } // ConnectToCanonicalValidators connects to the canonical validators of the given subnet and returns the connected @@ -258,7 +258,7 @@ func (n *appRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*Conn ConnectedWeight: connectedWeight, TotalValidatorWeight: totalValidatorWeight, ValidatorSet: validatorSet, - nodeValidatorIndexMap: nodeValidatorIndexMap, + NodeValidatorIndexMap: nodeValidatorIndexMap, }, nil } diff --git a/signature-aggregator/aggregator/aggregator.go b/signature-aggregator/aggregator/aggregator.go index 5bb13b59..7d2b9b01 100644 --- a/signature-aggregator/aggregator/aggregator.go +++ b/signature-aggregator/aggregator/aggregator.go @@ -5,6 +5,7 @@ package aggregator import ( "bytes" + "encoding/hex" "errors" "fmt" "math/big" @@ -556,6 +557,7 @@ func (s *SignatureAggregator) isValidSignatureResponse( if !bls.Verify(pubKey, sig, unsignedMessage.Bytes()) { s.logger.Debug( "Failed verification for signature", + zap.String("pubKey", hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey))), ) return blsSignatureBuf{}, false } diff --git a/signature-aggregator/aggregator/aggregator_test.go b/signature-aggregator/aggregator/aggregator_test.go index 8876a646..03273145 100644 --- a/signature-aggregator/aggregator/aggregator_test.go +++ b/signature-aggregator/aggregator/aggregator_test.go @@ -1,13 +1,22 @@ package aggregator import ( + "bytes" + "context" + "os" "testing" "time" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/proto/pb/sdk" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/subnets" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/platformvm/warp" "github.com/ava-labs/awm-relayer/peers" "github.com/ava-labs/awm-relayer/peers/mocks" @@ -15,10 +24,15 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/proto" ) -var sigAggMetrics *metrics.SignatureAggregatorMetrics -var messageCreator message.Creator +var ( + sigAggMetrics *metrics.SignatureAggregatorMetrics + messageCreator message.Creator +) func instantiateAggregator(t *testing.T) ( *SignatureAggregator, @@ -36,25 +50,89 @@ func instantiateAggregator(t *testing.T) ( constants.DefaultNetworkCompressionType, constants.DefaultNetworkMaximumInboundTimeout, ) - require.Equal(t, err, nil) + require.NoError(t, err) } aggregator, err := NewSignatureAggregator( mockNetwork, - logging.NoLog{}, + logging.NewLogger( + "aggregator_test", + logging.NewWrappedCore( + logging.Debug, + os.Stdout, + zapcore.NewConsoleEncoder( + zap.NewProductionEncoderConfig(), + ), + ), + ), 1024, sigAggMetrics, messageCreator, // Setting the etnaTime to a minute ago so that the post-etna code path is used in the test time.Now().Add(-1*time.Minute), ) - require.Equal(t, err, nil) + require.NoError(t, err) return aggregator, mockNetwork } +// Generate the validator values. +type validatorInfo struct { + nodeID ids.NodeID + blsSecretKey *bls.SecretKey + blsPublicKey *bls.PublicKey + blsPublicKeyBytes []byte +} + +func (v validatorInfo) Compare(o validatorInfo) int { + return bytes.Compare(v.blsPublicKeyBytes, o.blsPublicKeyBytes) +} + +func makeConnectedValidators(validatorCount int) (*peers.ConnectedCanonicalValidators, []*bls.SecretKey) { + validatorValues := make([]validatorInfo, validatorCount) + for i := 0; i < validatorCount; i++ { + secretKey, err := bls.NewSecretKey() + if err != nil { + panic(err) + } + pubKey := bls.PublicFromSecretKey(secretKey) + nodeID := ids.GenerateTestNodeID() + validatorValues[i] = validatorInfo{ + nodeID: nodeID, + blsSecretKey: secretKey, + blsPublicKey: pubKey, + blsPublicKeyBytes: bls.PublicKeyToUncompressedBytes(pubKey), + } + } + + // Sort the validators by public key to construct the NodeValidatorIndexMap + utils.Sort(validatorValues) + + // Placeholder for results + validatorSet := make([]*warp.Validator, validatorCount) + validatorSecretKeys := make([]*bls.SecretKey, validatorCount) + nodeValidatorIndexMap := make(map[ids.NodeID]int) + for i, validator := range validatorValues { + validatorSecretKeys[i] = validator.blsSecretKey + validatorSet[i] = &warp.Validator{ + PublicKey: validator.blsPublicKey, + PublicKeyBytes: validator.blsPublicKeyBytes, + Weight: 1, + NodeIDs: []ids.NodeID{validator.nodeID}, + } + nodeValidatorIndexMap[validator.nodeID] = i + } + + return &peers.ConnectedCanonicalValidators{ + ConnectedWeight: uint64(validatorCount), + TotalValidatorWeight: uint64(validatorCount), + ValidatorSet: validatorSet, + NodeValidatorIndexMap: nodeValidatorIndexMap, + }, validatorSecretKeys +} + func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { aggregator, mockNetwork := instantiateAggregator(t) msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) - require.Equal(t, err, nil) + require.NoError(t, err) mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( &peers.ConnectedCanonicalValidators{ @@ -71,7 +149,7 @@ func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { aggregator, mockNetwork := instantiateAggregator(t) msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) - require.Equal(t, err, nil) + require.NoError(t, err) mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( &peers.ConnectedCanonicalValidators{ @@ -88,3 +166,234 @@ func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { "failed to connect to a threshold of stake", ) } + +func makeAppRequests( + chainID ids.ID, + requestID uint32, + connectedValidators *peers.ConnectedCanonicalValidators, +) []ids.RequestID { + var appRequests []ids.RequestID + for _, validator := range connectedValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + appRequests = append( + appRequests, + ids.RequestID{ + NodeID: nodeID, + ChainID: chainID, + RequestID: requestID, + Op: byte( + message.AppResponseOp, + ), + }, + ) + } + } + return appRequests +} + +func TestCreateSignedMessageRetriesAndFailsWithoutP2PResponses(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + + var ( + connectedValidators, _ = makeConnectedValidators(2) + requestID = aggregator.currentRequestID.Load() + 1 + ) + + chainID := ids.GenerateTestID() + + msg, err := warp.NewUnsignedMessage(0, chainID, []byte{}) + require.NoError(t, err) + + subnetID := ids.GenerateTestID() + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times( + maxRelayerQueryAttempts, + ) + } + + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return( + make(chan message.InboundMessage, len(appRequests)), + ).Times(maxRelayerQueryAttempts) + + var nodeIDs set.Set[ids.NodeID] + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + } + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(maxRelayerQueryAttempts) + + _, err = aggregator.CreateSignedMessage(msg, nil, subnetID, 80) + require.ErrorContains( + t, + err, + "failed to collect a threshold of signatures", + ) +} + +func TestCreateSignedMessageSucceeds(t *testing.T) { + var msg *warp.UnsignedMessage // to be signed + chainID := ids.GenerateTestID() + networkID := constants.UnitTestID + msg, err := warp.NewUnsignedMessage( + networkID, + chainID, + utils.RandomBytes(1234), + ) + require.NoError(t, err) + + // the signers: + var connectedValidators, validatorSecretKeys = makeConnectedValidators(5) + + // prime the aggregator: + + aggregator, mockNetwork := instantiateAggregator(t) + + subnetID := ids.GenerateTestID() + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + // prime the signers' responses: + + var requestID = aggregator.currentRequestID.Load() + 1 + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times(1) + } + + var nodeIDs set.Set[ids.NodeID] + responseChan := make(chan message.InboundMessage, len(appRequests)) + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + validatorSecretKey := validatorSecretKeys[connectedValidators.NodeValidatorIndexMap[appRequest.NodeID]] + responseBytes, err := proto.Marshal( + &sdk.SignatureResponse{ + Signature: bls.SignatureToBytes( + bls.Sign( + validatorSecretKey, + msg.Bytes(), + ), + ), + }, + ) + require.NoError(t, err) + responseChan <- message.InboundAppResponse( + chainID, + requestID, + responseBytes, + appRequest.NodeID, + ) + } + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return(responseChan).Times(1) + + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(1).Return(nodeIDs) + + // aggregate the signatures: + var quorumPercentage uint64 = 80 + signedMessage, err := aggregator.CreateSignedMessage( + msg, + nil, + subnetID, + quorumPercentage, + ) + require.NoError(t, err) + + // verify the aggregated signature: + pChainState := newPChainStateStub( + chainID, + subnetID, + 1, + connectedValidators, + ) + verifyErr := signedMessage.Signature.Verify( + context.Background(), + msg, + networkID, + pChainState, + pChainState.currentHeight, + quorumPercentage, + 100, + ) + require.NoError(t, verifyErr) +} + +type pChainStateStub struct { + subnetIDByChainID map[ids.ID]ids.ID + connectedCanonicalValidators *peers.ConnectedCanonicalValidators + currentHeight uint64 +} + +func newPChainStateStub( + chainID, subnetID ids.ID, + currentHeight uint64, + connectedValidators *peers.ConnectedCanonicalValidators, +) *pChainStateStub { + subnetIDByChainID := make(map[ids.ID]ids.ID) + subnetIDByChainID[chainID] = subnetID + return &pChainStateStub{ + subnetIDByChainID: subnetIDByChainID, + connectedCanonicalValidators: connectedValidators, + currentHeight: currentHeight, + } +} + +func (p pChainStateStub) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { + return p.subnetIDByChainID[chainID], nil +} + +func (p pChainStateStub) GetMinimumHeight(context.Context) (uint64, error) { return 0, nil } + +func (p pChainStateStub) GetCurrentHeight(context.Context) (uint64, error) { + return p.currentHeight, nil +} + +func (p pChainStateStub) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + output := make(map[ids.NodeID]*validators.GetValidatorOutput) + for _, validator := range p.connectedCanonicalValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + output[nodeID] = &validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: validator.PublicKey, + Weight: validator.Weight, + } + } + } + return output, nil +} diff --git a/signature-aggregator/aggregator/cache/cache.go b/signature-aggregator/aggregator/cache/cache.go index 475a6fdc..7363e07e 100644 --- a/signature-aggregator/aggregator/cache/cache.go +++ b/signature-aggregator/aggregator/cache/cache.go @@ -41,7 +41,11 @@ func (c *Cache) Get(msgID ids.ID) (map[PublicKeyBytes]SignatureBytes, bool) { cachedValue, isCached := c.signatures.Get(msgID) if isCached { - c.logger.Debug("cache hit", zap.Stringer("msgID", msgID)) + c.logger.Debug( + "cache hit", + zap.Stringer("msgID", msgID), + zap.Int("signatureCount", len(cachedValue)), + ) return cachedValue, true } else { c.logger.Debug("cache miss", zap.Stringer("msgID", msgID))