@@ -711,18 +711,20 @@ The ``mujoco`` package contains two sub-modules: ``mujoco.rollout`` and ``mujoco
711
711
rollout
712
712
-------
713
713
714
- ``mujoco.rollout `` shows how to add additional C/C++ functionality, exposed as a Python module via pybind11. It is
715
- implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc >`__
714
+ ``mujoco.rollout `` and `` mujoco.Rollout `` shows how to add additional C/C++ functionality, exposed as a Python module
715
+ via pybind11. It is implemented in `rollout.cc <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.cc >`__
716
716
and wrapped in `rollout.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout.py >`__. The module
717
717
performs a common functionality where tight loops implemented outside of Python are beneficial: rolling out a trajectory
718
718
(i.e., calling :ref: `mj_step ` in a loop), given an intial state and sequence of controls, and returning subsequent
719
- states and sensor values. The basic usage form is
719
+ states and sensor values. The rollouts are run in parallel with an internally managed thread pool if multiple MjData instances
720
+ (one per thread) are passed as an argument. The basic usage form is
720
721
721
722
.. code-block :: python
722
723
723
724
state, sensordata = rollout.rollout(model, data, initial_state, control)
724
725
725
726
``model `` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll ``.
727
+ ``data `` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread ``.
726
728
``initial_state `` is an ``nroll x nstate `` array, with ``nroll `` initial states of size ``nstate ``, where
727
729
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS) `` is the size of the
728
730
:ref: `full physics state<geFullPhysics> `. ``control `` is a ``nroll x nstep x ncontrol `` array of controls. Controls are
@@ -732,13 +734,39 @@ specified by passing an optional ``control_spec`` bitflag.
732
734
If a rollout diverges, the current state and sensor values are used to fill the remainder of the trajectory.
733
735
Therefore, non-increasing time values can be used to detect diverged rollouts.
734
736
735
- The ``rollout `` function is designed to be completely stateless, so all inputs of the stepping pipeline are set and any
737
+ The ``rollout `` function is designed to be computationally stateless, so all inputs of the stepping pipeline are set and any
736
738
values already present in the given ``MjData `` instance will have no effect on the output.
737
739
738
- Since the Global Interpreter Lock can be released, this function can be efficiently threaded using Python threads. See
739
- the ``test_threading `` function in
740
+ By default ``rollout.rollout `` creates a new thread pool every call if ``len(data) > 1 ``. To reuse the thread pool
741
+ over multiple calls use the ``persistent_pool `` argument. ``rollout.rollout `` is not thread safe when using
742
+ a persistent pool. The basic usage form is
743
+
744
+ .. code-block :: python
745
+
746
+ state, sensordata = rollout.rollout(model, data, initial_state, persistent_pool = True )
747
+
748
+ The pool is shutdown on interpreter shutdown or by a call to ``rollout.shutdown_persistent_pool ``.
749
+
750
+ To use multiple thread pools from multiple threads, use ``Rollout `` objects. The basic usage form is
751
+
752
+ .. code-block :: python
753
+
754
+ # Pool shutdown upon exiting block
755
+ with rollout.Rollout(nthread) as rollout_:
756
+ rollout_.rollout(model, data, initial_state)
757
+
758
+ or
759
+
760
+ .. code-block :: python
761
+
762
+ # pool shutdown on object deletion, interpreter shutdown, or call to rollout_.shutdown_pool
763
+ rollout_ = rollout.Rollout(nthread)
764
+ rollout_.rollout(model, data, initial_state)
765
+
766
+ Since the Global Interpreter Lock is released, this function can also be threaded using Python threads. However, this
767
+ is less efficient than using native threads. See the ``test_threading `` function in
740
768
`rollout_test.py <https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/rollout_test.py >`__ for an example
741
- of threaded operation (and more generally for usage examples).
769
+ of threaded operation (and for more general usage examples).
742
770
743
771
.. _PyMinimize :
744
772
0 commit comments