Skip to content

Commit

Permalink
Rename torch_xla2 to torchax (#8599)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Jan 25, 2025
1 parent 79e4e72 commit 8b24140
Show file tree
Hide file tree
Showing 102 changed files with 414 additions and 394 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ on:
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/**'
- 'torchax/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/**'
- 'torchax/**'
workflow_dispatch:

concurrency:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build_upstream_image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/**'
- 'torchax/**'
workflow_dispatch:
jobs:
build:
Expand Down
27 changes: 21 additions & 6 deletions .github/workflows/torch_xla2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ on:
- master
- r[0-9]+.[0-9]+
paths:
- 'experimental/torch_xla2/**'
- 'torchax/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths:
- 'experimental/torch_xla2/**'
- 'torchax/**'
workflow_dispatch:

concurrency:
Expand All @@ -28,21 +28,36 @@ jobs:
uses: actions/checkout@v4
with:
sparse-checkout: |
experimental/torch_xla2
torchax
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install
shell: bash
working-directory: experimental/torch_xla2
working-directory: torchax
run: |
pip install -r test-requirements.txt
pip install -e .[cpu]
pip install tensorflow-cpu # for TF integrations tests
- name: Run tests
working-directory: experimental/torch_xla2
working-directory: torchax
shell: bash
run: |
pytest test/
export JAX_PLATFORMS=cpu
pytest test/test_conv.py
pytest test/test_unbounded_dynamism.py
pytest test/test_interop.py
pytest test/test_ops.py
pytest test/test_context.py
pytest test/test_train.py
pytest test/test_mutations.py
pytest test/test_tf_integration.py
pytest test/gemma/test_gemma.py
pytest test/llama/test_llama.py
pytest test/test_core_aten_ops.py
pytest test/test_functions.py
pytest test/test_libraries.py
pytest test/test_symbolic_shapes.py
pytest test/test_exports.py
XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/
4 changes: 0 additions & 4 deletions experimental/torch_xla2/format.sh

This file was deleted.

10 changes: 0 additions & 10 deletions experimental/torch_xla2/torch_xla2/ops/__init__.py

This file was deleted.

File renamed without changes.
26 changes: 13 additions & 13 deletions experimental/torch_xla2/README.md → torchax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ TorchXLA2 and torch-xla have different installation instructions, please follow
the instructions below from scratch (fresh venv / conda environment.)


### 1. Installing `torch_xla2`
### 1. Installing `torchax`

The following instructions assume you are in the `torch_xla2` directory:
The following instructions assume you are in the `torchax` directory:

```
Fork the repository
$ git clone https://github.com/<github_username>/xla.git
$ cd xla/experimental/torch_xla2
$ cd xla/experimental/torchax
```


Expand Down Expand Up @@ -55,13 +55,13 @@ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.

#### 1.1 Install this package

If you want to install torch_xla2 without the jax dependency and use the jax dependency from torch_xla:
If you want to install torchax without the jax dependency and use the jax dependency from torch_xla:
```bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -e .
```

Otherwise, install `torch_xla2` from source for your platform:
Otherwise, install `torchax` from source for your platform:
```bash
pip install -e .[cpu]
pip install -e .[cuda]
Expand All @@ -77,7 +77,7 @@ pytest test

## Run a model

Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model
Now let's execute a model under torchax. We'll start with a simple 2-layer model
it can be in theory any instance of `torch.nn.Module`.

```python
Expand Down Expand Up @@ -110,15 +110,15 @@ print(m(inputs))
This model `m` contains 2 parts: the weights that is stored inside of the model
and it's submodules (`nn.Linear`).

To execute this model with `torch_xla2`; we need construct and run the model
To execute this model with `torchax`; we need construct and run the model
under an `environment` that captures pytorch ops and swaps them with TPU equivalent.

To create this environment: use

```python
import torch_xla2
import torchax

env = torch_xla2.default_env()
env = torchax.default_env()
```
Then, execute the instantiation of the model, as well as evaluation of model,
using `env` as a context manager:
Expand All @@ -128,14 +128,14 @@ with env:
inputs = torch.randn(3, 3, 28, 28)
m = MyModel()
res = m(inputs)
print(type(res)) # outputs XLATensor2
print(type(res)) # outputs Tensor
```

You can also enable the environment globally with
```python
import torch_xla2
import torchax

torch_xla2.enable_globally()
torchax.enable_globally()
```

Then everything afterwards is run with XLA.
Expand Down Expand Up @@ -209,7 +209,7 @@ def model_func(param, inputs):
Now, we can apply `jax_jit`

```python
from torch_xla2.interop import jax_jit
from torchax.interop import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ NIGHTLY_VERSION=$(date '+%Y%m%d%H%M')

# Update the version to <version>.devYYYYMMDDHHMM in __init__.py
VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;"
sed -r "$VERSION_UPDATE_PATTERN" torch_xla2/__init__.py --in-place
sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place

hatch build -t wheel
File renamed without changes.
File renamed without changes
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ Remove one line from the `skiplist` set.
i.e.

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
(base) hanq-macbookpro:torchax hanq$ git diff
diff --git a/experimental/torchax/test/test_ops.py b/experimental/torchax/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
--- a/experimental/torchax/test/test_ops.py
+++ b/experimental/torchax/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
"_native_batch_norm_legit",
"_segment_reduce",
Expand All @@ -41,15 +41,15 @@ index 72a39ae85..2a156cbce 100644
### Run test to see what failure
For errors you might get after running test, there are two kind:
- Target op failure
- error shows related to target op, such as `No lowering found for 'aten::addbmm'`, please follow instruction like [Fix Target op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torch_xla2/docs/fixing_op_info_test.md#fix-target-op-failure)
- error shows related to target op, such as `No lowering found for 'aten::addbmm'`, please follow instruction like [Fix Target op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-target-op-failure)
- Decomposed op failure
- no implementation found for target ops, but error is not `no lowering`, error shows target op has been implemented somewhere; for sitution like this, please follow instruction like [Fix Decomposed op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torch_xla2/docs/fixing_op_info_test.md#fix-other-op-failure)
- no implementation found for target ops, but error is not `no lowering`, error shows target op has been implemented somewhere; for sitution like this, please follow instruction like [Fix Decomposed op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-other-op-failure)
#### Fix Target op failure
Error gotten:
```
(base) hanq-macbookpro:torch_xla2 hanq$ python test/test_ops.py
(base) hanq-macbookpro:torchax hanq$ python test/test_ops.py
...
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```
Expand All @@ -59,7 +59,7 @@ From here we have 2 strategies for fixing this test:
1. Add an implementation to `aten::addbmm` operator using Jax ops. Or,
2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions").
Either way works for torch_xla2. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of
Either way works for torchax. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of
upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.
Expand All @@ -68,10 +68,10 @@ For illustration purposes, let's implement this op in Jax.
(NOTE: this doesn't stop us from upstreaming a decomposition later if we want)
#### Fix Decomposed op failure
For situation that no target op(`trapezoid`) implemention found in `experimental/torch_xla2/torch_xla2/ops/jaten.py`, but error shows target op(`trapezoid`) has been implemented somewhere:
For situation that no target op(`trapezoid`) implemention found in `experimental/torchax/torchax/ops/jaten.py`, but error shows target op(`trapezoid`) has been implemented somewhere:
```
======================================================================
FAIL: test_reference_eager_trapezoid_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
FAIL: test_reference_eager_trapezoid_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001]
----------------------------------------------------------------------
...
AssertionError: The values for attribute 'dtype' do not match: torch.float64 != torch.float32.
Expand All @@ -80,17 +80,17 @@ Please try to fix it by following these steps:
1. confirm your target op `trapezoid` is decomposed by running this code to print each sub ops:
```
import torch
import torch_xla2
import torchax
env = torch_xla2.default_env()
env = torchax.default_env()
env.config.debug_print_each_op = True
env.config.debug_accuracy_for_each_op = True
with env:
y = torch.tensor([1, 5, 10])
print(torch.trapezoid(y))
```
2. (optional) Debug by modify [debug_accuracy()](https://github.com/pytorch/xla/blob/c26b19ebdefccd3a4300763e1085724d3d4cd3d0/experimental/torch_xla2/torch_xla2/tensor.py#L171C1-L194C14) to check `res`(from jax) and `expected_res`(from torch)'s value and dtype/type.
2. (optional) Debug by modify [debug_accuracy()](https://github.com/pytorch/xla/blob/c26b19ebdefccd3a4300763e1085724d3d4cd3d0/experimental/torchax/torchax/tensor.py#L171C1-L194C14) to check `res`(from jax) and `expected_res`(from torch)'s value and dtype/type.
3. you might need to debug/modify/add implementation of sub ops(found in step1) to support `trapezoid` by using step 2, like:
```
@op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar)
Expand Down Expand Up @@ -130,12 +130,12 @@ python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
We now see this error:
```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 654, in run_export_and_compare
diff_output(
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 617, in diff_output
testcase.assertTrue(
AssertionError: False is not true
```
Expand All @@ -144,7 +144,7 @@ This is telling me that our implementation did not produce
the same result as the ops in PyTorch.
To debug this, let's figure out what exact input caused this.
We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. Here we can
We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can
inspect values of `res` and `res2`, as well as the `sample_input`.
The sample input we get is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ How it works

## Tensor subclass and eager mode

The class `XLATensor2` is a `torch.Tensor` subclass
The class `Tensor` is a `torch.Tensor` subclass
that overrides `__torch_dispatch__`.

It roughly looks like this (with some details removed):

The complete class impl is at [tensor.py](../torch_xla2/tensor.py).
The complete class impl is at [tensor.py](../torchax/tensor.py).

```python
class XLATensor2(torch.Tensor):
class Tensor(torch.Tensor):

@staticmethod
def __new__(cls, elem):
Expand All @@ -33,21 +33,21 @@ class XLATensor2(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# here assumes ALL tensors in args / kwargs are
# instances of XLATensor2
# instances of Tensor
args, kwargs = unwrap((args, kwargs))
jax_func = some_registry[func]
res = jax_func(*args, **kwargs)
return wrap(res)

def wrap(tree):
# wrap jax.Array with XLATensor2
# wrap jax.Array with Tensor
return pytree.tree_map_only(
jax.Array, XLATensor2, tree)
jax.Array, Tensor, tree)

def unwrap(tree):
# get jax.Array out ofXLATensor2
# get jax.Array out ofTensor
return pytree.tree_map_only(
XLATensor2, lambda x: x._elem, tree)
Tensor, lambda x: x._elem, tree)
```

In other words, assuming that we have a function
Expand All @@ -56,7 +56,7 @@ but otherwise implement the same semantics
as a `ATen` op; then, using this tensor we would
be able to route the call to this jax function.

[_ops.py](../torch_xla2/_ops.py) files defines some of those ops.
[_ops.py](../torchax/_ops.py) files defines some of those ops.

Let's take `aten::add` as example:

Expand Down Expand Up @@ -120,7 +120,7 @@ def backend(fxgraph):
The inner function `tojit` is a function that takes and returns
`jax.Array`'s. So it's suitable to be jitted with `jax.jit`.

`f` is returned callable that takes `XLATensor2`; so can interop with
`f` is returned callable that takes `Tensor`; so can interop with
other torch codes.

## nn.Modules and state management
Expand Down
File renamed without changes.
Loading

0 comments on commit 8b24140

Please sign in to comment.