diff --git a/client/client.go b/client/client.go index fe26b194..d1dc8d09 100644 --- a/client/client.go +++ b/client/client.go @@ -18,6 +18,7 @@ type Factory interface { // Client is the interface that defines the implementation of all the endpoints type Client interface { GetOffChainData(ctx context.Context, hash common.Hash) ([]byte, error) + ListOffChainData(ctx context.Context, hashes []common.Hash) (map[common.Hash][]byte, error) SignSequence(signedSequence types.SignedSequence) ([]byte, error) } @@ -84,3 +85,27 @@ func (c *client) GetOffChainData(ctx context.Context, hash common.Hash) ([]byte, return result, nil } + +// ListOffChainData returns data based on the given hashes +func (c *client) ListOffChainData(ctx context.Context, hashes []common.Hash) (map[common.Hash][]byte, error) { + response, err := rpc.JSONRPCCallWithContext(ctx, c.url, "sync_listOffChainData", hashes) + if err != nil { + return nil, err + } + + if response.Error != nil { + return nil, fmt.Errorf("%v %v", response.Error.Code, response.Error.Message) + } + + result := make(map[common.Hash]types.ArgBytes) + if err = json.Unmarshal(response.Result, &result); err != nil { + return nil, err + } + + preparedResult := make(map[common.Hash][]byte) + for key, val := range result { + preparedResult[key] = val + } + + return preparedResult, nil +} diff --git a/client/client_test.go b/client/client_test.go index 2eb73530..4198493e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -169,3 +169,81 @@ func TestClient_GetOffChainData(t *testing.T) { }) } } + +func TestClient_ListOffChainData(t *testing.T) { + tests := []struct { + name string + hashes []common.Hash + result string + data map[common.Hash][]byte + statusCode int + err error + }{ + { + name: "successfully got offhcain data", + hashes: []common.Hash{common.BytesToHash([]byte("hash"))}, + result: fmt.Sprintf(`{"result":{"%s":"%s"}}`, + common.BytesToHash([]byte("hash")).Hex(), hex.EncodeToString([]byte("offchaindata"))), + data: map[common.Hash][]byte{ + common.BytesToHash([]byte("hash")): []byte("offchaindata"), + }, + }, + { + name: "error returned by server", + hashes: []common.Hash{common.BytesToHash([]byte("hash"))}, + result: `{"error":{"code":123,"message":"test error"}}`, + err: errors.New("123 test error"), + }, + { + name: "invalid offchain data returned by server", + hashes: []common.Hash{common.BytesToHash([]byte("hash"))}, + result: fmt.Sprintf(`{"result":{"%s":"invalid-signature"}}`, + common.BytesToHash([]byte("hash")).Hex()), + data: map[common.Hash][]byte{ + common.BytesToHash([]byte("hash")): nil, + }, + }, + { + name: "unsuccessful status code returned by server", + hashes: []common.Hash{common.BytesToHash([]byte("hash"))}, + statusCode: http.StatusUnauthorized, + err: errors.New("invalid status code, expected: 200, found: 401"), + }, + } + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res rpc.Request + require.NoError(t, json.NewDecoder(r.Body).Decode(&res)) + require.Equal(t, "sync_listOffChainData", res.Method) + + var params [][]common.Hash + require.NoError(t, json.Unmarshal(res.Params, ¶ms)) + require.Equal(t, tt.hashes, params[0]) + + if tt.statusCode > 0 { + w.WriteHeader(tt.statusCode) + } + + _, err := fmt.Fprint(w, tt.result) + require.NoError(t, err) + })) + defer svr.Close() + + c := &client{url: svr.URL} + + got, err := c.ListOffChainData(context.Background(), tt.hashes) + if tt.err != nil { + require.Error(t, err) + require.EqualError(t, tt.err, err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tt.data, got) + } + }) + } +} diff --git a/cmd/main.go b/cmd/main.go index db51f51c..077c14a2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -142,11 +142,11 @@ func start(cliCtx *cli.Context) error { []rpc.Service{ { Name: sync.APISYNC, - Service: sync.NewSyncEndpoints(storage), + Service: sync.NewEndpoints(storage), }, { Name: datacom.APIDATACOM, - Service: datacom.NewDataComEndpoints( + Service: datacom.NewEndpoints( storage, pk, sequencerTracker, diff --git a/db/db.go b/db/db.go index cf392a9b..71263164 100644 --- a/db/db.go +++ b/db/db.go @@ -29,6 +29,7 @@ type DB interface { Exists(ctx context.Context, key common.Hash) bool GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.QueryerContext) (types.ArgBytes, error) + ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error) StoreOffChainData(ctx context.Context, od []types.OffChainData, dbTx sqlx.ExecerContext) error } @@ -218,6 +219,54 @@ func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx. return common.FromHex(hexValue), nil } +// ListOffChainData returns values identified by the given keys +func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error) { + if len(keys) == 0 { + return nil, nil + } + + const listOffchainDataSQL = ` + SELECT key, value + FROM data_node.offchain_data + WHERE key IN (?); + ` + + preparedKeys := make([]string, len(keys)) + for i, key := range keys { + preparedKeys[i] = key.Hex() + } + + query, args, err := sqlx.In(listOffchainDataSQL, preparedKeys) + if err != nil { + return nil, err + } + + // sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend + query = db.pg.Rebind(query) + + rows, err := db.querier(dbTx).QueryxContext(ctx, query, args...) + if err != nil { + return nil, err + } + + defer rows.Close() + + list := make(map[common.Hash]types.ArgBytes) + for rows.Next() { + data := struct { + Key string `db:"key"` + Value string `db:"value"` + }{} + if err = rows.StructScan(&data); err != nil { + return nil, err + } + + list[common.HexToHash(data.Key)] = common.FromHex(data.Value) + } + + return list, nil +} + func (db *pgDB) execer(dbTx sqlx.ExecerContext) sqlx.ExecerContext { if dbTx != nil { return dbTx diff --git a/db/db_test.go b/db/db_test.go index 8195b3aa..43700376 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -2,6 +2,7 @@ package db import ( "context" + "database/sql/driver" "errors" "testing" @@ -478,6 +479,125 @@ func Test_DB_GetOffChainData(t *testing.T) { } } +func Test_DB_ListOffChainData(t *testing.T) { + testTable := []struct { + name string + od []types.OffChainData + keys []common.Hash + expected map[common.Hash]types.ArgBytes + sql string + returnErr error + }{ + { + name: "successfully selected one value", + od: []types.OffChainData{{ + Key: common.HexToHash("key1"), + Value: []byte("value1"), + }}, + keys: []common.Hash{ + common.BytesToHash([]byte("key1")), + }, + expected: map[common.Hash]types.ArgBytes{ + common.BytesToHash([]byte("key1")): []byte("value1"), + }, + sql: `SELECT key, value FROM data_node\.offchain_data WHERE key IN \(\$1\)`, + }, + { + name: "successfully selected two values", + od: []types.OffChainData{{ + Key: common.HexToHash("key1"), + Value: []byte("value1"), + }, { + Key: common.HexToHash("key2"), + Value: []byte("value2"), + }}, + keys: []common.Hash{ + common.BytesToHash([]byte("key1")), + common.BytesToHash([]byte("key2")), + }, + expected: map[common.Hash]types.ArgBytes{ + common.BytesToHash([]byte("key1")): []byte("value1"), + common.BytesToHash([]byte("key2")): []byte("value2"), + }, + sql: `SELECT key, value FROM data_node\.offchain_data WHERE key IN \(\$1\, \$2\)`, + }, + { + name: "error returned", + od: []types.OffChainData{{ + Key: common.HexToHash("key1"), + Value: []byte("value1"), + }}, + keys: []common.Hash{ + common.BytesToHash([]byte("key1")), + }, + sql: `SELECT key, value FROM data_node\.offchain_data WHERE key IN \(\$1\)`, + returnErr: errors.New("test error"), + }, + { + name: "no rows", + od: []types.OffChainData{{ + Key: common.HexToHash("key1"), + Value: []byte("value1"), + }}, + keys: []common.Hash{ + common.BytesToHash([]byte("underfined")), + }, + sql: `SELECT key, value FROM data_node\.offchain_data WHERE key IN \(\$1\)`, + returnErr: ErrStateNotSynchronized, + }, + } + + for _, tt := range testTable { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db, mock, err := sqlmock.New() + require.NoError(t, err) + + defer db.Close() + + wdb := sqlx.NewDb(db, "postgres") + + // Seed data + seedOffchainData(t, wdb, mock, tt.od) + + preparedKeys := make([]driver.Value, len(tt.keys)) + for i, key := range tt.keys { + preparedKeys[i] = key.Hex() + } + + expected := mock.ExpectQuery(tt.sql). + WithArgs(preparedKeys...) + + if tt.returnErr != nil { + expected.WillReturnError(tt.returnErr) + } else { + returnData := sqlmock.NewRows([]string{"key", "value"}) + + for key, val := range tt.expected { + returnData = returnData.AddRow(key.Hex(), common.Bytes2Hex(val)) + } + + expected.WillReturnRows(returnData) + } + + dbPG := New(wdb) + + data, err := dbPG.ListOffChainData(context.Background(), tt.keys, wdb) + if tt.returnErr != nil { + require.ErrorIs(t, err, tt.returnErr) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, data) + } + + require.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + func Test_DB_Exist(t *testing.T) { testTable := []struct { name string diff --git a/mocks/client.generated.go b/mocks/client.generated.go index 1bcf5872..24152c46 100644 --- a/mocks/client.generated.go +++ b/mocks/client.generated.go @@ -84,6 +84,65 @@ func (_c *Client_GetOffChainData_Call) RunAndReturn(run func(context.Context, co return _c } +// ListOffChainData provides a mock function with given fields: ctx, hashes +func (_m *Client) ListOffChainData(ctx context.Context, hashes []common.Hash) (map[common.Hash][]byte, error) { + ret := _m.Called(ctx, hashes) + + if len(ret) == 0 { + panic("no return value specified for ListOffChainData") + } + + var r0 map[common.Hash][]byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []common.Hash) (map[common.Hash][]byte, error)); ok { + return rf(ctx, hashes) + } + if rf, ok := ret.Get(0).(func(context.Context, []common.Hash) map[common.Hash][]byte); ok { + r0 = rf(ctx, hashes) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[common.Hash][]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []common.Hash) error); ok { + r1 = rf(ctx, hashes) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_ListOffChainData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListOffChainData' +type Client_ListOffChainData_Call struct { + *mock.Call +} + +// ListOffChainData is a helper method to define mock.On call +// - ctx context.Context +// - hashes []common.Hash +func (_e *Client_Expecter) ListOffChainData(ctx interface{}, hashes interface{}) *Client_ListOffChainData_Call { + return &Client_ListOffChainData_Call{Call: _e.mock.On("ListOffChainData", ctx, hashes)} +} + +func (_c *Client_ListOffChainData_Call) Run(run func(ctx context.Context, hashes []common.Hash)) *Client_ListOffChainData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]common.Hash)) + }) + return _c +} + +func (_c *Client_ListOffChainData_Call) Return(_a0 map[common.Hash][]byte, _a1 error) *Client_ListOffChainData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_ListOffChainData_Call) RunAndReturn(run func(context.Context, []common.Hash) (map[common.Hash][]byte, error)) *Client_ListOffChainData_Call { + _c.Call.Return(run) + return _c +} + // SignSequence provides a mock function with given fields: signedSequence func (_m *Client) SignSequence(signedSequence types.SignedSequence) ([]byte, error) { ret := _m.Called(signedSequence) diff --git a/mocks/db.generated.go b/mocks/db.generated.go index 8c6924fb..9c27e216 100644 --- a/mocks/db.generated.go +++ b/mocks/db.generated.go @@ -357,6 +357,66 @@ func (_c *DB_GetUnresolvedBatchKeys_Call) RunAndReturn(run func(context.Context) return _c } +// ListOffChainData provides a mock function with given fields: ctx, keys, dbTx +func (_m *DB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error) { + ret := _m.Called(ctx, keys, dbTx) + + if len(ret) == 0 { + panic("no return value specified for ListOffChainData") + } + + var r0 map[common.Hash]types.ArgBytes + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []common.Hash, sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error)); ok { + return rf(ctx, keys, dbTx) + } + if rf, ok := ret.Get(0).(func(context.Context, []common.Hash, sqlx.QueryerContext) map[common.Hash]types.ArgBytes); ok { + r0 = rf(ctx, keys, dbTx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[common.Hash]types.ArgBytes) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []common.Hash, sqlx.QueryerContext) error); ok { + r1 = rf(ctx, keys, dbTx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DB_ListOffChainData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListOffChainData' +type DB_ListOffChainData_Call struct { + *mock.Call +} + +// ListOffChainData is a helper method to define mock.On call +// - ctx context.Context +// - keys []common.Hash +// - dbTx sqlx.QueryerContext +func (_e *DB_Expecter) ListOffChainData(ctx interface{}, keys interface{}, dbTx interface{}) *DB_ListOffChainData_Call { + return &DB_ListOffChainData_Call{Call: _e.mock.On("ListOffChainData", ctx, keys, dbTx)} +} + +func (_c *DB_ListOffChainData_Call) Run(run func(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext)) *DB_ListOffChainData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]common.Hash), args[2].(sqlx.QueryerContext)) + }) + return _c +} + +func (_c *DB_ListOffChainData_Call) Return(_a0 map[common.Hash]types.ArgBytes, _a1 error) *DB_ListOffChainData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *DB_ListOffChainData_Call) RunAndReturn(run func(context.Context, []common.Hash, sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error)) *DB_ListOffChainData_Call { + _c.Call.Return(run) + return _c +} + // StoreLastProcessedBlock provides a mock function with given fields: ctx, task, block, dbTx func (_m *DB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64, dbTx sqlx.ExecerContext) error { ret := _m.Called(ctx, task, block, dbTx) diff --git a/services/datacom/datacom.go b/services/datacom/datacom.go index b61d8d4c..871e17a2 100644 --- a/services/datacom/datacom.go +++ b/services/datacom/datacom.go @@ -14,29 +14,27 @@ import ( // APIDATACOM is the namespace of the datacom service const APIDATACOM = "datacom" -// DataComEndpoints contains implementations for the "datacom" RPC endpoints -type DataComEndpoints struct { +// Endpoints contains implementations for the "datacom" RPC endpoints +type Endpoints struct { db db.DB txMan rpc.DBTxManager privateKey *ecdsa.PrivateKey sequencerTracker *sequencer.Tracker } -// NewDataComEndpoints returns DataComEndpoints -func NewDataComEndpoints( - db db.DB, privateKey *ecdsa.PrivateKey, sequencerTracker *sequencer.Tracker, -) *DataComEndpoints { - return &DataComEndpoints{ +// NewEndpoints returns Endpoints +func NewEndpoints(db db.DB, pk *ecdsa.PrivateKey, st *sequencer.Tracker) *Endpoints { + return &Endpoints{ db: db, - privateKey: privateKey, - sequencerTracker: sequencerTracker, + privateKey: pk, + sequencerTracker: st, } } // SignSequence generates the accumulated input hash aka accInputHash of the sequence and sign it. // After storing the data that will be sent hashed to the contract, it returns the signature. // This endpoint is only accessible to the sequencer -func (d *DataComEndpoints) SignSequence(signedSequence types.SignedSequence) (interface{}, rpc.Error) { +func (d *Endpoints) SignSequence(signedSequence types.SignedSequence) (interface{}, rpc.Error) { // Verify that the request comes from the sequencer sender, err := signedSequence.Signer() if err != nil { diff --git a/services/datacom/datacom_test.go b/services/datacom/datacom_test.go index 45c1a7ef..281b1076 100644 --- a/services/datacom/datacom_test.go +++ b/services/datacom/datacom_test.go @@ -89,7 +89,7 @@ func TestDataCom_SignSequence(t *testing.T) { signer = cfg.signer } - dce := NewDataComEndpoints(dbMock, signer, sequencer) + dce := NewEndpoints(dbMock, signer, sequencer) sig, err := dce.SignSequence(*signedSequence) if cfg.expectedError != "" { diff --git a/services/sync/sync.go b/services/sync/sync.go index 8bed2474..11eae652 100644 --- a/services/sync/sync.go +++ b/services/sync/sync.go @@ -7,26 +7,27 @@ import ( "github.com/0xPolygon/cdk-data-availability/log" "github.com/0xPolygon/cdk-data-availability/rpc" "github.com/0xPolygon/cdk-data-availability/types" + "github.com/ethereum/go-ethereum/common" ) // APISYNC is the namespace of the sync service const APISYNC = "sync" -// SyncEndpoints contains implementations for the "zkevm" RPC endpoints -type SyncEndpoints struct { +// Endpoints contains implementations for the "zkevm" RPC endpoints +type Endpoints struct { db db.DB txMan rpc.DBTxManager } -// NewSyncEndpoints returns ZKEVMEndpoints -func NewSyncEndpoints(db db.DB) *SyncEndpoints { - return &SyncEndpoints{ +// NewEndpoints returns Endpoints +func NewEndpoints(db db.DB) *Endpoints { + return &Endpoints{ db: db, } } // GetOffChainData returns the image of the given hash -func (z *SyncEndpoints) GetOffChainData(hash types.ArgHash) (interface{}, rpc.Error) { +func (z *Endpoints) GetOffChainData(hash types.ArgHash) (interface{}, rpc.Error) { return z.txMan.NewDbTxScope(z.db, func(ctx context.Context, dbTx db.Tx) (interface{}, rpc.Error) { data, err := z.db.GetOffChainData(ctx, hash.Hash(), dbTx) if err != nil { @@ -37,3 +38,21 @@ func (z *SyncEndpoints) GetOffChainData(hash types.ArgHash) (interface{}, rpc.Er return data, nil }) } + +// ListOffChainData returns the list of images of the given hashes +func (z *Endpoints) ListOffChainData(hashes []types.ArgHash) (interface{}, rpc.Error) { + keys := make([]common.Hash, len(hashes)) + for i, hash := range hashes { + keys[i] = hash.Hash() + } + + return z.txMan.NewDbTxScope(z.db, func(ctx context.Context, dbTx db.Tx) (interface{}, rpc.Error) { + list, err := z.db.ListOffChainData(ctx, keys, dbTx) + if err != nil { + log.Errorf("failed to list the requested data from the DB: %v", err) + return "0x0", rpc.NewRPCError(rpc.DefaultErrorCode, "failed to list the requested data") + } + + return list, nil + }) +} diff --git a/services/sync/sync_test.go b/services/sync/sync_test.go index 2b928715..a207fd4f 100644 --- a/services/sync/sync_test.go +++ b/services/sync/sync_test.go @@ -7,6 +7,7 @@ import ( "github.com/0xPolygon/cdk-data-availability/mocks" "github.com/0xPolygon/cdk-data-availability/types" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" ) @@ -66,7 +67,7 @@ func TestSyncEndpoints_GetOffChainData(t *testing.T) { defer txMock.AssertExpectations(t) defer dbMock.AssertExpectations(t) - z := &SyncEndpoints{db: dbMock} + z := &Endpoints{db: dbMock} got, err := z.GetOffChainData(tt.hash) if tt.err != nil { @@ -79,3 +80,84 @@ func TestSyncEndpoints_GetOffChainData(t *testing.T) { }) } } + +func TestSyncEndpoints_ListOffChainData(t *testing.T) { + tests := []struct { + name string + hashes []types.ArgHash + data interface{} + dbErr error + txErr error + err error + }{ + { + name: "successfully got offchain data", + hashes: []types.ArgHash{}, + data: map[common.Hash]types.ArgBytes{ + common.BytesToHash(nil): types.ArgBytes("offchaindata"), + }, + }, + { + name: "db returns error", + hashes: []types.ArgHash{}, + data: map[common.Hash]types.ArgBytes{ + common.BytesToHash(nil): types.ArgBytes("offchaindata"), + }, + dbErr: errors.New("test error"), + err: errors.New("failed to list the requested data"), + }, + { + name: "tx returns error", + hashes: []types.ArgHash{}, + data: map[common.Hash]types.ArgBytes{ + common.BytesToHash(nil): types.ArgBytes("offchaindata"), + }, + txErr: errors.New("test error"), + err: errors.New("failed to connect to the state"), + }, + } + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + txMock := mocks.NewTx(t) + + dbMock := mocks.NewDB(t) + dbMock.On("BeginStateTransaction", context.Background()). + Return(txMock, tt.txErr) + if tt.txErr == nil { + keys := make([]common.Hash, len(tt.hashes)) + for i, hash := range tt.hashes { + keys[i] = hash.Hash() + } + + dbMock.On("ListOffChainData", context.Background(), keys, txMock). + Return(tt.data, tt.dbErr) + + if tt.err != nil { + txMock.On("Rollback"). + Return(nil) + } else { + txMock.On("Commit"). + Return(nil) + } + } + + defer txMock.AssertExpectations(t) + defer dbMock.AssertExpectations(t) + + z := &Endpoints{db: dbMock} + + got, err := z.ListOffChainData(tt.hashes) + if tt.err != nil { + require.Error(t, err) + require.EqualError(t, tt.err, err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tt.data, got) + } + }) + } +} diff --git a/test/e2e/datacommittee_test.go b/test/e2e/datacommittee_test.go index f6984c97..3ec85adc 100644 --- a/test/e2e/datacommittee_test.go +++ b/test/e2e/datacommittee_test.go @@ -204,11 +204,12 @@ func TestDataCommittee(t *testing.T) { expectedKeys, err := getSequenceBatchesKeys(clientL1, iter.Event) require.NoError(t, err) for _, m := range membs { + offchainData, err := listOffchainDataKeys(m, expectedKeys) + require.NoError(t, err) + // Each member (including m0) should have all the keys for _, expected := range expectedKeys { - actual, err := getOffchainDataKeys(m, expected) - require.NoError(t, err) - require.Equal(t, expected, actual) + require.Equal(t, expected, offchainData[expected]) } } } @@ -239,16 +240,24 @@ func getSequenceBatchesKeys(clientL1 *ethclient.Client, event *polygonvalidium.P return keys, err } -func getOffchainDataKeys(m member, tx common.Hash) (common.Hash, error) { +func listOffchainDataKeys(m member, txes []common.Hash) (map[common.Hash]common.Hash, error) { testUrl := fmt.Sprintf("http://127.0.0.1:420%d", m.i) mc := newTestClient(testUrl, m.addr) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - data, err := mc.client.GetOffChainData(ctx, tx) + + data, err := mc.client.ListOffChainData(ctx, txes) if err != nil { - return common.Hash{}, err + return nil, err } - return crypto.Keccak256Hash(data), nil + + preparedData := make(map[common.Hash]common.Hash) + for hash, val := range data { + preparedData[hash] = crypto.Keccak256Hash(val) + } + + return preparedData, nil } type member struct { diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index e36a6010..82933bed 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -116,6 +116,7 @@ func (tc *testClient) signSequence(t *testing.T, expected *types.SignedSequence, actualAddr, err := expected.Signer() require.NoError(t, err) assert.Equal(t, tc.dacMemberAddr, actualAddr) + // Check that offchain data has been stored expectedOffchainData := expected.Sequence.OffChainData() for _, od := range expectedOffchainData { @@ -126,5 +127,17 @@ func (tc *testClient) signSequence(t *testing.T, expected *types.SignedSequence, require.NoError(t, err) assert.Equal(t, od.Value, actualData) } + + hashes := make([]common.Hash, len(expectedOffchainData)) + for i, od := range expectedOffchainData { + hashes[i] = od.Key + } + + actualData, err := tc.client.ListOffChainData(context.Background(), hashes) + require.NoError(t, err) + + for _, od := range expectedOffchainData { + assert.Equal(t, od.Value, actualData[od.Key]) + } } }