🚀 Feature

Today we have two related concepts in torchax: the environment object and the "jax" device.
In particular, the user needs to know both in order to access JAX (e.g. TPU) features.
This RFC proposes that we refactor the API so that the user only needs to know the "jax" device in order to operate on TPUs.
Motivation
The device is a well-known concept in PyTorch: tutorials talk about tensor.cpu() and tensor.cuda(). It's commonly understood that if I create a tensor without the device argument, that tensor lives on the default device (usually CPU), and if I call .cuda(), the tensor is moved to the GPU.
Since people primarily choose torchax to be able to use the TPU (or access the XLA GPU backend), it makes sense to present this functionality as a PyTorch device. By the principle of symmetry, it's natural to introduce jax counterparts for various cuda APIs, where applicable. This gives people a clear mental model of when they are or aren't using JAX.
Let's look at a few examples (all of these assume we have already run import torchax):
- I can call torch.cuda.current_device() to get the index of the current CUDA device.
- I can also call torch.jax.current_device() to get the index of the current JAX (XLA) device.
- I can call torch.cuda.is_available() to check if CUDA support is available.
- I can call torch.jax.is_available() to check if the JAX backend is available.
- I can run torch.randn(1, 2, device='cuda') to generate a random tensor on the CUDA device.
- ❌ If I run torch.randn(1, 2, device='jax'), it fails with a confusing dispatcher error [1].
- I can run torch.set_default_device('cuda') to make all subsequent tensors live on the CUDA device.
- ❌ If I run torch.set_default_device('jax') and then create a tensor, it fails with another confusing error [2].
This RFC proposes that we change torchax to close behavior divergences such as the two ❌ items above. In the limit, using eager torchax should feel identical to using any other PyTorch backend (a sketch of the intended end state follows).
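To make that end state concrete, here is a minimal sketch of the intended usage once the proposal is implemented. This is illustrative only: today the two ❌ lines above still fail, and the torch.jax namespace shown here simply mirrors the torch.cuda counterparts listed in the examples.

```python
import torch
import torchax  # proposed: importing torchax is enough to register the "jax" device

# The jax namespace mirrors its torch.cuda counterparts.
if torch.jax.is_available():
    print(torch.jax.current_device())

# Proposed to work out of the box, like device='cuda' (fails today, see above).
x = torch.randn(1, 2, device='jax')

# Proposed to behave like torch.set_default_device('cuda') (fails today, see above).
torch.set_default_device('jax')
y = torch.ones(3)  # now lives on the JAX device
```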
Pitch
Always call enable_globally()
We're pretty close to closing the gaps above. If I run torchax.enable_globally() after importing torchax, then torch.randn(1, 2, device='jax') works, and the error after torch.set_default_device('jax') looks like a fixable bug. I propose we go one step further: automatically call enable_globally() (e.g. as part of importing torchax), and also fix the default device behavior.
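In user code the difference would look roughly like this; whether enable_globally() would remain as a backwards-compatible no-op is my assumption, not something this RFC decides.

```python
import torch
import torchax

# Today: an explicit opt-in call is required before device='jax' works.
torchax.enable_globally()

# Proposed: the call above is no longer needed; `import torchax` alone
# activates the backend, and tensor creation works out of the box.
x = torch.randn(1, 2, device='jax')
```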
Always keep the torchax modes activated
Today the environment object is what activates the XLAFunctionMode and XLADispatchMode that intercept PyTorch operations. However, these modes are an implementation detail of how torchax supports the JAX device. It should be possible to always keep the XLAFunctionMode and XLADispatchMode activated in the mode stack, without changing the behavior of non-JAX tensors. This is akin to how PyTorch already keeps a few modes such as FuncTorchVmapMode and FuncTorchDynamicLayerFrontMode in the stack most of the time. For testing purposes, it could be useful to temporarily disable the XLAFunctionMode and XLADispatchMode, but that should be an internal API that users don't know about.
As a pressure test, we could try running a subset of the PyTorch test suite with XLA{Function,Dispatch}Mode on the mode stack and make sure those tests don't fail. That ensures that even if a user imports torchax, the behavior of their CPU tensors doesn't change.
This suggests we need to decouple XLAFunctionMode and XLADispatchMode from the environment. For example, perhaps they could be relocated to a torchax._internal.XLAModes context manager.
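A hedged sketch of what such an internal helper could look like. XLAFunctionMode and XLADispatchMode are the existing torchax modes mentioned above; the module path, the import path, and the assumption that both modes are constructed from an environment are illustrative only.

```python
# Sketch of a possible torchax/_internal.py helper; not the actual torchax layout.
from torchax.tensor import XLAFunctionMode, XLADispatchMode  # hypothetical import path


class XLAModes:
    """Context manager that keeps the torchax interception modes on the mode stack.

    Internal API: entered process-wide at import time, and exited only by tests
    that need plain PyTorch behavior.
    """

    def __init__(self, env):
        # Assumes both mode classes take the environment they dispatch into.
        self._function_mode = XLAFunctionMode(env)
        self._dispatch_mode = XLADispatchMode(env)

    def __enter__(self):
        self._function_mode.__enter__()
        self._dispatch_mode.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Pop the modes in reverse order of entry.
        self._dispatch_mode.__exit__(exc_type, exc_value, traceback)
        self._function_mode.__exit__(exc_type, exc_value, traceback)
```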
Configuration context managers
The environment object also holds certain configuration (e.g. whether to optimize for performance or for accuracy). As a user, it's sometimes useful to change these settings. We can keep them in the environment and always provide sensible defaults in the default environment. We could also support a stack of environments via context managers, where the configuration at the top of the stack takes precedence. That's a useful way to locally change some config and have it revert to the previous values when leaving the scope.
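A minimal self-contained sketch of the configuration-stack idea. The Config class, the prefer_accuracy flag, and the configure() helper are made-up names for illustration, not torchax APIs; torchax's real environment object would hold its own settings.

```python
from contextlib import contextmanager
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Config:
    prefer_accuracy: bool = False  # e.g. trade speed for numerically safer ops


_config_stack = [Config()]  # the default environment's configuration


def current_config() -> Config:
    """The configuration at the top of the stack takes precedence."""
    return _config_stack[-1]


@contextmanager
def configure(**overrides):
    """Locally override settings; they revert when the scope exits."""
    _config_stack.append(replace(current_config(), **overrides))
    try:
        yield current_config()
    finally:
        _config_stack.pop()


# Usage: settings changed inside the block revert automatically afterwards.
with configure(prefer_accuracy=True):
    assert current_config().prefer_accuracy is True
assert current_config().prefer_accuracy is False
```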
RNG seed
The environment object also holds a seed for the pseudo-random number generator. That should probably change as part of solving #8636.
Alternatives
An alternative is to do nothing and stick to the status quo.
Another follow-up is to figure out which torch changes we need in order to remove friction when using the JAX device. For example, today if I write tensor.jax() hoping to move the tensor to the JAX device, the Python type checker complains that jax() is not a known function on Tensor, unlike cuda().
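For illustration, the friction looks roughly like this; whether tensor.jax() would work at runtime is a separate question, the point is that static type checkers only know the methods declared on Tensor, such as cuda().

```python
import torch
import torchax

x = torch.randn(2, 2)

y = x.cuda()  # fine: type checkers know about Tensor.cuda()
w = x.jax()   # hoped-for counterpart: flagged by type checkers,
              # since Tensor has no declared jax() method
```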
Additional context
Anecdotally, some people asked why they have to create an environment to use a torchax tensor, and didn't understand the error message they got when the environment was missing.