Skip to content

Commit 31226ce

Browse files
committed
rollout benchmark produces reasonable results
1 parent 1fa1727 commit 31226ce

File tree

2 files changed

+55
-39
lines changed

2 files changed

+55
-39
lines changed

python/mujoco/rollout.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
380380
py::gil_scoped_release no_gil;
381381

382382
// call unsafe rollout function
383-
if (nthread > 1) {
383+
if (nthread > 1 && nroll > 1) {
384384
int chunk_size = std::max(1, nroll / (10 * nthread));
385385
InterceptMjErrors(_unsafe_rollout_threaded)(
386386
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,

python/mujoco/rollout_benchmark.py

+54-38
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""benchmarks for rollout function."""
1616

1717
import concurrent.futures
18+
import os
1819
import threading
1920
import time
2021
import timeit
@@ -63,52 +64,67 @@ def run(self, model_list, initial_state, nstep):
6364
for future in concurrent.futures.as_completed(futures):
6465
future.result()
6566

66-
def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'):
67-
nthread = 24
68-
nroll = [int(1e0), int(1e1), int(1e2), int(2e2)]
69-
nstep = [int(1e0), int(1e1), int(1e2), int(2e2)]
67+
def benchmark_rollout(model_file, nthread=os.cpu_count()):
68+
print('\n', model_file)
69+
bench_steps = int(1e4) # Run approximately bench_steps per thread
70+
71+
# A grid search
72+
nroll = [int(1e0), int(1e1), int(1e2), int(1e3)]
73+
nstep = [int(1e0), int(1e1), int(1e2), int(1e3)]
74+
nnroll, nnstep = np.meshgrid(nroll, nstep)
75+
nroll_nstep_grid = np.stack((nnroll.flatten(), nnstep.flatten()), axis=1)
76+
77+
# Typical nroll/nstep for sysid, rl, mpc respectively
78+
nroll = [50, 3000, 100]
79+
nstep = [1000, 1, 50]
80+
nroll_nstep_app = np.stack((nroll, nstep), axis=1)
81+
82+
nroll_nstep = np.vstack((nroll_nstep_grid, nroll_nstep_app))
7083

71-
print('making structures')
7284
m = mujoco.MjModel.from_xml_path(model_file)
73-
m_list = [m]*nroll[-1] # models do not need to be copied
85+
print('nv:', m.nv)
86+
87+
m_list = [m]*np.max(nroll) # models do not need to be copied
7488
d_list = [mujoco.MjData(m) for i in range(nthread)]
7589

7690
initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
7791
mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
78-
initial_state = np.tile(initial_state, (nroll[-1], 1))
92+
initial_state = np.tile(initial_state, (np.max(nroll), 1))
7993

80-
print('initializing thread pools')
8194
pt = PythonThreading(m_list[0], len(d_list))
8295

83-
print('running benchmark')
84-
for nroll_i in nroll:
85-
print('roll', nroll_i)
86-
for nstep_i in nstep:
87-
number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i)
88-
number = max(20, number)
89-
90-
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number)
91-
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number)
92-
nt_res /= number
93-
pt_res /= number
94-
95-
# times = [time.time()]
96-
# for i in range(number):
97-
# rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i)
98-
# times.append(time.time())
99-
# dt = np.diff(times)
100-
# nt_res = np.mean(dt)
101-
102-
# times = [time.time()]
103-
# for i in range(number):
104-
# pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i)
105-
# times.append(time.time())
106-
# dt = np.diff(times)
107-
# pt_res = np.mean(dt)
108-
109-
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number))
110-
111-
# TODO generate plots
96+
for i in range(nroll_nstep.shape[0]):
97+
nroll_i = int(nroll_nstep[i, 0])
98+
nstep_i = int(nroll_nstep[i, 1])
99+
100+
nbench = max(1, int(np.round(min(nthread, nroll_i) * bench_steps / nstep_i / nroll_i)))
101+
102+
times = [time.time()]
103+
for i in range(nbench):
104+
rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i)
105+
times.append(time.time())
106+
dt = np.diff(times)
107+
nt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt))
108+
109+
times = [time.time()]
110+
for i in range(nbench):
111+
pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i)
112+
times.append(time.time())
113+
dt = np.diff(times)
114+
pt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt))
115+
116+
print('nbench: {:06d} nroll: {:04d} nstep: {:04d} '
117+
'nt_min: {:0.4f} nt_max: {:0.4f} nt_mean: {:0.4f} nt_std: {:0.4f} '
118+
'pt_min: {:0.4f} pt_max: {:0.4f} pt_mean: {:0.4f} pt_std: {:0.4f} '
119+
'nt/pt min: {:0.3f} nt/pt mean: {:0.3f}'.format(
120+
nbench, nroll_i, nstep_i,
121+
*nt_stats, *pt_stats,
122+
nt_stats[0] / pt_stats[0], nt_stats[2] / pt_stats[2]))
112123

113124
if __name__ == '__main__':
114-
benchmark_rollout()
125+
benchmark_rollout(model_file='../../../dm_control/dm_control/suite/hopper.xml')
126+
benchmark_rollout(model_file='../../../mujoco_menagerie/unitree_go2/scene.xml')
127+
benchmark_rollout(model_file='../../model/humanoid/humanoid.xml')
128+
benchmark_rollout(model_file='../../model/humanoid/humanoid100.xml')
129+
# benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
130+
# benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml')

0 commit comments

Comments
 (0)