Provide ONNX converter #538

Enables easier coupling of ML surrogates trained with various libraries.
Guessing this would be using https://github.com/google/jaxonnxruntime? Can have a crack at it if that would be helpful. We've got another surrogate that we're interested in bringing in, and supporting ONNX might speed things up.
Thanks! We were looking into it; not sure where we're at right now. @ernoc, @sbodenstein, @Nush395?
@theo-brown: I will help work on this. And yes, I think we'll use jaxonnxruntime. Do you have a model with inputs I can test on? You can also share it privately.
Awesome! It might be easiest to start with the ONNX version of the TGLF neural net in […]. Equally, if the QLKNN model has an ONNX version, we could start with that? I can come up with the beginnings of a PR on Friday.
Just to make 100% sure: do you have a surrogate model trained in a different framework (PyTorch, scikit-learn) that you want to run in JAX? (jaxonnxruntime only converts ONNX models to JAX.) Or do you want to export JAX surrogate models to ONNX?
Ah yes, I see how what I said could be confusing. I was thinking of beginning with a model for which we have both a) a non-JAX version and b) an existing TORAX-connected JAX version. For example, for the TGLF surrogate, Lorenzo originally built the whole thing in PyTorch. I made a JAX version of the architecture, which we could load the saved weights into, and wrote the wrapper to connect it to TORAX. We could easily ONNX-ify the PyTorch version, figure out how to load it from ONNX into JAX, and reuse the same wrapper to interface with TORAX I/O. A sketch of the export step is below.
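For concreteness, a minimal sketch of that PyTorch-to-ONNX export step, using a stand-in two-layer MLP (the real TGLF surrogate's architecture, output width, and file names all differ; the input width of 15 just matches the test input used later in this thread):

```python
import torch
import torch.nn as nn

# Placeholder architecture standing in for the actual TGLF surrogate.
model = nn.Sequential(nn.Linear(15, 64), nn.ReLU(), nn.Linear(64, 4))
model.eval()

# A dummy input fixes the input dtype and shape for tracing.
dummy_input = torch.zeros((1, 15), dtype=torch.float32)
torch.onnx.export(
    model,
    dummy_input,
    "tglf_surrogate.onnx",             # placeholder filename
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}},  # keep the batch dimension dynamic
)
```

The exported file could then be loaded on the JAX side, e.g. with jaxonnxruntime as in the snippet further down.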
OK, that makes sense! Just so we don't duplicate efforts: do you have the beginnings of the PR you mentioned? Or should I implement a PyTorch-to-JAX workflow via ONNX?
I made less progress than expected because I discovered a bunch of connective stuff we hadn't yet done on the TORAX side with this surrogate, so I worked on that instead. Do go ahead and start if you've got time - I don't have the beginnings of a PR.
Note that there is an alternative path from PyTorch to JAX: https://pytorch.org/xla/master/features/stablehlo.html#using-extract-jax. One useful thing for us: does this conversion approach work with your model? And does the jaxonnxruntime approach work?
@sbodenstein could you drop me an email, theo.brown@ukaea.uk, to discuss this further?
@sbodenstein I haven't tried the PyTorch XLA backend. What's the motivation for me doing so? The following jaxonnxruntime snippet works:

```python
import jax.numpy as jnp
import onnx

from jaxonnxruntime.backend import Backend
from jaxonnxruntime.config_class import config

# TODO: identify why this is needed. It's something to do with the
# input normalisation/output denormalisation.
config.update("jaxort_only_allow_initializers_as_static_args", False)

ONNX_MODEL_PATH = "model.onnx"  # placeholder path to the exported model

onnx_model = onnx.load_model(ONNX_MODEL_PATH)
jax_model = Backend.prepare(onnx_model)
jax_model.run({"x": jnp.zeros((1, 15), dtype=jnp.float32)})
```

A couple of gotchas I've noted for the PR: […]
That said, this is all from just playing around with an ONNX-saved model that I was given. There may be things we could change in the exporting step that sidestep these problems entirely. Handily, the information about the input dims and dtypes is all readily available from the […].

@jcitrin I'm not entirely sure where this would belong best. Are there any plans for a general […]?
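On the dims/dtypes point above: a minimal sketch of reading input names, shapes, and dtypes from a loaded ONNX model (the file path is a placeholder):

```python
import onnx
from onnx import helper

onnx_model = onnx.load_model("tglf_surrogate.onnx")  # placeholder path

for inp in onnx_model.graph.input:
    tensor_type = inp.type.tensor_type
    # Each dim is either a fixed size (dim_value) or symbolic (dim_param,
    # e.g. "batch" when exported with dynamic axes).
    shape = [
        d.dim_value if d.HasField("dim_value") else d.dim_param
        for d in tensor_type.shape.dim
    ]
    dtype = helper.tensor_dtype_to_np_dtype(tensor_type.elem_type)
    print(inp.name, shape, dtype)
```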
There's actually a new repository that @hamelphi just launched today 🚀 See #784, the WIP PR for getting the new QLKNN_7_11 model in. fusion_transport_surrogates is where the NN inference code lives. The idea is that the turbulence surrogates can be used in multiple frameworks, so it makes sense for TORAX to depend on this external library that others can depend on too. So this sounds like the logical place for the ONNX interface, and also for the inference code and model descriptions in #477?
Yeah just saw this - great work! Looking forward to hearing more about it.
Definitely makes sense to me! Next week I'll separate out the TGLFNN model and inference definitions from the TORAX interface, and make a new PR for the models in fusion_transport_surrogates. Having played around with ONNX this week, I feel quite well placed to make a PR for that there too. @hamelphi I will probably seek your advice on what sort of interface you envision working best!
> […]

This is not true, which is good! It just requires a bit more care when exporting.
@theo-brown: the motivation is that this will likely be a better-maintained path for importing Torch -> JAX. But happy to support the ONNX route. Do you think we should just document how to convert the models with the ONNX converter, or add an ONNX dependency to TORAX?
Shouldn't the ONNX dependency be in https://github.com/google-deepmind/fusion_transport_surrogates? @hamelphi
Yes, I think it would make sense to have the ONNX dependencies in fusion_transport_surrogates.
Makes sense - jaxonnxruntime doesn't look particularly long-term; better to go with something that's got a strong maintenance trajectory. That said, supporting ONNX is useful generally, as it's used e.g. in coupling Julia models to JINTRAC at UKAEA.
It may be that documentation and good examples are all that's required in the end. When I've got time to work on this (Thu/Fri), I'll turn what I've got into some docs, and then, if there's lots of boilerplate that would be reused, I can add that to one of the codebases.
In my opinion, this depends slightly on whether there is any current interest in coupling non-transport surrogates.