@@ -519,6 +519,35 @@ def call_rollout(initial_state, control, state, sensordata):
519
519
np .testing .assert_array_equal (state , py_state )
520
520
np .testing .assert_array_equal (sensordata , py_sensordata )
521
521
522
+ def test_threading_native (self ):
523
+ model = mujoco .MjModel .from_xml_string (TEST_XML )
524
+ nstate = mujoco .mj_stateSize (model , mujoco .mjtState .mjSTATE_FULLPHYSICS )
525
+ num_workers = 32
526
+ nroll = 10000
527
+ nstep = 5
528
+ initial_state = np .random .randn (nroll , nstate )
529
+ state = np .empty ((nroll , nstep , nstate ))
530
+ sensordata = np .empty ((nroll , nstep , model .nsensordata ))
531
+ control = np .random .randn (nroll , nstep , model .nu )
532
+
533
+ model_list = [model ] * nroll
534
+ data_list = [mujoco .MjData (model ) for i in range (num_workers )]
535
+
536
+ rollout .rollout (
537
+ model_list ,
538
+ data_list ,
539
+ initial_state ,
540
+ control ,
541
+ nstep = nstep ,
542
+ state = state ,
543
+ sensordata = sensordata ,
544
+ )
545
+
546
+ data = mujoco .MjData (model )
547
+ py_state , py_sensordata = py_rollout (model , data , initial_state , control )
548
+ np .testing .assert_array_equal (state , py_state )
549
+ np .testing .assert_array_equal (sensordata , py_sensordata )
550
+
522
551
# ---------------------------- test advanced operation
523
552
524
553
def test_warmstart (self ):
0 commit comments