Skip to content

Commit 36f3321

Browse files
committed
rollout prototype C++ threadpool for benchmarking
1 parent a793a34 commit 36f3321

File tree

1 file changed

+154
-7
lines changed

1 file changed

+154
-7
lines changed

python/mujoco/rollout.cc

+154-7
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,116 @@
2525
#include <pybind11/pybind11.h>
2626
#include <pybind11/stl.h>
2727

28+
#include <condition_variable>
29+
#include <cstdint>
30+
#include <functional>
31+
#include <iostream>
32+
#include <mutex>
33+
#include <queue>
34+
#include <thread>
35+
#include <vector>
36+
#include <absl/base/attributes.h>
37+
2838
namespace mujoco::python {
2939

3040
namespace {
3141

3242
namespace py = ::pybind11;
3343

44+
// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.h
45+
// ThreadPool class
46+
class ThreadPool {
47+
public:
48+
// constructor
49+
explicit ThreadPool(int num_threads);
50+
// destructor
51+
~ThreadPool();
52+
int NumThreads() const { return threads_.size(); }
53+
// returns an ID between 0 and NumThreads() - 1. must be called within
54+
// worker thread (returns -1 if not).
55+
static int WorkerId() { return worker_id_; }
56+
// ----- methods ----- //
57+
// set task for threadpool
58+
void Schedule(std::function<void()> task);
59+
// return number of tasks completed
60+
std::uint64_t GetCount() { return ctr_; }
61+
// reset count to zero
62+
void ResetCount() { ctr_ = 0; }
63+
// wait for count, then return
64+
void WaitCount(int value) {
65+
std::unique_lock<std::mutex> lock(m_);
66+
cv_ext_.wait(lock, [&]() { return this->GetCount() >= value; });
67+
}
68+
private:
69+
// ----- methods ----- //
70+
// execute task with available thread
71+
void WorkerThread(int i);
72+
ABSL_CONST_INIT static thread_local int worker_id_;
73+
// ----- members ----- //
74+
std::vector<std::thread> threads_;
75+
std::mutex m_;
76+
std::condition_variable cv_in_;
77+
std::condition_variable cv_ext_;
78+
std::queue<std::function<void()>> queue_;
79+
std::uint64_t ctr_;
80+
};
81+
82+
// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.cc
83+
ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1;
84+
// ThreadPool constructor
85+
ThreadPool::ThreadPool(int num_threads) : ctr_(0) {
86+
for (int i = 0; i < num_threads; i++) {
87+
threads_.push_back(std::thread(&ThreadPool::WorkerThread, this, i));
88+
}
89+
}
90+
// ThreadPool destructor
91+
ThreadPool::~ThreadPool() {
92+
{
93+
std::unique_lock<std::mutex> lock(m_);
94+
for (int i = 0; i < threads_.size(); i++) {
95+
queue_.push(nullptr);
96+
}
97+
cv_in_.notify_all();
98+
}
99+
for (auto& thread : threads_) {
100+
thread.join();
101+
}
102+
}
103+
// ThreadPool scheduler
104+
void ThreadPool::Schedule(std::function<void()> task) {
105+
std::unique_lock<std::mutex> lock(m_);
106+
queue_.push(std::move(task));
107+
cv_in_.notify_one();
108+
}
109+
// ThreadPool worker
110+
void ThreadPool::WorkerThread(int i) {
111+
worker_id_ = i;
112+
while (true) {
113+
auto task = [&]() {
114+
std::unique_lock<std::mutex> lock(m_);
115+
cv_in_.wait(lock, [&]() { return !queue_.empty(); });
116+
std::function<void()> task = std::move(queue_.front());
117+
queue_.pop();
118+
cv_in_.notify_one();
119+
return task;
120+
}();
121+
if (task == nullptr) {
122+
{
123+
std::unique_lock<std::mutex> lock(m_);
124+
++ctr_;
125+
cv_ext_.notify_one();
126+
}
127+
break;
128+
}
129+
task();
130+
{
131+
std::unique_lock<std::mutex> lock(m_);
132+
++ctr_;
133+
cv_ext_.notify_one();
134+
}
135+
}
136+
}
137+
34138
// NOLINTBEGIN(whitespace/line_length)
35139

36140
const auto rollout_doc = R"(
@@ -54,7 +158,7 @@ Roll out open-loop trajectories from initial states, get resulting states and se
54158
// C-style rollout function, assumes all arguments are valid
55159
// all input fields of d are initialised, contents at call time do not matter
56160
// after returning, d will contain the last step of the last rollout
57-
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int nstep, unsigned int control_spec,
161+
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec,
58162
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
59163
mjtNum* state, mjtNum* sensordata) {
60164
// sizes
@@ -75,7 +179,7 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
75179
}
76180

77181
// loop over rollouts
78-
for (int r = 0; r < nroll; r++) {
182+
for (int r = start_roll; r < end_roll; r++) {
79183
// clear user inputs if unspecified
80184
if (!(control_spec & mjSTATE_MOCAP_POS)) {
81185
for (int i = 0; i < nbody; i++) {
@@ -158,6 +262,35 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int nroll, int n
158262
}
159263
}
160264

265+
// C-style threaded version of _unsafe_rollout
266+
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
267+
int nroll, int nstep, unsigned int control_spec,
268+
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
269+
mjtNum* state, mjtNum* sensordata,
270+
int nthread, int chunk_size) {
271+
int njobs = nroll / chunk_size;
272+
int chunk_remainder = nroll % chunk_size;
273+
274+
ThreadPool pool = ThreadPool(nthread);
275+
for (int j = 0; j < njobs; j++) {
276+
auto task = [=, &m, &d, &pool](void) {
277+
_unsafe_rollout(m, d[pool.WorkerId()], j*chunk_size, (j+1)*chunk_size,
278+
nstep, control_spec, state0, warmstart0, control, state, sensordata);
279+
};
280+
pool.Schedule(task);
281+
}
282+
283+
if (chunk_remainder > 0) {
284+
auto task = [=, &m, &d, &pool](void) {
285+
_unsafe_rollout(m, d[pool.WorkerId()], njobs*chunk_size, njobs*chunk_size+chunk_remainder,
286+
nstep, control_spec, state0, warmstart0, control, state, sensordata);
287+
};
288+
pool.Schedule(task);
289+
}
290+
291+
pool.WaitCount(nroll);
292+
}
293+
161294
// NOLINTEND(whitespace/line_length)
162295

163296
// check size of optional argument to rollout(), return raw pointer
@@ -190,7 +323,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
190323
// get subsequent states and corresponding sensor values
191324
pymodule.def(
192325
"rollout",
193-
[](py::list m, MjDataWrapper& d,
326+
[](py::list m, py::list d,
194327
int nstep, unsigned int control_spec,
195328
const PyCArray state0,
196329
std::optional<const PyCArray> warmstart0,
@@ -204,7 +337,12 @@ PYBIND11_MODULE(_rollout, pymodule) {
204337
for (int r = 0; r < nroll; r++) {
205338
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
206339
}
207-
raw::MjData* data = d.get();
340+
341+
int nthread = py::len(d);
342+
std::vector<raw::MjData*> data_ptrs(nthread);
343+
for (int t = 0; t < nthread; t++) {
344+
data_ptrs[t] = d[t].cast<MjDataWrapper*>()->get();
345+
}
208346

209347
// check that some steps need to be taken, return if not
210348
if (nstep < 1) {
@@ -230,9 +368,18 @@ PYBIND11_MODULE(_rollout, pymodule) {
230368
py::gil_scoped_release no_gil;
231369

232370
// call unsafe rollout function
233-
InterceptMjErrors(_unsafe_rollout)(
234-
model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
235-
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
371+
if (nthread > 0) {
372+
int chunk_size = std::max(1, nroll/(10 * nthread));
373+
InterceptMjErrors(_unsafe_rollout_threaded)(
374+
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
375+
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
376+
nthread, chunk_size);
377+
}
378+
else {
379+
InterceptMjErrors(_unsafe_rollout)(
380+
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr,
381+
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
382+
}
236383
}
237384
},
238385
py::arg("model"),

0 commit comments

Comments
 (0)