Skip to content

Commit 1fa1727

Browse files
committed
rollout better benchmarking and use skip_checks with native threading
1 parent 166640c commit 1fa1727

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

python/mujoco/rollout_benchmark.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import concurrent.futures
1818
import threading
19+
import time
1920
import timeit
2021

2122
import mujoco
@@ -81,12 +82,33 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'
8182

8283
print('running benchmark')
8384
for nroll_i in nroll:
85+
print('roll', nroll_i)
8486
for nstep_i in nstep:
85-
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10)
86-
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10)
87-
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
87+
number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i)
88+
number = max(20, number)
8889

89-
# Generate plots
90+
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number)
91+
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number)
92+
nt_res /= number
93+
pt_res /= number
94+
95+
# times = [time.time()]
96+
# for i in range(number):
97+
# rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i)
98+
# times.append(time.time())
99+
# dt = np.diff(times)
100+
# nt_res = np.mean(dt)
101+
102+
# times = [time.time()]
103+
# for i in range(number):
104+
# pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i)
105+
# times.append(time.time())
106+
# dt = np.diff(times)
107+
# pt_res = np.mean(dt)
108+
109+
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number))
110+
111+
# TODO generate plots
90112

91113
if __name__ == '__main__':
92114
benchmark_rollout()

0 commit comments

Comments
 (0)