Skip to content

Commit 166640c

Browse files
committed
rollout reuse native threadpool
1 parent 63a6e82 commit 166640c

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

python/mujoco/rollout.cc

+19-8
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
262262
}
263263

264264
// C-style threaded version of _unsafe_rollout
265+
static ThreadPool* pool = nullptr;
265266
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
266267
int nroll, int nstep, unsigned int control_spec,
267268
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
@@ -272,24 +273,34 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
272273
int njobs = nfulljobs;
273274
if (chunk_remainder > 0) njobs++;
274275

275-
ThreadPool pool = ThreadPool(nthread);
276+
if (pool == nullptr) {
277+
pool = new ThreadPool(nthread);
278+
}
279+
else if (pool->NumThreads() != nthread) {
280+
delete pool; // TODO make sure pool is shutdown correctly
281+
pool = new ThreadPool(nthread);
282+
} else {
283+
pool->ResetCount();
284+
}
285+
276286
for (int j = 0; j < nfulljobs; j++) {
277-
auto task = [=, &m, &d, &pool](void) {
278-
_unsafe_rollout(m, d[pool.WorkerId()], j*chunk_size, (j+1)*chunk_size,
287+
auto task = [=, &m, &d](void) {
288+
int id = pool->WorkerId();
289+
_unsafe_rollout(m, d[id], j*chunk_size, (j+1)*chunk_size,
279290
nstep, control_spec, state0, warmstart0, control, state, sensordata);
280291
};
281-
pool.Schedule(task);
292+
pool->Schedule(task);
282293
}
283294

284295
if (chunk_remainder > 0) {
285-
auto task = [=, &m, &d, &pool](void) {
286-
_unsafe_rollout(m, d[pool.WorkerId()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
296+
auto task = [=, &m, &d](void) {
297+
_unsafe_rollout(m, d[pool->WorkerId()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
287298
nstep, control_spec, state0, warmstart0, control, state, sensordata);
288299
};
289-
pool.Schedule(task);
300+
pool->Schedule(task);
290301
}
291302

292-
pool.WaitCount(njobs);
303+
pool->WaitCount(njobs);
293304
}
294305

295306
// NOLINTEND(whitespace/line_length)

python/mujoco/rollout_benchmark.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def run(self, model_list, initial_state, nstep):
6464

6565
def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'):
6666
nthread = 24
67-
nroll = [int(1e0), int(1e1), int(1e2)]
67+
nroll = [int(1e0), int(1e1), int(1e2), int(2e2)]
6868
nstep = [int(1e0), int(1e1), int(1e2), int(2e2)]
6969

7070
print('making structures')
@@ -84,7 +84,7 @@ def benchmark_rollout(model_file='../../test/benchmark/testdata/humanoid200.xml'
8484
for nstep_i in nstep:
8585
nt_res = timeit.timeit(lambda: rollout.rollout(m_list[:nroll_i], d_list, initial_state[:nroll_i], nstep=nstep_i), number=10)
8686
pt_res = timeit.timeit(lambda: pt.run(m_list[:nroll_i], initial_state[:nroll_i], nstep_i), number=10)
87-
print('{:03d} {:03d} {:0.3f} {:0.3f} {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
87+
print('nroll: {:04d} nstep: {:04d} nt: {:0.3f} pt: {:0.3f} nt/pt: {:0.3f}'.format(nroll_i, nstep_i, nt_res, pt_res, nt_res / pt_res))
8888

8989
# Generate plots
9090

0 commit comments

Comments
 (0)