@@ -262,6 +262,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
262
262
}
263
263
264
264
// C-style threaded version of _unsafe_rollout
265
+ static ThreadPool* pool = nullptr ;
265
266
void _unsafe_rollout_threaded (std::vector<const mjModel*>& m, std::vector<mjData*>& d,
266
267
int nroll, int nstep, unsigned int control_spec,
267
268
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
@@ -272,24 +273,34 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData
272
273
int njobs = nfulljobs;
273
274
if (chunk_remainder > 0 ) njobs++;
274
275
275
- ThreadPool pool = ThreadPool (nthread);
276
+ if (pool == nullptr ) {
277
+ pool = new ThreadPool (nthread);
278
+ }
279
+ else if (pool->NumThreads () != nthread) {
280
+ delete pool; // TODO make sure pool is shutdown correctly
281
+ pool = new ThreadPool (nthread);
282
+ } else {
283
+ pool->ResetCount ();
284
+ }
285
+
276
286
for (int j = 0 ; j < nfulljobs; j++) {
277
- auto task = [=, &m, &d, &pool](void ) {
278
- _unsafe_rollout (m, d[pool.WorkerId ()], j*chunk_size, (j+1 )*chunk_size,
287
+ auto task = [=, &m, &d](void ) {
288
+ int id = pool->WorkerId ();
289
+ _unsafe_rollout (m, d[id], j*chunk_size, (j+1 )*chunk_size,
279
290
nstep, control_spec, state0, warmstart0, control, state, sensordata);
280
291
};
281
- pool. Schedule (task);
292
+ pool-> Schedule (task);
282
293
}
283
294
284
295
if (chunk_remainder > 0 ) {
285
- auto task = [=, &m, &d, &pool ](void ) {
286
- _unsafe_rollout (m, d[pool. WorkerId ()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
296
+ auto task = [=, &m, &d](void ) {
297
+ _unsafe_rollout (m, d[pool-> WorkerId ()], nfulljobs*chunk_size, nfulljobs*chunk_size+chunk_remainder,
287
298
nstep, control_spec, state0, warmstart0, control, state, sensordata);
288
299
};
289
- pool. Schedule (task);
300
+ pool-> Schedule (task);
290
301
}
291
302
292
- pool. WaitCount (njobs);
303
+ pool-> WaitCount (njobs);
293
304
}
294
305
295
306
// NOLINTEND(whitespace/line_length)
0 commit comments