|
15 | 15 | """benchmarks for rollout function."""
|
16 | 16 |
|
17 | 17 | import concurrent.futures
|
| 18 | +import os |
18 | 19 | import threading
|
19 | 20 | import time
|
20 | 21 | import timeit
|
@@ -63,52 +64,67 @@ def run(self, model_list, initial_state, nstep):
|
63 | 64 | for future in concurrent.futures.as_completed(futures):
|
64 | 65 | future.result()
|
65 | 66 |
|
66 |
| -def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'): |
67 |
| - nthread = 24 |
68 |
| - nroll = [int(1e0), int(1e1), int(1e2), int(2e2)] |
69 |
| - nstep = [int(1e0), int(1e1), int(1e2), int(2e2)] |
| 67 | +def benchmark_rollout(model_file, nthread=os.cpu_count()): |
| 68 | + print('\n', model_file) |
| 69 | + bench_steps = int(1e4) # Run approximately bench_steps per thread |
| 70 | + |
| 71 | + # A grid search |
| 72 | + nroll = [int(1e0), int(1e1), int(1e2), int(1e3)] |
| 73 | + nstep = [int(1e0), int(1e1), int(1e2), int(1e3)] |
| 74 | + nnroll, nnstep = np.meshgrid(nroll, nstep) |
| 75 | + nroll_nstep_grid = np.stack((nnroll.flatten(), nnstep.flatten()), axis=1) |
| 76 | + |
| 77 | + # Typical nroll/nstep for sysid, rl, mpc respectively |
| 78 | + nroll = [50, 3000, 100] |
| 79 | + nstep = [1000, 1, 50] |
| 80 | + nroll_nstep_app = np.stack((nroll, nstep), axis=1) |
| 81 | + |
| 82 | + nroll_nstep = np.vstack((nroll_nstep_grid, nroll_nstep_app)) |
70 | 83 |
|
71 |
| - print('making structures') |
72 | 84 | m = mujoco.MjModel.from_xml_path(model_file)
|
73 |
| - m_list = [m]*nroll[-1] # models do not need to be copied |
| 85 | + print('nv:', m.nv) |
| 86 | + |
| 87 | + m_list = [m]*np.max(nroll) # models do not need to be copied |
74 | 88 | d_list = [mujoco.MjData(m) for i in range(nthread)]
|
75 | 89 |
|
76 | 90 | initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
|
77 | 91 | mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
|
78 |
| - initial_state = np.tile(initial_state, (nroll[-1], 1)) |
| 92 | + initial_state = np.tile(initial_state, (np.max(nroll), 1)) |
79 | 93 |
|
80 |
| - print('initializing thread pools') |
81 | 94 | pt = PythonThreading(m_list[0], len(d_list))
|
82 | 95 |
|
83 |
| - print('running benchmark') |
84 |
| - for nroll_i in nroll: |
85 |
| - print('roll', nroll_i) |
86 |
| - for nstep_i in nstep: |
87 |
| - number = int((1*nroll[-1] * nstep[-1]) / nroll_i / nstep_i) |
88 |
| - number = max(20, number) |
89 |
| - |
90 |
| - nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i), number=number) |
91 |
| - pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=number) |
92 |
| - nt_res /= number |
93 |
| - pt_res /= number |
94 |
| - |
95 |
| - # times = [time.time()] |
96 |
| - # for i in range(number): |
97 |
| - # rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i) |
98 |
| - # times.append(time.time()) |
99 |
| - # dt = np.diff(times) |
100 |
| - # nt_res = np.mean(dt) |
101 |
| - |
102 |
| - # times = [time.time()] |
103 |
| - # for i in range(number): |
104 |
| - # pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i) |
105 |
| - # times.append(time.time()) |
106 |
| - # dt = np.diff(times) |
107 |
| - # pt_res = np.mean(dt) |
108 |
| - |
109 |
| - print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f} {}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res, number)) |
110 |
| - |
111 |
| - # TODO generate plots |
| 96 | + for i in range(nroll_nstep.shape[0]): |
| 97 | + nroll_i = int(nroll_nstep[i, 0]) |
| 98 | + nstep_i = int(nroll_nstep[i, 1]) |
| 99 | + |
| 100 | + nbench = max(1, int(np.round(min(nthread, nroll_i) * bench_steps / nstep_i / nroll_i))) |
| 101 | + |
| 102 | + times = [time.time()] |
| 103 | + for i in range(nbench): |
| 104 | + rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], skip_checks=True, nstep=nstep_i) |
| 105 | + times.append(time.time()) |
| 106 | + dt = np.diff(times) |
| 107 | + nt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt)) |
| 108 | + |
| 109 | + times = [time.time()] |
| 110 | + for i in range(nbench): |
| 111 | + pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i) |
| 112 | + times.append(time.time()) |
| 113 | + dt = np.diff(times) |
| 114 | + pt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt)) |
| 115 | + |
| 116 | + print('nbench: {:06d} nroll: {:04d} nstep: {:04d} ' |
| 117 | + 'nt_min: {:0.4f} nt_max: {:0.4f} nt_mean: {:0.4f} nt_std: {:0.4f} ' |
| 118 | + 'pt_min: {:0.4f} pt_max: {:0.4f} pt_mean: {:0.4f} pt_std: {:0.4f} ' |
| 119 | + 'nt/pt min: {:0.3f} nt/pt mean: {:0.3f}'.format( |
| 120 | + nbench, nroll_i, nstep_i, |
| 121 | + *nt_stats, *pt_stats, |
| 122 | + nt_stats[0] / pt_stats[0], nt_stats[2] / pt_stats[2])) |
112 | 123 |
|
113 | 124 | if __name__ == '__main__':
|
114 |
| - benchmark_rollout() |
| 125 | + benchmark_rollout(model_file='../../../dm_control/dm_control/suite/hopper.xml') |
| 126 | + benchmark_rollout(model_file='../../../mujoco_menagerie/unitree_go2/scene.xml') |
| 127 | + benchmark_rollout(model_file='../../model/humanoid/humanoid.xml') |
| 128 | + benchmark_rollout(model_file='../../model/humanoid/humanoid100.xml') |
| 129 | + # benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml') |
| 130 | + # benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml') |
0 commit comments