Skip to content

Commit 63a6e82

Browse files
committed
prototype rollout benchmark
1 parent d7d0456 commit 63a6e82

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

python/mujoco/rollout_benchmark.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""benchmarks for rollout function."""
16+
17+
import concurrent.futures
18+
import threading
19+
import timeit
20+
21+
import mujoco
22+
from mujoco import rollout
23+
import numpy as np
24+
25+
class PythonThreading:
26+
def __init__(self, m_example, num_workers):
27+
self.m_example = m_example
28+
self.num_workers = num_workers
29+
self.thread_local = threading.local()
30+
self.executor = concurrent.futures.ThreadPoolExecutor(
31+
max_workers=self.num_workers, initializer=self.thread_initializer)
32+
33+
def thread_initializer(self):
34+
self.thread_local.data = mujoco.MjData(self.m_example)
35+
36+
def call_rollout(self, model_list, initial_state, nstep):
37+
rollout.rollout(model_list, [self.thread_local.data], initial_state,
38+
skip_checks=True,
39+
nstep=nstep)
40+
41+
def run(self, model_list, initial_state, nstep):
42+
nroll = len(model_list)
43+
chunk_size = max(1, nroll // (10 * self.num_workers))
44+
nfulljobs = nroll // chunk_size;
45+
chunk_remainder = nroll % chunk_size;
46+
njobs = nfulljobs
47+
if (chunk_remainder > 0): njobs += 1
48+
49+
chunks = [] # a list of tuples, one per worker
50+
for i in range(nfulljobs):
51+
chunks.append((model_list[i*chunk_size:(i+1)*chunk_size],
52+
initial_state[i*chunk_size:(i+1)*chunk_size],
53+
nstep))
54+
if chunk_remainder > 0:
55+
chunks.append((model_list[nfulljobs*chunk_size:],
56+
initial_state[nfulljobs*chunk_size:],
57+
nstep))
58+
59+
futures = []
60+
for chunk in chunks:
61+
futures.append(self.executor.submit(self.call_rollout, *chunk))
62+
for future in concurrent.futures.as_completed(futures):
63+
future.result()
64+
65+
def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'):
66+
nthread = 24
67+
nroll = [int(1e0), int(1e1), int(1e2)]
68+
nstep = [int(1e0), int(1e1), int(1e2), int(2e2)]
69+
70+
print('making structures')
71+
m = mujoco.MjModel.from_xml_path(model_file)
72+
m_list = [m]*nroll[-1] # models do not need to be copied
73+
d_list = [mujoco.MjData(m) for i in range(nthread)]
74+
75+
initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
76+
mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
77+
initial_state = np.tile(initial_state, (nroll[-1], 1))
78+
79+
print('initializing thread pools')
80+
pt = PythonThreading(m_list[0], len(d_list))
81+
82+
print('running benchmark')
83+
for nroll_i in nroll:
84+
for nstep_i in nstep:
85+
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10)
86+
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10)
87+
print('{:03d} {:03d} {:0.3f} {:0.3f} {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
88+
89+
# Generate plots
90+
91+
if __name__ == '__main__':
92+
benchmark_rollout()

0 commit comments

Comments
 (0)