Skip to content

Commit ba78821

Browse files
committed
rollout update docs and changelog
1 parent 1e8bffa commit ba78821

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

doc/changelog.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ Bug fixes
1212
Python bindings
1313
^^^^^^^^^^^^^^^
1414
- :ref:`rollout<PyRollout>` now features native multi-threading. If a sequence of MjData instances
15-
of length ``nthread`` is passed in, ``rollout`` will automatically create a persistent threadpool
16-
and parallelize the computation. Contribution by :github:user:`aftersomemath`.
15+
of length ``nthread`` is passed in, ``rollout`` will automatically create a thread pool and parallelize
16+
the computation. The thread pool can resused across calls, but then the function cannot be called simultaneously
17+
from multiple threads. To run multiple threaded rollouts simultaneously, use the new class ``Rollout`` which
18+
encapsulates the thread pool. Contribution by :github:user:`aftersomemath`.
1719

1820
Version 3.2.6 (Dec 2, 2024)
1921
---------------------------

doc/python.rst

+35-7
Original file line numberDiff line numberDiff line change
@@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco
711711
rollout
712712
-------
713713

714-
``mujoco.rollout`` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
715-
implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
714+
``mujoco.rollout`` and ``mujoco.Rollout`` shows how to add additional C/C++ functionality, exposed as a Python module
715+
via pybind11. It is implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc>`__
716716
and wrapped in `rollout.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.py>`__. The module
717717
performs a common functionality where tight loops implemented outside of Python are beneficial: rolling out a trajectory
718718
(i.e., calling :ref:`mj_step` in a loop), given an intial state and sequence of controls, and returning subsequent
719-
states and sensor values. The basic usage form is
719+
states and sensor values. The rollouts are run in parallel with an internally managed thread pool if multiple MjData instances
720+
(one per thread) are passed as an argument. The basic usage form is
720721

721722
.. code-block:: python
722723
723724
state, sensordata = rollout.rollout(model, data, initial_state, control)
724725
725726
``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``.
727+
``data`` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread``.
726728
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
727729
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
728730
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are
@@ -732,13 +734,39 @@ specified by passing an optional ``control_spec`` bitflag.
732734
If a rollout diverges, the current state and sensor values are used to fill the remainder of the trajectory.
733735
Therefore, non-increasing time values can be used to detect diverged rollouts.
734736

735-
The ``rollout`` function is designed to be completely stateless, so all inputs of the stepping pipeline are set and any
737+
The ``rollout`` function is designed to be computationally stateless, so all inputs of the stepping pipeline are set and any
736738
values already present in the given ``MjData`` instance will have no effect on the output.
737739

738-
Since the Global Interpreter Lock can be released, this function can be efficiently threaded using Python threads. See
739-
the ``test_threading`` function in
740+
By default ``rollout.rollout`` creates a new thread pool every call if ``len(data) > 1``. To reuse the thread pool
741+
over multiple calls use the ``persistent_pool`` argument. ``rollout.rollout`` is not thread safe when using
742+
a persistent pool. The basic usage form is
743+
744+
.. code-block:: python
745+
746+
state, sensordata = rollout.rollout(model, data, initial_state, persistent_pool=True)
747+
748+
The pool is shutdown on interpreter shutdown or by a call to ``rollout.shutdown_persistent_pool``.
749+
750+
To use multiple thread pools from multiple threads, use ``Rollout`` objects. The basic usage form is
751+
752+
.. code-block:: python
753+
754+
# Pool shutdown upon exiting block
755+
with rollout.Rollout(nthread) as rollout_:
756+
rollout_.rollout(model, data, initial_state)
757+
758+
or
759+
760+
.. code-block:: python
761+
762+
# pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
763+
rollout_ = rollout.Rollout(nthread)
764+
rollout_.rollout(model, data, initial_state)
765+
766+
Since the Global Interpreter Lock is released, this function can also be threaded using Python threads. However, this
767+
is less efficient than using native threads. See the ``test_threading`` function in
740768
`rollout_test.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout_test.py>`__ for an example
741-
of threaded operation (and more generally for usage examples).
769+
of threaded operation (and for more general usage examples).
742770

743771
.. _PyMinimize:
744772

0 commit comments

Comments
 (0)