Improve Compile Time Args to Make Them Programmatically Accessible (TMP) #16040

Open
4 tasks
SeanNijjar opened this issue Dec 16, 2024 · 3 comments · May be fixed by #16043 or #18379

SeanNijjar commented Dec 16, 2024

This new approach makes compile time args programmatically accessible and usable in techniques like template metaprogramming. At one end of the spectrum, this greatly extends what kernels can do. At the other end, it greatly improves code reuse, which is currently nearly impossible because compile time args can only be accessed with literal indices.

The proposed approach generates the compile time args directly into a std::array.

Rather than passing compile time args individually like:

-DKERNEL_COMPILE_TIME_ARG_0=5 -DKERNEL_COMPILE_TIME_ARG_1=99 ... and so on

We instead will pass them together like so:

"-DKERNEL_COMPILE_TIME_ARGS=100,101,102,103,104,105,106,107,108"

This godbolt shows the minimal example:
https://godbolt.org/z/11Ka4WK93

#include <array>
#include <cstdint>

template <class T, class... Ts>
constexpr auto make_array(Ts... values) -> std::array<T, sizeof...(Ts)> {
    return {T(values)...};
}

inline constexpr auto kernel_compile_time_args{
    make_array<std::uint32_t>(KERNEL_COMPILE_TIME_ARGS),
};

The above will replace (or rather extend) the current implementation of get_compile_time_arg_val.

To maintain backwards compatibility, we keep get_compile_time_arg_val but replace its implementation with an access into the compile time args array, kernel_compile_time_args.
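
A minimal sketch of what the backwards-compatible accessor might look like, assuming the array from the godbolt example above. The fallback `#define` is there only so the snippet compiles standalone; in the real flow the JIT compile command would supply `KERNEL_COMPILE_TIME_ARGS`. This issue does not say whether `get_compile_time_arg_val` stays a macro or becomes a function. A macro is assumed here so that existing call sites keep compiling unchanged.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Assumption: normally supplied by the JIT compile command via -D.
// Defined here only so the sketch compiles on its own.
#ifndef KERNEL_COMPILE_TIME_ARGS
#define KERNEL_COMPILE_TIME_ARGS 100, 101, 102, 103
#endif

template <class T, class... Ts>
constexpr auto make_array(Ts... values) -> std::array<T, sizeof...(Ts)> {
    return {T(values)...};
}

inline constexpr auto kernel_compile_time_args =
    make_array<std::uint32_t>(KERNEL_COMPILE_TIME_ARGS);

// Existing call sites such as get_compile_time_arg_val(3) keep working,
// but the index may now be any constant expression, not only a literal.
#define get_compile_time_arg_val(arg_idx) (kernel_compile_time_args[(arg_idx)])

constexpr std::size_t base = 1;
constexpr std::uint32_t second = get_compile_time_arg_val(base);     // index 1
constexpr std::uint32_t third = get_compile_time_arg_val(base + 1);  // index 2
```

Because the index is an ordinary expression, `base` can itself be computed from earlier compile time args, which is what the later examples in this issue rely on.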

Tasks for this issue:

  • Modify JIT compile commands so that compile time args are passed as a single list rather than as many individual defines
  • Possibly, for backwards compatibility, also export the old defines. I see a few limited tests using them directly, though they are tests and can just be updated to use the new APIs instead (they'll actually end up cleaner)
  • Extend the device headers to modify get_compile_time_arg_val and expose kernel_compile_time_args
  • Validate scalability: develop some tests that specify a large minimum number of compile time args (e.g. 1k, 10k) to ensure the approach is scalable (ideally we can pass at least as many args as before)
  • Update uses of get_compile_time_arg_val to instead use constexpr variables in some example files (e.g. the CCL command interpreter) to simplify the code and to provide the diff as an example for other developers

Note on tests:

It would be great if we could have a test that does something like the following as a basic demonstration of the new capabilities:

template <size_t index, size_t n_vals>
void ct_unrolled_loop(volatile uint32_t* buffer_ptr) {
  // index is a template parameter, so this array access is resolved at compile time
  buffer_ptr[index] = kernel_compile_time_args[index];
  if constexpr (index < n_vals - 1) {
    ct_unrolled_loop<index + 1, n_vals>(buffer_ptr);
  }
}

void kernel_main() {
  constexpr auto cb_id = kernel_compile_time_args[0];
  constexpr auto n_vals = kernel_compile_time_args[1]; // constexpr so it can be a template argument
  cb_reserve_back(cb_id, 1); // Assume page size == n_vals * sizeof(val);
  auto buffer_addr = get_wr_ptr(cb_id);
  auto buffer_ptr = reinterpret_cast<volatile uint32_t*>(buffer_addr);
  ct_unrolled_loop<0, n_vals>(buffer_ptr);

  // Destination NOC x, y and address come from runtime args 0, 1 and 2
  auto dst_noc_addr = get_noc_addr(get_arg_val<uint32_t>(0), get_arg_val<uint32_t>(1), get_arg_val<uint32_t>(2));
  noc_async_write(buffer_addr, dst_noc_addr, n_vals * sizeof(uint32_t)); // assume val == uint32
}
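
As a host-compilable variation on the test idea above, the same unrolling can be sketched with `std::index_sequence` and a fold expression instead of recursive instantiation. The array contents and the helper names `copy_ct_args` and `copied_payload` are hypothetical, and a plain (non-volatile) destination is used so that the result can be checked in a constant expression, which the device buffer in the kernel version would not allow.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <utility>

// Stand-in for the generated array; the values are arbitrary.
// Hypothetical layout: [0] cb_id, [1] n_vals, [2..] payload.
inline constexpr std::array<std::uint32_t, 5> kernel_compile_time_args{0, 3, 100, 101, 102};

template <std::size_t start, std::size_t... offsets>
constexpr void copy_ct_args_impl(std::uint32_t* dst, std::index_sequence<offsets...>) {
    // Fold over the comma operator: one store per offset, no runtime loop.
    ((dst[offsets] = kernel_compile_time_args[start + offsets]), ...);
}

// Copies n_vals compile time args, beginning at index start, into dst.
template <std::size_t start, std::size_t n_vals>
constexpr void copy_ct_args(std::uint32_t* dst) {
    static_assert(start + n_vals <= kernel_compile_time_args.size(),
                  "requested range exceeds the compile time args");
    copy_ct_args_impl<start>(dst, std::make_index_sequence<n_vals>{});
}

// Usage: the element count itself is read from the compile time args.
constexpr auto copied_payload() {
    constexpr std::size_t n_vals = kernel_compile_time_args[1];
    std::array<std::uint32_t, n_vals> out{};
    copy_ct_args<2, n_vals>(out.data());
    return out;
}
```

The point of the sketch is that `n_vals` comes out of the args array and is then used as a template argument, which the per-index defines cannot express.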

Adding Notes from Offline Discussion About Benefits

Here's an example of what we have to deal with today. Any time there's conditionality on how many CT args are captured for some specific purpose, that conditionality needs to propagate to all later CT arg captures (and any further conditionality is combinatorial with earlier conditionality).

In this example, I have two operand tensors to the kernel. They may be: {INTERLEAVED, INTERLEAVED}, {INTERLEAVED, SHARDED}, {SHARDED, INTERLEAVED}, {SHARDED, SHARDED}. In this case, interleaved consumes fewer args than sharded.

If operand 0 is interleaved and operand 1 is sharded, then operand 1's CT args are earlier in the list. If operand 0 is sharded, then operand 1's CT args are at a later index (notice that both branches below initialize in1_sharded_addrgen_fields). I need to do this in multiple places in the file, which is a major annoyance for something so simple.

Also... no reusability here because each kernel has slightly different CT arg usage patterns.

#ifndef SINGLE_TENSOR
#if defined(TENSOR1_SHARDED_MEM_LAYOUT)
#if defined(TENSOR0_SHARDED_MEM_LAYOUT)
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(17),
    get_compile_time_arg_val(18),
    get_compile_time_arg_val(19),
    get_compile_time_arg_val(20),
    get_compile_time_arg_val(21),
    get_compile_time_arg_val(22),
    get_compile_time_arg_val(23) != 0};
#else
// Operand 0 is interleaved, so only operand 1 consumes sharded CT args and we resume from CT arg 10
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(10),
    get_compile_time_arg_val(11),
    get_compile_time_arg_val(12),
    get_compile_time_arg_val(13),
    get_compile_time_arg_val(14),
    get_compile_time_arg_val(15),
    get_compile_time_arg_val(16) != 0};
#endif
#else
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
#endif

With this feature I can collapse the above to:

// some CT arg capture above
constexpr size_t operand_1_ct_arg_idx_start = last_ct_arg + 1; // (last_ct_arg is also constexpr)
#ifndef SINGLE_TENSOR
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {
    true,
    get_compile_time_arg_val(operand_1_ct_arg_idx_start),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 1),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 2),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 3),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 4),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 5),
    get_compile_time_arg_val(operand_1_ct_arg_idx_start + 6) != 0};
#else
constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {0, 0, 0, 0, 0, 0, 0, 0};
#endif

And once I can do the above I can start to put this into a library for reuse across kernels:

template <size_t start_idx>
constexpr sharded_addrgen_fields build_sharded_addrgen_fields_from_ct_args() {
  return sharded_addrgen_fields{
    true,
    get_compile_time_arg_val(start_idx),
    get_compile_time_arg_val(start_idx + 1),
    get_compile_time_arg_val(start_idx + 2),
    get_compile_time_arg_val(start_idx + 3),
    get_compile_time_arg_val(start_idx + 4),
    get_compile_time_arg_val(start_idx + 5),
    get_compile_time_arg_val(start_idx + 6) != 0};
}

Then I just call build_sharded_addrgen_fields_from_ct_args instead.

Dealing with Arg alignment

Alignment is no different (and no worse) than before, but in practice it becomes much easier to manage. Why?

On host you can (and we do in CCLs) have "serialization" of certain types to CT or RT args. Without this feature, I need to manually implement "deserialization" of the type for every kernel that wants it. WITH this feature, I can implement a single "deserializer" for reuse everywhere. As long as those two are in agreement with each other (which we could write gtests for), we are golden.
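
A rough sketch of the "single deserializer" idea, under stated assumptions: the types `InterleavedArgs` and `ShardedArgs`, their fields, their `num_args` counts, and the array contents are all hypothetical stand-ins, not the real CCL types. Each type records how many CT args it consumes, so the start index of the next type is computed rather than hard-coded.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for the generated array. Hypothetical layout:
// [0]    operand 0 (interleaved): is_dram
// [1..3] operand 1 (sharded): pages_per_shard_x, pages_per_shard_y, contiguity flag
inline constexpr std::array<std::uint32_t, 4> kernel_compile_time_args{1, 4, 2, 1};

// Hypothetical types; num_args records how many CT args each one consumes.
struct InterleavedArgs {
    static constexpr std::size_t num_args = 1;
    bool is_dram;
};

struct ShardedArgs {
    static constexpr std::size_t num_args = 3;
    std::uint32_t pages_per_shard_x;
    std::uint32_t pages_per_shard_y;
    bool contiguous;
};

template <std::size_t start>
constexpr InterleavedArgs read_interleaved() {
    static_assert(start + InterleavedArgs::num_args <= kernel_compile_time_args.size());
    return {kernel_compile_time_args[start] != 0};
}

template <std::size_t start>
constexpr ShardedArgs read_sharded() {
    static_assert(start + ShardedArgs::num_args <= kernel_compile_time_args.size());
    return {kernel_compile_time_args[start],
            kernel_compile_time_args[start + 1],
            kernel_compile_time_args[start + 2] != 0};
}

// Each operand starts where the previous one ended, so no index is hard-coded.
constexpr std::size_t op0_start = 0;
constexpr auto op0 = read_interleaved<op0_start>();
constexpr std::size_t op1_start = op0_start + InterleavedArgs::num_args;
constexpr auto op1 = read_sharded<op1_start>();
```

If operand 0 were sharded instead, only the line computing `op1_start` would change; the readers themselves could live in a shared header.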

We can go so far as to include tag or special values at the start/end of each complex type (one that requires multiple args to be together in some order), so that any misalignment can be detected at compile time.
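
The tag idea could be sketched roughly as follows. The tag value, the `ShardShape` type, and the array contents are made up for illustration; the only point is that a `static_assert` on the tag positions turns a host/kernel layout disagreement into a compile error.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical marker the host would emit before and after a complex type's fields.
inline constexpr std::uint32_t SHARDED_ARGS_TAG = 0x5AD0BEEFu;

// Stand-in for the generated array: one unrelated arg, then a tagged group.
inline constexpr std::array<std::uint32_t, 5> kernel_compile_time_args{
    7, SHARDED_ARGS_TAG, 4, 2, SHARDED_ARGS_TAG};

struct ShardShape {
    std::uint32_t pages_x;
    std::uint32_t pages_y;
};

// Reads {tag, pages_x, pages_y, tag} starting at index start.
template <std::size_t start>
constexpr ShardShape read_tagged_shard_shape() {
    static_assert(start + 4 <= kernel_compile_time_args.size(), "CT arg list too short");
    static_assert(kernel_compile_time_args[start] == SHARDED_ARGS_TAG,
                  "begin tag missing: host and kernel disagree on CT arg layout");
    static_assert(kernel_compile_time_args[start + 3] == SHARDED_ARGS_TAG,
                  "end tag missing: host and kernel disagree on CT arg layout");
    return {kernel_compile_time_args[start + 1], kernel_compile_time_args[start + 2]};
}

constexpr auto shard_shape = read_tagged_shard_shape<1>();
// read_tagged_shard_shape<0>() would fail to compile, because arg 0 is not the tag.
```

The tags cost two extra args per complex type, so this is presumably something to enable selectively rather than everywhere.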

Additionally, with the above we can start to do much more sophisticated code gen. I've built an "interpreter" of sorts for CCLs that takes in "commands" or "instructions". These are higher level operations ("forward this logical tensor slice from that CB to the fabric" type of thing), and with features like the above, I can transform a command stream directly into a kernel.

@SeanNijjar
Contributor Author

FYI @pgkeller - please let us know if this raises any concerns on your end or if you feel there is any additional set of work we should include here.

@SeanNijjar
Contributor Author

SeanNijjar commented Dec 16, 2024

@SeanNijjar SeanNijjar changed the title Improve Compile Time Args to Make Them Programmatically Accesible (TMP) Improve Compile Time Args to Make Them Programmatically Accessible (TMP) Dec 16, 2024
@SeanNijjar SeanNijjar removed their assignment Feb 26, 2025
@sagarwalTT sagarwalTT linked a pull request Feb 26, 2025 that will close this issue
6 tasks
@bbradelTT
Contributor

@SeanNijjar what would be the minimal example if instead of std::array we'd want to populate a struct? E.g. for two parameters something like

struct CoreParams {
   std::uint32_t per_core_M;
   std::uint32_t per_core_N;
};
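
One possible shape for this, offered as a sketch rather than as the answer given in the thread: because the proposed define expands to a plain comma-separated list, it can aggregate-initialize a struct directly, or the struct can be built from a slice of the array. The fallback `#define`, the stand-in array, and the helper name `make_core_params` are assumptions for illustration.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Assumption: normally supplied by the JIT compile command via -D.
#ifndef KERNEL_COMPILE_TIME_ARGS
#define KERNEL_COMPILE_TIME_ARGS 8, 16
#endif

struct CoreParams {
    std::uint32_t per_core_M;
    std::uint32_t per_core_N;
};

// Option 1: the list initializes the struct directly. This only compiles when
// the list has no more values than the struct has fields.
inline constexpr CoreParams core_params{KERNEL_COMPILE_TIME_ARGS};

// Option 2: build the struct from a slice of the array, for when the struct is
// only one part of a longer arg list.
inline constexpr std::array<std::uint32_t, 2> kernel_compile_time_args{KERNEL_COMPILE_TIME_ARGS};

template <std::size_t start>
constexpr CoreParams make_core_params() {
    static_assert(start + 2 <= kernel_compile_time_args.size());
    return {kernel_compile_time_args[start], kernel_compile_time_args[start + 1]};
}

inline constexpr CoreParams core_params_from_array = make_core_params<0>();
```

Option 1 is the shortest, but it ties the whole arg list to a single struct; option 2 composes with other types that share the list.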
