Improve Compile Time Args to Make Them Programmatically Accessible (TMP) #16040
Open
4 tasks
FYI @pgkeller - please let us know if this raises any concerns on your end or if you feel there is any additional set of work we should include here.
@SeanNijjar what would be the minimal example if instead of std::array we'd want to populate a struct? E.g. for two parameters something like
This new approach makes compile time args programmatically accessible. At one end of the spectrum, that lets us use them in techniques like template metaprogramming, which greatly improves capabilities. At the other end, it lets us greatly improve code reuse, which is currently nearly impossible because compile time args can only be accessed with literals.
The proposed approach will generate the compile time args directly into a `std::array`.
Rather than passing compile time args individually like:
We instead will pass them together like so:
This godbolt shows the minimal example:
https://godbolt.org/z/11Ka4WK93
The above will replace (or rather extend) the current implementation of `get_compile_time_arg_val`. To maintain backwards compatibility, we keep `get_compile_time_arg_val` but replace its implementation with an access to the compile time args array, `kernel_compile_time_args`.
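As a rough illustration of that backwards-compatible accessor, here is a minimal C++17 sketch. The array contents are placeholder values, and modeling `get_compile_time_arg_val` as a `constexpr` function is an assumption made for this example; the real accessor and the way the array gets generated may look different.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Assumption: in the real flow the build would generate this array from the
// host-provided compile time args. The values here are placeholders.
constexpr std::array<std::uint32_t, 4> kernel_compile_time_args = {7, 1, 32, 2048};

// Backwards-compatible accessor: existing call sites keep working, but the
// body is now just an indexed read of the array.
// Assumption: a constexpr function is used here purely for illustration.
constexpr std::uint32_t get_compile_time_arg_val(std::size_t index) {
    return kernel_compile_time_args[index];
}

// Old style: the index is written as a literal.
constexpr std::uint32_t page_size = get_compile_time_arg_val(3);

// New capability: the index can be any constant expression, so offsets can
// be computed instead of hard-coded.
constexpr std::size_t base = 1;
constexpr std::uint32_t flag = kernel_compile_time_args[base];
constexpr std::uint32_t tile_dim = kernel_compile_time_args[base + 1];
```

Because everything is `constexpr`, the results can still be used as template arguments or in `static_assert`, just like values read the old way.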
Tasks for this issue:
- [ ] Replace the implementation of `get_compile_time_arg_val` and expose `kernel_compile_time_args`
- [ ] Update uses of `get_compile_time_arg_val` to instead use constexpr variables in some example files (e.g., CCL command interpreter) to simplify the code and to provide the diff as an example for other developers.

Note on tests:
It would be great if we could have a test that does something like the following as a basic demonstration of the new capabilities:
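One possible shape for such a test is sketched below in C++17. It assumes a made-up arg layout (a count followed by that many values, then one trailing arg) and a hypothetical helper `ct_args_slice`; neither comes from the existing code base.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <utility>

// Placeholder layout (an assumption for this sketch):
// [count][count values...][one trailing arg]
constexpr std::array<std::uint32_t, 5> kernel_compile_time_args = {3, 10, 20, 30, 99};

// Hypothetical helper: copy Count args starting at Offset into a new array.
template <std::size_t Offset, std::size_t... Is>
constexpr std::array<std::uint32_t, sizeof...(Is)> ct_args_slice_impl(std::index_sequence<Is...>) {
    return {{kernel_compile_time_args[Offset + Is]...}};
}

template <std::size_t Offset, std::size_t Count>
constexpr std::array<std::uint32_t, Count> ct_args_slice() {
    static_assert(Offset + Count <= kernel_compile_time_args.size(), "slice out of range");
    return ct_args_slice_impl<Offset>(std::make_index_sequence<Count>{});
}

// The number of values is itself a compile time arg and is used as a
// template argument, which a literal-only accessor cannot express generically.
constexpr std::size_t value_count = kernel_compile_time_args[0];
constexpr auto values = ct_args_slice<1, value_count>();

// The arg after the variable-length group is located by index arithmetic.
constexpr std::uint32_t trailer = kernel_compile_time_args[1 + value_count];

static_assert(values.size() == 3, "slice length comes from a CT arg");
static_assert(values[0] == 10 && values[2] == 30, "slice contents");
```

The `static_assert` lines double as the test: if the array or the helper is wrong, the kernel fails to compile.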
Adding Notes from Offline Discussion About Benefits
Here's an example of what we have to deal with today. Any time there's conditionality on how many CT args are captured for some specific purpose, that conditionality needs to propagate to all later CT arg captures (and any further conditionality is combinatorial with earlier conditionality).
In this example, I have two operand tensors to the kernel. They may be: {INTERLEAVED, INTERLEAVED}, {INTERLEAVED, SHARDED}, {SHARDED, INTERLEAVED}, {SHARDED, SHARDED}. In this case, interleaved consumes fewer args than sharded.
If operand 0 is interleaved and operand 1 is sharded, then operand 1's CT args are earlier in the list. If operand 0 is sharded, then operand 1's CT args are at a later index (notice both of these are for initializations of in1_sharded_addrgen_fields). I need to do this in multiple places in the file, which is a major annoyance for something so simple.
Also... no reusability here because each kernel has slightly different CT arg usage patterns.
With this feature I can collapse the above to:
And once I can do the above I can start to put this into a library for reuse across kernels:
and I just call build_sharded_addrgen_fields_from_ct_args instead
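To make the idea concrete, here is a minimal C++17 sketch of what such a reusable helper could look like. Only the names `build_sharded_addrgen_fields_from_ct_args` and `in1_sharded_addrgen_fields` come from the discussion above; the struct fields, the per-operand arg layout, the template signature, and `operand_ct_arg_count` are assumptions invented for illustration.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Placeholder args. Assumed per-operand layout: [layout tag], followed by
// three fields only when the operand is sharded.
// Operand 0 is interleaved (1 arg), operand 1 is sharded (4 args), 77 trails.
constexpr std::array<std::uint32_t, 6> kernel_compile_time_args = {0, 1, 2, 4, 16, 77};

constexpr std::uint32_t kInterleaved = 0;
constexpr std::uint32_t kSharded = 1;

// Hypothetical field set; the real address generator fields will differ.
struct ShardedAddrgenFields {
    bool is_sharded;
    std::uint32_t shard_grid_height;
    std::uint32_t shard_grid_width;
    std::uint32_t pages_per_shard;
};

// How many CT args the operand starting at `offset` consumes.
constexpr std::size_t operand_ct_arg_count(std::size_t offset) {
    return kernel_compile_time_args[offset] == kSharded ? 4 : 1;
}

// Library-style builder: reads the fields relative to Offset, so no call
// site needs to know what earlier operands consumed.
template <std::size_t Offset>
constexpr ShardedAddrgenFields build_sharded_addrgen_fields_from_ct_args() {
    if constexpr (kernel_compile_time_args[Offset] == kSharded) {
        return {true,
                kernel_compile_time_args[Offset + 1],
                kernel_compile_time_args[Offset + 2],
                kernel_compile_time_args[Offset + 3]};
    } else {
        return {false, 0, 0, 0};
    }
}

// Offsets are computed once; the conditionality no longer propagates.
constexpr std::size_t in0_offset = 0;
constexpr std::size_t in1_offset = in0_offset + operand_ct_arg_count(in0_offset);
constexpr std::size_t next_offset = in1_offset + operand_ct_arg_count(in1_offset);

constexpr auto in0_sharded_addrgen_fields = build_sharded_addrgen_fields_from_ct_args<in0_offset>();
constexpr auto in1_sharded_addrgen_fields = build_sharded_addrgen_fields_from_ct_args<in1_offset>();
```

All four {INTERLEAVED, SHARDED} combinations go through the same code path, since each operand's starting index is derived from the one before it rather than written out per combination.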
Dealing with Arg alignment
Alignment is no different from (and no worse than) before, but in practice it becomes much easier to deal with. Why?
On host you can (and we do in CCLs) have "serialization" of certain types to CT or RT args. Without this feature, I need to manually implement "deserialization" of the type for every kernel that wants it. WITH this feature, I can implement a single "deserializer" for reuse everywhere. As long as those two are in agreement with each other (which we could write gtests for), we are golden.
We can go so far as to include tag or special values at the start/end of each complex type (one that requires multiple args to be together in some order), so that any misalignment can be detected at compile time.
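A minimal C++17 sketch of the serializer/deserializer pairing with tag values might look like the following. The `ShardSpec` type, the tag constants, and both function names are hypothetical, and the host-side serializer is shown in the same file only to keep the example self-contained.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical tag values marking the start and end of one serialized type.
constexpr std::uint32_t kShardSpecBeginTag = 0xC0DE0001u;
constexpr std::uint32_t kShardSpecEndTag = 0xC0DE0002u;
constexpr std::size_t kShardSpecWords = 5;  // begin tag + 3 fields + end tag

// Hypothetical complex type that needs several args in a fixed order.
struct ShardSpec {
    std::uint32_t height;
    std::uint32_t width;
    std::uint32_t pages;
};

// "Host" side: one serializer shared by every op that passes a ShardSpec.
constexpr std::array<std::uint32_t, kShardSpecWords> serialize_shard_spec(const ShardSpec& s) {
    return {{kShardSpecBeginTag, s.height, s.width, s.pages, kShardSpecEndTag}};
}

// Placeholder args: one unrelated arg, then a serialized ShardSpec.
constexpr std::array<std::uint32_t, 6> kernel_compile_time_args = {
    42, kShardSpecBeginTag, 2, 4, 16, kShardSpecEndTag};

// "Kernel" side: one deserializer reused everywhere. A wrong Offset fails
// to compile instead of silently reading the wrong args.
template <std::size_t Offset>
constexpr ShardSpec deserialize_shard_spec() {
    static_assert(Offset + kShardSpecWords <= kernel_compile_time_args.size(),
                  "ShardSpec runs past the end of the CT args");
    static_assert(kernel_compile_time_args[Offset] == kShardSpecBeginTag,
                  "CT args misaligned: missing ShardSpec begin tag");
    static_assert(kernel_compile_time_args[Offset + kShardSpecWords - 1] == kShardSpecEndTag,
                  "CT args misaligned: missing ShardSpec end tag");
    return {kernel_compile_time_args[Offset + 1],
            kernel_compile_time_args[Offset + 2],
            kernel_compile_time_args[Offset + 3]};
}

constexpr ShardSpec spec = deserialize_shard_spec<1>();
constexpr auto round_trip = serialize_shard_spec(spec);
```

Here `deserialize_shard_spec<0>()` would be rejected by the begin-tag check, which is the compile time misalignment detection described above. The tags cost two extra args per complex type, so they may be worth enabling only in debug builds.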
Additionally, with the above we can start to do much more sophisticated code gen. I've built an "interpreter" of sorts for CCLs that takes in "commands" or "instructions". These are higher level operations ("forward this logical tensor slice from that CB to the fabric" type of thing), and with features like the above, I can transform a command stream directly into a kernel.
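As a toy illustration of turning a command stream into code at compile time, here is a C++17 sketch. It is not the CCL interpreter: the opcodes are invented and do simple arithmetic as stand-ins for the real higher level operations, and the `[opcode][operand]` layout is an assumption.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Invented opcodes standing in for real commands.
enum Opcode : std::uint32_t { kEnd = 0, kAdd = 1, kMul = 2 };

// Placeholder command stream: [opcode][operand] pairs terminated by kEnd.
constexpr std::array<std::uint32_t, 5> kernel_compile_time_args = {kAdd, 5, kMul, 3, kEnd};

// Each instantiation handles the command at index Pc and recurses to the
// next one. `if constexpr` discards every branch except the selected one,
// so the generated code contains only the operations the stream asks for.
template <std::size_t Pc>
constexpr std::uint32_t run_commands(std::uint32_t acc) {
    constexpr std::uint32_t op = kernel_compile_time_args[Pc];
    if constexpr (op == kEnd) {
        return acc;
    } else if constexpr (op == kAdd) {
        return run_commands<Pc + 2>(acc + kernel_compile_time_args[Pc + 1]);
    } else {
        static_assert(op == kMul, "unknown opcode in command stream");
        return run_commands<Pc + 2>(acc * kernel_compile_time_args[Pc + 1]);
    }
}

// (1 + 5) * 3
constexpr std::uint32_t result = run_commands<0>(1);
```

An unknown opcode becomes a compile error rather than a runtime fault, and since the dispatch is resolved during compilation, no interpreter loop needs to exist in the generated kernel.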