Add a DeviceBoundThreadPool class conforming to the ThreadPool interface #18751
base: main
Conversation
- Built on top of the NumaAwareExecutor class
- Can be created using the create_device_bound_thread_pool API
- Allows submitting tasks to specific threads (each thread is tied to a physical device, and can thus process tasks only for that device)
- Is NUMA aware: threads spawned for a physical device will be pinned to the NUMA nodes "closest" to the device
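For illustration, a rough usage sketch of the API described above; the factory and enqueue signatures are taken from the diffs in this PR, while the include path and the exact device-to-thread mapping are assumptions:

```cpp
#include "tt_metal/api/tt-metalium/thread_pool.hpp"  // assumed header path, for illustration only

int main() {
    // One worker thread per physical device, each pinned near that device's NUMA node.
    auto pool = tt::tt_metal::create_device_bound_thread_pool(/*num_threads=*/4, /*logical_cpu_offset=*/0);

    // Submit work for device 2: the task is expected to run on the thread bound to device 2.
    pool->enqueue([]() { /* dispatch work for device 2 */ }, /*thread_idx=*/2);

    return 0;
}
```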
There are a lot of good thread pool implementations all over the internet. It doesn't make any sense to reinvent the wheel; you will need to support it. Btw, for boost threads, here is an easy way to set thread affinities (maybe there is a better way, I don't know).
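The comment above presumably pointed at a concrete snippet; as a hedged illustration (my sketch, not the commenter's), on Linux one way to pin each worker of a `boost::asio::thread_pool` to a core is to post one blocking task per thread and set the affinity from inside it, assuming C++20 and glibc:

```cpp
#include <boost/asio/post.hpp>
#include <boost/asio/thread_pool.hpp>
#include <pthread.h>
#include <sched.h>
#include <barrier>
#include <cstddef>
#include <latch>

// Pin each worker thread of `pool` to one CPU core, starting at `first_cpu`.
// Posting num_threads tasks that all block on a barrier guarantees that every
// worker picks up exactly one pinning task before any of them is released.
void pin_pool_threads(boost::asio::thread_pool& pool, std::size_t num_threads, std::size_t first_cpu) {
    std::barrier sync(static_cast<std::ptrdiff_t>(num_threads));
    std::latch done(static_cast<std::ptrdiff_t>(num_threads));
    for (std::size_t i = 0; i < num_threads; ++i) {
        boost::asio::post(pool, [&sync, &done, cpu = first_cpu + i]() {
            cpu_set_t cpuset;
            CPU_ZERO(&cpuset);
            CPU_SET(cpu, &cpuset);
            pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
            sync.arrive_and_wait();  // hold this worker until all workers have been pinned
            done.count_down();
        });
    }
    done.wait();  // keep `sync`/`done` alive until every pinning task has finished
}
```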
"Existing Boost backed thread pool implementation does not guarantee an optimal distribution of work across devices |
```cpp
}

private:
    std::vector<std::unique_ptr<NumaAwareExecutor>> workers_;
```
Can't we have a vector of boost thread pools (backed by 1 thread each), and then follow Denys' suggestion on pinning the threads to CPU cores accordingly?
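A minimal sketch of that idea, with hypothetical names (thread pinning and synchronization omitted):

```cpp
#include <boost/asio/post.hpp>
#include <boost/asio/thread_pool.hpp>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical: one single-threaded boost pool per device, so posting to pools_[device_idx]
// is equivalent to submitting to a specific thread. Each pool's single thread could then be
// pinned to a CPU core close to its device, as suggested above.
class PerDeviceBoostPools {
public:
    explicit PerDeviceBoostPools(std::size_t num_devices) {
        for (std::size_t i = 0; i < num_devices; ++i) {
            pools_.push_back(std::make_unique<boost::asio::thread_pool>(1));
        }
    }

    void enqueue(std::function<void()>&& f, uint32_t device_idx) {
        boost::asio::post(*pools_[device_idx], std::move(f));
    }

private:
    std::vector<std::unique_ptr<boost::asio::thread_pool>> pools_;
};
```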
```diff
@@ -12,10 +12,11 @@ namespace tt::tt_metal {
 class ThreadPool {
 public:
     virtual ~ThreadPool() = default;
-    virtual void enqueue(std::function<void()>&& f) = 0;
+    virtual void enqueue(std::function<void()>&& f, uint32_t thread_idx = 0) = 0;
```
Let's change `thread_idx` to `device_idx`, and also make it optional. I think it makes sense to support a general case, where we do some work that is not affiliated with a device?
Yeah, let's see if we really need `thread_idx` on the interface. Seems like we just need uniform distribution of work, which can be done underneath the API by better assignment of work to threads.
Oh, I thought the point of this PR is to make sure dispatches related to particular devices go to particular threads? My comment here was to say that I treat it as a "hint", not as a requirement.
Let's see what the perf results show, but I agree with your comment.
Thanks for the idea Oleg. I think we can:
- Use the hint provided by the caller if given
- Maintain internal state and use that for picking a thread, if a hint is not provided
This way we can ensure an even distribution of work.
I'll update the arg to be `device_idx`.
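A sketch of what that selection policy could look like (illustrative names only; an optional device hint with an atomic round-robin fallback):

```cpp
#include <atomic>
#include <cstdint>
#include <functional>
#include <optional>

// Illustrative policy: honor the caller's device hint when given, otherwise spread work
// across workers with a round-robin counter so the distribution stays even.
class WorkerSelectionSketch {
public:
    explicit WorkerSelectionSketch(uint32_t num_workers) : num_workers_(num_workers) {}

    void enqueue(std::function<void()>&& f, std::optional<uint32_t> device_idx = std::nullopt) {
        uint32_t idx = device_idx.has_value()
                           ? (*device_idx % num_workers_)
                           : (next_worker_.fetch_add(1, std::memory_order_relaxed) % num_workers_);
        submit_to_worker(idx, std::move(f));  // hand the task to worker `idx`
    }

private:
    void submit_to_worker(uint32_t idx, std::function<void()>&& f);  // implementation elided

    uint32_t num_workers_;
    std::atomic<uint32_t> next_worker_{0};
};
```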
```diff
 std::shared_ptr<ThreadPool> create_boost_thread_pool(int num_threads);
+std::shared_ptr<ThreadPool> create_device_bound_thread_pool(int num_threads, uint32_t logical_cpu_offset = 0);
```
Adding some tests would be great. We can just have a simple stress test, but if you compile it and run with `--enable-tsan`, it might show bugs.
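For example, a stress test along these lines (a hypothetical gtest sketch; the factory signature comes from this PR's diff, while the header path and the wait() call are assumptions):

```cpp
#include <gtest/gtest.h>
#include <atomic>
#include <cstdint>
// An include for the ThreadPool / create_device_bound_thread_pool declarations is assumed here.

// Hammer the pool from the main thread and check that every task ran exactly once.
// Compiled with --enable-tsan (-fsanitize=thread), data races in the pool would also surface.
TEST(DeviceBoundThreadPoolTest, StressEnqueue) {
    constexpr uint32_t kNumThreads = 8;
    constexpr uint32_t kNumTasks = 100000;
    auto pool = tt::tt_metal::create_device_bound_thread_pool(kNumThreads);

    std::atomic<uint32_t> counter{0};
    for (uint32_t i = 0; i < kNumTasks; ++i) {
        pool->enqueue([&counter]() { counter.fetch_add(1, std::memory_order_relaxed); }, i % kNumThreads);
    }
    pool->wait();  // assumed synchronization point, as discussed elsewhere in this thread
    EXPECT_EQ(counter.load(), kNumTasks);
}
```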
As I understood, you need to have a few different pools, one per NUMA node?
Thanks for the feedback and ideas everyone, I appreciate you going through the code. A small model was used for illustration purposes, but the same pattern applies across the board.

[Trace comparisons were shown for: Boost Thread-Pool; Custom Thread-Pool with a Vector of Boost Thread Pools (1 Thread Each); Thread-Pool from this PR]

Does this matter for End to End Performance?

- Boost Thread-Pool: Variation: 17%, Min: 9364 tok/s, Max: 11269 tok/s
- Custom Thread-Pool with a Vector of Boost Thread Pools (1 Thread Each): Variation: 15%, Min: 9498 tok/s, Max: 11137 tok/s
- Thread-Pool from this PR: Variation: 3%, Min: 11063 tok/s, Max: 11353 tok/s

Why does this matter?

How can we ensure that a custom implementation is not buggy?

The implementation can then be mainlined along with the rest of our work. We will run all stress tests and model tests before merging to main. In the meantime, I can suggest 2 things:

I hope this answers any outstanding questions. Thanks again everyone :)
@tt-asaigal could you show your code? How do you run tasks? If you claim that the boost thread pool doesn't use all threads, you can create a simple standalone example showing this issue.
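For reference, a standalone check of how a plain boost::asio::thread_pool spreads tasks could look roughly like this (my sketch, unrelated to the branches linked below): record which worker thread executes each task and print the per-thread counts.

```cpp
#include <boost/asio/post.hpp>
#include <boost/asio/thread_pool.hpp>
#include <cstdio>
#include <functional>
#include <map>
#include <mutex>
#include <thread>

// Push N tasks from the main thread and count how many land on each worker thread.
// A heavily skewed distribution means not all threads are actually being used.
int main() {
    constexpr int kThreads = 8;
    constexpr int kTasks = 10000;
    boost::asio::thread_pool pool(kThreads);

    std::mutex mutex;
    std::map<std::thread::id, int> tasks_per_thread;

    for (int i = 0; i < kTasks; ++i) {
        boost::asio::post(pool, [&]() {
            std::lock_guard<std::mutex> lock(mutex);
            ++tasks_per_thread[std::this_thread::get_id()];
        });
    }
    pool.wait();  // joins the workers; the pool cannot be reused afterwards

    for (const auto& [tid, count] : tasks_per_thread) {
        std::printf("thread %zu ran %d tasks\n", std::hash<std::thread::id>{}(tid), count);
    }
    return 0;
}
```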
```diff
@@ -478,7 +479,9 @@ void MeshCommandQueue::enqueue_write_shards(
     });

     for (std::size_t shard_idx = 0; shard_idx < shard_data_transfers.size(); shard_idx++) {
-        dispatch_thread_pool_->enqueue([&dispatch_lambda, shard_idx]() { dispatch_lambda(shard_idx); });
+        dispatch_thread_pool_->enqueue(
+            [&dispatch_lambda, shard_idx]() { dispatch_lambda(shard_idx); },
```
As I understood: you create a dispatch lambda which is a std::function, then pass it by reference as a capture parameter to another lambda, which becomes another std::function inside of the enqueue. Making dispatch_lambda just a lambda might save you some time.
@cfjchu @omilyutin-tt ptal.
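For illustration, the suggested change is roughly the following (a simplified sketch, not the actual MeshCommandQueue code):

```cpp
#include <cstddef>
#include <functional>

// Simplified before/after for the suggestion above.
void example(std::size_t num_shards, const std::function<void(std::function<void()>&&)>& enqueue) {
    // Before: dispatch_lambda is itself a std::function, so each task goes through two layers
    // of type erasure: this one, plus the std::function constructed inside enqueue().
    // std::function<void(std::size_t)> dispatch_lambda = [](std::size_t shard_idx) { /* ... */ };

    // After: a plain lambda; only the std::function built for enqueue() remains.
    auto dispatch_lambda = [](std::size_t shard_idx) { (void)shard_idx; /* write shard */ };

    for (std::size_t shard_idx = 0; shard_idx < num_shards; ++shard_idx) {
        enqueue([&dispatch_lambda, shard_idx]() { dispatch_lambda(shard_idx); });
    }
}
```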
Removing the std::function and making it a raw lambda is a good optimization, though it doesn't affect the performance numbers I posted above. The function gets allocated once and reused across `enqueue` calls.
```diff
@@ -23,7 +187,7 @@ class BoostThreadPool : public ThreadPool {

     ~BoostThreadPool() noexcept override = default;

-    void enqueue(std::function<void()>&& f) override {
+    void enqueue(std::function<void()>&& f, uint32_t thread_idx) override {
         std::packaged_task<void()> task(std::move(f));
```
We can optimize this code and compare vs your solution.
Let's forget about managing futures and just call pool.join() in the wait().
As an experiment, you can just remove all future-related things, for example.
`pool.join()` kills the thread pool. It will be unusable once you call join. Identical semantics to `thread.join()`.
then use pool.wait()
When looking at the docs, I learnt that both APIs behave the same way: https://live.boost.org/doc/libs/master/doc/html/boost_asio/reference/thread_pool/wait.html.
I have also tested this myself. Here is the branch, with your suggestion: asaigal/thread_pool_experiments
The following test will fail:
./build_Release/test/tt_metal/distributed/distributed_unit_tests_wormhole_b0 --gtest_filter="*ThreadPoolTest*"
If you remove the first wait, it will pass.
Boost does not provide a lightweight API to synchronize threads without shutting the pool down. We need this for our implementation, which is why we wrote our own class.
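For context, the kind of lightweight synchronization meant here can be built around a task counter, for example (my sketch, not the PR's actual implementation): wait() blocks until all outstanding tasks finish, without stopping the worker threads, so the pool stays reusable.

```cpp
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Sketch of a reusable completion barrier: the pool bumps the counter on enqueue, each
// worker decrements it when a task finishes, and wait() blocks until it reaches zero.
class TaskCounter {
public:
    void on_enqueue() { outstanding_.fetch_add(1, std::memory_order_relaxed); }

    void on_task_done() {
        if (outstanding_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
            std::lock_guard<std::mutex> lock(mutex_);
            cv_.notify_all();
        }
    }

    void wait() {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this]() { return outstanding_.load(std::memory_order_acquire) == 0; });
    }

private:
    std::atomic<uint32_t> outstanding_{0};
    std::mutex mutex_;
    std::condition_variable cv_;
};
```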
Hey @dmakoviichuk-tt, implementations for each thread-pool are pushed to the following branches:

All tasks are pushed by the main thread. A standalone example using boost has been pushed to https://github.com/tenstorrent/tt-metal/tree/asaigal/thread_pool_experiments. You can try this out and verify the behaviour on your end as well.

Here is the Falcon 7B test:

Our integration branch for testing model performance is: jchu/ttnn-integration-with-mesh
Ticket
Link to Github Issue

Problem description
Existing Boost backed thread pool implementation does not guarantee an optimal distribution of work across devices.

What's changed
- Added a `DeviceBoundThreadPool` class conforming to the `ThreadPool` interface
- Built on top of the `NumaAwareExecutor` class
- Can be created using the `create_device_bound_thread_pool` API

Checklist