Tutorial: PyTorch Export to StableHLO (#2237)
View rich text GH rendering here:
[pytorch-export.ipynb](https://github.com/openxla/stablehlo/blob/60adde95c9337ce031c91f3f3fb7e5b405f4cef0/docs/tutorials/pytorch-export.ipynb).

Includes "Open in Colab" URL fix for JAX tutorial.
GleasonK authored Apr 18, 2024
1 parent 5217297 commit c8d96c0
Showing 3 changed files with 350 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/_toc.yaml
@@ -29,6 +29,8 @@ toc:
section:
- title: JAX Export to StableHLO
path: /stablehlo/tutorials/jax-export
- title: PyTorch Export to StableHLO
path: /stablehlo/tutorials/pytorch-export
- title: Contributing
section:
- title: Governance
4 changes: 2 additions & 2 deletions docs/tutorials/jax-export.ipynb
@@ -19,7 +19,7 @@
"source": [
"# Tutorial: Exporting StableHLO from JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/examples/jax-export.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/examples/jax-export.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb)\n",
"\n",
"## Tutorial Setup\n",
"\n",
@@ -438,4 +438,4 @@
}
}
]
}
}
346 changes: 346 additions & 0 deletions docs/tutorials/pytorch-export.ipynb
@@ -0,0 +1,346 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# # Tutorial: Exporting StableHLO from PyTorch\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb)\n",
"\n",
"_Intro to the [`torch_xla.stablehlo`](https://github.com/pytorch/xla/blob/main/docs/stablehlo.md) module._\n",
"\n",
"## Tutorial Setup\n",
"\n",
"### Install required dependencies\n",
"\n",
"We'll be using `torch` and `torchvision` to get a `resnet18` model, and `torch_xla` to export to StableHLO."
],
"metadata": {
"id": "42_6VTj7l_xS"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-GTAijyd3rOf"
},
"outputs": [],
"source": [
"!pip install torch_xla torch torchvision"
]
},
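{
"cell_type": "markdown",
"source": [
"The `torch_xla.stablehlo` APIs track PyTorch releases closely, so it can help to confirm which versions were installed before continuing. A minimal sketch:\n",
"\n",
"```python\n",
"import torch\n",
"import torch_xla\n",
"\n",
"# torch and torch_xla should come from the same release line.\n",
"print(torch.__version__, torch_xla.__version__)\n",
"```"
],
"metadata": {}
},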
{
"cell_type": "markdown",
"source": [
"## Export PyTorch model to StableHLO\n",
"\n",
"The general set of steps for exporting a PyTorch model to StableHLO is:\n",
"1. Export using PyTorch's `torch.export` API.\n",
"2. Convert exported FX Graph to StableHLO using `torch_xla.stablehlo` APIs.\n",
"\n",
"### Export model to FX graph using `torch.export`\n",
"\n",
"This step uses entirely vanilla PyTorch APIs to export a `resnet18` model from `torchvision`. Sample inputs are required for graph tracing, we use a `tensor<4x3x224x224xf32>` in this case."
],
"metadata": {
"id": "V_AtPpV30Bt8"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torchvision\n",
"from torch.export import export\n",
"\n",
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
"sample_input = (torch.randn(4, 3, 224, 224), )\n",
"exported = export(resnet18, sample_input)"
],
"metadata": {
"id": "GhIpxnx5fuxy"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Export FX Graph to StableHLO using TorchXLA\n",
"\n",
"Once we have an exported FX graph, we can convert to StableHLO using the `torch_xla.stablehlo` module. In this case we'll use `exported_program_to_stablehlo`."
],
"metadata": {
"id": "zuMAr3WO1PBk"
}
},
{
"cell_type": "code",
"source": [
"from torch_xla.stablehlo import exported_program_to_stablehlo\n",
"\n",
"stablehlo_program = exported_program_to_stablehlo(exported)\n",
"print(stablehlo_program.get_stablehlo_text('forward')[0:4000],\"\\n...\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xiP7psIQgUc-",
"outputId": "0de7d5bc-01a1-4e96-b2ce-19faa9965459"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"module @IrToHlo.508 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n",
" func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {\n",
" %0 = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>\n",
" %1 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>\n",
" %2 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>\n",
" %3 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>\n",
" %4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>\n",
" %5 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>\n",
" %6 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n",
" %7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
" %8 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>\n",
" %output, %batch_mean, %batch_var = \"stablehlo.batch_norm_training\"(%8, %arg20, %arg19) {epsilon = 9.99999974E-6 : f3 \n",
"...\n"
]
}
]
},
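{
"cell_type": "markdown",
"source": [
"In addition to the human-readable MLIR text above, the same program can be serialized as StableHLO portable bytecode, the format covered by StableHLO's compatibility guarantees. A minimal sketch, assuming your `torch_xla` version exposes `get_stablehlo_bytecode` alongside `get_stablehlo_text`:\n",
"\n",
"```python\n",
"# Portable bytecode for the 'forward' method, returned as a bytes object.\n",
"bytecode = stablehlo_program.get_stablehlo_bytecode('forward')\n",
"print(len(bytecode), 'bytes of StableHLO portable bytecode')\n",
"```"
],
"metadata": {}
},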
{
"cell_type": "markdown",
"source": [
"### Export with dynamic batch dimension\n",
"\n",
"_This is a new feature and will work after 2.3 release cut, or if using `torch_xla` nightly. Once PyTorch/XLA 2.3 is released, this will be converted into a running example. Using the nightly `torch` and `torch_xla` will likely lead to notebook failures in the meantime._\n",
"\n",
"Dynamic batch dimensions can be specified as a part of the inital `torch.export` step. The FX Graph's symint information is used to export to dynamic StableHLO.\n",
"\n",
"In this example, we specify that dim 0 of the sample input is dynamic, which propagates shape using a `tensor<?x3x224x224xf32>`.\n",
"\n",
"```python\n",
"from torch.export import Dim\n",
"\n",
"# Create a dynamic batch size, for the first dimension of the input\n",
"batch = Dim(\"batch\", max=15)\n",
"dynamic_shapes = ({0: batch},)\n",
"dynamic_export = export(resnet18, sample_input, dynamic_shapes=dynamic_shapes)\n",
"dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)\n",
"print(dynamic_stablehlo.get_stablehlo_text('forward')[0:5000],\"\\n...\")\n",
"```"
],
"metadata": {
"id": "2Ujt2OjtpERw"
}
},
{
"cell_type": "markdown",
"source": [
"### Saving and reloading StableHLO\n",
"\n",
"The `StableHLOGraphModule` has methods to `save` and `load` StableHLO artifacts. This stores StableHLO portable bytecode artifacts which have full backward compatiblity guarantees."
],
"metadata": {
"id": "ySTEYXzG1fU6"
}
},
{
"cell_type": "code",
"source": [
"from torch_xla.stablehlo import StableHLOGraphModule\n",
"\n",
"# Save to tmp\n",
"stablehlo_program.save('/tmp/stablehlo_dir')\n",
"!ls /tmp/stablehlo_dir\n",
"!ls /tmp/stablehlo_dir/functions\n",
"\n",
"# Reload and execute - Stable serialization, forward / backward compatible.\n",
"reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')\n",
"print(reloaded(sample_input[0]))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jeVTUs7jh8lk",
"outputId": "4c248260-7442-495c-932b-a618e9eb2c67"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"constants data functions\n",
"forward.bytecode forward.meta\tforward.mlir\n",
"tensor([[-0.7431, -2.5955, -0.0718, ..., -1.4230, 1.5928, -0.7693],\n",
" [ 0.6199, -1.5941, -0.9018, ..., 0.2452, 0.6159, 2.4765],\n",
" [-3.0291, 2.2174, -2.2809, ..., -0.9081, 1.8253, 2.2141],\n",
" [ 0.9318, -0.0566, 0.8561, ..., -0.1650, 0.7882, -0.2697]],\n",
" device='xla:0')\n"
]
}
]
},
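{
"cell_type": "markdown",
"source": [
"As a quick sanity check, the reloaded program's output can be compared against the original eager-mode PyTorch model. A minimal sketch; the tolerance is an illustrative choice, not a requirement:\n",
"\n",
"```python\n",
"# Run the original model eagerly and compare with the reloaded StableHLO program.\n",
"with torch.no_grad():\n",
"    eager_out = resnet18(sample_input[0])\n",
"# Move the result off the XLA device before comparing on CPU.\n",
"reloaded_out = reloaded(sample_input[0]).cpu()\n",
"print(torch.allclose(eager_out, reloaded_out, atol=1e-4))\n",
"```"
],
"metadata": {}
},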
{
"cell_type": "markdown",
"source": [
"## Export to SavedModel\n",
"\n",
"It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).\n",
"\n",
"PyTorch/XLA makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed."
],
"metadata": {
"id": "NsJyDxxnjd4B"
}
},
{
"cell_type": "markdown",
"source": [
"### Install latest TF\n",
"\n",
"SavedModel definition lives in TF, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`."
],
"metadata": {
"id": "gsUYjXg75mHM"
}
},
{
"cell_type": "code",
"source": [
"!pip install tensorflow-cpu"
],
"metadata": {
"collapsed": true,
"id": "d-y5rLcQjqbk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Export to SavedModel using `torch_xla.tf_saved_model_integration`\n",
"\n",
"PyTorch/XLA provides a simple API for exporting StableHLO in a SavedModel `save_torch_module_as_tf_saved_model`. This uses the `torch.export` and `torch_xla.stablehlo.exported_program_to_stablehlo` functions under the hood.\n",
"\n",
"The input to the API is a PyTorch model, we'll use the same `resnet18` from the previous examples."
],
"metadata": {
"id": "pty18Wc-5sGb"
}
},
{
"cell_type": "code",
"source": [
"from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n",
"import tensorflow as tf\n",
"\n",
"save_torch_module_as_tf_saved_model(\n",
" resnet18, # original pytorch torch.nn.Module\n",
" sample_input, # sample inputs used to trace\n",
" '/tmp/resnet_tf' # directory for tf.saved_model\n",
")\n",
"\n",
"!ls /tmp/resnet_tf/"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1g24aShhjG6e",
"outputId": "382d14f5-7b7b-4297-8af4-0999b7ba4f3f"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"assets\tfingerprint.pb\tsaved_model.pb\tvariables\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Reload and call the SavedModel\n",
"\n",
"Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n",
"\n",
"_Note: the restored model does *not* require PyTorch or PyTorch/XLA to run, just XLA._"
],
"metadata": {
"id": "shmmhYP76eX2"
}
},
{
"cell_type": "code",
"source": [
"loaded_m = tf.saved_model.load('/tmp/resnet_tf')\n",
"print(loaded_m.f(tf.constant(sample_input[0].numpy())))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2TZMQyJj6fHy",
"outputId": "c2f97232-a623-4146-d369-ef75c2033136"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=\n",
"array([[-0.74313045, -2.595488 , -0.0718156 , ..., -1.4230162 ,\n",
" 1.5928375 , -0.7693139 ],\n",
" [ 0.61994374, -1.594082 , -0.901797 , ..., 0.2451565 ,\n",
" 0.6159245 , 2.4764667 ],\n",
" [-3.029084 , 2.2174084 , -2.2808676 , ..., -0.90810233,\n",
" 1.8252805 , 2.214109 ],\n",
" [ 0.93176216, -0.0566061 , 0.8560745 , ..., -0.16496754,\n",
" 0.7881946 , -0.26973075]], dtype=float32)>]\n"
]
}
]
},
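{
"cell_type": "markdown",
"source": [
"If you want to see exactly what was packed into the SavedModel (function signatures, input/output dtypes and shapes), TensorFlow's standard inspection tools work as usual. A minimal sketch; the exact signature names depend on how `save_torch_module_as_tf_saved_model` registered the function:\n",
"\n",
"```python\n",
"# List any serving signatures attached to the loaded SavedModel.\n",
"print(list(loaded_m.signatures.keys()))\n",
"\n",
"# Or dump the full SavedModel metadata with the CLI that ships with TensorFlow.\n",
"!saved_model_cli show --dir /tmp/resnet_tf --all\n",
"```"
],
"metadata": {}
},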
{
"cell_type": "markdown",
"source": [
"# Common Troubleshooting\n",
"\n",
"Most issues in PyTorch to StableHLO require a GH ticket as a next step. Teams are generally quick to help resolve issues.\n",
"\n",
"- Issues in `torch.export`: These need to be resolved in upstream PyTorch.\n",
"- Issues in `torch_xla.stablehlo`: Open a ticket on pytorch/xla repo."
],
"metadata": {
"id": "fSk7HFVk8CqR"
}
}
]
}
