|
| 1 | +# Copyright 2024 DeepMind Technologies Limited |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""benchmarks for rollout function.""" |
| 16 | + |
| 17 | +import concurrent.futures |
| 18 | +import threading |
| 19 | +import timeit |
| 20 | + |
| 21 | +import mujoco |
| 22 | +from mujoco import rollout |
| 23 | +import numpy as np |
| 24 | + |
| 25 | +class PythonThreading: |
| 26 | + def __init__(self, m_example, num_workers): |
| 27 | + self.m_example = m_example |
| 28 | + self.num_workers = num_workers |
| 29 | + self.thread_local = threading.local() |
| 30 | + self.executor = concurrent.futures.ThreadPoolExecutor( |
| 31 | + max_workers=self.num_workers, initializer=self.thread_initializer) |
| 32 | + |
| 33 | + def thread_initializer(self): |
| 34 | + self.thread_local.data = mujoco.MjData(self.m_example) |
| 35 | + |
| 36 | + def call_rollout(self, model_list, initial_state, nstep): |
| 37 | + rollout.rollout(model_list, [self.thread_local.data], initial_state, |
| 38 | + skip_checks=True, |
| 39 | + nstep=nstep) |
| 40 | + |
| 41 | + def run(self, model_list, initial_state, nstep): |
| 42 | + nroll = len(model_list) |
| 43 | + chunk_size = max(1, nroll // (10 * self.num_workers)) |
| 44 | + nfulljobs = nroll // chunk_size; |
| 45 | + chunk_remainder = nroll % chunk_size; |
| 46 | + njobs = nfulljobs |
| 47 | + if (chunk_remainder > 0): njobs += 1 |
| 48 | + |
| 49 | + chunks = [] # a list of tuples, one per worker |
| 50 | + for i in range(nfulljobs): |
| 51 | + chunks.append((model_list[i*chunk_size:(i+1)*chunk_size], |
| 52 | + initial_state[i*chunk_size:(i+1)*chunk_size], |
| 53 | + nstep)) |
| 54 | + if chunk_remainder > 0: |
| 55 | + chunks.append((model_list[nfulljobs*chunk_size:], |
| 56 | + initial_state[nfulljobs*chunk_size:], |
| 57 | + nstep)) |
| 58 | + |
| 59 | + futures = [] |
| 60 | + for chunk in chunks: |
| 61 | + futures.append(self.executor.submit(self.call_rollout, *chunk)) |
| 62 | + for future in concurrent.futures.as_completed(futures): |
| 63 | + future.result() |
| 64 | + |
| 65 | +def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'): |
| 66 | + nthread = 24 |
| 67 | + nroll = [int(1e0), int(1e1), int(1e2)] |
| 68 | + nstep = [int(1e0), int(1e1), int(1e2), int(2e2)] |
| 69 | + |
| 70 | + print('making structures') |
| 71 | + m = mujoco.MjModel.from_xml_path(model_file) |
| 72 | + m_list = [m]*nroll[-1] # models do not need to be copied |
| 73 | + d_list = [mujoco.MjData(m) for i in range(nthread)] |
| 74 | + |
| 75 | + initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),)) |
| 76 | + mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS) |
| 77 | + initial_state = np.tile(initial_state, (nroll[-1], 1)) |
| 78 | + |
| 79 | + print('initializing thread pools') |
| 80 | + pt = PythonThreading(m_list[0], len(d_list)) |
| 81 | + |
| 82 | + print('running benchmark') |
| 83 | + for nroll_i in nroll: |
| 84 | + for nstep_i in nstep: |
| 85 | + nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10) |
| 86 | + pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10) |
| 87 | + print('{:03d} {:03d} {:0.3f} {:0.3f} {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res)) |
| 88 | + |
| 89 | + # Generate plots |
| 90 | + |
| 91 | +if __name__ == '__main__': |
| 92 | + benchmark_rollout() |
0 commit comments