Improve Compile Time Args to Make Them Programmatically Accessible #18379

Draft
wants to merge 14 commits into
base: main

sagarwalTT
Contributor

Ticket

#16040

Checklist

@sagarwalTT sagarwalTT linked an issue Feb 26, 2025 that may be closed by this pull request
4 tasks
Contributor

@SeanNijjar SeanNijjar left a comment

cool looks good! Just a few minor things (would be good to get some input from others on the embedded comments).

constexpr auto kernel_compile_time_args = make_array<std::uint32_t>();
#endif

#define get_compile_time_arg_val(arg_idx) \
Contributor

I think we need to think about this a little. Ideally we want to static_assert if arg_idx >= kernel_compile_time_args.size(). Maybe could be (but then lots of template code):

template <size_t i>
constexpr uint32_t get_ct_arg() {
  static_assert(i < kernel_compile_time_args.size());
  return kernel_compile_time_args[i];
}

Contributor

I'd rather it cause a substitution failure than a static_assert personally. I also think that rather than storing the compile-time arguments in a global constexpr std::array, the access should be done more directly. This array is what causes the kernel OOM issues when there are too many compile-time arguments (~400 based on the other thread I saw here), but there are ways to circumvent that limitation completely.

Contributor

@patrickroberts patrickroberts Feb 27, 2025

FWIW, the substitution failure could be implemented like this:

template <size_t Idx>
constexpr std::enable_if_t<(Idx < kernel_compile_time_args.size()), uint32_t> get_ct_arg() {
  return kernel_compile_time_args[Idx];
}

but what I would suggest for replacing the std::array is a bit more complicated than what I can fit in this thread, though it could be used in much the same way without taking up memory.
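
One practical difference between the two variants in this thread: a failed static_assert is a hard error, while the enable_if form simply removes the overload, so other templates can ask whether an index is valid. A minimal sketch of that, assuming a two-element stand-in for kernel_compile_time_args with made-up values. The has_ct_arg trait and get_ct_arg_or helper are hypothetical names for illustration and are not part of this PR:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Stand-in for the array the build would generate; the values are made up.
constexpr std::array<std::uint32_t, 2> kernel_compile_time_args = {5, 99};

// The enable_if variant from this thread: no overload exists for an out-of-range Idx.
template <std::size_t Idx>
constexpr std::enable_if_t<(Idx < kernel_compile_time_args.size()), std::uint32_t> get_ct_arg() {
    return kernel_compile_time_args[Idx];
}

// Hypothetical detection trait: true only when get_ct_arg<Idx>() is well-formed.
template <std::size_t Idx, typename = void>
struct has_ct_arg : std::false_type {};

template <std::size_t Idx>
struct has_ct_arg<Idx, std::void_t<decltype(get_ct_arg<Idx>())>> : std::true_type {};

// Hypothetical helper: use a fallback value when the index is out of range.
template <std::size_t Idx>
constexpr std::uint32_t get_ct_arg_or(std::uint32_t fallback) {
    if constexpr (has_ct_arg<Idx>::value) {
        return get_ct_arg<Idx>();
    } else {
        return fallback;
    }
}

static_assert(get_ct_arg<1>() == 99);
static_assert(!has_ct_arg<2>::value);
static_assert(get_ct_arg_or<2>(7) == 7);
```

With the static_assert variant, the has_ct_arg trait would not work: instantiating get_ct_arg<2>() would stop the build instead of being discarded.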

Contributor

@patrickroberts would you say that, after applying your suggestion with enable_if, this is a good enough starting point, or should we evaluate the other approach you are thinking of?

Contributor

If you can ensure existing ops (actually used in models, not contrived ones with 400 ct args) don't have increased binary sizes, I am okay with this as a starting point

Contributor

Also if this version of GCC has __builtin_is_constant_evaluated(), I would suggest using that to enforce get_ct_arg is only called at compile-time, since it's compiled with C++17 and we can't use std::is_constant_evaluated() or if consteval.
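
A rough sketch of what that could look like, assuming the kernel toolchain provides __builtin_is_constant_evaluated() (upstream GCC has had it since GCC 9; whether the kernel compiler does is not confirmed in this thread). The array contents are made up, and this version only traps at run time when a call was not constant-evaluated. It does not fail the build, so a stricter mechanism would still be needed for a true compile-time error:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for the generated array; the values are made up.
constexpr std::array<std::uint32_t, 2> kernel_compile_time_args = {5, 99};

constexpr std::uint32_t get_ct_arg(std::size_t idx) {
    if (!__builtin_is_constant_evaluated()) {
        // Reached only when the call was evaluated at run time.
        __builtin_trap();
    }
    // Same out-of-range behaviour as the macro in this PR: yield 0.
    return idx < kernel_compile_time_args.size() ? kernel_compile_time_args[idx] : 0;
}

// Initializing a constexpr variable forces constant evaluation, so the trap is never reached.
constexpr std::uint32_t per_core_M = get_ct_arg(0);
static_assert(per_core_M == 5);
```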

@@ -219,7 +231,8 @@ FORCE_INLINE T get_common_arg_val(int arg_idx) {
* | arg_idx | The index of the argument | uint32_t | 0 to 31 | True |
*/
// clang-format on
-#define get_compile_time_arg_val(arg_idx) KERNEL_COMPILE_TIME_ARG_##arg_idx
+#define get_compile_time_arg_val(arg_idx) \
+    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)
Contributor

can it be commonized with the impl in compile_time_args?

}
defines += "-DKERNEL_COMPILE_TIME_ARGS=";
for (uint32_t i = 0; i < values.size(); i++) {
defines += to_string(values[i]) + ",";
Contributor

consider stringstream?
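
For reference, a small host-side sketch of the stringstream variant. The function name make_compile_time_args_define is hypothetical, and unlike the loop in the diff this version emits no trailing comma, which is a choice made for this sketch and may not match what the kernel-side macro expects:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: builds the single -D flag from the host-side values.
inline std::string make_compile_time_args_define(const std::vector<std::uint32_t>& values) {
    std::ostringstream defines;
    defines << "-DKERNEL_COMPILE_TIME_ARGS=";
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (i != 0) {
            defines << ',';  // separator only between values, so no trailing comma
        }
        defines << values[i];
    }
    return defines.str();
}

// Usage example with the {5, 99} values used elsewhere in this conversation.
inline void example_define_usage() {
    assert(make_compile_time_args_define({5, 99}) == "-DKERNEL_COMPILE_TIME_ARGS=5,99");
}
```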

Contributor

might be cool to have some TMP in a test too :)

CoreCoord core = {0, 0};
Program program;

const uint32_t num_compile_time_args = 400;
Contributor

Can you try even bigger? :D

Contributor Author

Initially, I tried higher values, but I kept on hitting the mem limit. I think the maximum number of compile time args that a kernel can have is ~450.

Contributor

Okay, sounds good. Can you elaborate a bit on the error? Additionally, can you please add this to the PR description, in addition to the limit you saw?

Contributor

I'm guessing the number of args will depend on the binary size, so for a bigger kernel it will be lower.

.noc = NOC::RISCV_0_default,
.compile_args = compile_time_args,
.defines = defines});
this->RunProgram(device, program);
Member

How does the test verify the compile args successfully went through?
A simple way would be to pass an L1 address into the kernel. The kernel reads the compile args and writes them to L1. Then, from the host side, read the device L1 to check that those args were written there. That's how runtime_args_kernel.cpp does it.

for (uint32_t i = 0; i < NUM_COMPILE_TIME_ARGS; i++) {
if (kernel_compile_time_args[i] != i) {
ASSERT(0);
while (1);
Member

I think we should verify this from the host side using the approach I described (similar to runtime_args_kernel.cpp), because if this doesn't work it'll hang the CI until timeout.
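
The shape of that host-side check can be sketched without any device at all. The code below is a host-only simulation and uses no tt-metal API: a std::vector stands in for the L1 region, and write_ct_args_to, readback_matches and example_host_check are hypothetical names. The point is only that the kernel side writes values and returns, and the comparison (and any failure) happens on the host:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the generated array; as in this PR's test, each arg equals its index.
constexpr std::array<std::uint32_t, 4> kernel_compile_time_args = {0, 1, 2, 3};

// "Kernel" side (simulated): copy each compile-time arg to a result buffer and return.
// There is no ASSERT and no spin loop, so a bad value cannot hang anything.
inline void write_ct_args_to(volatile std::uint32_t* result_buffer) {
    for (std::size_t i = 0; i < kernel_compile_time_args.size(); ++i) {
        result_buffer[i] = kernel_compile_time_args[i];
    }
}

// "Host" side (simulated): compare what was read back against what was passed in.
inline bool readback_matches(const std::vector<std::uint32_t>& expected,
                             const std::vector<std::uint32_t>& readback) {
    return expected == readback;
}

// Usage: the vector stands in for the L1 region whose address the kernel would receive.
inline void example_host_check() {
    const std::vector<std::uint32_t> expected = {0, 1, 2, 3};
    std::vector<std::uint32_t> fake_l1(expected.size(), 0xdeadbeef);
    write_ct_args_to(fake_l1.data());
    assert(readback_matches(expected, fake_l1));
}
```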

@bbradelTT
Contributor

bbradelTT commented Feb 27, 2025

What would be the minimal example if instead of std::array (shown in the issue) we'd want to populate a struct? E.g. for two parameters something like

struct CoreParams {
   std::uint32_t per_core_M;
   std::uint32_t per_core_N;
};

Both on the program factory (CreateKernel) side and kernel side?

Also, for std::array, what would that look like on the program factory (CreateKernel) side (assuming the example is on the kernel side, which I may be mistaken about)?

@ayerofieiev-tt
Member

This PR deserves a PR description.

@SeanNijjar
Contributor

This PR deserves a PR description.

it's still only a draft, but yes!

@SeanNijjar
Contributor

SeanNijjar commented Feb 27, 2025

What would be the minimal example if instead of std::array (shown in the issue) we'd want to populate a struct? E.g. for two parameters something like

struct CoreParams {
   std::uint32_t per_core_M;
   std::uint32_t per_core_N;
};

Both on the program factory (CreateKernel) side and kernel side?

Also, for std::array, what would that look like on the program factory (CreateKernel) side (assuming the example is on the kernel side, which I may be mistaken about)?

@bbradelTT - fundamentally, nothing has changed about compile time arg passing from the host side. So the std::array usage on the device side is just to put the args into a container.

Previously compile time args were implemented as individual macros:
e.g. if ct_args on host = {5, 99};

then the compile command for the kernel would be g++ <compile_args> -DKERNEL_COMPILE_TIME_ARG_0=5 -DKERNEL_COMPILE_TIME_ARG_1=99 and then on the kernel side, compile time args were implemented as macro expansions:

i.e. #define get_compile_time_arg_val(arg_idx) KERNEL_COMPILE_TIME_ARG_##arg_idx

which has the undesired consequence that all "calls" to get_compile_time_arg_val require a literal... literal to work (otherwise you get a compile error). You couldn't pass for example constexpr variables to get_compile_time_arg_val.

With this change, you can now use (constexpr) variables to fetch compile time args. This is the fundamental missing piece to let us write reusable template code libraries/utils that relies on compile time args.
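
To make the difference concrete, here is a small sketch using the macro as it appears in this PR's diff, with a made-up two-element stand-in for kernel_compile_time_args (base_idx, first and second are illustrative names). Under the old token-pasting macro, get_compile_time_arg_val(base_idx) would expand to the undefined identifier KERNEL_COMPILE_TIME_ARG_base_idx:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for the array generated from the host's ct_args = {5, 99}.
constexpr std::array<std::uint32_t, 2> kernel_compile_time_args = {5, 99};

// Macro as it appears in this PR's diff: an out-of-range index yields 0.
#define get_compile_time_arg_val(arg_idx) \
    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)

// The index is now an ordinary constant expression, not a literal token.
constexpr std::size_t base_idx = 0;
constexpr std::uint32_t first = get_compile_time_arg_val(base_idx);
constexpr std::uint32_t second = get_compile_time_arg_val(base_idx + 1);

static_assert(first == 5 && second == 99);
```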

Note that this PR does NOT add automatic or "free" object serialization (from host) and deserialization (from device) at compile time. However, all the required pieces are in place so that we could start to do this.

For example, we could now do something like this to pack/serialize

// on host - this code could already be written before this change, and similar code is used in some ops
void serialize_to_ct_args(CoreParams const& core_params, std::vector<uint32_t> &ct_args_out) {
  ct_args_out.push_back(core_params.per_core_M);
  ct_args_out.push_back(core_params.per_core_N);
}

then something like this to unpack/deserialize (very crude example but conveys the idea - ideally this could be formalized on the type itself):

// device scope
template <size_t ct_arg_idx>
constexpr CoreParams build_core_params_from_ct_args() {
  return CoreParams{get_compile_time_arg_val(ct_arg_idx), get_compile_time_arg_val(ct_arg_idx + 1)}; 
}
constexpr size_t core_params_ct_args_consumed() {
  return 2;
}

then in the device kernel code that wants this object (again, crude but could be prettified):

void kernel_main() {
  constexpr size_t ct_arg_idx = 0;
  constexpr CoreParams core_params = build_core_params_from_ct_args<ct_arg_idx>();

  constexpr size_t next_ct_arg_idx = core_params_ct_args_consumed() + ct_arg_idx;
  // unpack any other objects you want
  
  // main kernel code
  //...
}
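
The "unpack any other objects" step can be sketched as follows: each object reports how many CT args it consumes, and the next object starts where the previous one ended. This is a self-contained illustration with a made-up five-element stand-in array; BlockParams, its members, and block_params_ct_args_consumed are hypothetical and only serve as a second object to chain:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for the generated array; the values are made up.
constexpr std::array<std::uint32_t, 5> kernel_compile_time_args = {4, 8, 2, 1, 3};

#define get_compile_time_arg_val(arg_idx) \
    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)

struct CoreParams {
    std::uint32_t per_core_M;
    std::uint32_t per_core_N;
};

// Hypothetical second object, used only to show chaining.
struct BlockParams {
    std::uint32_t in0_block_w;
    std::uint32_t out_subblock_h;
    std::uint32_t out_subblock_w;
};

template <std::size_t ct_arg_idx>
constexpr CoreParams build_core_params_from_ct_args() {
    return CoreParams{get_compile_time_arg_val(ct_arg_idx), get_compile_time_arg_val(ct_arg_idx + 1)};
}
constexpr std::size_t core_params_ct_args_consumed() { return 2; }

template <std::size_t ct_arg_idx>
constexpr BlockParams build_block_params_from_ct_args() {
    return BlockParams{get_compile_time_arg_val(ct_arg_idx),
                       get_compile_time_arg_val(ct_arg_idx + 1),
                       get_compile_time_arg_val(ct_arg_idx + 2)};
}
constexpr std::size_t block_params_ct_args_consumed() { return 3; }

// Each start index is derived from the previous object, never hard-coded.
constexpr std::size_t core_idx = 0;
constexpr CoreParams core_params = build_core_params_from_ct_args<core_idx>();
constexpr std::size_t block_idx = core_idx + core_params_ct_args_consumed();
constexpr BlockParams block_params = build_block_params_from_ct_args<block_idx>();

static_assert(core_params.per_core_N == 8);
static_assert(block_params.in0_block_w == 2 && block_params.out_subblock_w == 3);
```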

This gives us some nice (simple) benefits: we can easily insert or reorder things in our CT args without having to manually renumber every CT arg access that follows. But more generally it means we can start to share these ct arg unpack functions.

For example, see below how some really annoying and error-prone code in a CCL kernel is simplified (and now reusable across CCL kernels!). The context is that I have two operand tensors to the kernel. They may be: {INTERLEAVED, INTERLEAVED}, {INTERLEAVED, SHARDED}, {SHARDED, INTERLEAVED}, {SHARDED, SHARDED}. In this case, interleaved consumes fewer args than sharded, so I need to have different ct_arg getters (facilitated by macro defines) for every possible combination.

#if defined(TENSOR1_SHARDED_MEM_LAYOUT)
#if defined(TENSOR0_SHARDED_MEM_LAYOUT)
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(17),
    get_compile_time_arg_val(18),
    get_compile_time_arg_val(19),
    get_compile_time_arg_val(20),
    get_compile_time_arg_val(21),
    get_compile_time_arg_val(22),
    get_compile_time_arg_val(23) != 0};
#else
// Then we are only consuming ct args for the second operand, and we resume from ct arg 10
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(10),
    get_compile_time_arg_val(11),
    get_compile_time_arg_val(12),
    get_compile_time_arg_val(13),
    get_compile_time_arg_val(14),
    get_compile_time_arg_val(15),
    get_compile_time_arg_val(16) != 0};
#endif
#endif  // TENSOR1_SHARDED_MEM_LAYOUT

With the changes from this PR, we can replace the above code with the following:

// some CT arg capture above
constexpr size_t operand_1_ct_arg_idx_start = last_ct_arg + 1;  // (last_ct_arg is also constexpr)
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(operand_1_ct_arg_idx_start),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 1),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 2),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 3),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 4),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 5),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 6) != 0};

This could be wrapped in a function if we wanted and then this is pretty much fully composable and reusable.
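
A sketch of that wrapping, under stated assumptions: the member names of sharded_addrgen_fields are not shown in this thread, so the struct below is a hypothetical layout (one bool, six values, one flag) that merely matches the initializer above, and the stand-in array holds each index as its own value so the results are easy to check. The start indices 10 and 17 mirror the two preprocessor branches in the original code:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in array where each compile-time arg equals its own index.
template <std::size_t N>
constexpr std::array<std::uint32_t, N> make_index_array() {
    std::array<std::uint32_t, N> values{};
    for (std::size_t i = 0; i < N; ++i) {
        values[i] = static_cast<std::uint32_t>(i);
    }
    return values;
}
constexpr auto kernel_compile_time_args = make_index_array<24>();

#define get_compile_time_arg_val(arg_idx) \
    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)

// Hypothetical layout: the real struct's member names are not shown in this thread.
struct sharded_addrgen_fields {
    bool is_sharded;
    std::uint32_t arg0, arg1, arg2, arg3, arg4, arg5;
    bool last_flag;
};

// One builder replaces both preprocessor branches; only the start index differs.
template <std::size_t start_idx>
constexpr sharded_addrgen_fields build_sharded_addrgen_fields() {
    return {true,
            get_compile_time_arg_val(start_idx),
            get_compile_time_arg_val(start_idx + 1),
            get_compile_time_arg_val(start_idx + 2),
            get_compile_time_arg_val(start_idx + 3),
            get_compile_time_arg_val(start_idx + 4),
            get_compile_time_arg_val(start_idx + 5),
            get_compile_time_arg_val(start_idx + 6) != 0};
}
constexpr std::size_t sharded_addrgen_ct_args_consumed() { return 7; }

// Operand 1 starts at 10 after an interleaved operand 0, and 7 args later after a sharded one.
constexpr auto in1_after_interleaved = build_sharded_addrgen_fields<10>();
constexpr auto in1_after_sharded =
    build_sharded_addrgen_fields<10 + sharded_addrgen_ct_args_consumed()>();

static_assert(in1_after_interleaved.arg0 == 10 && in1_after_sharded.arg0 == 17);
```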

Additionally, we can trivially start doing cool things like adding special value CT args at the start of compile time arg sequences representing an object.

For example

// on host - this code could already be written before this change, and similar code is used in some ops
void serialize_to_ct_args(CoreParams const& core_params, std::vector<uint32_t> &ct_args_out) {
  ct_args_out.push_back(0x00c0ffee);  // some special marker value
  ct_args_out.push_back(core_params.per_core_M);
  ct_args_out.push_back(core_params.per_core_N);
}

then something like this to unpack/deserialize (very crude example but conveys the idea - ideally this could be formalized on the type itself):

// device scope
template <size_t ct_arg_idx>
constexpr CoreParams build_core_params_from_ct_args() {
  static_assert(get_compile_time_arg_val(ct_arg_idx) == 0x00c0ffee, "CT arg misalignment. Didn't get expected key marker at start of CoreParams unpacking");  // make sure we got the special value
  return CoreParams{get_compile_time_arg_val(ct_arg_idx + 1), get_compile_time_arg_val(ct_arg_idx + 2)}; 
}
constexpr size_t core_params_ct_args_consumed() {
  return 3;
}
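
Putting the host and device halves of the marker idea side by side in one translation unit gives a cheap consistency check between them. This is a host-only sketch: the array is a hand-written stand-in for what the host would emit for CoreParams{4, 8}, and core_params_marker and example_host_layout_check are hypothetical names:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct CoreParams {
    std::uint32_t per_core_M;
    std::uint32_t per_core_N;
};

constexpr std::uint32_t core_params_marker = 0x00c0ffee;  // same marker value as above

// Host side: marker first, then the fields.
inline void serialize_to_ct_args(const CoreParams& core_params, std::vector<std::uint32_t>& ct_args_out) {
    ct_args_out.push_back(core_params_marker);
    ct_args_out.push_back(core_params.per_core_M);
    ct_args_out.push_back(core_params.per_core_N);
}

// Device side (simulated): stand-in for what the host would emit for CoreParams{4, 8}.
constexpr std::array<std::uint32_t, 3> kernel_compile_time_args = {core_params_marker, 4, 8};

#define get_compile_time_arg_val(arg_idx) \
    ((arg_idx) < kernel_compile_time_args.size() ? kernel_compile_time_args[(arg_idx)] : 0)

template <std::size_t ct_arg_idx>
constexpr CoreParams build_core_params_from_ct_args() {
    static_assert(get_compile_time_arg_val(ct_arg_idx) == core_params_marker,
                  "CT arg misalignment: expected marker at start of CoreParams");
    return CoreParams{get_compile_time_arg_val(ct_arg_idx + 1), get_compile_time_arg_val(ct_arg_idx + 2)};
}
constexpr std::size_t core_params_ct_args_consumed() { return 3; }

constexpr CoreParams core_params = build_core_params_from_ct_args<0>();
static_assert(core_params.per_core_M == 4 && core_params.per_core_N == 8);

// Usage: the host layout must agree with what the device-side unpacker consumes.
inline void example_host_layout_check() {
    std::vector<std::uint32_t> ct_args;
    serialize_to_ct_args(CoreParams{4, 8}, ct_args);
    assert(ct_args.size() == core_params_ct_args_consumed());
    assert(ct_args.front() == core_params_marker);
}
```

Instantiating build_core_params_from_ct_args<1>() against this array would fail the static_assert at compile time, which is the misalignment case the marker is meant to catch.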

Anyways this was a super long response, and some of the examples were a little basic, but the point is we can now start to do some really useful things!

@sagarwalTT sagarwalTT force-pushed the sagarwal/compile_time_args branch 3 times, most recently from 5d43369 to 71af073 Compare February 27, 2025 20:02
Contributor

@sjameelTT sjameelTT left a comment

Could you add a description to the PR detailing the nature of the change? Great work!

@sagarwalTT sagarwalTT force-pushed the sagarwal/compile_time_args branch 4 times, most recently from 950e791 to e100ec8 Compare March 3, 2025 15:16
@sagarwalTT sagarwalTT force-pushed the sagarwal/compile_time_args branch from e100ec8 to 4244b93 Compare March 6, 2025 22:18
Successfully merging this pull request may close these issues.

Improve Compile Time Args to Make Them Programmatically Accessible (TMP)
7 participants