
Provide ONNX converter #538

Open
jcitrin opened this issue Nov 19, 2024 · 18 comments

Comments

@jcitrin (Collaborator) commented Nov 19, 2024

Enables easier coupling of ML surrogates trained with various libraries.

@theo-brown (Collaborator) commented Feb 13, 2025

Guessing this would use https://github.com/google/jaxonnxruntime? I can have a crack at it if that would be helpful. We've got another surrogate that we're interested in bringing in, and supporting ONNX might speed things up.

@jcitrin (Collaborator, Author) commented Feb 13, 2025

Thanks! We were looking into it; I'm not sure where we're at right now. @ernoc, @sbodenstein, @Nush395?

@sbodenstein (Collaborator)

@theo-brown: I will help work on this. And yes, I think we'll use jaxonnxruntime. Do you have a model with inputs I can test on? You can also share it privately.

@theo-brown (Collaborator) commented Feb 20, 2025

Awesome!

It might be easiest to start with the ONNX version of the TGLF neural net in tglf-surrogate, as the interface to/from TORAX structures is already there. I can add the ONNX version to that branch, or hop on a new one if that would be better.

Equally, if the QLKNN model has an ONNX version, we could start with that?

I can come up with the beginnings of a PR on Friday.

@sbodenstein (Collaborator)

Just to make 100% sure: do you have a surrogate model trained in a different framework (PyTorch, scikit-learn) that you want to run in JAX? (jaxonnxruntime only converts ONNX models to JAX.) Or do you want to export JAX surrogate models to ONNX?

@theo-brown (Collaborator)

Ah, yes, I see how what I said could be confusing.

I was thinking of beginning with a model for which we have both a) a non-JAX version and b) an existing TORAX-connected JAX version.
We would then ONNX-ify the non-JAX version and compare the performance of TORAX with the ONNX and JAX-native versions.

For example, for the TGLF surrogate, Lorenzo originally built the whole thing in PyTorch. I made a JAX version of the architecture, which we could load the saved weights into, and wrote the wrapper to connect it to TORAX. We could easily ONNX-ify the PyTorch version, figure out how to load it from ONNX into JAX, and reuse the same wrapper to interface with TORAX I/O.

@sbodenstein (Collaborator)

OK, that makes sense! Just so we don't duplicate efforts: do you have the beginnings of the PR you mentioned? Or should I implement a PyTorch-to-JAX workflow via ONNX?

@theo-brown (Collaborator) commented Feb 24, 2025

I made less progress than expected because I discovered a bunch of connective stuff we hadn't yet done on the TORAX side with this surrogate, so I worked on that instead. Do go ahead and start if you've got time - I don't have the beginnings of a PR.

@sbodenstein (Collaborator) commented Feb 27, 2025

Note that there is an alternative path from PyTorch to JAX: https://pytorch.org/xla/master/features/stablehlo.html#using-extract-jax. One thing that would be useful for us to know: does this conversion approach work with your model? And does the jaxonnxruntime approach work?

@theo-brown (Collaborator)

@sbodenstein could you drop me an email, theo.brown@ukaea.uk, to discuss this further?

@theo-brown (Collaborator)

@sbodenstein I haven't tried the PyTorch XLA backend. What's the motivation for me doing so?

The jaxonnxruntime method works fine for my model, subject to one quirk that I don't quite understand, marked in the TODO below. This isn't TORAX-related per se, so I'll chase it up with the jaxonnxruntime team.

import jax.numpy as jnp
import onnx
from jaxonnxruntime.backend import Backend
from jaxonnxruntime.config_class import config

# TODO: identify why this is needed. It's something to do with the input
# normalisation/output denormalisation.
config.update("jaxort_only_allow_initializers_as_static_args", False)

onnx_model = onnx.load_model(ONNX_MODEL_PATH)  # path to the exported ONNX file

jax_model = Backend.prepare(onnx_model)
jax_model.run({"x": jnp.zeros((1, 15), dtype=jnp.float32)})

A couple of gotchas I've noted for the PR:

  • As far as I can tell, exporting to ONNX requires fixing the batch dimension. For example, in PyTorch this model could take inputs of shape (b, 15), where b is a batch dimension (in this case, the number of radial points at which to evaluate a transport model). However, when exported to ONNX this gets restricted to (1, 15). If there's appetite for an ONNX converter in TORAX, the wrapper will need to vmap the model so it can be applied on the correct grid.
  • Dtypes are super important: this ONNX model will error if its inputs aren't f32!

That said, this is all just from playing around with an ONNX-saved model that I was given. There may be things we could change in the exporting step that sidestep these problems entirely.

Handily, the information about the input dims and dtypes is readily available from the onnx_model, which means both of these should be trivial to handle. I imagine a function _wrap_onnx_model that loads the model using onnx, converts it with jaxonnxruntime, and returns a function f(jax.Array) -> jax.Array, with all the dimension checking, vectorisation, and type checking performed under the hood.
In this way, we're not pinning the ONNX interface to a particular object (transport models, sources, etc.), and the model call can just be treated like a standard jittable function.
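
A minimal sketch of what such a wrapper could look like, assuming the prepared jaxonnxruntime model stays traceable under jax.vmap, has a single output, and takes a single 2D (batch, features) input; _wrap_onnx_model is just the hypothetical helper floated above:

import jax
import jax.numpy as jnp
import onnx
from jaxonnxruntime.backend import Backend

def _wrap_onnx_model(onnx_model_path):
    """Loads an ONNX model and returns a jittable f(jax.Array) -> jax.Array."""
    onnx_model = onnx.load_model(onnx_model_path)
    jax_model = Backend.prepare(onnx_model)

    # Read the expected input name and feature width from the ONNX graph,
    # assuming a single 2D input of shape (batch, n_features).
    graph_input = onnx_model.graph.input[0]
    input_name = graph_input.name
    n_features = graph_input.type.tensor_type.shape.dim[1].dim_value

    def _single_call(x):
        # The exported graph expects shape (1, n_features) in float32;
        # assume a single output and strip the unit batch dimension.
        (y,) = jax_model.run({input_name: x[None, :].astype(jnp.float32)})
        return jnp.asarray(y)[0]

    def f(x):
        if x.shape[-1] != n_features:
            raise ValueError(
                f"Expected trailing dimension {n_features}, got {x.shape[-1]}")
        # vmap over the leading (radial grid) dimension so callers
        # can pass (b, n_features).
        return jax.vmap(_single_call)(x)

    return f

A caller would then just apply _wrap_onnx_model(path) to a (n_radial_points, 15) array and treat the result like any other jittable function.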

@jcitrin not entirely sure where this would belong best. Are there any plans for a general surrogate_utils-type folder?

@jcitrin (Collaborator, Author) commented Feb 28, 2025

@jcitrin not entirely sure where this would belong best. Are there any plans for a general surrogate_utils-type folder?

There's actually a new repository that @hamelphi just launched today 🚀
https://github.com/google-deepmind/fusion_transport_surrogates

See the WIP PR #784 for getting the new QLKNN_7_11 model in.

fusion_transport_surrogates is where the NN inference code lives. The idea is that the turbulence surrogates can be used in multiple frameworks, so it makes sense for TORAX to depend on this external library that others can depend on too.

So this sounds like the logical place for the ONNX interface, and also for the inference code and model descriptions in #477?

@theo-brown (Collaborator)

There's actually a new repository that @hamelphi just launched today 🚀

Yeah just saw this - great work! Looking forward to hearing more about it.

this sounds like the logical place for the ONNX interface and also the inference code and model descriptions in #477

Definitely makes sense to me! Next week I'll separate out the TGLFNN model and inference definitions from the TORAX interface, and make a new PR for the models in fusion_transport_surrogates. Having played around with ONNX this week I feel quite well placed to make a PR for that there too. @hamelphi I will probably seek your advice on what sort of interface you envision working best!

@theo-brown (Collaborator)

As far as I can tell, exporting to ONNX requires fixing the batch dimension

It turns out this is not true, which is good! It just requires a bit more care when exporting; see e.g. https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
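
For illustration, a minimal sketch of such an export with a dynamic batch dimension, along the lines of that tutorial (the tiny stand-in MLP and its output width are made up; the input width of 15 matches the shapes discussed above):

import torch

# Stand-in for the real surrogate network.
model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 4),
)
model.eval()

dummy_input = torch.zeros(1, 15, dtype=torch.float32)
torch.onnx.export(
    model,
    dummy_input,
    "surrogate.onnx",
    input_names=["x"],
    output_names=["y"],
    # Mark axis 0 of the input and output as dynamic so the exported graph
    # accepts (b, 15) rather than being pinned to (1, 15).
    dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
)

With dynamic_axes set, the exported graph accepts (b, 15) inputs directly, so the vmap workaround mentioned above may not be needed.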

@sbodenstein (Collaborator)

@theo-brown: the motivation is that this will likely be a better-maintained path for importing Torch -> JAX. But happy to support the ONNX route. Do you think we should just document how to convert the models with the ONNX converter, or add an ONNX dependency to TORAX?

@jcitrin (Collaborator, Author) commented Mar 3, 2025

Shouldn't the ONNX dependency be in https://github.com/google-deepmind/fusion_transport_surrogates ? @hamelphi

@hamelphi (Collaborator) commented Mar 3, 2025

Yes, I think it would make sense to have the ONNX dependencies in fusion_transport_surrogates.

@theo-brown (Collaborator)

this will likely be a better maintained path for importing Torch -> JAX

Makes sense - jaxonnxruntime doesn't look like a particularly long-term bet, so better to go with something that's got a strong maintenance trajectory.

That said, supporting ONNX is useful generally, as it's used e.g. in coupling Julia models to JINTRAC at UKAEA.

Do you think we should just document how to convert the models with the ONNX converter, or add a ONNX dependency to Torax [or fusion-transport-surrogates]?

It may be that documentation and good examples are all that's required in the end. When I've got time to work on this (Thu/Fri) I'll turn what I've got into some docs, and then if there's loads of boilerplate that would probably be reused then I can add that to one of the codebases.

Shouldn't the ONNX dependency be in fusion-transport-surrogates?

In my opinion this depends slightly on whether there is any current interest in coupling non-transport surrogates.
