25
25
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include <absl/base/attributes.h>
37
+
28
38
namespace mujoco ::python {
29
39
30
40
namespace {
31
41
32
42
namespace py = ::pybind11;
33
43
44
+ // Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.h
45
+ // ThreadPool class
46
+ class ThreadPool {
47
+ public:
48
+ // constructor
49
+ explicit ThreadPool (int num_threads);
50
+ // destructor
51
+ ~ThreadPool ();
52
+ int NumThreads () const { return threads_.size (); }
53
+ // returns an ID between 0 and NumThreads() - 1. must be called within
54
+ // worker thread (returns -1 if not).
55
+ static int WorkerId () { return worker_id_; }
56
+ // ----- methods ----- //
57
+ // set task for threadpool
58
+ void Schedule (std::function<void ()> task);
59
+ // return number of tasks completed
60
+ std::uint64_t GetCount () { return ctr_; }
61
+ // reset count to zero
62
+ void ResetCount () { ctr_ = 0 ; }
63
+ // wait for count, then return
64
+ void WaitCount (int value) {
65
+ std::unique_lock<std::mutex> lock (m_);
66
+ cv_ext_.wait (lock, [&]() { return this ->GetCount () >= value; });
67
+ }
68
+ private:
69
+ // ----- methods ----- //
70
+ // execute task with available thread
71
+ void WorkerThread (int i);
72
+ ABSL_CONST_INIT static thread_local int worker_id_;
73
+ // ----- members ----- //
74
+ std::vector<std::thread> threads_;
75
+ std::mutex m_;
76
+ std::condition_variable cv_in_;
77
+ std::condition_variable cv_ext_;
78
+ std::queue<std::function<void ()>> queue_;
79
+ std::uint64_t ctr_;
80
+ };
81
+
82
// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.cc
// Per-thread worker index; stays -1 in any thread that is not a pool worker
// (WorkerThread overwrites it with the worker's own index on startup).
ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1;
84
+ // ThreadPool constructor
85
+ ThreadPool::ThreadPool (int num_threads) : ctr_(0 ) {
86
+ for (int i = 0 ; i < num_threads; i++) {
87
+ threads_.push_back (std::thread (&ThreadPool::WorkerThread, this , i));
88
+ }
89
+ }
90
+ // ThreadPool destructor
91
+ ThreadPool::~ThreadPool () {
92
+ {
93
+ std::unique_lock<std::mutex> lock (m_);
94
+ for (int i = 0 ; i < threads_.size (); i++) {
95
+ queue_.push (nullptr );
96
+ }
97
+ cv_in_.notify_all ();
98
+ }
99
+ for (auto & thread : threads_) {
100
+ thread.join ();
101
+ }
102
+ }
103
+ // ThreadPool scheduler
104
+ void ThreadPool::Schedule (std::function<void ()> task) {
105
+ std::unique_lock<std::mutex> lock (m_);
106
+ queue_.push (std::move (task));
107
+ cv_in_.notify_one ();
108
+ }
109
// ThreadPool worker
// Main loop for worker i: repeatedly dequeue and run tasks until a null
// sentinel task (pushed by the destructor) is received.
void ThreadPool::WorkerThread(int i) {
  // publish this worker's index so WorkerId() works from inside tasks
  worker_id_ = i;
  while (true) {
    // block until a task is available, then take ownership of it; the lock is
    // released when the immediately-invoked lambda returns
    auto task = [&]() {
      std::unique_lock<std::mutex> lock(m_);
      cv_in_.wait(lock, [&]() { return !queue_.empty(); });
      std::function<void()> task = std::move(queue_.front());
      queue_.pop();
      // chain the wakeup: another worker may have work waiting too
      cv_in_.notify_one();
      return task;
    }();
    if (task == nullptr) {
      // shutdown sentinel: still bump the completion count and notify any
      // WaitCount() waiter before exiting the loop
      {
        std::unique_lock<std::mutex> lock(m_);
        ++ctr_;
        cv_ext_.notify_one();
      }
      break;
    }
    // run the task outside the lock so other workers can proceed
    task();
    // record completion and wake any WaitCount() waiter
    {
      std::unique_lock<std::mutex> lock(m_);
      ++ctr_;
      cv_ext_.notify_one();
    }
  }
}
137
+
34
138
// NOLINTBEGIN(whitespace/line_length)
35
139
36
140
const auto rollout_doc = R"(
@@ -54,7 +158,7 @@ Roll out open-loop trajectories from initial states, get resulting states and se
54
158
// C-style rollout function, assumes all arguments are valid
55
159
// all input fields of d are initialised, contents at call time do not matter
56
160
// after returning, d will contain the last step of the last rollout
57
- void _unsafe_rollout (std::vector<const mjModel*>& m, mjData* d, int nroll , int nstep, unsigned int control_spec,
161
+ void _unsafe_rollout (std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll , int nstep, unsigned int control_spec,
58
162
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
59
163
mjtNum* state, mjtNum* sensordata) {
60
164
// sizes
@@ -75,7 +179,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
75
179
}
76
180
77
181
// loop over rollouts
78
- for (int r = 0 ; r < nroll ; r++) {
182
+ for (int r = start_roll ; r < end_roll ; r++) {
79
183
// clear user inputs if unspecified
80
184
if (!(control_spec & mjSTATE_MOCAP_POS)) {
81
185
for (int i = 0 ; i < nbody; i++) {
@@ -158,6 +262,35 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
158
262
}
159
263
}
160
264
265
+ // C-style threaded version of _unsafe_rollout
266
+ void _unsafe_rollout_threaded (std::vector<const mjModel*>& m, std::vector<mjData*>& d,
267
+ int nroll, int nstep, unsigned int control_spec,
268
+ const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
269
+ mjtNum* state, mjtNum* sensordata,
270
+ int nthread, int chunk_size) {
271
+ int njobs = nroll / chunk_size;
272
+ int chunk_remainder = nroll % chunk_size;
273
+
274
+ ThreadPool pool = ThreadPool (nthread);
275
+ for (int j = 0 ; j < njobs; j++) {
276
+ auto task = [=, &m, &d, &pool](void ) {
277
+ _unsafe_rollout (m, d[pool.WorkerId ()], j*chunk_size, (j+1 )*chunk_size,
278
+ nstep, control_spec, state0, warmstart0, control, state, sensordata);
279
+ };
280
+ pool.Schedule (task);
281
+ }
282
+
283
+ if (chunk_remainder > 0 ) {
284
+ auto task = [=, &m, &d, &pool](void ) {
285
+ _unsafe_rollout (m, d[pool.WorkerId ()], njobs*chunk_size, njobs*chunk_size+chunk_remainder,
286
+ nstep, control_spec, state0, warmstart0, control, state, sensordata);
287
+ };
288
+ pool.Schedule (task);
289
+ }
290
+
291
+ pool.WaitCount (nroll);
292
+ }
293
+
161
294
// NOLINTEND(whitespace/line_length)
162
295
163
296
// check size of optional argument to rollout(), return raw pointer
@@ -190,7 +323,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
190
323
// get subsequent states and corresponding sensor values
191
324
pymodule.def (
192
325
" rollout" ,
193
- [](py::list m, MjDataWrapper& d,
326
+ [](py::list m, py::list d,
194
327
int nstep, unsigned int control_spec,
195
328
const PyCArray state0,
196
329
std::optional<const PyCArray> warmstart0,
@@ -204,7 +337,12 @@ PYBIND11_MODULE(_rollout, pymodule) {
204
337
for (int r = 0 ; r < nroll; r++) {
205
338
model_ptrs[r] = m[r].cast <const MjModelWrapper*>()->get ();
206
339
}
207
- raw::MjData* data = d.get ();
340
+
341
+ int nthread = py::len (d);
342
+ std::vector<raw::MjData*> data_ptrs (nthread);
343
+ for (int t = 0 ; t < nthread; t++) {
344
+ data_ptrs[t] = d[t].cast <MjDataWrapper*>()->get ();
345
+ }
208
346
209
347
// check that some steps need to be taken, return if not
210
348
if (nstep < 1 ) {
@@ -230,9 +368,18 @@ PYBIND11_MODULE(_rollout, pymodule) {
230
368
py::gil_scoped_release no_gil;
231
369
232
370
// call unsafe rollout function
233
- InterceptMjErrors (_unsafe_rollout)(
234
- model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
235
- warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
371
+ if (nthread > 0 ) {
372
+ int chunk_size = std::max (1 , nroll/(10 * nthread));
373
+ InterceptMjErrors (_unsafe_rollout_threaded)(
374
+ model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
375
+ warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
376
+ nthread, chunk_size);
377
+ }
378
+ else {
379
+ InterceptMjErrors (_unsafe_rollout)(
380
+ model_ptrs, data_ptrs[0 ], 0 , nroll, nstep, control_spec, state0_ptr,
381
+ warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
382
+ }
236
383
}
237
384
},
238
385
py::arg (" model" ),
0 commit comments