
Improve argument inference mechanism #39

Open
govereau opened this issue Feb 12, 2025 · 0 comments
Labels
enhancement New feature or request

@govereau (Collaborator)

We added a feature that infers the types of kernel arguments, so that kernels can be compiled without writing a Python driver to supply the arguments. This is handy, but the mechanism is weak: it simply uses small test tensors for every argument, which will not be the correct choice in many cases.

Background

ML kernels are compiled for specific inputs. That is, even though the code could be used with many different inputs, such as tensors of different sizes, specific inputs are selected before compilation, and the compilation is only valid for arguments similar to those specific inputs. For example:

def kernel(a):
  return nl.load(a) + 1

def test():
  a = np.ndarray((128, 128))  # prototype argument, only used to drive translation
  K = Parser(kernel)
  K.apply_args(a)
  K.to_klr()

In this example, in order to translate the kernel we have to supply a test argument a. This is inconvenient because the testing code is longer than the kernel itself, and because you have to run the Python test function to start the translation process. The second problem prevents us from having a traditional compiler workflow. Note that this is not a problem when the kernel is situated inside a framework (like JAX); it is only a problem for developers writing kernels outside of a framework, or for developers of KLR itself, who just want to easily generate KLR from various example kernels.

We would like a solution that allows the developer to start using KLR without the test function above. Such a solution would make development easier, and would also simplify setting up unit-test frameworks.

Current Approach

The current approach simply fills in all missing arguments with 128x128 tensors of element type float32. While this works for the example above, it is not general enough.
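
For concreteness, here is a minimal sketch of roughly what that behavior amounts to; fill_missing_args, DEFAULT_SHAPE, and DEFAULT_DTYPE are illustrative names, not part of the actual implementation:

import inspect
import numpy as np

DEFAULT_SHAPE = (128, 128)
DEFAULT_DTYPE = np.float32

def fill_missing_args(kernel, supplied=None):
    # Any parameter without a supplied argument gets a small default test tensor.
    supplied = supplied or {}
    params = inspect.signature(kernel).parameters
    return [supplied.get(name, np.zeros(DEFAULT_SHAPE, dtype=DEFAULT_DTYPE))
            for name in params]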

Ideas

Use type annotations

If the user wrote a type annotation, e.g.:

def kernel(a : tensor[float32, (128,512)]): ...

then we do not need prototype arguments: everything the translation needs to know is already in the type annotation. The issue is with kernels that are meant to work on many different tensor sizes. First, the syntax for these type annotations needs to be designed, and second, deciding how to instantiate such a type brings us back to the original problem.
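
As a rough sketch of annotation-driven inference: the tensor[...] syntax above is not defined yet, so plain (dtype, shape) tuples stand in for it here, and args_from_annotations is a hypothetical helper, not existing KLR code.

import inspect
import numpy as np

def args_from_annotations(kernel):
    # Build one prototype tensor per parameter from its (dtype, shape) annotation.
    args = []
    for name, param in inspect.signature(kernel).parameters.items():
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(f"parameter {name!r} has no type annotation")
        dtype, shape = param.annotation
        args.append(np.zeros(shape, dtype=dtype))
    return args

def annotated_kernel(a: ("float32", (128, 512))): ...

prototype_args = args_from_annotations(annotated_kernel)  # one (128, 512) float32 array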

Use files

For tensors, the user could generate some serialized numpy files with names that correspond to the arguments (e.g. a.npy). We could use these files automatically to instantiate the arguments.
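
A possible sketch of this idea, assuming one name.npy file per kernel parameter in a directory chosen by the developer (load_args_from_files is an invented name, used only for illustration):

import inspect
from pathlib import Path
import numpy as np

def load_args_from_files(kernel, directory="."):
    # Match each parameter to a serialized array by name, e.g. parameter 'a' -> a.npy.
    args = []
    for name in inspect.signature(kernel).parameters:
        path = Path(directory) / f"{name}.npy"
        if not path.exists():
            raise FileNotFoundError(f"no prototype file for parameter {name!r}: {path}")
        args.append(np.load(path))
    return args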

Use command-line arguments

We could have a system of command-line arguments for specifying how to instantiate the missing arguments, but as the number of parameters grows it seems easier to just write a Python driver function.
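
For example, a spec such as a=float32[128,512] could be parsed into a prototype tensor. The --arg flag and the spec format below are assumptions made purely for illustration:

import argparse
import numpy as np

def parse_arg_spec(spec):
    # Parse 'name=dtype[d1,d2,...]' into (name, prototype array).
    name, _, rest = spec.partition("=")
    dtype, _, dims = rest.partition("[")
    shape = tuple(int(d) for d in dims.rstrip("]").split(","))
    return name, np.zeros(shape, dtype=dtype)

parser = argparse.ArgumentParser()
parser.add_argument("--arg", action="append", default=[],
                    help="prototype argument, e.g. a=float32[128,512]")
ns = parser.parse_args(["--arg", "a=float32[128,512]"])
prototypes = dict(parse_arg_spec(s) for s in ns.arg)

Even with a scheme like this, each additional kind of parameter needs its own syntax, which is why a Python driver function quickly becomes the simpler option.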

govereau added the enhancement (New feature or request) label on Feb 12, 2025