Skip to content

Commit 39a3755

Browse files
committed
rollout bench chunk_divisor
1 parent 06b90fe commit 39a3755

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""benchmarks for rollout function."""
16+
17+
import os
18+
import time
19+
20+
import mujoco
21+
from mujoco import rollout
22+
import numpy as np
23+
24+
def benchmark_rollout(model_file, nthread=os.cpu_count()):
25+
print('\n', model_file)
26+
bench_steps = int(1e4) # Run approximately bench_steps per thread
27+
28+
# A grid search
29+
nroll = [int(1e0), int(1e1), int(1e2), int(1e3)]
30+
nstep = [int(1e0), int(1e1), int(1e2), int(1e3)]
31+
nnroll, nnstep = np.meshgrid(nroll, nstep)
32+
nroll_nstep_grid = np.stack((nnroll.flatten(), nnstep.flatten()), axis=1)
33+
34+
# Typical nroll/nstep for sysid, rl, mpc respectively
35+
nroll = [50, 3000, 100]
36+
nstep = [1000, 1, 50]
37+
nroll_nstep_app = np.stack((nroll, nstep), axis=1)
38+
39+
nroll_nstep = np.vstack((nroll_nstep_grid, nroll_nstep_app))
40+
41+
chunk_divisors = [10, 1, 2, 4, 8, 16, 32, 64, 128] # First element is the nominal divisor
42+
43+
m = mujoco.MjModel.from_xml_path(model_file)
44+
print('nv:', m.nv)
45+
46+
m_list = [m]*np.max(nroll) # models do not need to be copied
47+
d_list = [mujoco.MjData(m) for i in range(nthread)]
48+
49+
initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
50+
mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
51+
initial_state = np.tile(initial_state, (np.max(nroll), 1))
52+
53+
for i in range(nroll_nstep.shape[0]):
54+
nroll_i = int(nroll_nstep[i, 0])
55+
nstep_i = int(nroll_nstep[i, 1])
56+
57+
nbench = max(1, int(np.round(min(nthread, nroll_i) * bench_steps / nstep_i / nroll_i)))
58+
59+
chunk_divisors_stats = []
60+
for chunk_divisor in chunk_divisors:
61+
times = [time.time()]
62+
for i in range(nbench):
63+
rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i, chunk_divisor=chunk_divisor)
64+
times.append(time.time())
65+
dt = np.diff(times)
66+
chunk_divisors_stats.append((np.min(dt), np.max(dt), np.mean(dt), np.std(dt)))
67+
chunk_divisors_stats = np.array(chunk_divisors_stats)
68+
69+
slowest_chunk_divisor_i = np.argmax(chunk_divisors_stats[:, 2])
70+
slowest_chunk_divisor = chunk_divisors[slowest_chunk_divisor_i]
71+
72+
fastest_chunk_divisor_i = np.argmin(chunk_divisors_stats[:, 2])
73+
fastest_chunk_divisor = chunk_divisors[fastest_chunk_divisor_i]
74+
75+
print('nbench: {:06d} nroll: {:04d} nstep: {:04d} '
76+
'mean_nom {:0.4f} mean_slow: {:0.4f} mean_fast: {:0.4f} chunk_div_slow {:03d} chunk_div_fast {:03d} fast/slow {:0.3f} fast/nom {:0.3f}'.format(
77+
nbench, nroll_i, nstep_i,
78+
chunk_divisors_stats[0, 2], # nominal chunk divisor
79+
chunk_divisors_stats[slowest_chunk_divisor_i, 2],
80+
chunk_divisors_stats[fastest_chunk_divisor_i, 2],
81+
slowest_chunk_divisor, fastest_chunk_divisor,
82+
chunk_divisors_stats[fastest_chunk_divisor_i, 2] / chunk_divisors_stats[slowest_chunk_divisor_i, 2],
83+
chunk_divisors_stats[fastest_chunk_divisor_i, 2] / chunk_divisors_stats[0, 2]))
84+
85+
if __name__ == '__main__':
86+
print('============================================================')
87+
print('small to medium models')
88+
print('============================================================')
89+
90+
benchmark_rollout(model_file='../../../dm_control/dm_control/suite/hopper.xml')
91+
benchmark_rollout(model_file='../../../mujoco_menagerie/unitree_go2/scene.xml')
92+
benchmark_rollout(model_file='../../model/humanoid/humanoid.xml')
93+
94+
print()
95+
print('============================================================')
96+
print('very large models')
97+
print('============================================================')
98+
benchmark_rollout(model_file='../../model/cards/cards.xml')
99+
benchmark_rollout(model_file='../../model/humanoid/humanoid100.xml')
100+
benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
101+
# benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml')

0 commit comments

Comments
 (0)