diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index ce5e86db0623..683af2abaa1e 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -6,12 +6,14 @@ on: - r[0-9]+.[0-9]+ paths-ignore: - 'experimental/**' + - 'torchax/**' push: branches: - master - r[0-9]+.[0-9]+ paths-ignore: - 'experimental/**' + - 'torchax/**' workflow_dispatch: concurrency: diff --git a/.github/workflows/build_upstream_image.yml b/.github/workflows/build_upstream_image.yml index 37992bc20f8e..3000f7efcdba 100644 --- a/.github/workflows/build_upstream_image.yml +++ b/.github/workflows/build_upstream_image.yml @@ -6,6 +6,7 @@ on: - r[0-9]+.[0-9]+ paths-ignore: - 'experimental/**' + - 'torchax/**' workflow_dispatch: jobs: build: diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml index b6581323a3fa..9fce8c95e38c 100644 --- a/.github/workflows/torch_xla2.yml +++ b/.github/workflows/torch_xla2.yml @@ -4,13 +4,13 @@ on: - master - r[0-9]+.[0-9]+ paths: - - 'experimental/torch_xla2/**' + - 'torchax/**' push: branches: - master - r[0-9]+.[0-9]+ paths: - - 'experimental/torch_xla2/**' + - 'torchax/**' workflow_dispatch: concurrency: @@ -28,21 +28,36 @@ jobs: uses: actions/checkout@v4 with: sparse-checkout: | - experimental/torch_xla2 + torchax - name: Setup Python uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install shell: bash - working-directory: experimental/torch_xla2 + working-directory: torchax run: | pip install -r test-requirements.txt pip install -e .[cpu] pip install tensorflow-cpu # for TF integrations tests - name: Run tests - working-directory: experimental/torch_xla2 + working-directory: torchax shell: bash run: | - pytest test/ + export JAX_PLATFORMS=cpu + pytest test/test_conv.py + pytest test/test_unbounded_dynamism.py + pytest test/test_interop.py + pytest test/test_ops.py + pytest test/test_context.py + pytest test/test_train.py + pytest test/test_mutations.py + pytest test/test_tf_integration.py + pytest test/gemma/test_gemma.py + pytest test/llama/test_llama.py + pytest test/test_core_aten_ops.py + pytest test/test_functions.py + pytest test/test_libraries.py + pytest test/test_symbolic_shapes.py + pytest test/test_exports.py XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/ diff --git a/experimental/torch_xla2/format.sh b/experimental/torch_xla2/format.sh deleted file mode 100755 index 08efc04b3995..000000000000 --- a/experimental/torch_xla2/format.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env bash -set -ex - -yapf --recursive -i *.py test torch_xla2 \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/ops/__init__.py b/experimental/torch_xla2/torch_xla2/ops/__init__.py deleted file mode 100644 index 3ba99a250c21..000000000000 --- a/experimental/torch_xla2/torch_xla2/ops/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -def all_aten_jax_ops(): - # to load the ops - import torch_xla2.ops.jaten # type: ignore - import torch_xla2.ops.ops_registry # type: ignore - - return { - key: val.func - for key, val in torch_xla2.ops.ops_registry.all_aten_ops.items() - if val.is_jax_function - } diff --git a/experimental/torch_xla2/LICENSE b/torchax/LICENSE similarity index 100% rename from experimental/torch_xla2/LICENSE rename to torchax/LICENSE diff --git a/experimental/torch_xla2/README.md b/torchax/README.md similarity index 89% rename from experimental/torch_xla2/README.md rename to torchax/README.md index 
774c9206f2ab..ff38ea5d75bb 100644 --- a/experimental/torch_xla2/README.md +++ b/torchax/README.md @@ -19,14 +19,14 @@ TorchXLA2 and torch-xla have different installation instructions, please follow the instructions below from scratch (fresh venv / conda environment.) -### 1. Installing `torch_xla2` +### 1. Installing `torchax` -The following instructions assume you are in the `torch_xla2` directory: +The following instructions assume you are in the `torchax` directory: ``` Fork the repository $ git clone https://github.com//xla.git -$ cd xla/experimental/torch_xla2 +$ cd xla/torchax ``` @@ -55,13 +55,13 @@ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch. #### 1.1 Install this package -If you want to install torch_xla2 without the jax dependency and use the jax dependency from torch_xla: +If you want to install torchax without the jax dependency and use the jax dependency from torch_xla: ```bash pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html pip install -e . ``` -Otherwise, install `torch_xla2` from source for your platform: +Otherwise, install `torchax` from source for your platform: ```bash pip install -e .[cpu] pip install -e .[cuda] @@ -77,7 +77,7 @@ pytest test ## Run a model -Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model +Now let's execute a model under torchax. We'll start with a simple 2-layer model it can in theory be any instance of `torch.nn.Module`. ```python @@ -110,15 +110,15 @@ print(m(inputs)) This model `m` contains 2 parts: the weights that are stored inside the model and its submodules (`nn.Linear`). -To execute this model with `torch_xla2`; we need construct and run the model +To execute this model with `torchax`, we need to construct and run the model under an `environment` that captures pytorch ops and swaps them with their TPU equivalents. To create this environment: use ```python -import torch_xla2 +import torchax -env = torch_xla2.default_env() +env = torchax.default_env() ``` Then, execute the instantiation of the model, as well as evaluation of the model, using `env` as a context manager: @@ -128,14 +128,14 @@ with env: inputs = torch.randn(3, 3, 28, 28) m = MyModel() res = m(inputs) - print(type(res)) # outputs XLATensor2 + print(type(res)) # outputs Tensor ``` You can also enable the environment globally with ```python -import torch_xla2 +import torchax -torch_xla2.enable_globally() +torchax.enable_globally() ``` Then everything afterwards is run with XLA.
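For reference, here is a minimal runnable sketch of the two usage modes the README hunks above describe, assuming only the renamed `torchax` package and the `default_env()` / `enable_globally()` entry points shown in this diff:

```python
# Sketch of the README's two usage modes, assuming the post-rename `torchax`
# package. Any torch.nn.Module works; a plain Linear keeps the example small.
import torch
import torchax

# Mode 1: scoped environment -- ops inside the block are captured and run via JAX.
env = torchax.default_env()
with env:
    m = torch.nn.Linear(28 * 28, 10)
    res = m(torch.randn(3, 28 * 28))
    print(type(res))  # torchax's Tensor subclass, not a plain torch.Tensor

# Mode 2: global environment -- everything afterwards runs with XLA.
torchax.enable_globally()
print(torch.ones(2, 2) @ torch.ones(2, 2))
```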
@@ -209,7 +209,7 @@ def model_func(param, inputs): Now, we can apply `jax_jit` ```python -from torch_xla2.interop import jax_jit +from torchax.interop import jax_jit model_func_jitted = jax_jit(model_func) print(model_func_jitted(new_state_dict, inputs)) ``` diff --git a/experimental/torch_xla2/build_nightly.sh b/torchax/build_nightly.sh similarity index 79% rename from experimental/torch_xla2/build_nightly.sh rename to torchax/build_nightly.sh index 977303eaba7e..885a90c6d44d 100755 --- a/experimental/torch_xla2/build_nightly.sh +++ b/torchax/build_nightly.sh @@ -5,6 +5,6 @@ NIGHTLY_VERSION=$(date '+%Y%m%d%H%M') # Update the version to .devYYYYMMDDHHMM in __init__.py VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;" -sed -r "$VERSION_UPDATE_PATTERN" torch_xla2/__init__.py --in-place +sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place hatch build -t wheel diff --git a/experimental/torch_xla2/dev-requirements.txt b/torchax/dev-requirements.txt similarity index 100% rename from experimental/torch_xla2/dev-requirements.txt rename to torchax/dev-requirements.txt diff --git a/experimental/torch_xla2/docs/dispatch.png b/torchax/docs/dispatch.png similarity index 100% rename from experimental/torch_xla2/docs/dispatch.png rename to torchax/docs/dispatch.png diff --git a/experimental/torch_xla2/docs/fixing_op_info_test.md b/torchax/docs/fixing_op_info_test.md similarity index 86% rename from experimental/torch_xla2/docs/fixing_op_info_test.md rename to torchax/docs/fixing_op_info_test.md index 6e3ef7a6f89b..f5e8fe606c13 100644 --- a/experimental/torch_xla2/docs/fixing_op_info_test.md +++ b/torchax/docs/fixing_op_info_test.md @@ -23,11 +23,11 @@ Remove one line from the `skiplist` set. i.e. ```bash -(base) hanq-macbookpro:torch_xla2 hanq$ git diff -diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py +(base) hanq-macbookpro:torchax hanq$ git diff +diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py index 72a39ae85..2a156cbce 100644 ---- a/experimental/torch_xla2/test/test_ops.py -+++ b/experimental/torch_xla2/test/test_ops.py +--- a/torchax/test/test_ops.py ++++ b/torchax/test/test_ops.py @@ -15,7 +15,6 @@ skiplist = { "_native_batch_norm_legit", "_segment_reduce", @@ -41,15 +41,15 @@ index 72a39ae85..2a156cbce 100644 ### Run the test to see the failure Errors you might get after running the test come in two kinds: - Target op failure - - the error relates to the target op, such as `No lowering found for 'aten::addbmm'`; follow the instructions in [Fix Target op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-target-op-failure) - Decomposed op failure - - no implementation is found for the target op, but the error is not `no lowering`; the error shows the target op has been
implemented somewhere; for situations like this, follow the instructions in [Fix Decomposed op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-other-op-failure) #### Fix Target op failure The error received: ``` -(base) hanq-macbookpro:torch_xla2 hanq$ python test/test_ops.py +(base) hanq-macbookpro:torchax hanq$ python test/test_ops.py ... E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm') ``` @@ -59,7 +59,7 @@ From here we have 2 strategies for fixing this test: 1. Add an implementation of the `aten::addbmm` operator using Jax ops. Or, 2. Add an implementation of the `aten::addbmm` operator using torch ops (this is commonly known as a "decomposition"). -Either way works for torch_xla2. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of +Either way works for torchax. For ops that are not "Core Aten" we sometimes implement them in torch ops, with the goal of upstreaming this decomposition to [pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) so other projects can benefit from it. @@ -68,10 +68,10 @@ For illustration purposes, let's implement this op in Jax. (NOTE: this doesn't stop us from upstreaming a decomposition later if we want) #### Fix Decomposed op failure -For situation that no target op(`trapezoid`) implemention found in `experimental/torch_xla2/torch_xla2/ops/jaten.py`, but error shows target op(`trapezoid`) has been implemented somewhere: +For situations where no implementation of the target op (`trapezoid`) is found in `torchax/torchax/ops/jaten.py`, but the error shows the target op (`trapezoid`) has been implemented somewhere: ``` ====================================================================== -FAIL: test_reference_eager_trapezoid_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001] +FAIL: test_reference_eager_trapezoid_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001] ---------------------------------------------------------------------- ... AssertionError: The values for attribute 'dtype' do not match: torch.float64 != torch.float32. @@ -80,9 +80,9 @@ Please try to fix it by following these steps: 1. confirm your target op `trapezoid` is decomposed by running this code to print each sub op: ``` import torch - import torch_xla2 + import torchax - env = torch_xla2.default_env() + env = torchax.default_env() env.config.debug_print_each_op = True env.config.debug_accuracy_for_each_op = True @@ -90,7 +90,7 @@ Please try to fix it by following these steps: y = torch.tensor([1, 5, 10]) print(torch.trapezoid(y)) ``` - 2. (optional) Debug by modify [debug_accuracy()](https://github.com/pytorch/xla/blob/c26b19ebdefccd3a4300763e1085724d3d4cd3d0/experimental/torch_xla2/torch_xla2/tensor.py#L171C1-L194C14) to check `res`(from jax) and `expected_res`(from torch)'s value and dtype/type. + 2. (optional) Debug by modifying [debug_accuracy()](https://github.com/pytorch/xla/blob/c26b19ebdefccd3a4300763e1085724d3d4cd3d0/experimental/torchax/torchax/tensor.py#L171C1-L194C14) to check the value and dtype/type of `res` (from jax) and `expected_res` (from torch). 3.
you might need to debug/modify/add implementations of the sub ops (found in step 1) to support `trapezoid`, using step 2, like: ``` @op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) @@ -130,12 +130,12 @@ python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64 We now see this error: ``` -FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001] +FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001] ---------------------------------------------------------------------- Traceback (most recent call last): - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 654, in run_export_and_compare diff_output( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 617, in diff_output testcase.assertTrue( AssertionError: False is not true ``` @@ -144,7 +144,7 @@ This is telling me that our implementation did not produce the same result as the ops in PyTorch. To debug this, let's figure out what exact input caused this. -We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. Here we can +We can achieve this by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/torchax/test/test_ops.py#L644), right before the diff. Here we can inspect values of `res` and `res2`, as well as the `sample_input`. The sample input we get is diff --git a/experimental/torch_xla2/docs/how_it_works.md b/torchax/docs/how_it_works.md similarity index 86% rename from experimental/torch_xla2/docs/how_it_works.md rename to torchax/docs/how_it_works.md index e4098ca00968..47a5fd102e51 100644 --- a/experimental/torch_xla2/docs/how_it_works.md +++ b/torchax/docs/how_it_works.md @@ -4,15 +4,15 @@ How it works ## Tensor subclass and eager mode -The class `XLATensor2` is a `torch.Tensor` subclass +The class `Tensor` is a `torch.Tensor` subclass that overrides `__torch_dispatch__`. It roughly looks like this (with some details removed): -The complete class impl is at [tensor.py](../torch_xla2/tensor.py). +The complete class impl is at [tensor.py](../torchax/tensor.py). ```python -class XLATensor2(torch.Tensor): +class Tensor(torch.Tensor): @staticmethod def __new__(cls, elem): @@ -33,21 +33,21 @@ class XLATensor2(torch.Tensor): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # here assumes ALL tensors in args / kwargs are - # instances of XLATensor2 + # instances of Tensor args, kwargs = unwrap((args, kwargs)) jax_func = some_registry[func] res = jax_func(*args, **kwargs) return wrap(res) def wrap(tree): - # wrap jax.Array with XLATensor2 + # wrap jax.Array with Tensor return pytree.tree_map_only( - jax.Array, XLATensor2, tree) + jax.Array, Tensor, tree) def unwrap(tree): - # get jax.Array out ofXLATensor2 + # get jax.Array out of Tensor return pytree.tree_map_only( - XLATensor2, lambda x: x._elem, tree) + Tensor, lambda x: x._elem, tree) ``` In other words, assuming that we have a function @@ -56,7 +56,7 @@ but otherwise implement the same semantics as an `ATen` op; then, using this tensor we would be able to route the call to this jax function.
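To make that round trip concrete, here is a toy, self-contained version of the unwrap → registry lookup → call flow from the pseudocode above; `some_registry` is a stand-in for torchax's real op registry, and a plain `jnp.asarray` conversion replaces the `Tensor` subclass to keep the sketch short:

```python
# Toy illustration of the unwrap -> lookup -> call flow sketched above.
# `some_registry` stands in for torchax's real op registry.
import jax.numpy as jnp
import torch
from torch.utils import _pytree as pytree

some_registry = {
    torch.ops.aten.add.Tensor: jnp.add,       # aten::add routed to its jax equivalent
    torch.ops.aten.mul.Tensor: jnp.multiply,  # aten::mul routed to its jax equivalent
}

def dispatch(func, *args):
    # "unwrap": turn torch tensors into jax arrays (the real code pulls the
    # jax.Array out of the Tensor subclass instead of copying through numpy).
    jax_args = pytree.tree_map_only(
        torch.Tensor, lambda t: jnp.asarray(t.numpy()), args)
    return some_registry[func](*jax_args)

print(dispatch(torch.ops.aten.add.Tensor, torch.ones(3), torch.ones(3)))
```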
-[_ops.py](../torch_xla2/_ops.py) files defines some of those ops. +[_ops.py](../torchax/_ops.py) defines some of those ops. Let's take `aten::add` as example: @@ -120,7 +120,7 @@ def backend(fxgraph): The inner function `tojit` is a function that takes and returns `jax.Array`'s. So it's suitable to be jitted with `jax.jit`. -`f` is returned callable that takes `XLATensor2`; so can interop with +`f` is the returned callable that takes `Tensor`s, so it can interop with other torch codes. ## nn.Modules and state management diff --git a/experimental/torch_xla2/docs/ops_registry.md b/torchax/docs/ops_registry.md similarity index 100% rename from experimental/torch_xla2/docs/ops_registry.md rename to torchax/docs/ops_registry.md diff --git a/experimental/torch_xla2/docs/support_a_new_model.md b/torchax/docs/support_a_new_model.md similarity index 82% rename from experimental/torch_xla2/docs/support_a_new_model.md rename to torchax/docs/support_a_new_model.md index 8578861e4cf9..9fd0db52aebb 100644 --- a/experimental/torch_xla2/docs/support_a_new_model.md +++ b/torchax/docs/support_a_new_model.md @@ -1,23 +1,23 @@ -# Run a model under torch_xla2 +# Run a model under torchax -Supporting a new model in torch_xla2 means -having this model run using torch_xla2 and succeeds. +Supporting a new model in torchax means +having this model run successfully using torchax. A model usually consists of executing a list of torch ops on a set of tensors (i.e. the parameters and inputs) and producing new tensor(s). These ops should just work. However, there are cases where the model doesn't run on -torch_xla2, because: +torchax, because: 1. Some op it needs is not implemented. 2. Some op it needs is implemented incorrectly -3. There are some non-torch-op code that interacts with torch_xla2 in a non-friendly matter. +3. There is some non-torch-op code that interacts with torchax in an unfriendly manner. Here we present a few steps to attempt to fix the related issues, using the dlrm model as an example. -This assumes that you already installed torch_xla2 with `pip install -e .` locally. +This assumes that you have already installed torchax with `pip install -e .` locally.
Following the instructions in [README](../README.md) @@ -77,15 +77,15 @@ Traceback (most recent call last): return handle_torch_function( File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function result = mode.__torch_function__(public_api, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 215, in __torch_function__ + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 215, in __torch_function__ return func(*args, **(kwargs or {})) File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag ret, _, _, _ = torch.embedding_bag( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 230, in __torch_dispatch__ + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 230, in __torch_dispatch__ return self.env.dispatch(func, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 310, in dispatch + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 310, in dispatch raise OperatorNotFound( -torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag has no lowering +torchax.tensor.OperatorNotFound: Operator with name aten::_embedding_bag has no lowering ``` Now let's implement this op. @@ -110,19 +110,19 @@ After finishing `embedding_bag` badly, I reached the next op return F.embedding_bag(input, self.weight, offsets, File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag ret, _, _, _ = torch.embedding_bag( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 124, in __torch_dispatch__ + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 124, in __torch_dispatch__ return func(*args, **(kwargs or {})) File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ return self_._op(*args, **kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in __torch_function__ + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 212, in __torch_function__ return func(*args, **(kwargs or {})) File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ return self_._op(*args, **kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in __torch_dispatch__ + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 227, in __torch_dispatch__ return self.env.dispatch(func, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 308, in dispatch + File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 308, in dispatch raise OperatorNotFound( -torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag_forward_only has no lowering +torchax.tensor.OperatorNotFound: Operator with name aten::_embedding_bag_forward_only has no lowering ``` Turns out, that is the same operator. 
so adding the @op(torch.ops.aten._embedding_bag_forward_only) diff --git a/experimental/torch_xla2/docs/torch_dispatch/README.md b/torchax/docs/torch_dispatch/README.md similarity index 100% rename from experimental/torch_xla2/docs/torch_dispatch/README.md rename to torchax/docs/torch_dispatch/README.md diff --git a/experimental/torch_xla2/docs/torch_dispatch/example.py b/torchax/docs/torch_dispatch/example.py similarity index 100% rename from experimental/torch_xla2/docs/torch_dispatch/example.py rename to torchax/docs/torch_dispatch/example.py diff --git a/experimental/torch_xla2/docs/torch_dispatch/run_env.py b/torchax/docs/torch_dispatch/run_env.py similarity index 79% rename from experimental/torch_xla2/docs/torch_dispatch/run_env.py rename to torchax/docs/torch_dispatch/run_env.py index 7ee27658af6e..1e257b93122a 100644 --- a/experimental/torch_xla2/docs/torch_dispatch/run_env.py +++ b/torchax/docs/torch_dispatch/run_env.py @@ -1,7 +1,7 @@ import torch -import torch_xla2 +import torchax -env = torch_xla2.default_env() +env = torchax.default_env() env.config.debug_print_each_op = True env.config.debug_accuracy_for_each_op = True diff --git a/experimental/torch_xla2/docs/torch_xla2_dynamo.md b/torchax/docs/torch_xla2_dynamo.md similarity index 95% rename from experimental/torch_xla2/docs/torch_xla2_dynamo.md rename to torchax/docs/torch_xla2_dynamo.md index d3994b4bc350..14053e9e6d7f 100644 --- a/experimental/torch_xla2/docs/torch_xla2_dynamo.md +++ b/torchax/docs/torch_xla2_dynamo.md @@ -2,13 +2,13 @@ ## Goal -Have a dynamo backend backend by torch_xla2. +Have a dynamo backend backed by torchax. The users should be able to do the following: ```python m = model ... -m_compiled = torch.compile(m, backend='torch_xla2_compile') # backend name TBD +m_compiled = torch.compile(m, backend='torchax_compile') # backend name TBD result = m_compiled(*inputs) ``` @@ -24,7 +24,7 @@ For every `call_function` node, we look up said ATen op in a dictionary for its corresponding implementation in Jax, and we just call it. -This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23 +This is illustrated here: https://github.com/pytorch/xla/blob/master/torchax/torchax/export.py#L23 Now, the challenge is for dynamo to be able to 1. produce the graph; and 2. not incur any data copies in this process. @@ -33,11 +33,11 @@ not incur any data copies in this process. Consider the following pseudocode: ```python -class XLATensor2: +class Tensor: _data: jax.Array def __torch_dispatch__(...): # do stuff with _data, get new data - return XLATensor2(new_data) + return Tensor(new_data) def dynamo_backend(fx, sample): compiled = compile fx into graph that manipulate jax.Array. @@ -96,7 +96,7 @@ there is no Python logic in dispatching, so dynamo cannot trace through. ## Cons: Now we need to deal with C++ builds. In particular, `torch` becomes a source dependency instead of a pip dependency; meaning, again we need to start -building torch first then build torch_xla2. This might be mitigated if +building torch first and then build torchax. This might be mitigated if that subclass can be upstreamed.
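As a rough illustration of the dictionary-lookup backend the dynamo doc above describes (not torchax's actual implementation), a backend can interpret the captured FX graph and swap each `call_function` target for its registry entry; `some_registry` here is a toy stand-in mapping to plain torch callables so the sketch stays self-contained:

```python
# Hedged sketch of a lookup-based dynamo backend in the spirit of the doc above.
# torchax maps ATen ops to Jax functions; this toy maps to torch callables.
import torch

some_registry = {torch.ops.aten.add.Tensor: torch.add}

class LookupInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        impl = some_registry.get(target, target)  # fall back to the op itself
        return impl(*args, **kwargs)

def lookup_backend(gm: torch.fx.GraphModule, sample_inputs):
    # dynamo hands us the traced FX graph; return a callable that runs it.
    return lambda *args: LookupInterpreter(gm).run(*args)

m_compiled = torch.compile(lambda x: x + x, backend=lookup_backend)
print(m_compiled(torch.ones(2)))
```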
diff --git a/experimental/torch_xla2/docs/understand_jax_jit/jax_jit.py b/torchax/docs/understand_jax_jit/jax_jit.py similarity index 100% rename from experimental/torch_xla2/docs/understand_jax_jit/jax_jit.py rename to torchax/docs/understand_jax_jit/jax_jit.py diff --git a/experimental/torch_xla2/docs/understand_jax_jit/torch_module.py b/torchax/docs/understand_jax_jit/torch_module.py similarity index 78% rename from experimental/torch_xla2/docs/understand_jax_jit/torch_module.py rename to torchax/docs/understand_jax_jit/torch_module.py index 409a6059efdb..478e555d1eb4 100644 --- a/experimental/torch_xla2/docs/understand_jax_jit/torch_module.py +++ b/torchax/docs/understand_jax_jit/torch_module.py @@ -34,9 +34,9 @@ def forward(self, X): print('---- example 2 -----') -import torch_xla2 +import torchax -env = torch_xla2.default_env() +env = torchax.default_env() with env: m2 = Linear() @@ -48,20 +48,20 @@ def forward(self, X): print('---- example 3 -----') # where is the jax jit? -# m2 is a callable that takes in XLATensor2 and returns XLATensor2 -# m2: (XLATensor2 -> XLATensor2) +# m2 is a callable that takes in Tensor and returns Tensor +# m2: (Tensor -> Tensor) -# suppose t2j (XLATensor2 -> jax.Array) "unwraps the XLATensor" -# suppose j2t (jax.Array -> XLATensor2) "wraps the XLATensor" -from torch_xla2 import tensor +# suppose t2j (Tensor -> jax.Array) "unwraps the Tensor" +# suppose j2t (jax.Array -> Tensor) "wraps the Tensor" +from torchax import tensor import jax -def t2j(torch_tensor: tensor.XLATensor2) -> jax.Array: return torch_tensor._elem -def j2t(jax_array: jax.Array) -> tensor.XLATensor2: - return tensor.XLATensor2(jax_array, env) +def t2j(torch_tensor: tensor.Tensor) -> jax.Array: return torch_tensor._elem +def j2t(jax_array: jax.Array) -> tensor.Tensor: + return tensor.Tensor(jax_array, env) # # further notice t2j(j2t(x)) == x; j2t(t2j(x)) == x @@ -75,7 +75,7 @@ def jax_m(X: jax.Array): jax_x = jnp.ones((10, 1000)) print(jax_m(jax_x)) -## Let f: XLATensor2 -> XLATensor2 +## Let f: Tensor -> Tensor ## There is a function g: jax.Array -> jax.Array; ## g = x |-> t2j (f (j2t(x))). OR, ## g = t2j . f . j2t (. denotes function composition) @@ -110,14 +110,14 @@ def jax_m_functional(states, X): # ## interop module # print('---- example 4 ----') -# import torch_xla2.interop +# import torchax.interop # def m_functional(states, x): # return torch.func.functional_call(m2, states, x) # with jax.checking_leaks(): -# print(torch_xla2.interop.jax_jit(m_functional)(m2.state_dict(), x)) +# print(torchax.interop.jax_jit(m_functional)(m2.state_dict(), x)) # # Experiment if time: diff --git a/experimental/torch_xla2/examples/README.md b/torchax/examples/README.md similarity index 98% rename from experimental/torch_xla2/examples/README.md rename to torchax/examples/README.md index 0e22d28c5315..092ce967df94 100644 --- a/experimental/torch_xla2/examples/README.md +++ b/torchax/examples/README.md @@ -2,7 +2,7 @@ This readme will have a subsection for every example *.py file.
-Please follow the instructions in [README.md](../README.md) to install torch_xla2, +Please follow the instructions in [README.md](../README.md) to install torchax, then install requirements for all of the examples with ```bash @@ -23,7 +23,7 @@ Example: ```python state_dict = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.move_to_device, state_dict) + torchax.tensor.move_to_device, state_dict) ``` This fragment moves the state_dict to XLA devices; then the state_dict is passed diff --git a/experimental/torch_xla2/examples/__init__.py b/torchax/examples/__init__.py similarity index 100% rename from experimental/torch_xla2/examples/__init__.py rename to torchax/examples/__init__.py diff --git a/experimental/torch_xla2/examples/_diffusion.py b/torchax/examples/_diffusion.py similarity index 86% rename from experimental/torch_xla2/examples/_diffusion.py rename to torchax/examples/_diffusion.py index 5eae15edf255..e7bb67b556a8 100644 --- a/experimental/torch_xla2/examples/_diffusion.py +++ b/torchax/examples/_diffusion.py @@ -6,9 +6,9 @@ from torch.utils import _pytree as pytree -import torch_xla2 -import torch_xla2.functions -from torch_xla2.extra import torch_view, jax_view +import torchax +import torchax.functions +from torchax.extra import torch_view, jax_view import jax import torch.func @@ -19,15 +19,15 @@ class CompiledModule: def __init__(self, model): weights = model.state_dict() weights.update(model.named_parameters()) - self._weights = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.move_to_device, weights) + self._weights = pytree.tree_map_only(torch.Tensor, torchax.tensor.move_to_device, weights) self._model = model self._func_jitted_torch = None #torch_view(func_mod_jitted) def _maybe_move_tensor(self, tensor): - if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torch_xla2.tensor.XLATensor2): - return torch_xla2.tensor.move_to_device(tensor) + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.Tensor): + return torchax.tensor.move_to_device(tensor) return tensor def _make_jitted(self, args, kwargs): @@ -41,12 +41,12 @@ def _make_jitted(self, args, kwargs): static_argnames.append(k) def f(weights, *args, **kwargs): - weights, args, kwargs = torch_xla2.tensor.wrap((weights, args, kwargs)) - with torch_xla2.functions.XLAFunctionMode(), torch_xla2.tensor.XLADispatchMode(): + weights, args, kwargs = torchax.tensor.wrap((weights, args, kwargs)) + with torchax.functions.XLAFunctionMode(), torchax.tensor.XLADispatchMode(): res = torch.func.functional_call(self._model, weights, args, kwargs) if isinstance(res, tuple) and len(res) == 1: res = res[0] - return torch_xla2.tensor.unwrap(res) + return torchax.tensor.unwrap(res) fjit = jax.jit(f, static_argnames=tuple(static_argnames)) return torch_view(fjit) diff --git a/experimental/torch_xla2/examples/_grad_of_attention.py b/torchax/examples/_grad_of_attention.py similarity index 86% rename from experimental/torch_xla2/examples/_grad_of_attention.py rename to torchax/examples/_grad_of_attention.py index af5f422fa2ab..83d252cbc13c 100644 --- a/experimental/torch_xla2/examples/_grad_of_attention.py +++ b/torchax/examples/_grad_of_attention.py @@ -2,11 +2,11 @@ import jax from jax.experimental.pallas.ops.tpu import flash_attention -import torch_xla2 +import torchax from jax.experimental import mesh_utils -from torch_xla2.ops.jtorch import _tpu_flash_attention +from torchax.ops.jtorch import _tpu_flash_attention -env = torch_xla2.default_env() +env = torchax.default_env() 
jax.config.update('jax_enable_x64', False) env._mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh((4, )), @@ -38,7 +38,7 @@ def forward(self, x): return self.a(x) m = M() -from torch_xla2.interop import JittableModule +from torchax.interop import JittableModule mjit = JittableModule(m) @@ -62,10 +62,10 @@ def crossent(x, y): k = jnp.ones(shape, dtype='bfloat16') -env = torch_xla2.default_env() +env = torchax.default_env() weights = env.t2j_iso(env.to_xla(mjit.params)) -from torch_xla2.interop import jax_view +from torchax.interop import jax_view #print(jax.jit(graded).lower(q, v, k).as_text()) print(jax.jit(jax.grad(jax_view(f))).lower( diff --git a/experimental/torch_xla2/examples/basic_training.py b/torchax/examples/basic_training.py similarity index 97% rename from experimental/torch_xla2/examples/basic_training.py rename to torchax/examples/basic_training.py index fb814fcf9788..9a9ebf6b20e0 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/torchax/examples/basic_training.py @@ -15,8 +15,8 @@ #from datetime import datetime # NOTE: add these lines to make it run on TPUs! -import torch_xla2 -torch_xla2.enable_globally() +import torchax +torchax.enable_globally() transform = transforms.Compose( [transforms.ToTensor(), @@ -51,8 +51,8 @@ def matplotlib_imshow(img, one_channel=False): plt.imshow(npimg, cmap="Greys") else: plt.imshow(np.transpose(npimg, (1, 2, 0))) -#torch_xla2.env.config.debug_print_each_op = True -#torch_xla2.env.config.debug_mixed_tensor = True +#torchax.env.config.debug_print_each_op = True +#torchax.env.config.debug_mixed_tensor = True dataiter = iter(training_loader) images, labels = next(dataiter) diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/torchax/examples/basic_training_jax.py similarity index 97% rename from experimental/torch_xla2/examples/basic_training_jax.py rename to torchax/examples/basic_training_jax.py index 5ca14398fd2e..c5429d8943c8 100644 --- a/experimental/torch_xla2/examples/basic_training_jax.py +++ b/torchax/examples/basic_training_jax.py @@ -4,13 +4,13 @@ """ import functools -from torch_xla2 import train, interop +from torchax import train, interop import torch from torch.utils import _pytree as pytree import torchvision import torchvision.transforms as transforms -import torch_xla2 -import torch_xla2.interop +import torchax +import torchax.interop import jax import optax import numpy as np @@ -19,7 +19,7 @@ from torch.utils.tensorboard import SummaryWriter from datetime import datetime -env = torch_xla2.enable_globally() +env = torchax.enable_globally() transform = transforms.Compose( diff --git a/experimental/torch_xla2/examples/eager_mode.py b/torchax/examples/eager_mode.py similarity index 87% rename from experimental/torch_xla2/examples/eager_mode.py rename to torchax/examples/eager_mode.py index 16561ed4f641..adfb5581b8b6 100644 --- a/experimental/torch_xla2/examples/eager_mode.py +++ b/torchax/examples/eager_mode.py @@ -1,9 +1,9 @@ -import torch_xla2 +import torchax from torch import nn from torch.nn import functional as F import torch -xla_env = torch_xla2.enable_globally() +xla_env = torchax.enable_globally() class MyModel(nn.Module): @@ -29,7 +29,7 @@ def forward(self, x): print(m(inputs)) print('---=====') -m_compiled = torch_xla2.compile(m) +m_compiled = torchax.compile(m) print(m_compiled(inputs)) diff --git a/experimental/torch_xla2/examples/lightning_training.py b/torchax/examples/lightning_training.py similarity index 95% rename from 
experimental/torch_xla2/examples/lightning_training.py rename to torchax/examples/lightning_training.py index b09f00d94731..420089762fcf 100644 --- a/experimental/torch_xla2/examples/lightning_training.py +++ b/torchax/examples/lightning_training.py @@ -30,8 +30,8 @@ def configure_optimizers(self): # ==== above is the lightning example from # https://lightning.ai/pytorch-lightning -import torch_xla2 -from torch_xla2.interop import jax_view, torch_view +import torchax +from torchax.interop import jax_view, torch_view import jax import optax @@ -46,7 +46,7 @@ def torch_opt_to_jax_opt(self, torch_opt): def fit(self, lightning_mod, data_loader): - xla_env = torch_xla2.default_env() + xla_env = torchax.default_env() def lightning_mod_loss( weights: jax.Array, data: jax.Array, batch_id): diff --git a/experimental/torch_xla2/examples/mnist_tpu.ipynb b/torchax/examples/mnist_tpu.ipynb similarity index 91% rename from experimental/torch_xla2/examples/mnist_tpu.ipynb rename to torchax/examples/mnist_tpu.ipynb index ff41f276fcef..8ffcc9ec27d0 100644 --- a/experimental/torch_xla2/examples/mnist_tpu.ipynb +++ b/torchax/examples/mnist_tpu.ipynb @@ -12,9 +12,9 @@ }, "outputs": [], "source": [ - "# Uncomment and run these if you haven't already installed `torch_xla2`\n", + "# Uncomment and run these if you haven't already installed `torchax`\n", "#!pip uninstall -y tensorflow\n", - "#!pip install tpu-info 'torch_xla2[tpu] @ git+https://github.com/pytorch/xla.git#subdirectory=experimental/torch_xla2' -f https://storage.googleapis.com/libtpu-releases/index.html\n", + "#!pip install tpu-info 'torchax[tpu] @ git+https://github.com/pytorch/xla.git#subdirectory=torchax' -f https://storage.googleapis.com/libtpu-releases/index.html\n", "#!pip install torchvision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Distributed training with `torch_xla2`\n", + "# Distributed training with `torchax`\n", "\n", - "This Notebook demonstrates how to perform distributed training using `torch_xla2`, which allows you to run PyTorch models with JAX." + "This Notebook demonstrates how to perform distributed training using `torchax`, which allows you to run PyTorch models with JAX." ] }, { @@ -141,7 +141,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`torch_xla2` uses JAX as a backend, so we can use JAX to double-check the device count. Don't worry -- we won't have to directly use JAX to run the model." + "`torchax` uses JAX as a backend, so we can use JAX to double-check the device count. Don't worry -- we won't have to directly use JAX to run the model." ] }, { @@ -179,7 +179,7 @@ "source": [ "The device count above should match the output of `tpu-info` (4 devices in the case of a v4-8).\n", "\n", - "In this example, we'll use `torch_xla2`'s custom `DistributedDataParallel` implementation to replicate the model parameters across all available TPU devices and split input data between each core." + "In this example, we'll use `torchax`'s custom `DistributedDataParallel` implementation to replicate the model parameters across all available TPU devices and split input data between each core." ] }, { @@ -199,9 +199,9 @@ } ], "source": [ - "import torch_xla2\n", + "import torchax\n", "\n", - "ddp_model = torch_xla2.distributed.DistributedDataParallel(model)" + "ddp_model = torchax.distributed.DistributedDataParallel(model)" ] }, { @@ -304,7 +304,7 @@ { "data": { "text/plain": [ - "XLATensor2( [[ 0.03249096 0.01343462 -0.022144 ...
0.00668433 0.00833362\n", + "Tensor( [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", " 0.00225713]\n", " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", " 0.01084254]\n", @@ -419,7 +419,7 @@ "source": [ "## Putting it all together\n", "\n", - "`torch_xla2` allows us to seamlessly shard and replicate tensors across devices, while still maintaining a singular view of that tensor through PyTorch. With some minor changes, we can adapt the conventional PyTorch training loop to use the TPU.\n", + "`torchax` allows us to seamlessly shard and replicate tensors across devices, while still maintaining a singular view of that tensor through PyTorch. With some minor changes, we can adapt the conventional PyTorch training loop to use the TPU.\n", "\n", "Note that we do not have to spawn any child processes. Although each parameter and input is represented by one tensor, that tensor is already distributed across multiple devices." ] @@ -447,7 +447,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "JAX gets significantly better performance when compiled, normally through `jax.jit`. `torch_xla2`'s DDP implementation contains a utility `jit_step` that can be used to compile a training step. Note that for this to work, the training step must be separated out into a function. Otherwise, the actual contents are the same as they would be for eager CPU or GPU." + "JAX gets significantly better performance when compiled, normally through `jax.jit`. `torchax`'s DDP implementation contains a utility `jit_step` that can be used to compile a training step. Note that for this to work, the training step must be separated out into a function. Otherwise, the actual contents are the same as they would be for eager CPU or GPU." ] }, { @@ -613,9 +613,9 @@ "source": [ "## Conclusion\n", "\n", - "With some minor changes to your training loop, `torch_xla2` allows you to distribute a model across multiple devices and run a compiled version with JAX. All of the data you interact with directly is still a `torch` tensor, and JAX handles all of the distributed details in the background.\n", + "With some minor changes to your training loop, `torchax` allows you to distribute a model across multiple devices and run a compiled version with JAX. All of the data you interact with directly is still a `torch` tensor, and JAX handles all of the distributed details in the background.\n", "\n", - "`torch_xla2` (and especially training) is still under heavy development. To learn more about the project and its current status, see https://github.com/pytorch/xla/tree/master/experimental/torch_xla2" + "`torchax` (and especially training) is still under heavy development. 
To learn more about the project and its current status, see https://github.com/pytorch/xla/tree/master/torchax" ] } ], diff --git a/experimental/torch_xla2/examples/requirements.txt b/torchax/examples/requirements.txt similarity index 100% rename from experimental/torch_xla2/examples/requirements.txt rename to torchax/examples/requirements.txt diff --git a/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py b/torchax/examples/torchbench_models/BERT_pytorch.py similarity index 92% rename from experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py rename to torchax/examples/torchbench_models/BERT_pytorch.py index fc0b4653d6c6..f3bada44d223 100644 --- a/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py +++ b/torchax/examples/torchbench_models/BERT_pytorch.py @@ -1,7 +1,7 @@ import torch import time -import torch_xla2 -import torch_xla2.interop +import torchax +import torchax.interop import os import importlib import sys @@ -28,7 +28,7 @@ model, example = benchmark.get_module() -env = torch_xla2.default_env() +env = torchax.default_env() env.config.debug_print_each_op = False model = env.to_xla(model) example = env.to_xla(example) @@ -42,7 +42,7 @@ def func_call(state, example): return torch.func.functional_call(model, state, example, tie_weights=False) -jitted = torch_xla2.interop.jax_jit(func_call) +jitted = torchax.interop.jax_jit(func_call) start = time.perf_counter() print(func_call(model.state_dict(), example)) end = time.perf_counter() diff --git a/experimental/torch_xla2/examples/train_gpt/requirements.txt b/torchax/examples/train_gpt/requirements.txt similarity index 100% rename from experimental/torch_xla2/examples/train_gpt/requirements.txt rename to torchax/examples/train_gpt/requirements.txt diff --git a/experimental/torch_xla2/examples/train_gpt/train_ddp.py b/torchax/examples/train_gpt/train_ddp.py similarity index 93% rename from experimental/torch_xla2/examples/train_gpt/train_ddp.py rename to torchax/examples/train_gpt/train_ddp.py index 60a8432fc6cb..1538cd096504 100644 --- a/experimental/torch_xla2/examples/train_gpt/train_ddp.py +++ b/torchax/examples/train_gpt/train_ddp.py @@ -5,7 +5,7 @@ https://github.com/karpathy/nanoGPT Example command (single host): -torchrun --standalone xla/experimental/torch_xla2/examples/train_gpt/train_ddp.py +torchrun --standalone xla/torchax/examples/train_gpt/train_ddp.py Tested on a TPU v4-8 """ @@ -17,7 +17,7 @@ import torch.utils.data.distributed import torch.distributed as dist import torch.optim as optim -import torch_xla2 +import torchax from tqdm import tqdm from mingpt.model import GPT from datasets import load_dataset @@ -29,8 +29,8 @@ def _checkpoint(jax_model, path: pathlib.Path): torch.save( torch_pytree.tree_map_only( - torch_xla2.tensor.XLATensor2, - torch_xla2.tensor.XLATensor2.torch, + torchax.tensor.Tensor, + torchax.tensor.Tensor.torch, jax_model.state_dict(), ), path, @@ -66,7 +66,7 @@ def group_texts(exs): group_texts, batched=True, remove_columns=["text", "ids"], num_proc=16 ) dataset.shard(dist.get_world_size(), dist.get_rank()) - env = torch_xla2.default_env() + env = torchax.default_env() print(jax.device_count(), "devices") @@ -92,7 +92,7 @@ def create_model(): "checkpoints" ) / datetime.datetime.now().strftime("%Y%m%d_%H%M%S") checkpoint_subdir.mkdir(parents=True) - jax_model = torch_xla2.distributed.DistributedDataParallel( + jax_model = torchax.distributed.DistributedDataParallel( create_model(), env ) diff --git
a/experimental/torch_xla2/examples/train_llama/README.md b/torchax/examples/train_llama/README.md similarity index 97% rename from experimental/torch_xla2/examples/train_llama/README.md rename to torchax/examples/train_llama/README.md index b390953b813a..fb2dbb2b5289 100644 --- a/experimental/torch_xla2/examples/train_llama/README.md +++ b/torchax/examples/train_llama/README.md @@ -107,7 +107,7 @@ class FSDPv2(torch.nn.Module): return self.shard(res) def shard(self, x): - return torch_xla2.interop.call_jax( + return torchax.interop.call_jax( jax.lax.with_sharding_constraint, x, self.sharding, @@ -137,8 +137,8 @@ def scaled_dot_product_attention( return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale) ``` -this implementation is located in [jtorch.py](../../torch_xla2/ops/jtorch.py) in -torch_xla2. The model itself does not need to change to use TPU version of +this implementation is located in [jtorch.py](../../torchax/ops/jtorch.py) in +torchax. The model itself does not need to change to use TPU version of flash attention, because it's calling pytorch's `F.scaled_dot_product_attention`. ## Misc optimizations diff --git a/experimental/torch_xla2/examples/train_llama/__init__.py b/torchax/examples/train_llama/__init__.py similarity index 100% rename from experimental/torch_xla2/examples/train_llama/__init__.py rename to torchax/examples/train_llama/__init__.py diff --git a/experimental/torch_xla2/examples/train_llama/model.py b/torchax/examples/train_llama/model.py similarity index 100% rename from experimental/torch_xla2/examples/train_llama/model.py rename to torchax/examples/train_llama/model.py diff --git a/experimental/torch_xla2/examples/train_llama/train_llama_lightning.py b/torchax/examples/train_llama/train_llama_lightning.py similarity index 96% rename from experimental/torch_xla2/examples/train_llama/train_llama_lightning.py rename to torchax/examples/train_llama/train_llama_lightning.py index 1f75cd57a240..2264b1cbb3a9 100644 --- a/experimental/torch_xla2/examples/train_llama/train_llama_lightning.py +++ b/torchax/examples/train_llama/train_llama_lightning.py @@ -10,7 +10,7 @@ from jax.experimental import shard_map import torch.nn.functional -import torch_xla2.interop +import torchax.interop from . import utils from . 
import model as editted_model @@ -70,7 +70,7 @@ def one_layer(weights, args): self.gpt_orig.transformer.h[0], weights, args) - self.one_layer = torch_xla2.interop.jax_jit(one_layer) + self.one_layer = torchax.interop.jax_jit(one_layer) def forward(self, idx: torch.Tensor, input_pos=None) -> torch.Tensor: @@ -116,9 +116,9 @@ def one_layer(args, weights): if self.manual_all_gather: weights, cos, sin = jax.lax.all_gather((weights, cos, sin), 'fsdp', tiled=True) args = (x, cos, sin, mask, input_pos) - args, weights = torch_xla2.default_env().j2t_iso((args, weights)) + args, weights = torchax.default_env().j2t_iso((args, weights)) res = torch.func.functional_call(one_block, weights, args) - res = torch_xla2.default_env().t2j_iso(res) + res = torchax.default_env().t2j_iso(res) return (res, *orig_args[1:]), jnp.array([0]) if self.manual_all_gather: @@ -185,7 +185,7 @@ def forward_with_weights(self, weights, idx): args = (x, cos, sin, mask, None) #import pdb; pdb.set_trace() - x = torch_xla2.interop.call_jax( + x = torchax.interop.call_jax( self.compiled_block, weights['block'], args, @@ -208,7 +208,7 @@ def forward_with_weights(self, weights, idx): import logging -import torch_xla2 +import torchax # Modes: @@ -231,7 +231,7 @@ def main_one( print(f"Running with parameters {locals()}") utils.SEQLEN = seqlen utils.BATCH = batch_size - env = torch_xla2.default_env() + env = torchax.default_env() env.config.use_tpu_flash_attention = use_flash_attention cfg = config.Config.from_name("Meta-Llama-3-8B") cfg.n_layer = n_layers diff --git a/experimental/torch_xla2/examples/train_llama/utils.py b/torchax/examples/train_llama/utils.py similarity index 97% rename from experimental/torch_xla2/examples/train_llama/utils.py rename to torchax/examples/train_llama/utils.py index 77cf2f66b5e4..f2b7948bdaa1 100644 --- a/experimental/torch_xla2/examples/train_llama/utils.py +++ b/torchax/examples/train_llama/utils.py @@ -1,7 +1,7 @@ from typing import Tuple import time -import torch_xla2 -from torch_xla2.interop import jax_view, torch_view, JittableModule +import torchax +from torchax.interop import jax_view, torch_view, JittableModule import jax import optax import jax @@ -77,7 +77,7 @@ def forward(self, *args): return self.shard(res) def shard(self, x): - return torch_xla2.interop.call_jax( + return torchax.interop.call_jax( jax.lax.with_sharding_constraint, x, self.sharding, @@ -107,7 +107,7 @@ def torch_opt_to_jax_opt(self, torch_opt): return optax.adamw(0.01) def fit_model_fori(self, gpt_mod, data_loader): - xla_env = torch_xla2.default_env() + xla_env = torchax.default_env() jax.config.update('jax_enable_x64', False) xla_env._mesh = self.mesh xla_env.use_flash_attention = True @@ -130,7 +130,7 @@ def fit_model_fori(self, gpt_mod, data_loader): def loss(jax_params, data): data = jax.lax.with_sharding_constraint(data, self.x_sharding) # fsdpv2 x, y = data - res = torch_xla2.interop.call_torch( + res = torchax.interop.call_torch( gpt_mod.forward_with_weights, jax_params, x) res = jax.lax.with_sharding_constraint(res, self.x_sharding) return jnp.mean( @@ -194,7 +194,7 @@ def _shard_fsdp_style(self, state_dict, sharding=None): if sharding is None: sharding = self.x_sharding def move_one_tensor(x): - jval = torch_xla2.tensor.t2j(x) + jval = torchax.tensor.t2j(x) return sharded_device_put(jval, sharding) if isinstance(state_dict, torch.Tensor): @@ -206,7 +206,7 @@ def move_one_tensor(x): def fit(self, lightning_mod, data_loader): - xla_env = torch_xla2.default_env() + xla_env = torchax.default_env() 
jax.config.update('jax_enable_x64', False) xla_env._mesh = self.mesh xla_env.use_flash_attention = True diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile b/torchax/examples/train_llama_torchtitan/Dockerfile similarity index 97% rename from experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile rename to torchax/examples/train_llama_torchtitan/Dockerfile index 21a27cbeddbb..f1f5575f247f 100644 --- a/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile +++ b/torchax/examples/train_llama_torchtitan/Dockerfile @@ -27,7 +27,7 @@ RUN pip install . WORKDIR / RUN git clone https://github.com/pytorch/xla.git -WORKDIR xla/experimental/torch_xla2 +WORKDIR xla/torchax RUN pip install -e . ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"] diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/README.md b/torchax/examples/train_llama_torchtitan/README.md similarity index 96% rename from experimental/torch_xla2/examples/train_llama_torchtitan/README.md rename to torchax/examples/train_llama_torchtitan/README.md index 4f5e6b7745a4..52192d1371b8 100644 --- a/experimental/torch_xla2/examples/train_llama_torchtitan/README.md +++ b/torchax/examples/train_llama_torchtitan/README.md @@ -27,7 +27,7 @@ pip install . cd ~ git clone https://github.com/pytorch/xla.git -cd xla/experimental/torch_xla2 +cd xla/torchax pip install -e . ``` @@ -39,7 +39,7 @@ NOTE: these flags are copied from https://github.com/AI-Hypercomputer/maxtext/bl Tested locally on v6e-8; it doesn't seem to make a difference. ```bash -cd ~/xla/experimental/torch_xla2/examples/train_llama_torchtitan +cd ~/xla/torchax/examples/train_llama_torchtitan python train_llama.py --seqlen=8192 ``` @@ -60,10 +60,10 @@ from torch.utils import _pytree as pytree import splash_attn import helper -import torch_xla2 as tx -import torch_xla2.interop -import torch_xla2.train -from torch_xla2.interop import jax_view, torch_view, JittableModule +import torchax as tx +import torchax.interop +import torchax.train +from torchax.interop import jax_view, torch_view, JittableModule import jax import jax.numpy as jnp from jax.experimental import shard_map @@ -176,7 +176,7 @@ class Trainer: self.replicated = jax.sharding.NamedSharding(self.mesh, P()) def fit(self, model, loss_fn, data_loader): - xla_env = torch_xla2.default_env() + xla_env = torchax.default_env() jax.config.update('jax_enable_x64', False) xla_env._mesh = self.mesh xla_env.use_flash_attention = True @@ -192,7 +192,7 @@ class Trainer: jax_optimizer = optax.sgd(0.01) opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params))) - train_step = torch_xla2.train.make_train_step( + train_step = torchax.train.make_train_step( model_fn, loss_fn, jax_optimizer, remat_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims('device', 'pinned_host'), mark_fsdp_sharding_axis='fsdp') @@ -224,7 +224,7 @@ class Trainer: loss, jittable_mod.params, opt_state = train_step( jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels) # wait for iteration to finish to measure time - torch_xla2.interop.call_jax(jax.block_until_ready, (loss, jittable_mod.params)) + torchax.interop.call_jax(jax.block_until_ready, (loss, jittable_mod.params)) step_end = time.perf_counter() print(i, 'loss', loss, 'step latency: ', step_end - step_start) loop_time = step_end - step_start @@ -308,7 +308,7 @@ This is a helper to process names in sharding map ```python def create_sharded_weights(model,
mesh, sharding_map): res = {} - env = torch_xla2.default_env() + env = torchax.default_env() for name, weight_meta in model.state_dict().items(): sharding_spec = sharding_map.get(_process_sharding_name(name)) if sharding_spec is None: @@ -319,7 +319,7 @@ def create_sharded_weights(model, mesh, sharding_map): weight_torch = torch.randn( weight_meta.shape, dtype=weight_meta.dtype) - weight_jax = torch_xla2.default_env().to_xla(weight_torch).jax() + weight_jax = torchax.default_env().to_xla(weight_torch).jax() #print(name, weight.shape, weight.dtype) res[name] = env.j2t_iso(jax.make_array_from_callback( weight_jax.shape, sharding, lambda a: weight_jax[a] @@ -352,8 +352,8 @@ def main( use_scan = True, tp_parallelism=1, ): - torch_xla2.enable_globally() - torch_xla2.enable_performance_mode() + torchax.enable_globally() + torchax.enable_performance_mode() #logging.getLogger("jax").setLevel(logging.DEBUG) print(f"Running with parameters {locals()}") @@ -374,7 +374,7 @@ split into fsdp x tp 2D array. Tensors will be sharded on those 2 axis Scan is implemented as the `TransformerWithScan` below. ```python - env = torch_xla2.default_env() + env = torchax.default_env() env.config.use_tpu_flash_attention = True env.config.shmap_flash_attention = True env._mesh = mesh # this is the mesh used by flash attention pallas kernel @@ -479,7 +479,7 @@ class TransfomerWithScan(torch.nn.Module): self.tok_embeddings = old_transformer.tok_embeddings self.norm = old_transformer.norm self.output = old_transformer.output - self.layers = torch_xla2.train.ScannedModule(list(old_transformer.layers.values()), checkpoint_policy) + self.layers = torchax.train.ScannedModule(list(old_transformer.layers.values()), checkpoint_policy) self.register_buffer('freqs_cis', old_transformer.freqs_cis) diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/__init__.py b/torchax/examples/train_llama_torchtitan/__init__.py similarity index 100% rename from experimental/torch_xla2/examples/train_llama_torchtitan/__init__.py rename to torchax/examples/train_llama_torchtitan/__init__.py diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/helper.py b/torchax/examples/train_llama_torchtitan/helper.py similarity index 97% rename from experimental/torch_xla2/examples/train_llama_torchtitan/helper.py rename to torchax/examples/train_llama_torchtitan/helper.py index e12785aa9931..ae5b99788817 100644 --- a/experimental/torch_xla2/examples/train_llama_torchtitan/helper.py +++ b/torchax/examples/train_llama_torchtitan/helper.py @@ -2,7 +2,7 @@ import jax from jax.tree_util import tree_map from jax.sharding import NamedSharding -from torch_xla2 import interop +from torchax import interop P = jax.sharding.PartitionSpec diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/splash_attn.py b/torchax/examples/train_llama_torchtitan/splash_attn.py similarity index 100% rename from experimental/torch_xla2/examples/train_llama_torchtitan/splash_attn.py rename to torchax/examples/train_llama_torchtitan/splash_attn.py diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py b/torchax/examples/train_llama_torchtitan/train_llama.py similarity index 94% rename from experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py rename to torchax/examples/train_llama_torchtitan/train_llama.py index 6a82e1fd83aa..530af71d652f 100644 --- a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py +++ b/torchax/examples/train_llama_torchtitan/train_llama.py @@ -10,10 +10,10 
@@ import splash_attn import helper -import torch_xla2 as tx -import torch_xla2.interop -import torch_xla2.train -from torch_xla2.interop import jax_view, torch_view, JittableModule +import torchax as tx +import torchax.interop +import torchax.train +from torchax.interop import jax_view, torch_view, JittableModule import jax import jax.numpy as jnp from jax.experimental import shard_map @@ -105,7 +105,7 @@ def __init__(self, mesh): self.replicated = jax.sharding.NamedSharding(self.mesh, P()) def fit(self, model, loss_fn, data_loader): - xla_env = torch_xla2.default_env() + xla_env = torchax.default_env() jax.config.update('jax_enable_x64', False) xla_env._mesh = self.mesh xla_env.use_flash_attention = True @@ -124,9 +124,9 @@ def model_fn(weights, buffers, args): jax_optimizer = optax.sgd(0.01) opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params))) - #opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params) + #opt_state = torchax.interop.call_jax(jax_optimizer.init, jittable_mod.params) - train_step = torch_xla2.train.make_train_step( + train_step = torchax.train.make_train_step( model_fn, loss_fn, jax_optimizer, remat_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims('device', 'pinned_host')) @@ -157,7 +157,7 @@ def model_fn(weights, buffers, args): loss, jittable_mod.params, opt_state = train_step( jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels) # wait for iteration to finish to measure time - torch_xla2.interop.call_jax(jax.block_until_ready, (loss, jittable_mod.params)) + torchax.interop.call_jax(jax.block_until_ready, (loss, jittable_mod.params)) step_end = time.perf_counter() print(i, 'loss', loss, 'step latency: ', step_end - step_start) loop_time = step_end - step_start @@ -192,7 +192,7 @@ def is_integer(t): def create_sharded_weights(model, mesh, sharding_map): res = {} - env = torch_xla2.default_env() + env = torchax.default_env() for name, weight_meta in model.state_dict().items(): sharding_spec = sharding_map.get(_process_sharding_name(name)) if sharding_spec is None: @@ -203,7 +203,7 @@ def create_sharded_weights(model, mesh, sharding_map): weight_torch = torch.randn( weight_meta.shape, dtype=weight_meta.dtype) - weight_jax = torch_xla2.default_env().to_xla(weight_torch).jax() + weight_jax = torchax.default_env().to_xla(weight_torch).jax() #print(name, weight.shape, weight.dtype) res[name] = env.j2t_iso(jax.make_array_from_callback( weight_jax.shape, sharding, lambda a: weight_jax[a] @@ -225,8 +225,8 @@ def main( use_scan = True, tp_parallelism=1, ): - torch_xla2.enable_globally() - torch_xla2.enable_performance_mode() + torchax.enable_globally() + torchax.enable_performance_mode() #logging.getLogger("jax").setLevel(logging.DEBUG) print(f"Running with parameters {locals()}") @@ -238,7 +238,7 @@ def main( else: sharding_map = sharding_map_original - env = torch_xla2.default_env() + env = torchax.default_env() env.config.use_tpu_flash_attention = True env.config.shmap_flash_attention = True env._mesh = mesh # this is the mesh used by flash attention pallas kernel @@ -316,7 +316,7 @@ def __init__(self, old_transformer, checkpoint_policy): self.tok_embeddings = old_transformer.tok_embeddings self.norm = old_transformer.norm self.output = old_transformer.output - self.layers = torch_xla2.train.ScannedModule(list(old_transformer.layers.values()), checkpoint_policy) + self.layers = torchax.train.ScannedModule(list(old_transformer.layers.values()), checkpoint_policy) self.register_buffer('freqs_cis', 
old_transformer.freqs_cis) diff --git a/torchax/format.sh b/torchax/format.sh new file mode 100755 index 000000000000..9b9663294ca4 --- /dev/null +++ b/torchax/format.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -ex + +yapf --recursive -i *.py test torchax \ No newline at end of file diff --git a/experimental/torch_xla2/pyproject.toml b/torchax/pyproject.toml similarity index 91% rename from experimental/torch_xla2/pyproject.toml rename to torchax/pyproject.toml index 676ced51dcdf..b32b110b4ed3 100644 --- a/experimental/torch_xla2/pyproject.toml +++ b/torchax/pyproject.toml @@ -3,7 +3,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "torch_xla2" +name = "torchax" dependencies = [ "absl-py", "immutabledict", @@ -16,7 +16,7 @@ license = {file = "LICENSE"} dynamic = ["version"] [tool.hatch.version] -path = "torch_xla2/__init__.py" +path = "torchax/__init__.py" [project.optional-dependencies] cpu = ["jax[cpu]>=0.4.30", "jax[cpu]", "tensorflow-cpu"] @@ -26,7 +26,7 @@ cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]", "tensorflow-cpu"] odml = ["jax[cpu]>=0.4.30", "jax[cpu]"] [tool.hatch.build.targets.wheel] -packages = ["torch_xla2"] +packages = ["torchax"] [tool.pytest.ini_options] addopts="-n auto" diff --git a/experimental/torch_xla2/test-requirements.txt b/torchax/test-requirements.txt similarity index 87% rename from experimental/torch_xla2/test-requirements.txt rename to torchax/test-requirements.txt index 3b0e1f2df61c..2819504c2543 100644 --- a/experimental/torch_xla2/test-requirements.txt +++ b/torchax/test-requirements.txt @@ -3,6 +3,7 @@ absl-py immutabledict pytest pytest-xdist +pytest-forked sentencepiece expecttest optax diff --git a/experimental/torch_xla2/test/BUILD b/torchax/test/BUILD similarity index 100% rename from experimental/torch_xla2/test/BUILD rename to torchax/test/BUILD diff --git a/experimental/torch_xla2/test/__init__.py b/torchax/test/__init__.py similarity index 100% rename from experimental/torch_xla2/test/__init__.py rename to torchax/test/__init__.py diff --git a/experimental/torch_xla2/test/gemma/__init__.py b/torchax/test/gemma/__init__.py similarity index 100% rename from experimental/torch_xla2/test/gemma/__init__.py rename to torchax/test/gemma/__init__.py diff --git a/experimental/torch_xla2/test/gemma/config.py b/torchax/test/gemma/config.py similarity index 100% rename from experimental/torch_xla2/test/gemma/config.py rename to torchax/test/gemma/config.py diff --git a/experimental/torch_xla2/test/gemma/model.py b/torchax/test/gemma/model.py similarity index 100% rename from experimental/torch_xla2/test/gemma/model.py rename to torchax/test/gemma/model.py diff --git a/experimental/torch_xla2/test/gemma/test_gemma.py b/torchax/test/gemma/test_gemma.py similarity index 95% rename from experimental/torch_xla2/test/gemma/test_gemma.py rename to torchax/test/gemma/test_gemma.py index 4d91bc6f9b0f..38e60f75a8f5 100644 --- a/experimental/torch_xla2/test/gemma/test_gemma.py +++ b/torchax/test/gemma/test_gemma.py @@ -1,6 +1,6 @@ import torch import unittest -import torch_xla2 +import torchax from torch.utils import _pytree as pytree from . import config from . 
import model as gemma @@ -72,9 +72,9 @@ def test_gemma(self): top_ks_tensor, ) - weights, jax_func = torch_xla2.extract_jax(model) + weights, jax_func = torchax.extract_jax(model) inputs_jax = pytree.tree_map_only( - torch.Tensor, torch_xla2.tensor.t2j, inputs) + torch.Tensor, torchax.tensor.t2j, inputs) import jax print(jax.jit(jax_func)(weights, inputs_jax)) diff --git a/experimental/torch_xla2/test/gemma/tokenizer.py b/torchax/test/gemma/tokenizer.py similarity index 100% rename from experimental/torch_xla2/test/gemma/tokenizer.py rename to torchax/test/gemma/tokenizer.py diff --git a/experimental/torch_xla2/test/llama/BUILD b/torchax/test/llama/BUILD similarity index 100% rename from experimental/torch_xla2/test/llama/BUILD rename to torchax/test/llama/BUILD diff --git a/experimental/torch_xla2/test/llama/__init__.py b/torchax/test/llama/__init__.py similarity index 100% rename from experimental/torch_xla2/test/llama/__init__.py rename to torchax/test/llama/__init__.py diff --git a/experimental/torch_xla2/test/llama/llama_model.py b/torchax/test/llama/llama_model.py similarity index 100% rename from experimental/torch_xla2/test/llama/llama_model.py rename to torchax/test/llama/llama_model.py diff --git a/experimental/torch_xla2/test/llama/model_exportable.py b/torchax/test/llama/model_exportable.py similarity index 100% rename from experimental/torch_xla2/test/llama/model_exportable.py rename to torchax/test/llama/model_exportable.py diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/torchax/test/llama/test_llama.py similarity index 91% rename from experimental/torch_xla2/test/llama/test_llama.py rename to torchax/test/llama/test_llama.py index a47e8572186f..41edaafba630 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/torchax/test/llama/test_llama.py @@ -1,7 +1,7 @@ import torch -from torch_xla2 import tensor # pylint: disable=unused-import -import torch_xla2 -import torch_xla2.export +from torchax import tensor # pylint: disable=unused-import +import torchax +import torchax.export from .. import test_base from . 
import llama_model @@ -12,7 +12,7 @@ class LlamaTest(test_base.TestCase): def test_can_run(self): - with torch_xla2.default_env(): + with torchax.default_env(): sample_args = ( torch.randint(0, 32000, (1, 2048), device='jax:0'), torch.arange(0, 2048, device='jax:0'), @@ -88,7 +88,7 @@ def make_cache(args, batch_size): with torch.no_grad(): m_prefill = torch.export.export(m, sample_input_prefill) - weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) + weights, mj_prefill = torchax.export.exported_program_to_jax(m_prefill) sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, sample_input_prefill) print('Prefill', mj_prefill(weights, sample_inputs)) @@ -103,7 +103,7 @@ def make_cache(args, batch_size): ) with torch.no_grad(): m_decode = torch.export.export(m, sample_input_decode) - weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) + weights, mj_decode = torchax.export.exported_program_to_jax(m_decode) sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, sample_input_decode) print('Decode', mj_decode(weights, sample_inputs)) diff --git a/experimental/torch_xla2/test/moe/__init__.py b/torchax/test/moe/__init__.py similarity index 100% rename from experimental/torch_xla2/test/moe/__init__.py rename to torchax/test/moe/__init__.py diff --git a/experimental/torch_xla2/test/moe/model.py b/torchax/test/moe/model.py similarity index 100% rename from experimental/torch_xla2/test/moe/model.py rename to torchax/test/moe/model.py diff --git a/experimental/torch_xla2/test/moe/moe_test.py b/torchax/test/moe/moe_test.py similarity index 90% rename from experimental/torch_xla2/test/moe/moe_test.py rename to torchax/test/moe/moe_test.py index f8d4a22e3f2c..ecb259585fe1 100644 --- a/experimental/torch_xla2/test/moe/moe_test.py +++ b/torchax/test/moe/moe_test.py @@ -1,5 +1,5 @@ -import torch_xla2 -import torch_xla2.interop +import torchax +import torchax.interop import torch import unittest import jax @@ -46,12 +46,12 @@ def test_moe_layer(self): x = torch.randn((seqlen, model_args.dim)) res = moe_layer(x) - env = torch_xla2.default_env() + env = torchax.default_env() model_xla = env.to_xla(moe_layer) x_xla = env.to_xla(x) with jax.default_matmul_precision('float32'): res_xla = model_xla(x_xla) - res2 = torch_xla2.tensor.j2t(res_xla._elem) + res2 = torchax.tensor.j2t(res_xla._elem) print('max diff', torch.max((res - res2).abs())) self.assertTrue( @@ -62,7 +62,7 @@ def test_moe_layer(self): def f(weights, x): return torch.func.functional_call(moe_layer, weights, (x, )) - fjitted = torch_xla2.interop.jax_jit(f) + fjitted = torchax.interop.jax_jit(f) weights_xla = env.to_xla(moe_layer.state_dict()) print(fjitted(weights_xla, x_xla)) diff --git a/experimental/torch_xla2/test/test_base.py b/torchax/test/test_base.py similarity index 88% rename from experimental/torch_xla2/test/test_base.py rename to torchax/test/test_base.py index d8b409306b76..f155b78f67a9 100644 --- a/experimental/torch_xla2/test/test_base.py +++ b/torchax/test/test_base.py @@ -2,7 +2,7 @@ import torch from torch.utils import _pytree as pytree -from torch_xla2 import tensor +from torchax import tensor TestCase = unittest.TestCase main = unittest.main @@ -36,12 +36,12 @@ def run_function_and_compare(testcase, ignore_indices=False): with testcase.subTest("torch_eval"): res = func(*args, **kwargs) - with testcase.subTest("torch_xla2_eval"): + with testcase.subTest("torchax_eval"): args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, (args, kwargs)) res2 = 
func(*args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) - with testcase.subTest("torch_xla2_diff:" + str(atol)): + res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) + with testcase.subTest("torchax_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: diff_output( testcase, @@ -52,4 +52,4 @@ def run_function_and_compare(testcase, equal_nan=equal_nan) else: diff_output( - testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan) + testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan) \ No newline at end of file diff --git a/experimental/torch_xla2/test/test_context.py b/torchax/test/test_context.py similarity index 63% rename from experimental/torch_xla2/test/test_context.py rename to torchax/test/test_context.py index 5255f415ee1e..318795736727 100644 --- a/experimental/torch_xla2/test/test_context.py +++ b/torchax/test/test_context.py @@ -1,9 +1,9 @@ import unittest import torch -import torch_xla2 -from torch_xla2 import tensor -import torch_xla2.interop +import torchax +from torchax import tensor +import torchax.interop xla_env = tensor.Environment() @@ -20,9 +20,9 @@ def tearDown(self): def test_mode_context_manager(self): with xla_env: x = torch.full((3, 3), -1) - self.assertIsInstance(x, tensor.XLATensor2) + self.assertIsInstance(x, tensor.Tensor) y = x.abs() - self.assertIsInstance(y, tensor.XLATensor2) + self.assertIsInstance(y, tensor.Tensor) @staticmethod @xla_env @@ -34,32 +34,32 @@ def _test_mode_decorator(): def test_mode_decorator(self): x, y = self._test_mode_decorator() - self.assertIsInstance(x, tensor.XLATensor2) - self.assertIsInstance(y, tensor.XLATensor2) + self.assertIsInstance(x, tensor.Tensor) + self.assertIsInstance(y, tensor.Tensor) def test_same_manual_seed(self): with xla_env: torch.manual_seed(1234) x = torch.randn((3, 3)) - self.assertIsInstance(x, tensor.XLATensor2) + self.assertIsInstance(x, tensor.Tensor) torch.manual_seed(1234) y = torch.randn((3, 3)) - self.assertIsInstance(y, tensor.XLATensor2) + self.assertIsInstance(y, tensor.Tensor) - self.assertTrue(torch.equal(torch_xla2.tensor.j2t(x._elem), torch_xla2.tensor.j2t(y._elem))) + self.assertTrue(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem))) def test_different_manual_seed(self): with xla_env: torch.manual_seed(1234) x = torch.randn((3, 3)) - self.assertIsInstance(x, tensor.XLATensor2) + self.assertIsInstance(x, tensor.Tensor) torch.manual_seed(12345) y = torch.randn((3, 3)) - self.assertIsInstance(y, tensor.XLATensor2) + self.assertIsInstance(y, tensor.Tensor) - self.assertFalse(torch.equal(torch_xla2.tensor.j2t(x._elem), torch_xla2.tensor.j2t(y._elem))) + self.assertFalse(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem))) def test_jit_with_rng(self): @xla_env @@ -68,14 +68,14 @@ def random_op(): y = torch.randn(3, 3) return x @ y - random_jit = torch_xla2.interop.jax_jit(random_op) - self.assertIsInstance(random_jit(), tensor.XLATensor2) + random_jit = torchax.interop.jax_jit(random_op) + self.assertIsInstance(random_jit(), tensor.Tensor) # Result always expected to be the same for a jitted function because seeds # are baked in torch.testing.assert_close( - torch_xla2.tensor.j2t(random_jit()._elem), - torch_xla2.tensor.j2t(random_jit()._elem), + torchax.tensor.j2t(random_jit()._elem), + torchax.tensor.j2t(random_jit()._elem), atol=0, rtol=0) @@ -86,7 +86,7 @@ def test_generator_seed(self): # Values will be different, but still check device, layout, 
dtype, etc torch.testing.assert_close( - torch_xla2.tensor.j2t(x._elem), torch_xla2.tensor.j2t(y._elem)) + torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)) def test_buffer(self): @@ -101,13 +101,13 @@ def __init__(self): # Test context manager. with xla_env: m = M() - self.assertIsInstance(m.c, tensor.XLATensor2) - self.assertIsInstance(m.c2, tensor.XLATensor2) + self.assertIsInstance(m.c, tensor.Tensor) + self.assertIsInstance(m.c2, tensor.Tensor) # Test `to_xla`. m = M() m = xla_env.to_xla(m) - self.assertIsInstance(m.c, tensor.XLATensor2) - self.assertIsInstance(m.c2, tensor.XLATensor2) + self.assertIsInstance(m.c, tensor.Tensor) + self.assertIsInstance(m.c2, tensor.Tensor) if __name__ == "__main__": diff --git a/experimental/torch_xla2/test/test_conv.py b/torchax/test/test_conv.py similarity index 85% rename from experimental/torch_xla2/test/test_conv.py rename to torchax/test/test_conv.py index 56b0ea83c5c2..087b973239a5 100644 --- a/experimental/torch_xla2/test/test_conv.py +++ b/torchax/test/test_conv.py @@ -1,6 +1,6 @@ import torch from torch import nn -import torch_xla2 +import torchax from . import test_base @@ -59,20 +59,20 @@ def test_conv1(self): arg = torch.randn((20, 1, 50)) res = m(arg) - jax_weights, jax_func = torch_xla2.extract_jax(m) - arg = torch_xla2.tensor.t2j(arg) + jax_weights, jax_func = torchax.extract_jax(m) + arg = torchax.tensor.t2j(arg) res2 = jax_func(jax_weights, (arg,)) - res2_torch = torch_xla2.tensor.j2t(res2) + res2_torch = torchax.tensor.j2t(res2) self.assertTrue(torch.allclose(res, res2_torch)) def test_conv2(self): m = CustomConv2() arg = torch.randn((20, 4, 50, 100)) res = m(arg) - jax_weights, jax_func = torch_xla2.extract_jax(m) - arg = torch_xla2.tensor.t2j(arg) + jax_weights, jax_func = torchax.extract_jax(m) + arg = torchax.tensor.t2j(arg) res2 = jax_func(jax_weights, (arg,)) - res2_torch = torch_xla2.tensor.j2t(res2) + res2_torch = torchax.tensor.j2t(res2) self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4)) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py similarity index 99% rename from experimental/torch_xla2/test/test_core_aten_ops.py rename to torchax/test/test_core_aten_ops.py index a5806ddb266c..ca0c5c15a9f4 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/torchax/test/test_core_aten_ops.py @@ -2,7 +2,7 @@ import unittest import torch -from torch_xla2 import tensor +from torchax import tensor from . 
import test_base from torch.utils import _pytree as pytree @@ -35,13 +35,13 @@ def run_export_and_compare(testcase, with testcase.subTest("torch_eval"): res = func(*args, **kwargs) - with testcase.subTest("torch_xla2_eval"): + with testcase.subTest("torchax_eval"): args2, kwargs2 = testcase.env.to_xla((args, kwargs)) with testcase.env: res2 = func(*args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) + res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) # import pdb; pdb.set_trace() - with testcase.subTest("torch_xla2_diff:" + str(atol)): + with testcase.subTest("torchax_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: diff_output( testcase, @@ -3812,7 +3812,7 @@ def test_aten__softmax_1(self): def _compare_sorted_result(self, args): res = torch.ops.aten.sort(*args) - with self.subTest("torch_xla2_eval"): + with self.subTest("torchax_eval"): args2 = self.env.to_xla(args) with self.env: res2 = torch.ops.aten.sort(*args2) diff --git a/experimental/torch_xla2/test/test_exports.py b/torchax/test/test_exports.py similarity index 86% rename from experimental/torch_xla2/test/test_exports.py rename to torchax/test/test_exports.py index 1123a289db01..b24258426e38 100644 --- a/experimental/torch_xla2/test/test_exports.py +++ b/torchax/test/test_exports.py @@ -3,10 +3,10 @@ import torch.nn.functional as F import jax import jax.export -import torch_xla2 -import torch_xla2.export -from torch_xla2 import tensor -from torch_xla2.ops import mappings +import torchax +import torchax.export +from torchax import tensor +from torchax.ops import mappings class Interpolate(torch.nn.Module): @@ -34,7 +34,7 @@ class ExportTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - torch_xla2.enable_accuracy_mode() + torchax.enable_accuracy_mode() def test_interpolate(self): @@ -45,14 +45,14 @@ def test_interpolate(self): with torch.no_grad(): exported = torch.export.export(model, arg) - weights, func = torch_xla2.export.exported_program_to_jax(exported) + weights, func = torchax.export.exported_program_to_jax(exported) argj = tensor.t2j(arg[0]) ans2 = jax.jit(func)(weights, (argj,))[0] ans2 = tensor.j2t(ans2) self.assertTrue(torch.allclose(ans, ans2, atol=1e-3)) # Convert to StableHLO - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertIn("func.func public @main", module_str) self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str) @@ -68,14 +68,14 @@ def test_constant(self): with torch.no_grad(): exported = torch.export.export(model, arg) - weights, func = torch_xla2.export.exported_program_to_jax(exported) + weights, func = torchax.export.exported_program_to_jax(exported) argj = tensor.t2j(arg[0]) ans2 = jax.jit(func)(weights, (argj,))[0] ans2 = tensor.j2t(ans2) self.assertTrue(torch.allclose(ans, ans2, atol=1e-5)) # Convert to StableHLO - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertIn("func.func public @main", module_str) self.assertIn("stablehlo.divide", module_str) @@ -89,7 +89,7 @@ def test_interpolate_dynamic(self): with torch.no_grad(): exported = torch.export.export(model, arg, dynamic_shapes=dynamic_shapes) - weights, stablehlo = 
torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) # Look for dynamic shape artifacts @@ -139,7 +139,7 @@ def test_export_dtypes(self): arg = (torch.randn(10).to(torch_dtype),) with torch.no_grad(): exported = torch.export.export(model, arg) - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertIn(DTYPE_TO_MLIR_STR[torch_dtype], module_str) diff --git a/experimental/torch_xla2/test/test_functions.py b/torchax/test/test_functions.py similarity index 82% rename from experimental/torch_xla2/test/test_functions.py rename to torchax/test/test_functions.py index aab34bd1472f..c5e068aad4d8 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/torchax/test/test_functions.py @@ -2,16 +2,16 @@ from absl.testing import absltest from absl.testing import parameterized import torch -import torch_xla2 -import torch_xla2.tensor +import torchax +import torchax.tensor class TestTorchFunctions(parameterized.TestCase): def setUp(self): - self.env = torch_xla2.tensor.Environment() + self.env = torchax.tensor.Environment() self.env.config.use_torch_native_for_cpu_tensor = False - torch_xla2.enable_accuracy_mode() + torchax.enable_accuracy_mode() @parameterized.named_parameters( ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), @@ -26,9 +26,9 @@ def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): with self.env: actual = func() - self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2) + self.assertIsInstance(actual, torchax.tensor.Tensor) - torch.testing.assert_close(torch_xla2.tensor.j2t(actual._elem), expected) + torch.testing.assert_close(torchax.tensor.j2t(actual._elem), expected) def test_dont_capture_conversion(self): t = torch.tensor([1,2,3]) diff --git a/experimental/torch_xla2/test/test_interop.py b/torchax/test/test_interop.py similarity index 97% rename from experimental/torch_xla2/test/test_interop.py rename to torchax/test/test_interop.py index 285341969057..32427602bc4f 100644 --- a/experimental/torch_xla2/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -1,6 +1,6 @@ import torch import unittest -from torch_xla2 import interop +from torchax import interop class M1(torch.nn.Module): diff --git a/experimental/torch_xla2/test/test_libraries.py b/torchax/test/test_libraries.py similarity index 89% rename from experimental/torch_xla2/test/test_libraries.py rename to torchax/test/test_libraries.py index 5432e5cdd56f..db2e3bbfe69b 100644 --- a/experimental/torch_xla2/test/test_libraries.py +++ b/torchax/test/test_libraries.py @@ -2,10 +2,10 @@ import torch import torch.nn.functional as F from torch.library import Library, impl, impl_abstract -import torch_xla2 -import torch_xla2.export -from torch_xla2.ops import jaten -from torch_xla2.ops import jlibrary +import torchax +import torchax.export +from torchax.ops import jaten +from torchax.ops import jlibrary # Create a `mylib` library which has a basic SDPA op. 
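The hunk above shows `test_libraries.py` importing `Library, impl, impl_abstract` from `torch.library`, which is how the `mylib` SDPA op gets registered before the test lowers it through `torchax.export`. A minimal, hedged sketch of that registration pattern follows; the op name, schema, and body here are illustrative stand-ins, not the test's exact code:

```python
import torch
import torch.nn.functional as F
from torch.library import Library, impl

# Declare a custom op namespace and an SDPA-like schema in it.
mylib = Library("mylib", "DEF")
mylib.define("my_sdpa(Tensor q, Tensor k, Tensor v) -> Tensor")

# Register an implementation under a dispatch key so eager calls work;
# the test then wires up a torchax/JAX lowering separately via `jlibrary`.
@impl(mylib, "my_sdpa", "CompositeExplicitAutograd")
def my_sdpa(q, k, v):
  return F.scaled_dot_product_attention(q, k, v)

# The op is now reachable through the dispatcher:
q = k = v = torch.randn(1, 4, 8, 16)
out = torch.ops.mylib.my_sdpa(q, k, v)
```

Keeping the op behind a single library entry point is what lets the export path preserve it as one composite instead of a loose bag of aten ops.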
@@ -54,7 +54,7 @@ class LibraryTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False + torchax.default_env().config.use_torch_native_for_cpu_tensor = False def test_basic_sdpa_library(self): @@ -70,7 +70,7 @@ def forward(self, q,k,v): args = (arg, arg, arg, ) exported = torch.export.export(model, args=args) - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) ## TODO Update this machinery from producing function calls to producing diff --git a/experimental/torch_xla2/test/test_mutations.py b/torchax/test/test_mutations.py similarity index 93% rename from experimental/torch_xla2/test/test_mutations.py rename to torchax/test/test_mutations.py index f5385a445c67..ab23623a7cf3 100644 --- a/experimental/torch_xla2/test/test_mutations.py +++ b/torchax/test/test_mutations.py @@ -1,5 +1,5 @@ import unittest -import torch_xla2 +import torchax import torch from torch.testing._internal.common_utils import TestCase @@ -7,7 +7,7 @@ class TestMutations(TestCase): def setUp(self): - self.env = torch_xla2.tensor.Environment() + self.env = torchax.tensor.Environment() def test_add(self): with self.env: diff --git a/experimental/torch_xla2/test/test_ops.py b/torchax/test/test_ops.py similarity index 94% rename from experimental/torch_xla2/test/test_ops.py rename to torchax/test/test_ops.py index e3b686f68ad0..9e7f0ba22f3b 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/torchax/test/test_ops.py @@ -6,8 +6,8 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops) from torch.utils import _pytree as pytree -from torch_xla2 import tensor -import torch_xla2 +from torchax import tensor +import torchax skiplist = { @@ -56,7 +56,7 @@ "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior # such as 0 to negative power. "to_sparse", # We are not supporting sparse tensors yet. 
- "nn.functional.rrelu", # pure torch result match torch_xla2 test result, only OpInfo mismatch: https://gist.github.com/ManfeiBai/1a449b15f4e946bfcaa3e5ef86da20f4 + "nn.functional.rrelu", # pure torch result match torchax test result, only OpInfo mismatch: https://gist.github.com/ManfeiBai/1a449b15f4e946bfcaa3e5ef86da20f4 } # These inputs are themselves views @@ -139,13 +139,13 @@ def run_export_and_compare(testcase, with testcase.subTest("torch_eval"): res = func(sample_input.input, *sample_input.args, **sample_input.kwargs) - with testcase.subTest("torch_xla2_eval"): + with testcase.subTest("torchax_eval"): input2, args2, kwargs2 = testcase.env.to_xla(( sample_input.input, sample_input.args, sample_input.kwargs)) with testcase.env: res2 = func(input2, *args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) - with testcase.subTest("torch_xla2_diff:" + str(atol)): + res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) + with testcase.subTest("torchax_diff:" + str(atol)): if ignore_indices and isinstance(res, tuple) and len(res) == 2: diff_output( testcase, @@ -181,8 +181,8 @@ def setUpClass(cls): print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) def setUp(self): - self.env = torch_xla2.default_env() - torch_xla2.enable_accuracy_mode() + self.env = torchax.default_env() + torchax.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) self.old_var = self.env.config.use_torch_native_for_cpu_tensor diff --git a/experimental/torch_xla2/test/test_symbolic_shapes.py b/torchax/test/test_symbolic_shapes.py similarity index 90% rename from experimental/torch_xla2/test/test_symbolic_shapes.py rename to torchax/test/test_symbolic_shapes.py index ef2c7e9f25c6..89f40f9f7326 100644 --- a/experimental/torch_xla2/test/test_symbolic_shapes.py +++ b/torchax/test/test_symbolic_shapes.py @@ -1,6 +1,6 @@ import torch -import torch_xla2 -import torch_xla2.export +import torchax +import torchax.export from . 
import test_base class AddOne(torch.nn.Module): @@ -40,7 +40,7 @@ def test_constraints_min_max(self): with torch.no_grad(): exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertRegex(module_str, r"stablehlo.constant.*3") @@ -62,7 +62,7 @@ def test_constraints_multiply(self): with torch.no_grad(): exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertRegex(module_str, r"stablehlo.constant.*10") @@ -84,7 +84,7 @@ def test_constraint_indirection(self): with torch.no_grad(): exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10") diff --git a/experimental/torch_xla2/test/test_tf_integration.py b/torchax/test/test_tf_integration.py similarity index 92% rename from experimental/torch_xla2/test/test_tf_integration.py rename to torchax/test/test_tf_integration.py index 4562ba8cb0c8..f1f457954387 100644 --- a/experimental/torch_xla2/test/test_tf_integration.py +++ b/torchax/test/test_tf_integration.py @@ -4,9 +4,9 @@ import tensorflow as tf import torch import torch.nn.functional as F -import torch_xla2 +import torchax -from torch_xla2 import tf_integration +from torchax import tf_integration from . 
import test_base @@ -26,7 +26,7 @@ class TfIntegrationTest(test_base.TestCase): def setUp(self): torch.manual_seed(0) - torch_xla2.enable_accuracy_mode() + torchax.enable_accuracy_mode() def test_interpolate(self): """Simple model roundtripped through TF savedmodel""" diff --git a/experimental/torch_xla2/test/test_train.py b/torchax/test/test_train.py similarity index 88% rename from experimental/torch_xla2/test/test_train.py rename to torchax/test/test_train.py index 54b66679faa9..b6909a1ee7f1 100644 --- a/experimental/torch_xla2/test/test_train.py +++ b/torchax/test/test_train.py @@ -1,8 +1,8 @@ import unittest import torch -import torch_xla2 as tx -import torch_xla2.export -import torch_xla2.train +import torchax as tx +import torchax.export +import torchax.train from torch.testing._internal.common_utils import TestCase @@ -10,7 +10,7 @@ class TrainTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - torch_xla2.enable_accuracy_mode() + torchax.enable_accuracy_mode() def test_scan_module(self): x = torch.arange(300).reshape(3, 100).to(torch.float32) @@ -29,7 +29,7 @@ def test_scan_module(self): layers ) - with torch_xla2.default_env(): + with torchax.default_env(): x = x.to('jax') model.to('jax') result2 = model(x) @@ -37,7 +37,7 @@ def test_train_step_can_run(self): import optax - with torch_xla2.default_env(): + with torchax.default_env(): model = torch.nn.Linear(100, 100) model.to('jax') weights = model.state_dict() diff --git a/experimental/torch_xla2/test/test_unbounded_dynamism.py b/torchax/test/test_unbounded_dynamism.py similarity index 98% rename from experimental/torch_xla2/test/test_unbounded_dynamism.py rename to torchax/test/test_unbounded_dynamism.py index 675b52657383..e35bc2d05def 100644 --- a/experimental/torch_xla2/test/test_unbounded_dynamism.py +++ b/torchax/test/test_unbounded_dynamism.py @@ -4,12 +4,12 @@ import torch from torch.export import Dim, export -from torch_xla2.export import exported_program_to_stablehlo as exp2shlo -import torch_xla2 +from torchax.export import exported_program_to_stablehlo as exp2shlo +import torchax ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py` -## To test that torch_xla2 has identical behavior. -## The only differences in this test files are that torch_xla2 export preserves +## To test that torchax has identical behavior. +## The only differences in this test file are that torchax export preserves ## argument order more often than torch_xla export. ## ## This broke ~5 tests, for example: test_bmm_dynamic_out_dim @@ -20,7 +20,7 @@ ## dynamic_shapes = ((None, {2: Dim("dim")}),) ## ...
## torch_xla_regex = r'%arg.: tensor<8x256x\?xf32>.*%arg.: tensor<8x128x256xf32>.*->.*tensor<8x128x\?xf32>' -## torch_xla2_regex = r'%arg.: tensor<8x128x256xf32>.*%arg.: tensor<8x256x\?xf32>.*->.*tensor<8x128x\?xf32>' +## torchax_regex = r'%arg.: tensor<8x128x256xf32>.*%arg.: tensor<8x256x\?xf32>.*->.*tensor<8x128x\?xf32>' # Shim to run tests class ExportAdapter(): @@ -45,9 +45,9 @@ def forward(self, *args): class UnboundedDynamismExportTest(unittest.TestCase): def setUp(self): - self.env = torch_xla2.default_env() + self.env = torchax.default_env() self.env.config.use_torch_native_for_cpu_tensor = False - torch_xla2.enable_accuracy_mode() + torchax.enable_accuracy_mode() def tearDown(self): self.env.config.use_torch_native_for_cpu_tensor = True diff --git a/experimental/torch_xla2/test_dist/README.md b/torchax/test_dist/README.md similarity index 100% rename from experimental/torch_xla2/test_dist/README.md rename to torchax/test_dist/README.md diff --git a/experimental/torch_xla2/test_dist/__init__.py b/torchax/test_dist/__init__.py similarity index 100% rename from experimental/torch_xla2/test_dist/__init__.py rename to torchax/test_dist/__init__.py diff --git a/experimental/torch_xla2/test_dist/test_distributed.py b/torchax/test_dist/test_distributed.py similarity index 89% rename from experimental/torch_xla2/test_dist/test_distributed.py rename to torchax/test_dist/test_distributed.py index 7875b96dcb62..e600f211667b 100644 --- a/experimental/torch_xla2/test_dist/test_distributed.py +++ b/torchax/test_dist/test_distributed.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives import torch.distributed as dist -import torch_xla2 -import torch_xla2.distributed +import torchax +import torchax.distributed # Dummy group name to use with functional collectives. Ignored by # implementations. 
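Every test below follows the same `torchax.distributed.spawn` pattern: a per-replica function receives its rank as a tensor and runs under a JAX `pmap`, so collectives span all visible devices without creating extra processes. A minimal sketch, assuming multiple CPU devices are forced the way the CI workflow does with `XLA_FLAGS=--xla_force_host_platform_device_count=4`; the function body is illustrative:

```python
import torchax
import torchax.distributed

def f(index: torchax.tensor.Tensor):
  # `index` is this replica's rank; torch.distributed collectives issued
  # here would run across all replicas.
  return index + 1

res = torchax.distributed.spawn(f)
# One result slice per device, e.g. [1, 2, 3, 4] with four CPU devices.
print(res.numpy())
```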
@@ -42,13 +42,13 @@ def process_group(): def test_all_gather_tensor(multi_cpu, process_group): device_count = multi_cpu - def f(index: torch_xla2.tensor.XLATensor2): - with torch_xla2.default_env(): + def f(index: torchax.tensor.Tensor): + with torchax.default_env(): output = torch.zeros_like(index).expand(device_count) dist.all_gather_into_tensor(output, index) return output - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [[0, 1, 2, 3] for _ in range(device_count)] np.testing.assert_equal([r.numpy() for r in res], expected_tensors) @@ -58,12 +58,12 @@ def test_all_gather_tensor_func(multi_cpu, process_group): device_count = multi_cpu group_ranks = process_group - def f(index: torch_xla2.tensor.XLATensor2): + def f(index: torchax.tensor.Tensor): return torch.distributed._functional_collectives.all_gather_tensor( index, 0, group_ranks ) - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [[0, 1, 2, 3] for _ in range(device_count)] np.testing.assert_equal([r.numpy() for r in res], expected_tensors) @@ -85,7 +85,7 @@ def f(index): dist.all_reduce(index, op) return index - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [expected for _ in range(device_count)] np.testing.assert_equal(res.numpy(), expected_tensors) @@ -108,7 +108,7 @@ def f(index): index, op, GROUP_NAME ) - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [expected for _ in range(device_count)] np.testing.assert_equal(res.numpy(), expected_tensors) @@ -128,7 +128,7 @@ def f(index): dist.broadcast(index, rank) return index - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [expected for _ in range(device_count)] np.testing.assert_equal(res.numpy(), expected_tensors) @@ -149,7 +149,7 @@ def f(index): index, rank, GROUP_NAME ) - res = torch_xla2.distributed.spawn(f) + res = torchax.distributed.spawn(f) expected_tensors = [expected for _ in range(device_count)] np.testing.assert_equal(res.numpy(), expected_tensors) diff --git a/experimental/torch_xla2/torch_xla2/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md similarity index 100% rename from experimental/torch_xla2/torch_xla2/CONTRIBUTING.md rename to torchax/torchax/CONTRIBUTING.md diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/torchax/torchax/__init__.py similarity index 92% rename from experimental/torch_xla2/torch_xla2/__init__.py rename to torchax/torchax/__init__.py index f36a0737c000..0e6e085719df 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/torchax/torchax/__init__.py @@ -5,8 +5,8 @@ import os import torch from torch.utils import _pytree as pytree -from torch_xla2 import tensor -from torch_xla2 import distributed # noqa: F401 +from torchax import tensor +from torchax import distributed # noqa: F401 __version__ = "0.0.1" VERSION = __version__ @@ -20,7 +20,7 @@ from jax._src import xla_bridge os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') -# torch_xla2:oss-begin +# torchax:oss-begin old_pjrt_options = jax.config.jax_pjrt_client_create_options try: jax.config.update( @@ -36,7 +36,7 @@ ) xla_bridge._clear_backends() jax.devices() # open PJRT to see if it opens -# torch_xla2:oss-end +# torchax:oss-end env = None def default_env(): @@ -91,8 +91,8 @@ def disable_temporarily(): unsupported_dtype=unsupported_dtype) import jax -import torch_xla2.device_module -torch._register_device_module('jax', 
torch_xla2.device_module) +import torchax.device_module +torch._register_device_module('jax', torchax.device_module) @@ -121,7 +121,7 @@ class CompileOptions: def compile(fn, options: Optional[CompileOptions] = None): options = options or CompileOptions() if options.mode == 'jax': - from torch_xla2 import interop + from torchax import interop if isinstance(fn, torch.nn.Module): module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs) for n in options.methods_to_compile: diff --git a/experimental/torch_xla2/torch_xla2/config.py b/torchax/torchax/config.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/config.py rename to torchax/torchax/config.py diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/torchax/torchax/decompositions.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/decompositions.py rename to torchax/torchax/decompositions.py diff --git a/experimental/torch_xla2/torch_xla2/device_module.py b/torchax/torchax/device_module.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/device_module.py rename to torchax/torchax/device_module.py diff --git a/experimental/torch_xla2/torch_xla2/distributed.py b/torchax/torchax/distributed.py similarity index 92% rename from experimental/torch_xla2/torch_xla2/distributed.py rename to torchax/torchax/distributed.py index dacc4d697654..6bf3e4eaa0cf 100644 --- a/experimental/torch_xla2/torch_xla2/distributed.py +++ b/torchax/torchax/distributed.py @@ -21,12 +21,12 @@ import torch.distributed._functional_collectives from torch._C._distributed_c10d import ProcessGroup # type: ignore import torch.distributed -import torch_xla2 +import torchax from jax.sharding import NamedSharding from jax.sharding import Mesh, PartitionSpec as P from jax.experimental import mesh_utils import torch.utils._pytree as torch_pytree -from torch_xla2 import interop +from torchax import interop class ProcessGroupJax(ProcessGroup): @@ -63,8 +63,8 @@ def _allgather_base( input: torch.Tensor, opts=..., ) -> dist.Work: - assert isinstance(input, torch_xla2.tensor.XLATensor2) - assert isinstance(output, torch_xla2.tensor.XLATensor2) + assert isinstance(input, torchax.tensor.Tensor) + assert isinstance(output, torchax.tensor.Tensor) torch.distributed._functional_collectives.all_gather_tensor_inplace( output, input, group=self ) @@ -76,7 +76,7 @@ def allreduce( opts: dist.AllreduceOptions = ..., ) -> dist.Work: assert len(tensors) == 1 - assert isinstance(tensors[0], torch_xla2.tensor.XLATensor2) + assert isinstance(tensors[0], torchax.tensor.Tensor) torch.distributed._functional_collectives.all_reduce_inplace( tensors[0], torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ @@ -93,7 +93,7 @@ def broadcast( opts: dist.BroadcastOptions = ..., ) -> dist.Work: assert len(tensors) == 1 - assert isinstance(tensors[0], torch_xla2.tensor.XLATensor2) + assert isinstance(tensors[0], torchax.tensor.Tensor) tensors[0].copy_( torch.distributed._functional_collectives.broadcast( tensors[0], opts.rootRank, group=self @@ -132,13 +132,13 @@ def jax_rendezvous_handler( dist.register_rendezvous_handler("jax", jax_rendezvous_handler) -def spawn(f, args=(), env: Optional[torch_xla2.tensor.Environment] = None): +def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None): """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. `f` is expected to take the replica index as a positional argument, similar to `torch.multiprocessing.spawn`. 
Note: `spawn` does not actually create parallel processes. """ - env = env or torch_xla2.default_env() + env = env or torchax.default_env() def jax_wrapper(index, jax_args): index, args = env.j2t_iso([index, jax_args]) @@ -166,7 +166,7 @@ class DistributedDataParallel(torch.nn.Module): Example usage: ``` - jax_model = torch_xla2.distributed.DistributedDataParallel(create_model()) + jax_model = torchax.distributed.DistributedDataParallel(create_model()) for data, dataloader: jax_data = jax_model.shard_input(data) jax_output = jax_model(jax_data) @@ -175,14 +175,14 @@ class DistributedDataParallel(torch.nn.Module): def __init__( self, module: torch.nn.Module, - env: Optional[torch_xla2.tensor.Environment] = None, + env: Optional[torchax.tensor.Environment] = None, **kwargs, ): if kwargs: logging.warning(f"Unsupported kwargs {kwargs}") super().__init__() - self._env = env or torch_xla2.default_env() + self._env = env or torchax.default_env() self._mesh = Mesh( mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("batch",), diff --git a/experimental/torch_xla2/torch_xla2/environment.py b/torchax/torchax/environment.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/environment.py rename to torchax/torchax/environment.py diff --git a/experimental/torch_xla2/torch_xla2/export.py b/torchax/torchax/export.py similarity index 98% rename from experimental/torch_xla2/torch_xla2/export.py rename to torchax/torchax/export.py index c77bcea672b3..7f3b9a9b5fe1 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/torchax/torchax/export.py @@ -4,9 +4,9 @@ from typing import Any, Dict, Tuple import torch from torch.utils import _pytree as pytree -from torch_xla2 import tensor -from torch_xla2.ops import ops_registry -from torch_xla2 import decompositions +from torchax import tensor +from torchax.ops import ops_registry +from torchax import decompositions import jax import jax.export import sympy @@ -20,8 +20,8 @@ class JaxInterpreter(torch.fx.Interpreter): def __init__(self, graph_module): super().__init__(graph_module) - import torch_xla2.ops.jaten - import torch_xla2.ops.jtorch + import torchax.ops.jaten + import torchax.ops.jtorch def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: if not isinstance(target, diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/torchax/torchax/interop.py similarity index 95% rename from experimental/torch_xla2/torch_xla2/interop.py rename to torchax/torchax/interop.py index 83c7b8e749b6..a8c7ea5fa8cc 100644 --- a/experimental/torch_xla2/torch_xla2/interop.py +++ b/torchax/torchax/interop.py @@ -6,10 +6,10 @@ import jax.numpy as jnp from jax import tree_util as pytree from jax.experimental.shard_map import shard_map -from torch_xla2 import tensor -import torch_xla2 +from torchax import tensor +import torchax -from torch_xla2.types import JaxValue, TorchValue, JaxCallable, TorchCallable +from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable def extract_all_buffers(m: torch.nn.Module): @@ -136,7 +136,7 @@ def _torch_view(t: JaxValue) -> TorchValue: # view it as-if it's a torch land object if isinstance(t, jax.Array): # TODO - return tensor.XLATensor2(t, torch_xla2.default_env()) + return tensor.Tensor(t, torchax.default_env()) if isinstance(t, type(jnp.int32)): return tensor.t2j_type(t) if callable(t): # t is a JaxCallable @@ -151,7 +151,7 @@ def _jax_view(t: TorchValue) -> JaxValue: # t is an object from torch land # view it as-if it's a jax land object if isinstance(t, torch.Tensor): - 
assert isinstance(t, tensor.XLATensor2), type(t) + assert isinstance(t, tensor.Tensor), type(t) return t.jax() if isinstance(t, type(torch.int32)): return tensor.t2j_dtype(t) @@ -175,7 +175,7 @@ def call_jax(jax_func: JaxCallable, def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue: args, kwargs = torch_view((args, kwargs)) - with torch_xla2.default_env(): + with torchax.default_env(): res: TorchValue = torch_func(*args, **kwargs) return jax_view(res) diff --git a/torchax/torchax/ops/__init__.py b/torchax/torchax/ops/__init__.py new file mode 100644 index 000000000000..71c1b137132f --- /dev/null +++ b/torchax/torchax/ops/__init__.py @@ -0,0 +1,10 @@ +def all_aten_jax_ops(): + # to load the ops + import torchax.ops.jaten # type: ignore + import torchax.ops.ops_registry # type: ignore + + return { + key: val.func + for key, val in torchax.ops.ops_registry.all_aten_ops.items() + if val.is_jax_function + } diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/torchax/torchax/ops/jaten.py similarity index 99% rename from experimental/torch_xla2/torch_xla2/ops/jaten.py rename to torchax/torchax/ops/jaten.py index 0fc250f4c7a9..1dfc33963e30 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -11,13 +11,13 @@ import numpy as np import torch import torch.distributed._functional_collectives -from torch_xla2.ops import ops_registry -from torch_xla2.ops import op_base, mappings -from torch_xla2 import interop -from torch_xla2.ops import jax_reimplement +from torchax.ops import ops_registry +from torchax.ops import op_base, mappings +from torchax import interop +from torchax.ops import jax_reimplement # Keys are OpOverload, value is a callable that takes -# XLATensor2 +# Tensor all_ops = {} # list all Aten ops from pytorch that does mutation diff --git a/experimental/torch_xla2/torch_xla2/ops/jax_reimplement.py b/torchax/torchax/ops/jax_reimplement.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/ops/jax_reimplement.py rename to torchax/torchax/ops/jax_reimplement.py diff --git a/experimental/torch_xla2/torch_xla2/ops/jc10d.py b/torchax/torchax/ops/jc10d.py similarity index 96% rename from experimental/torch_xla2/torch_xla2/ops/jc10d.py rename to torchax/torchax/ops/jc10d.py index 82514003c775..a3d6a1207a3c 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jc10d.py +++ b/torchax/torchax/ops/jc10d.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from torch_xla2.ops import ops_registry +from torchax.ops import ops_registry def op(*aten, **kwargs): diff --git a/experimental/torch_xla2/torch_xla2/ops/jlibrary.py b/torchax/torchax/ops/jlibrary.py similarity index 96% rename from experimental/torch_xla2/torch_xla2/ops/jlibrary.py rename to torchax/torchax/ops/jlibrary.py index 97230d610ef4..bbaf82f12655 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jlibrary.py +++ b/torchax/torchax/ops/jlibrary.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -import torch_xla2 -from torch_xla2.ops import jaten +import torchax +from torchax.ops import jaten import jax import functools @@ -68,6 +68,6 @@ def forward(self, *args): # outside of the handler, we would build the jaxpr representation of the # module once during registration, potentially missing op registrations that # come after. I.e. may miss nested abstractions if we build jaxpr AoT. 
- state, jfn = torch_xla2.extract_jax(ImplWrapper()) + state, jfn = torchax.extract_jax(ImplWrapper()) jaxpr_impl = lambda *args: jfn(state, tuple([*args])) return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args) diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/torchax/torchax/ops/jtorch.py similarity index 98% rename from experimental/torch_xla2/torch_xla2/ops/jtorch.py rename to torchax/torchax/ops/jtorch.py index 4d541cd04d13..7bb254df863c 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -11,9 +11,9 @@ from jax.experimental.shard_map import shard_map import torch -from torch_xla2.ops.ops_registry import register_torch_function_op -from torch_xla2.ops import op_base, mappings, jaten -import torch_xla2.tensor +from torchax.ops.ops_registry import register_torch_function_op +from torchax.ops import op_base, mappings, jaten +import torchax.tensor def register_function(torch_func, **kwargs): @@ -29,7 +29,7 @@ def _as_tensor(data, dtype=None, device=None, env=None): jax_res = jnp.asarray(data) else: jax_res = _tensor(data, dtype=dtype) - return torch_xla2.tensor.XLATensor2(jax_res, env) + return torchax.tensor.Tensor(jax_res, env) @register_function(torch.tensor) diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorchvision_nms.py b/torchax/torchax/ops/jtorchvision_nms.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/ops/jtorchvision_nms.py rename to torchax/torchax/ops/jtorchvision_nms.py diff --git a/experimental/torch_xla2/torch_xla2/ops/mappings.py b/torchax/torchax/ops/mappings.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/ops/mappings.py rename to torchax/torchax/ops/mappings.py diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/torchax/torchax/ops/op_base.py similarity index 97% rename from experimental/torch_xla2/torch_xla2/ops/op_base.py rename to torchax/torchax/ops/op_base.py index f81ab2487ed2..c3326b7dd175 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -3,8 +3,8 @@ import jax.numpy as jnp import numpy as np import torch -from torch_xla2.ops import mappings -from torch_xla2 import types +from torchax.ops import mappings +from torchax import types import sys from typing import Callable, Optional, ParamSpec, Concatenate diff --git a/experimental/torch_xla2/torch_xla2/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py similarity index 95% rename from experimental/torch_xla2/torch_xla2/ops/ops_registry.py rename to torchax/torchax/ops/ops_registry.py index 495fcce1fb9e..24af4ee03a51 100644 --- a/experimental/torch_xla2/torch_xla2/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -1,6 +1,6 @@ import dataclasses import logging -from torch_xla2.types import JaxCallable, TorchCallable +from torchax.types import JaxCallable, TorchCallable from typing import Union, Dict diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/torchax/torchax/tensor.py similarity index 90% rename from experimental/torch_xla2/torch_xla2/tensor.py rename to torchax/torchax/tensor.py index 7d348e561b05..f8f8b4cb7314 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/torchax/torchax/tensor.py @@ -11,8 +11,8 @@ import torch.utils._python_dispatch as torch_dispatch import torch.utils._pytree as torch_pytree -from torch_xla2 import config -from torch_xla2.ops import mappings, ops_registry +from torchax import config +from torchax.ops import mappings, ops_registry class 
OperatorNotFound(Exception): @@ -20,15 +20,15 @@ class OperatorNotFound(Exception): def wrap(jaxarray): - return torch_pytree.tree_map_only(jnp.ndarray, XLATensor2, jaxarray) + return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) def unwrap(torchtensors): - return torch_pytree.tree_map_only(XLATensor2, lambda x: x._elem, torchtensors) + return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) def t2j(t): - if isinstance(t, XLATensor2): + if isinstance(t, Tensor): return t._elem return mappings.t2j(t) @@ -56,7 +56,7 @@ def log_nested(env, message): log_nested.level = 0 -class XLATensor2(torch.Tensor): +class Tensor(torch.Tensor): @staticmethod def __new__(cls, elem, env): @@ -81,7 +81,7 @@ def __init__(self, elem: jax.Array, env: 'Environment'): self._env = env def __str__(self): - return "XLATensor2({} {})".format(str(type(self._elem)), str(self._elem)) + return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) __repr__ = __str__ @@ -102,7 +102,7 @@ def flatten(self, start_dim=0, end_dim=-1): new_shape = ( self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim:]) new_elem = jnp.reshape(self._elem, new_shape) - return XLATensor2(new_elem, self._env) + return Tensor(new_elem, self._env) # return torch.reshape(self, new_shape) def __setitem__(self, key, val): @@ -119,7 +119,7 @@ def type_as(self, other): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): env = None for arg in torch_pytree.arg_tree_leaves(*args, **kwargs): - if isinstance(arg, XLATensor2): + if isinstance(arg, Tensor): env = arg._env break @@ -127,7 +127,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*args, **(kwargs or {})) def detach(self): - return XLATensor2(jax.lax.stop_gradient(self.jax()), self._env) + return Tensor(jax.lax.stop_gradient(self.jax()), self._env) def numpy(self) -> numpy.ndarray: import numpy as np @@ -162,12 +162,13 @@ def apply_jax(self, jax_function, *args, **kwargs): def apply_jax_(self, jax_function, *args, **kwargs): self._elem = jax_function(self._elem, *args, **kwargs) + return self def tolist(self): return self._elem.tolist() def shard_(self, sharding): - self.apply_(jax.lax.with_sharding_constraint, sharding) + self.apply_jax_(jax.lax.with_sharding_constraint, sharding) def debug_accuracy(func, args, kwargs, current_output): @@ -307,12 +308,12 @@ def get_as_jax_device(self, device: Any): def load_ops(self): - from torch_xla2.ops import jaten, jtorch, jc10d, jtorchvision_nms + from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms self._ops.update(ops_registry.all_aten_ops) self._ops.update(ops_registry.all_torch_functions) decomps = torch._decomp.core_aten_decompositions() - from torch_xla2.decompositions import EXTRA_DECOMP + from torchax.decompositions import EXTRA_DECOMP decomps.update(EXTRA_DECOMP) for k, v in decomps.items(): if k not in self._ops: @@ -325,7 +326,7 @@ def load_ops(self): ) def _to_copy(self, the_tensor, new_dtype, new_device): - if isinstance(the_tensor, XLATensor2): + if isinstance(the_tensor, Tensor): arr = the_tensor.jax() if new_dtype is not None and new_dtype != arr.dtype: arr = arr.astype(mappings.t2j_dtype(new_dtype)) @@ -335,7 +336,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): arr = jax.device_put(arr, jax_device) else: # converting to a non-jax device: let torch native handle it - torch_tensor = j2t(arr) if isinstance(the_tensor, XLATensor2) else arr + torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr with 
mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): return torch_tensor.to(new_device) else: @@ -350,7 +351,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): return the_tensor.to(new_device) - return XLATensor2(arr, self) + return Tensor(arr, self) def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None): @@ -372,7 +373,7 @@ def _handle_tensor_constructor(self, func, args, kwargs): op = self._ops.get(func) res = op.func(*args, **kwargs) if isinstance(res, jax.Array): - res = XLATensor2(res, self) + res = Tensor(res, self) return res def _torch_Tensor_to(self, args, kwargs): @@ -403,11 +404,11 @@ def dispatch(self, func, types, args, kwargs): if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default ,torch.ops.aten._to_copy, torch.ops.aten._to_copy.default): return self._torch_Tensor_to(args, kwargs) - # If the func doesn't act on XLATensor2, and is not a tensor constructor, + # If the func doesn't act on Tensor, and is not a tensor constructor, # We should skip and let torch handle it. tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)] - if tensor_args and all(not isinstance(t, XLATensor2) for t in tensor_args): + if tensor_args and all(not isinstance(t, Tensor) for t in tensor_args): return func(*args, **kwargs) with jax.named_scope(_name_of_func(func)): @@ -466,10 +467,10 @@ def _move_one_value(self, val): if isinstance(val, torch.nn.Module): with self: return val.to('jax') - if isinstance(val, XLATensor2): + if isinstance(val, Tensor): return val if isinstance(val, torch.Tensor): - return XLATensor2(t2j(val), self) + return Tensor(t2j(val), self) return val def to_xla(self, torchvalues): @@ -483,13 +484,13 @@ def t2j_iso(self, torchtensors): def to_jax(x): if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor): x = x.wait() - assert isinstance(x, XLATensor2), f'Expect a XLATensor2 but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor' + assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is mixed math between a torchax Tensor and a plain torch.Tensor' return x.jax() return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors) def j2t_iso(self, jaxarray): return torch_pytree.tree_map_only( - jnp.ndarray, lambda x: XLATensor2(x, self), jaxarray) + jnp.ndarray, lambda x: Tensor(x, self), jaxarray) def j2t_copy(self, args): pass diff --git a/experimental/torch_xla2/torch_xla2/tf_integration.py b/torchax/torchax/tf_integration.py similarity index 99% rename from experimental/torch_xla2/torch_xla2/tf_integration.py rename to torchax/torchax/tf_integration.py index e65c92527942..c9842089bfcf 100644 --- a/experimental/torch_xla2/torch_xla2/tf_integration.py +++ b/torchax/torchax/tf_integration.py @@ -5,7 +5,7 @@ from jax.experimental import jax2tf import tensorflow as tf import torch -from torch_xla2 import export +from torchax import export def exported_program_to_tf_function(ep, enable_xla=True): diff --git a/experimental/torch_xla2/torch_xla2/train.py b/torchax/torchax/train.py similarity index 96% rename from experimental/torch_xla2/torch_xla2/train.py rename to torchax/torchax/train.py index 4b6b40cdb0cb..bedf82a75c09 100644 --- a/experimental/torch_xla2/torch_xla2/train.py +++ b/torchax/torchax/train.py @@ -2,9 +2,9 @@ import functools import torch import jax -import torch_xla2 -from torch_xla2 import interop -from torch_xla2.interop import
torch_view, jax_view +import torchax +from torchax import interop +from torchax.interop import torch_view, jax_view import optax @@ -31,7 +31,7 @@ def make_train_step(model_fn, remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how to do gradient checkpointing. If None, then it means checkpoint everything. """ - env = torch_xla2.default_env() + env = torchax.default_env() def loss(weights, buffers, args, label): # inputs are XLATensor with env, jax.named_scope('compute_loss'): res = model_fn(weights, buffers, args) diff --git a/experimental/torch_xla2/torch_xla2/types.py b/torchax/torchax/types.py similarity index 100% rename from experimental/torch_xla2/torch_xla2/types.py rename to torchax/torchax/types.py
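The renamed pieces above compose into one training loop: `torch.func.functional_call` (or `JittableModule`) makes the model functional, `jax_view`/`torch_view` bridge optax state across the torch/JAX boundary, and `torchax.train.make_train_step` ties them together. A condensed, hedged sketch assembled only from the call patterns visible in `train_llama.py`, `test_train.py`, and `moe_test.py` above; the model, loss function, and hyperparameters are illustrative:

```python
import optax
import torch
import torchax
import torchax.train
from torchax.interop import jax_view, torch_view

with torchax.default_env():
  model = torch.nn.Linear(100, 100)
  model.to('jax')
  weights = model.state_dict()
  buffers = {}  # Linear has no buffers; a real model would pass its own.

  def model_fn(weights, buffers, args):
    # Run the module functionally so the weights stay an explicit pytree input.
    return torch.func.functional_call(model, weights, (args,))

  def loss_fn(output, label):
    # Assumed (output, label) signature, mirroring how fit() applies loss_fn.
    return torch.nn.functional.mse_loss(output, label)

  jax_optimizer = optax.sgd(0.01)
  # Build the optimizer state in JAX land, then view it back as torch values.
  opt_state = torch_view(jax_optimizer.init(jax_view(weights)))

  train_step = torchax.train.make_train_step(model_fn, loss_fn, jax_optimizer)

  inputs = torch.randn(8, 100).to('jax')
  labels = torch.randn(8, 100).to('jax')
  loss, weights, opt_state = train_step(weights, buffers, opt_state, inputs, labels)
  print('loss', loss)
```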