Commit e6aae74

rollout bench pool creation overhead
1 parent 169cf99 commit e6aae74

2 files changed: +180 -10 lines changed

python/mujoco/rollout.cc (+30 -10)

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <chrono>
 #include <iostream>
 #include <optional>
 #include <sstream>
@@ -162,25 +163,33 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
 }

 // C-style threaded version of _unsafe_rollout
-static ThreadPool* pool = nullptr;
+// static ThreadPool* pool = nullptr;
 void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
                               int nroll, int nstep, unsigned int control_spec,
                               const mjtNum* state0, const mjtNum* warmstart0,
                               const mjtNum* control, mjtNum* state, mjtNum* sensordata,
                               int nthread, int chunk_size) {
+  auto clock = std::chrono::high_resolution_clock();
+
+  auto start = clock.now();
+
   int nfulljobs = nroll / chunk_size;
   int chunk_remainder = nroll % chunk_size;
   int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

-  if (pool == nullptr) {
-    pool = new ThreadPool(nthread);
-  }
-  else if (pool->NumThreads() != nthread) {
-    delete pool;
-    pool = new ThreadPool(nthread);
-  } else {
-    pool->ResetCount();
-  }
+  // if (pool == nullptr) {
+  //   pool = new ThreadPool(nthread);
+  // }
+  // else if (pool->NumThreads() != nthread) {
+  //   delete pool;
+  //   pool = new ThreadPool(nthread);
+  // } else {
+  //   pool->ResetCount();
+  // }
+
+  auto pool_create_start = clock.now();
+  ThreadPool* pool = new ThreadPool(nthread);
+  auto pool_create_end = clock.now();

   for (int j = 0; j < nfulljobs; j++) {
     auto task = [=, &m, &d](void) {
@@ -200,6 +209,17 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
   }

   pool->WaitCount(njobs);
+
+  auto pool_delete_start = clock.now();
+  delete pool;
+  auto pool_delete_end = clock.now();
+
+  std::chrono::duration<double> total = pool_delete_end - start;
+  std::chrono::duration<double> create_pool = pool_create_end - pool_create_start;
+  std::chrono::duration<double> delete_pool = pool_delete_end - pool_delete_start;
+  std::chrono::duration<double> total_pool = create_pool + delete_pool;
+
+  std::cout << "total: " << total.count() << " pool: " << total_pool.count() << " ratio " << total_pool.count() / total.count() << std::endl;
 }

 // NOLINTEND(whitespace/line_length)
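
With the pool-reuse path commented out above, every call to _unsafe_rollout_threaded now constructs and destroys a fresh ThreadPool, and the added std::cout line reports how large that fixed cost is relative to the whole call. Below is a minimal sketch, not part of this commit, of how the worst case can be provoked from the Python rollout API: many rollouts with nstep=1 do almost no stepping work, so the printed "ratio" should be near its maximum. The model path, thread count, and rollout count are illustrative only.

# Illustrative sketch (not part of this commit): drive the instrumented build
# with many very short rollouts so that per-call thread-pool construction and
# destruction dominate the printed "total / pool / ratio" line.
import time

import mujoco
from mujoco import rollout
import numpy as np

m = mujoco.MjModel.from_xml_path('../../model/humanoid/humanoid.xml')  # any model
nthread, nroll = 4, 100                                                # illustrative sizes
d_list = [mujoco.MjData(m) for _ in range(nthread)]

state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
mujoco.mj_getState(m, d_list[0], state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
initial_state = np.tile(state, (nroll, 1))

t0 = time.time()
for _ in range(100):
  # nstep=1: almost no physics work per call, so pool setup/teardown dominates.
  rollout.rollout([m] * nroll, d_list, initial_state, skip_checks=True, nstep=1)
print('wall clock for 100 calls:', time.time() - t0)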

python/mujoco/rollout_benchmark.py (new file, +150)

@@ -0,0 +1,150 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for the rollout function."""

import concurrent.futures
import os
import threading
import time
import timeit

import mujoco
from mujoco import rollout
import numpy as np


class PythonThreading:

  def __init__(self, m_example, num_workers):
    self.m_example = m_example
    self.num_workers = num_workers
    self.thread_local = threading.local()
    self.executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=self.num_workers, initializer=self.thread_initializer)

  def thread_initializer(self):
    self.thread_local.data = mujoco.MjData(self.m_example)

  def call_rollout(self, model_list, initial_state, nstep):
    rollout.rollout(model_list, [self.thread_local.data], initial_state,
                    skip_checks=True,
                    nstep=nstep)

  def run(self, model_list, initial_state, nstep):
    nroll = len(model_list)

    # Divide jobs evenly across threads (as in the test);
    # better for very wide rollouts:
    # chunk_size = max(1, nroll // self.num_workers)

    # Divide jobs across threads with a chunk size 1/10th of the even amount;
    # new strategy, helps with load balancing.
    chunk_size = max(1, nroll // (10 * self.num_workers))

    nfulljobs = nroll // chunk_size
    chunk_remainder = nroll % chunk_size
    njobs = nfulljobs
    if chunk_remainder > 0:
      njobs += 1

    chunks = []  # a list of tuples, one per worker
    for i in range(nfulljobs):
      chunks.append((model_list[i*chunk_size:(i+1)*chunk_size],
                     initial_state[i*chunk_size:(i+1)*chunk_size],
                     nstep))
    if chunk_remainder > 0:
      chunks.append((model_list[nfulljobs*chunk_size:],
                     initial_state[nfulljobs*chunk_size:],
                     nstep))

    futures = []
    for chunk in chunks:
      futures.append(self.executor.submit(self.call_rollout, *chunk))
    for future in concurrent.futures.as_completed(futures):
      future.result()


def benchmark_rollout(model_file, nthread=os.cpu_count()):
  print()
  print(model_file)
  bench_steps = int(1e4)  # Run approximately bench_steps per thread

  # A grid search.
  nroll = [int(1e0), int(1e1), int(1e2), int(1e3)]
  nstep = [int(1e0), int(1e1), int(1e2), int(1e3)]
  nnroll, nnstep = np.meshgrid(nroll, nstep)
  nroll_nstep_grid = np.stack((nnroll.flatten(), nnstep.flatten()), axis=1)

  # Typical nroll/nstep for sysid, rl, mpc respectively.
  nroll = [50, 3000, 100]
  nstep = [1000, 1, 50]
  nroll_nstep_app = np.stack((nroll, nstep), axis=1)

  # nroll_nstep = np.vstack((nroll_nstep_grid, nroll_nstep_app))
  nroll_nstep = nroll_nstep_app

  m = mujoco.MjModel.from_xml_path(model_file)
  print('nv:', m.nv)

  m_list = [m] * np.max(nroll)  # models do not need to be copied
  d_list = [mujoco.MjData(m) for i in range(nthread)]

  initial_state = np.zeros((mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS),))
  mujoco.mj_getState(m, d_list[0], initial_state, mujoco.mjtState.mjSTATE_FULLPHYSICS)
  initial_state = np.tile(initial_state, (np.max(nroll), 1))

  pt = PythonThreading(m_list[0], len(d_list))

  for i in range(nroll_nstep.shape[0]):
    nroll_i = int(nroll_nstep[i, 0])
    nstep_i = int(nroll_nstep[i, 1])

    nbench = max(1, int(np.round(min(nthread, nroll_i) * bench_steps / nstep_i / nroll_i)))

    times = [time.time()]
    for i in range(nbench):
      rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i],
                      skip_checks=True, nstep=nstep_i)
      times.append(time.time())
    dt = np.diff(times)
    nt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt))

    # times = [time.time()]
    # for i in range(nbench):
    #   pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i)
    #   times.append(time.time())
    # dt = np.diff(times)
    # pt_stats = (np.min(dt), np.max(dt), np.mean(dt), np.std(dt))

    print('nbench: {:06d} nroll: {:04d} nstep: {:04d} '
          'nt_min: {:0.4f} nt_max: {:0.4f} nt_mean: {:0.4f} nt_std: {:0.4f} '.format(
              nbench, nroll_i, nstep_i,
              *nt_stats))  # , *pt_stats,
                           # nt_stats[0] / pt_stats[0], nt_stats[2] / pt_stats[2]))
    time.sleep(5)


if __name__ == '__main__':
  print('============================================================')
  print('small to medium models')
  print('============================================================')

  benchmark_rollout(model_file='../../../dm_control/dm_control/suite/hopper.xml')
  benchmark_rollout(model_file='../../../mujoco_menagerie/unitree_go2/scene.xml')
  benchmark_rollout(model_file='../../model/humanoid/humanoid.xml')
  exit(0)

  print()
  print('============================================================')
  print('very large models')
  print('============================================================')
  benchmark_rollout(model_file='../../model/cards/cards.xml')
  benchmark_rollout(model_file='../../model/humanoid/humanoid100.xml')
  benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml')
  # benchmark_rollout(model_file='../../model/humanoid/100_humanoids.xml')
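
For reference, here is the chunking arithmetic from PythonThreading.run above, worked through for one concrete case. The numbers (nroll=3000 rollouts, num_workers=32 threads, roughly the "rl" row on a 32-core machine) are hypothetical and only illustrate the 1/10th-of-even chunk-size strategy:

# Hypothetical numbers, chosen only to illustrate the chunk-size strategy.
nroll, num_workers = 3000, 32
chunk_size = max(1, nroll // (10 * num_workers))       # 3000 // 320 = 9 rollouts per job
nfulljobs = nroll // chunk_size                        # 333 full jobs
chunk_remainder = nroll % chunk_size                   # 3 rollouts left over
njobs = nfulljobs + (1 if chunk_remainder > 0 else 0)  # 334 jobs shared by 32 workers
# ~10 jobs per worker, so one slow chunk is absorbed by the others (load balancing).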

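Similarly, the nbench formula in benchmark_rollout sizes the repetition count so that each busy thread executes roughly bench_steps simulation steps in total. A worked example with the script's "sysid" configuration (nroll=50, nstep=1000) and a hypothetical nthread=32 (built-in round stands in for the script's np.round):

# nroll_i/nstep_i are the "sysid" row from the script above; nthread is hypothetical.
nthread, bench_steps = 32, int(1e4)
nroll_i, nstep_i = 50, 1000
nbench = max(1, int(round(min(nthread, nroll_i) * bench_steps / nstep_i / nroll_i)))
# min(32, 50) * 1e4 / 1000 / 50 = 6.4  ->  nbench = 6
# per repetition: 50 rollouts x 1000 steps spread over 32 busy threads
# total per thread: 6 * 50 * 1000 / 32 ~= 9.4e3 ~= bench_steps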