Improve Compile Time Args to Make Them Programmatically Accessible #18379
cool looks good! Just a few minor things (would be good to get some input from others on the embedded comments).
constexpr auto kernel_compile_time_args = make_array<std::uint32_t>();
#endif

#define get_compile_time_arg_val(arg_idx) \
I think we need to think about this a little. Ideally we want to static_assert if arg_idx >= kernel_compile_time_args.size(). Maybe it could be (but then lots of template code):
template <size_t i>
constexpr uint32_t get_ct_arg() {
    static_assert(i < kernel_compile_time_args.size());
    return kernel_compile_time_args[i];
}
I'd rather it cause a substitution failure than a static_assert personally. I also think that rather than storing the compile-time arguments in a global constexpr std::array, the access should be done more directly. This array is what causes the kernel OOM issues when there are too many compile-time arguments (~400 based on the other thread I saw here), but there are ways to circumvent that limitation completely.
FWIW, the substitution failure could be implemented like this:
template <size_t Idx>
constexpr std::enable_if_t<(Idx < kernel_compile_time_args.size()), uint32_t> get_ct_arg() {
    return kernel_compile_time_args[Idx];
}
but what I would suggest for replacing the std::array is a bit more complicated than what I can fit in this thread, though it could be used in much the same way without taking up memory.
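To make the trade-off between the two suggestions concrete, here is a minimal C++17 sketch. Only kernel_compile_time_args mirrors the PR; the array contents and the names get_ct_arg_checked, get_ct_arg_sfinae and has_ct_arg are assumptions made for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Stand-in for the array the build would generate (values are made up).
constexpr std::array<std::uint32_t, 3> kernel_compile_time_args = {10, 20, 30};

// Variant 1: an out-of-range index is a hard error with a readable message.
template <std::size_t Idx>
constexpr std::uint32_t get_ct_arg_checked() {
    static_assert(Idx < kernel_compile_time_args.size(), "compile-time arg index out of range");
    return kernel_compile_time_args[Idx];
}

// Variant 2: an out-of-range index removes the function from overload
// resolution (substitution failure), so other templates can detect it.
template <std::size_t Idx>
constexpr std::enable_if_t<(Idx < kernel_compile_time_args.size()), std::uint32_t> get_ct_arg_sfinae() {
    return kernel_compile_time_args[Idx];
}

// Detection helper (hypothetical) that is only possible with the SFINAE variant.
template <std::size_t Idx, typename = void>
struct has_ct_arg : std::false_type {};

template <std::size_t Idx>
struct has_ct_arg<Idx, std::void_t<decltype(get_ct_arg_sfinae<Idx>())>> : std::true_type {};

static_assert(get_ct_arg_checked<1>() == 20);
static_assert(get_ct_arg_sfinae<2>() == 30);
static_assert(has_ct_arg<2>::value);
static_assert(!has_ct_arg<3>::value);
```

The static_assert variant gives the clearer diagnostic; the enable_if variant is the one that composes with detection such as has_ct_arg, which is presumably the motivation for preferring a substitution failure.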
@patrickroberts would you say that (after applying your suggestion with enable_if) this is a good enough starting point, or should we evaluate the other approach you are thinking of?
If you can ensure existing ops (actually used in models, not contrived ones with 400 ct args) don't have increased binary sizes, I am okay with this as a starting point
Also if this version of GCC has __builtin_is_constant_evaluated(), I would suggest using that to enforce that get_ct_arg is only called at compile-time, since it's compiled with C++17 and we can't use std::is_constant_evaluated() or if consteval.
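A sketch of what that enforcement might look like, assuming a compiler that provides __builtin_is_constant_evaluated() (GCC 9 and later, and recent Clang). The function name get_ct_arg_constant_only and the array contents are hypothetical. With a template index, as in the snippets above, the index is already a constant, so the builtin mainly matters for a variant that takes the index as an ordinary function parameter.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-in for the generated array (values are made up).
constexpr std::array<std::uint32_t, 3> kernel_compile_time_args = {10, 20, 30};

constexpr std::uint32_t get_ct_arg_constant_only(std::size_t idx) {
    if (!__builtin_is_constant_evaluated()) {
        // Reached only when the call is evaluated at run time.
        __builtin_trap();
    }
    // During constant evaluation, an out-of-range idx is rejected by the compiler.
    return kernel_compile_time_args[idx];
}

// Initializing a constexpr variable is a constant-evaluated context,
// so the trap branch is never taken here.
constexpr std::uint32_t second_arg = get_ct_arg_constant_only(1);
static_assert(second_arg == 20);
```

This version turns an accidental run-time call into a trap instead of a silent array lookup. Calling a declared-but-undefined function in that branch would move the failure to link time instead; which of the two is preferable on the device is a judgment call.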
tt_metal/hw/inc/dataflow_api.h
@@ -219,7 +231,8 @@ FORCE_INLINE T get_common_arg_val(int arg_idx) {
 * | arg_idx | The index of the argument | uint32_t | 0 to 31 | True |
 */
// clang-format on
-#define get_compile_time_arg_val(arg_idx) KERNEL_COMPILE_TIME_ARG_##arg_idx
+#define get_compile_time_arg_val(arg_idx) \
+    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)
can it be commonized with the impl in compile_time_args?
}
defines += "-DKERNEL_COMPILE_TIME_ARGS=";
for (uint32_t i = 0; i < values.size(); i++) {
    defines += to_string(values[i]) + ",";
consider stringstream?
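For reference, a stream-based version might look like the sketch below. The function name build_compile_time_args_define is hypothetical, and unlike the loop in the diff this version emits no trailing comma, which assumes the consumer of the define does not rely on one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Builds "-DKERNEL_COMPILE_TIME_ARGS=v0,v1,..." with a single output stream
// instead of repeated std::string concatenation.
std::string build_compile_time_args_define(const std::vector<std::uint32_t>& values) {
    std::ostringstream defines;
    defines << "-DKERNEL_COMPILE_TIME_ARGS=";
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (i != 0) {
            defines << ',';  // separator only between values
        }
        defines << values[i];
    }
    return defines.str();
}
```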
might be cool to have some TMP in a test too :)
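One possible shape for such a TMP check, mirroring the test kernel's expectation that argument i equals i. The helper names all_args_match_index and sum_args are hypothetical, and a four-element array stands in for the generated kernel_compile_time_args.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Stand-in for the generated array; the test in the PR passes args 0..N-1.
constexpr std::array<std::uint32_t, 4> kernel_compile_time_args = {0, 1, 2, 3};

// Expands to (args[0] == 0) && (args[1] == 1) && ... via a C++17 fold expression.
template <std::size_t... Is>
constexpr bool all_args_match_index(std::index_sequence<Is...>) {
    return ((kernel_compile_time_args[Is] == Is) && ...);
}

// A second small example: the sum of all args, computed at compile time.
template <std::size_t... Is>
constexpr std::uint32_t sum_args(std::index_sequence<Is...>) {
    return (kernel_compile_time_args[Is] + ... + 0u);
}

constexpr auto all_indices = std::make_index_sequence<kernel_compile_time_args.size()>{};
static_assert(all_args_match_index(all_indices), "compile-time arg i should equal i");
static_assert(sum_args(all_indices) == 6);
```

Because the check is a static_assert, a mismatch fails the kernel build rather than hanging at run time.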
CoreCoord core = {0, 0};
Program program;

const uint32_t num_compile_time_args = 400;
Can you try even bigger? :D
Initially, I tried higher values, but I kept on hitting the mem limit. I think the maximum number of compile time args that a kernel can have is ~450.
okay sounds good. Can you elaborate a bit on the error? Additionally, can you please add this to the PR description, along with the limit you saw?
I'm guessing the number of args will depend on the binary size, so for a bigger kernel it will be lower.
.noc = NOC::RISCV_0_default,
.compile_args = compile_time_args,
.defines = defines});
this->RunProgram(device, program);
How does the test verify the compile args successfully went through?
A simple way would be to pass an L1 address into the kernel. The kernel reads the compile args and writes them to L1. Then, from the host side, read the device L1 and check that the args were written there. That's how runtime_args_kernel.cpp does it.
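The round trip described here can be sketched without any device API. In the hypothetical model below, a plain std::vector stands in for L1, fake_kernel_main stands in for the kernel, and host_verify stands in for the host-side readback check. None of these names come from tt-metal; a real test would read device L1 from the host instead.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the generated array (values are made up).
constexpr std::array<std::uint32_t, 4> kernel_compile_time_args = {7, 8, 9, 10};

// Stand-in for the kernel: copy every compile-time arg into "L1",
// starting at the word offset the host passed in.
inline void fake_kernel_main(std::vector<std::uint32_t>& l1, std::size_t l1_word_offset) {
    for (std::size_t i = 0; i < kernel_compile_time_args.size(); ++i) {
        l1[l1_word_offset + i] = kernel_compile_time_args[i];
    }
}

// Stand-in for the host check: compare what landed in "L1" with what was passed in.
inline bool host_verify(const std::vector<std::uint32_t>& l1, std::size_t l1_word_offset,
                        const std::vector<std::uint32_t>& expected) {
    for (std::size_t i = 0; i < expected.size(); ++i) {
        if (l1[l1_word_offset + i] != expected[i]) {
            return false;
        }
    }
    return true;
}
```

With this structure a mismatch shows up as a failed host-side comparison, not as a kernel that spins until the CI timeout.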
tests/tt_metal/tt_metal/test_kernels/misc/compile_time_args_kernel.cpp
for (uint32_t i = 0; i < NUM_COMPILE_TIME_ARGS; i++) {
    if (kernel_compile_time_args[i] != i) {
        ASSERT(0);
        while (1);
I think we should verify this from the host side using the approach I described ~similar to runtime_args_kernel.cpp because if this doesn't work it'll hang the CI until timeout.
What would be the minimal example if, instead of std::array (shown in the issue), we wanted to populate a struct? E.g. for two parameters something like
Both on the program factory (CreateKernel) side and the kernel side? Also, for std::array, what would that look like on the program factory (CreateKernel) side (assuming the example is on the kernel side, which I may be mistaken about)?
This PR deserves a PR description.
it's still only a draft, but yes!
@bbradelTT - fundamentally, nothing has changed about compile time arg passing from the host side. The std::array usage on the device side just puts the args into a container.
Previously, compile time args were implemented as individual macros (KERNEL_COMPILE_TIME_ARG_0, KERNEL_COMPILE_TIME_ARG_1, and so on), each defined separately on the kernel's compile command. The undesired consequence is that every "call" to get_compile_time_arg_val had to pass a literal index, because the macro pastes the index token onto the macro name.
With this change, you can now use (constexpr) variables to fetch compile time args. This is the fundamental missing piece that lets us write reusable template code libraries/utils that rely on compile time args.
Note that this PR does NOT add automatic or "free" object serialization (on the host) and deserialization (on the device) at compile time. However, all the required pieces are in place so that we could start to do this. For example, we could now do something like this to pack/serialize
then something like this to unpack/deserialize (very crude example but conveys the idea - ideally this could be formalized on the type itself):
then in the device kernel code that wants this object (again, crude but could be prettified):
This gives us some nice (simple) benefits: we can easily insert or reorder things in our CT args without potentially having to manually update the values of every CT arg. More generally, it means we can start to share these CT arg unpack functions. For example, see below how some really annoying and error-prone code in a CCL kernel is simplified (and is now reusable across CCL kernels!). The context is that I have two operand tensors to the kernel. They may be: {INTERLEAVED, INTERLEAVED}, {INTERLEAVED, SHARDED}, {SHARDED, INTERLEAVED}, {SHARDED, SHARDED}. Interleaved consumes fewer args than sharded, so I need different ct_arg getters (facilitated by macro defines) for every possible combination.
With the changes from this PR, we can replace the above code with the following:
This could be wrapped in a function if we wanted and then this is pretty much fully composable and reusable. Additionally, we can trivially start doing cool things like adding special value CT args at the start of compile time arg sequences representing an object. For example
then something like this to unpack/deserialize (very crude example but conveys the idea - ideally this could be formalized on the type itself):
Anyways this was a super long response, and some of the examples were a little basic, but the point is we can now start to do some really useful things!
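The pack/unpack idea described in this comment could look roughly like the following. This is an illustrative sketch, not the code from the PR: ShardSpec, pack, unpack_shard_spec and the argument layout are all assumptions, and a constexpr array stands in for the generated kernel_compile_time_args.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical object passed through compile-time args.
struct ShardSpec {
    std::uint32_t pages_per_shard;
    std::uint32_t num_shards;
    // Number of compile-time args this type consumes.
    static constexpr std::size_t num_ct_args = 2;
};

// Host side: append the fields in a fixed order.
inline void pack(const ShardSpec& spec, std::vector<std::uint32_t>& ct_args) {
    ct_args.push_back(spec.pages_per_shard);
    ct_args.push_back(spec.num_shards);
}

// Device-side stand-in: the values as they would arrive after packing two objects.
constexpr std::array<std::uint32_t, 4> kernel_compile_time_args = {32, 4, 64, 8};

// Unpack a ShardSpec starting at Offset; an out-of-range offset fails to compile.
template <std::size_t Offset>
constexpr ShardSpec unpack_shard_spec() {
    static_assert(Offset + ShardSpec::num_ct_args <= kernel_compile_time_args.size());
    return ShardSpec{kernel_compile_time_args[Offset], kernel_compile_time_args[Offset + 1]};
}

// Offsets chain from one object to the next, so inserting or reordering
// objects does not require editing literal indices by hand.
constexpr std::size_t offset_a = 0;
constexpr ShardSpec spec_a = unpack_shard_spec<offset_a>();
constexpr std::size_t offset_b = offset_a + ShardSpec::num_ct_args;
constexpr ShardSpec spec_b = unpack_shard_spec<offset_b>();

static_assert(spec_a.pages_per_shard == 32 && spec_a.num_shards == 4);
static_assert(spec_b.pages_per_shard == 64 && spec_b.num_shards == 8);
```

Types that consume different numbers of args (such as the interleaved and sharded cases mentioned above) would chain the same way, each advancing the offset by its own num_ct_args. None of this is possible with the old token-pasting macro, because the offsets here are constexpr variables instead of literal indices.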
Could you add a description to the PR detailing the nature of the change? Great work!
Ticket
#16040
Checklist