@@ -50,8 +50,8 @@ Roll out open-loop trajectories from initial states, get resulting states and se
50
50
output arguments (optional):
51
51
state (nroll x nstep x nstate) nroll nstep states
52
52
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
53
- chunk_divisor integer, determines threadpool chunk size according to
54
- chunk_size = max(1, nroll / (nthread * chunk_divisor )
53
+ chunk_size integer, determines threadpool chunk size. If unspecified
54
+ chunk_size = max(1, nroll / (nthread * 10 )
55
55
)" ;
56
56
57
57
// C-style rollout function, assumes all arguments are valid
@@ -241,7 +241,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
241
241
std::optional<const PyCArray> control,
242
242
std::optional<const PyCArray> state,
243
243
std::optional<const PyCArray> sensordata,
244
- int chunk_divisor
244
+ std::optional< int > chunk_size
245
245
) {
246
246
// get raw pointers
247
247
int nroll = state0.shape (0 );
@@ -281,11 +281,17 @@ PYBIND11_MODULE(_rollout, pymodule) {
281
281
282
282
// call unsafe rollout function
283
283
if (nthread > 1 && nroll > 1 ) {
284
- int chunk_size = std::max (1 , nroll / (chunk_divisor * nthread));
284
+ int chunk_size_final = 1 ;
285
+ if (!chunk_size.has_value ()) {
286
+ chunk_size_final = std::max (1 , nroll / (10 * nthread));
287
+ }
288
+ else {
289
+ chunk_size_final = *chunk_size;
290
+ }
285
291
InterceptMjErrors (_unsafe_rollout_threaded)(
286
292
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
287
293
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
288
- nthread, chunk_size );
294
+ nthread, chunk_size_final );
289
295
}
290
296
else {
291
297
InterceptMjErrors (_unsafe_rollout)(
@@ -303,7 +309,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
303
309
py::arg (" control" ) = py::none (),
304
310
py::arg (" state" ) = py::none (),
305
311
py::arg (" sensordata" ) = py::none (),
306
- py::arg (" chunk_divisor " ) = 10 ,
312
+ py::arg (" chunk_size " ) = py::none () ,
307
313
py::doc (rollout_doc)
308
314
);
309
315
}
0 commit comments