Skip to content

Commit 7ad4fff

Browse files
committed
rollout add native threading test
1 parent 3d90b03 commit 7ad4fff

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

python/mujoco/rollout_test.py

+29
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,35 @@ def call_rollout(initial_state, control, state, sensordata):
519519
np.testing.assert_array_equal(state, py_state)
520520
np.testing.assert_array_equal(sensordata, py_sensordata)
521521

522+
def test_threading_native(self):
523+
model = mujoco.MjModel.from_xml_string(TEST_XML)
524+
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
525+
num_workers = 32
526+
nroll = 10000
527+
nstep = 5
528+
initial_state = np.random.randn(nroll, nstate)
529+
state = np.empty((nroll, nstep, nstate))
530+
sensordata = np.empty((nroll, nstep, model.nsensordata))
531+
control = np.random.randn(nroll, nstep, model.nu)
532+
533+
model_list = [model] * nroll
534+
data_list = [mujoco.MjData(model) for i in range(num_workers)]
535+
536+
rollout.rollout(
537+
model_list,
538+
data_list,
539+
initial_state,
540+
control,
541+
nstep=nstep,
542+
state=state,
543+
sensordata=sensordata,
544+
)
545+
546+
data = mujoco.MjData(model)
547+
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
548+
np.testing.assert_array_equal(state, py_state)
549+
np.testing.assert_array_equal(sensordata, py_sensordata)
550+
522551
# ---------------------------- test advanced operation
523552

524553
def test_warmstart(self):

0 commit comments

Comments
 (0)