|
20 | 20 | #include "errors.h"
|
21 | 21 | #include "raw.h"
|
22 | 22 | #include "structs.h"
|
| 23 | +#include "threadpool.h" |
23 | 24 | #include <pybind11/buffer_info.h>
|
24 | 25 | #include <pybind11/numpy.h>
|
25 | 26 | #include <pybind11/pybind11.h>
|
26 | 27 | #include <pybind11/stl.h>
|
27 | 28 |
|
28 |
| -#include <condition_variable> |
29 |
| -#include <cstdint> |
30 |
| -#include <functional> |
31 |
| -#include <mutex> |
32 |
| -#include <queue> |
33 |
| -#include <thread> |
34 |
| -#include <vector> |
35 |
| -#include <absl/base/attributes.h> |
36 |
| - |
37 | 29 | namespace mujoco::python {
|
38 | 30 |
|
39 | 31 | namespace {
|
40 | 32 |
|
41 | 33 | namespace py = ::pybind11;
|
42 | 34 |
|
43 |
| -// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.h |
44 |
| -// ThreadPool class |
45 |
| -class ThreadPool { |
46 |
| - public: |
47 |
| - // constructor |
48 |
| - explicit ThreadPool(int num_threads); |
49 |
| - // destructor |
50 |
| - ~ThreadPool(); |
51 |
| - int NumThreads() const { return threads_.size(); } |
52 |
| - // returns an ID between 0 and NumThreads() - 1. must be called within |
53 |
| - // worker thread (returns -1 if not). |
54 |
| - static int WorkerId() { return worker_id_; } |
55 |
| - // ----- methods ----- // |
56 |
| - // set task for threadpool |
57 |
| - void Schedule(std::function<void()> task); |
58 |
| - // return number of tasks completed |
59 |
| - std::uint64_t GetCount() { return ctr_; } |
60 |
| - // reset count to zero |
61 |
| - void ResetCount() { ctr_ = 0; } |
62 |
| - // wait for count, then return |
63 |
| - void WaitCount(int value) { |
64 |
| - std::unique_lock<std::mutex> lock(m_); |
65 |
| - cv_ext_.wait(lock, [&]() { return this->GetCount() >= value; }); |
66 |
| - } |
67 |
| - private: |
68 |
| - // ----- methods ----- // |
69 |
| - // execute task with available thread |
70 |
| - void WorkerThread(int i); |
71 |
| - ABSL_CONST_INIT static thread_local int worker_id_; |
72 |
| - // ----- members ----- // |
73 |
| - std::vector<std::thread> threads_; |
74 |
| - std::mutex m_; |
75 |
| - std::condition_variable cv_in_; |
76 |
| - std::condition_variable cv_ext_; |
77 |
| - std::queue<std::function<void()>> queue_; |
78 |
| - std::uint64_t ctr_; |
79 |
| -}; |
80 |
| - |
81 |
| -// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.cc |
82 |
| -ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1; |
83 |
| -// ThreadPool constructor |
84 |
| -ThreadPool::ThreadPool(int num_threads) : ctr_(0) { |
85 |
| - for (int i = 0; i < num_threads; i++) { |
86 |
| - threads_.push_back(std::thread(&ThreadPool::WorkerThread, this, i)); |
87 |
| - } |
88 |
| -} |
89 |
| -// ThreadPool destructor |
90 |
| -ThreadPool::~ThreadPool() { |
91 |
| - { |
92 |
| - std::unique_lock<std::mutex> lock(m_); |
93 |
| - for (int i = 0; i < threads_.size(); i++) { |
94 |
| - queue_.push(nullptr); |
95 |
| - } |
96 |
| - cv_in_.notify_all(); |
97 |
| - } |
98 |
| - for (auto& thread : threads_) { |
99 |
| - thread.join(); |
100 |
| - } |
101 |
| -} |
102 |
| -// ThreadPool scheduler |
103 |
| -void ThreadPool::Schedule(std::function<void()> task) { |
104 |
| - std::unique_lock<std::mutex> lock(m_); |
105 |
| - queue_.push(std::move(task)); |
106 |
| - cv_in_.notify_one(); |
107 |
| -} |
108 |
| -// ThreadPool worker |
109 |
| -void ThreadPool::WorkerThread(int i) { |
110 |
| - worker_id_ = i; |
111 |
| - while (true) { |
112 |
| - auto task = [&]() { |
113 |
| - std::unique_lock<std::mutex> lock(m_); |
114 |
| - cv_in_.wait(lock, [&]() { return !queue_.empty(); }); |
115 |
| - std::function<void()> task = std::move(queue_.front()); |
116 |
| - queue_.pop(); |
117 |
| - cv_in_.notify_one(); |
118 |
| - return task; |
119 |
| - }(); |
120 |
| - if (task == nullptr) { |
121 |
| - { |
122 |
| - std::unique_lock<std::mutex> lock(m_); |
123 |
| - ++ctr_; |
124 |
| - cv_ext_.notify_one(); |
125 |
| - } |
126 |
| - break; |
127 |
| - } |
128 |
| - task(); |
129 |
| - { |
130 |
| - std::unique_lock<std::mutex> lock(m_); |
131 |
| - ++ctr_; |
132 |
| - cv_ext_.notify_one(); |
133 |
| - } |
134 |
| - } |
135 |
| -} |
136 |
| - |
137 | 35 | // NOLINTBEGIN(whitespace/line_length)
|
138 | 36 |
|
139 | 37 | const auto rollout_doc = R"(
|
@@ -265,19 +163,18 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
|
265 | 163 | static ThreadPool* pool = nullptr;
|
266 | 164 | void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
|
267 | 165 | int nroll, int nstep, unsigned int control_spec,
|
268 |
| - const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control, |
269 |
| - mjtNum* state, mjtNum* sensordata, |
| 166 | + const mjtNum* state0, const mjtNum* warmstart0, |
| 167 | + const mjtNum* control, mjtNum* state, mjtNum* sensordata, |
270 | 168 | int nthread, int chunk_size) {
|
271 | 169 | int nfulljobs = nroll / chunk_size;
|
272 | 170 | int chunk_remainder = nroll % chunk_size;
|
273 |
| - int njobs = nfulljobs; |
274 |
| - if (chunk_remainder > 0) njobs++; |
| 171 | + int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs; |
275 | 172 |
|
276 | 173 | if (pool == nullptr) {
|
277 | 174 | pool = new ThreadPool(nthread);
|
278 | 175 | }
|
279 | 176 | else if (pool->NumThreads() != nthread) {
|
280 |
| - delete pool; // TODO make sure pool is shutdown correctly |
| 177 | + delete pool; |
281 | 178 | pool = new ThreadPool(nthread);
|
282 | 179 | } else {
|
283 | 180 | pool->ResetCount();
|
|
0 commit comments