|
16 | 16 |
|
17 | 17 | import concurrent.futures
|
18 | 18 | import threading
|
| 19 | +import time |
19 | 20 | import timeit
|
20 | 21 |
|
21 | 22 | import mujoco
|
@@ -81,12 +82,33 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'
|
81 | 82 |
|
82 | 83 | print('running benchmark')
|
83 | 84 | for nroll_i in nroll:
|
| 85 | + print('roll', nroll_i) |
84 | 86 | for nstep_i in nstep:
|
85 |
| - nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10) |
86 |
| - pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10) |
87 |
| - print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res)) |
| 87 | + number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i) |
| 88 | + number = max(20, number) |
88 | 89 |
|
89 |
| - # Generate plots |
| 90 | + nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number) |
| 91 | + pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number) |
| 92 | + nt_res /= number |
| 93 | + pt_res /= number |
| 94 | + |
| 95 | + # times = [time.time()] |
| 96 | + # for i in range(number): |
| 97 | + # rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i) |
| 98 | + # times.append(time.time()) |
| 99 | + # dt = np.diff(times) |
| 100 | + # nt_res = np.mean(dt) |
| 101 | + |
| 102 | + # times = [time.time()] |
| 103 | + # for i in range(number): |
| 104 | + # pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i) |
| 105 | + # times.append(time.time()) |
| 106 | + # dt = np.diff(times) |
| 107 | + # pt_res = np.mean(dt) |
| 108 | + |
| 109 | + print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number)) |
| 110 | + |
| 111 | + # TODO generate plots |
90 | 112 |
|
91 | 113 | if __name__ == '__main__':
|
92 | 114 | benchmark_rollout()
|
0 commit comments