Skip to content

Commit e4cb773

Browse files
committed
rollout don't register atexit handler for Rollout objects
1 parent ba78821 commit e4cb773

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

doc/python.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -759,9 +759,11 @@ or
759759

760760
.. code-block:: python
761761
762-
# pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
762+
# pool shutdown on object deletion or call to rollout_.shutdown_pool
763+
# to ensure clean shutdown of threads, call shutdown_pool before interpreter exit
763764
rollout_ = rollout.Rollout(nthread)
764765
rollout_.rollout(model, data, initial_state)
766+
rollout_.shutdown_pool()
765767
766768
Since the Global Interpreter Lock is released, this function can also be threaded using Python threads. However, this
767769
is less efficient than using native threads. See the ``test_threading`` function in

python/mujoco/rollout.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
176176
int nroll, int nstep, unsigned int control_spec,
177177
const mjtNum* state0, const mjtNum* warmstart0,
178178
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
179-
std::shared_ptr<ThreadPool> pool, int chunk_size) {
179+
std::shared_ptr<ThreadPool>& pool, int chunk_size) {
180180
int nfulljobs = nroll / chunk_size;
181181
int chunk_remainder = nroll % chunk_size;
182182
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

python/mujoco/rollout.py

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def __init__(self, nthread: int = None):
3333
""" # fmt: skip
3434
self.nthread = 0 if nthread is None else nthread
3535
self.rollout_ = _rollout.Rollout(self.nthread)
36-
atexit.register(self.shutdown_pool)
3736

3837
def __enter__(self):
3938
return self

python/mujoco/rollout_test.py

+18
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,24 @@ def test_threading_native_persistent_object(self):
579579
np.testing.assert_array_equal(state, py_state)
580580
np.testing.assert_array_equal(sensordata, py_sensordata)
581581

582+
rollout_ = rollout.Rollout(num_workers)
583+
for i in range(2):
584+
rollout_.rollout(
585+
model_list,
586+
data_list,
587+
initial_state,
588+
control,
589+
nstep=nstep,
590+
state=state,
591+
sensordata=sensordata,
592+
)
593+
594+
data = mujoco.MjData(model)
595+
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
596+
np.testing.assert_array_equal(state, py_state)
597+
np.testing.assert_array_equal(sensordata, py_sensordata)
598+
rollout_.shutdown_pool()
599+
582600
def test_threading_native_persistent_function(self):
583601
model = mujoco.MjModel.from_xml_string(TEST_XML)
584602
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)

0 commit comments

Comments
 (0)