Skip to content

Commit 75603ee

Browse files
committed
rollout use threadpool as translation unit
1 parent efd8be1 commit 75603ee

File tree

4 files changed

+16
-119
lines changed

4 files changed

+16
-119
lines changed

python/mujoco/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ target_link_libraries(
383383
structs_header
384384
)
385385

386-
mujoco_pybind11_module(_rollout rollout.cc)
386+
mujoco_pybind11_module(_rollout rollout.cc threadpool.cc)
387387
target_link_libraries(_rollout PRIVATE functions_header mujoco raw)
388388

389389
mujoco_pybind11_module(

python/mujoco/rollout.cc

+5-108
Original file line numberDiff line numberDiff line change
@@ -20,120 +20,18 @@
2020
#include "errors.h"
2121
#include "raw.h"
2222
#include "structs.h"
23+
#include "threadpool.h"
2324
#include <pybind11/buffer_info.h>
2425
#include <pybind11/numpy.h>
2526
#include <pybind11/pybind11.h>
2627
#include <pybind11/stl.h>
2728

28-
#include <condition_variable>
29-
#include <cstdint>
30-
#include <functional>
31-
#include <mutex>
32-
#include <queue>
33-
#include <thread>
34-
#include <vector>
35-
#include <absl/base/attributes.h>
36-
3729
namespace mujoco::python {
3830

3931
namespace {
4032

4133
namespace py = ::pybind11;
4234

43-
// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.h
44-
// ThreadPool class
45-
class ThreadPool {
46-
public:
47-
// constructor
48-
explicit ThreadPool(int num_threads);
49-
// destructor
50-
~ThreadPool();
51-
int NumThreads() const { return threads_.size(); }
52-
// returns an ID between 0 and NumThreads() - 1. must be called within
53-
// worker thread (returns -1 if not).
54-
static int WorkerId() { return worker_id_; }
55-
// ----- methods ----- //
56-
// set task for threadpool
57-
void Schedule(std::function<void()> task);
58-
// return number of tasks completed
59-
std::uint64_t GetCount() { return ctr_; }
60-
// reset count to zero
61-
void ResetCount() { ctr_ = 0; }
62-
// wait for count, then return
63-
void WaitCount(int value) {
64-
std::unique_lock<std::mutex> lock(m_);
65-
cv_ext_.wait(lock, [&]() { return this->GetCount() >= value; });
66-
}
67-
private:
68-
// ----- methods ----- //
69-
// execute task with available thread
70-
void WorkerThread(int i);
71-
ABSL_CONST_INIT static thread_local int worker_id_;
72-
// ----- members ----- //
73-
std::vector<std::thread> threads_;
74-
std::mutex m_;
75-
std::condition_variable cv_in_;
76-
std::condition_variable cv_ext_;
77-
std::queue<std::function<void()>> queue_;
78-
std::uint64_t ctr_;
79-
};
80-
81-
// Copied from https://github.com/google-deepmind/mujoco_mpc/blob/main/mjpc/threadpool.cc
82-
ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1;
83-
// ThreadPool constructor
84-
ThreadPool::ThreadPool(int num_threads) : ctr_(0) {
85-
for (int i = 0; i < num_threads; i++) {
86-
threads_.push_back(std::thread(&ThreadPool::WorkerThread, this, i));
87-
}
88-
}
89-
// ThreadPool destructor
90-
ThreadPool::~ThreadPool() {
91-
{
92-
std::unique_lock<std::mutex> lock(m_);
93-
for (int i = 0; i < threads_.size(); i++) {
94-
queue_.push(nullptr);
95-
}
96-
cv_in_.notify_all();
97-
}
98-
for (auto& thread : threads_) {
99-
thread.join();
100-
}
101-
}
102-
// ThreadPool scheduler
103-
void ThreadPool::Schedule(std::function<void()> task) {
104-
std::unique_lock<std::mutex> lock(m_);
105-
queue_.push(std::move(task));
106-
cv_in_.notify_one();
107-
}
108-
// ThreadPool worker
109-
void ThreadPool::WorkerThread(int i) {
110-
worker_id_ = i;
111-
while (true) {
112-
auto task = [&]() {
113-
std::unique_lock<std::mutex> lock(m_);
114-
cv_in_.wait(lock, [&]() { return !queue_.empty(); });
115-
std::function<void()> task = std::move(queue_.front());
116-
queue_.pop();
117-
cv_in_.notify_one();
118-
return task;
119-
}();
120-
if (task == nullptr) {
121-
{
122-
std::unique_lock<std::mutex> lock(m_);
123-
++ctr_;
124-
cv_ext_.notify_one();
125-
}
126-
break;
127-
}
128-
task();
129-
{
130-
std::unique_lock<std::mutex> lock(m_);
131-
++ctr_;
132-
cv_ext_.notify_one();
133-
}
134-
}
135-
}
136-
13735
// NOLINTBEGIN(whitespace/line_length)
13836

13937
const auto rollout_doc = R"(
@@ -265,19 +163,18 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
265163
static ThreadPool* pool = nullptr;
266164
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
267165
int nroll, int nstep, unsigned int control_spec,
268-
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
269-
mjtNum* state, mjtNum* sensordata,
166+
const mjtNum* state0, const mjtNum* warmstart0,
167+
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
270168
int nthread, int chunk_size) {
271169
int nfulljobs = nroll / chunk_size;
272170
int chunk_remainder = nroll % chunk_size;
273-
int njobs = nfulljobs;
274-
if (chunk_remainder > 0) njobs++;
171+
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;
275172

276173
if (pool == nullptr) {
277174
pool = new ThreadPool(nthread);
278175
}
279176
else if (pool->NumThreads() != nthread) {
280-
delete pool; // TODO make sure pool is shutdown correctly
177+
delete pool;
281178
pool = new ThreadPool(nthread);
282179
} else {
283180
pool->ResetCount();

python/mujoco/threadpool.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2022 DeepMind Technologies Limited
1+
// Copyright 2024 DeepMind Technologies Limited
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "mjpc/threadpool.h"
15+
#include "threadpool.h"
1616

1717
#include <condition_variable>
1818
#include <functional>
@@ -22,7 +22,7 @@
2222

2323
#include <absl/base/attributes.h>
2424

25-
namespace mjpc {
25+
namespace mujoco::python {
2626

2727
ABSL_CONST_INIT thread_local int ThreadPool::worker_id_ = -1;
2828

@@ -84,4 +84,4 @@ void ThreadPool::WorkerThread(int i) {
8484
}
8585
}
8686

87-
} // namespace mjpc
87+
} // namespace mujoco::python

python/mujoco/threadpool.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2022 DeepMind Technologies Limited
1+
// Copyright 2024 DeepMind Technologies Limited
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,8 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#ifndef MJPC_THREADPOOL_H_
16-
#define MJPC_THREADPOOL_H_
15+
#ifndef MUJOCO_PYTHON_THREADPOOL_H_
16+
#define MUJOCO_PYTHON_THREADPOOL_H_
1717

1818
#include <condition_variable>
1919
#include <cstdint>
@@ -26,7 +26,7 @@
2626

2727
#include <absl/base/attributes.h>
2828

29-
namespace mjpc {
29+
namespace mujoco::python {
3030

3131
// ThreadPool class
3232
class ThreadPool {
@@ -76,6 +76,6 @@ class ThreadPool {
7676
std::uint64_t ctr_;
7777
};
7878

79-
} // namespace mjpc
79+
} // namespace mujoco::python
8080

81-
#endif // MJPC_THREADPOOL_H_
81+
#endif // MUJOCO_PYTHON_THREADPOOL_H_

0 commit comments

Comments
 (0)