Tutorial: PyTorch Export to StableHLO (#2237)
View rich text GH rendering here: [pytorch-export.ipynb](https://github.com/openxla/stablehlo/blob/60adde95c9337ce031c91f3f3fb7e5b405f4cef0/docs/tutorials/pytorch-export.ipynb). Includes "Open in Colab" URL fix for JAX tutorial.
Showing 3 changed files with 350 additions and 2 deletions.
@@ -0,0 +1,346 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# # Tutorial: Exporting StableHLO from PyTorch\n", | ||
"\n", | ||
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb)\n", | ||
"\n", | ||
"_Intro to the [`torch_xla.stablehlo`](https://github.com/pytorch/xla/blob/main/docs/stablehlo.md) module._\n", | ||
"\n", | ||
"## Tutorial Setup\n", | ||
"\n", | ||
"### Install required dependencies\n", | ||
"\n", | ||
"We'll be using `torch` and `torchvision` to get a `resnet18` model, and `torch_xla` to export to StableHLO." | ||
], | ||
"metadata": { | ||
"id": "42_6VTj7l_xS" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "-GTAijyd3rOf" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install torch_xla torch torchvision" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Export PyTorch model to StableHLO\n", | ||
"\n", | ||
"The general set of steps for exporting a PyTorch model to StableHLO is:\n", | ||
"1. Export using PyTorch's `torch.export` API.\n", | ||
"2. Convert exported FX Graph to StableHLO using `torch_xla.stablehlo` APIs.\n", | ||
"\n", | ||
"### Export model to FX graph using `torch.export`\n", | ||
"\n", | ||
"This step uses entirely vanilla PyTorch APIs to export a `resnet18` model from `torchvision`. Sample inputs are required for graph tracing, we use a `tensor<4x3x224x224xf32>` in this case." | ||
], | ||
"metadata": { | ||
"id": "V_AtPpV30Bt8" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"import torch\n", | ||
"import torchvision\n", | ||
"from torch.export import export\n", | ||
"\n", | ||
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", | ||
"sample_input = (torch.randn(4, 3, 224, 224), )\n", | ||
"exported = export(resnet18, sample_input)" | ||
], | ||
"metadata": { | ||
"id": "GhIpxnx5fuxy" | ||
}, | ||
"execution_count": 7, | ||
"outputs": [] | ||
}, | ||
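{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Before lowering to StableHLO, it can be useful to inspect what `torch.export` produced. The following is a minimal sketch reusing `exported` from the cell above; the printout follows the `torch.export.ExportedProgram` API and varies with the PyTorch version.\n", | ||
"\n", | ||
"```python\n", | ||
"# Inspect the exported program: graph signature plus the traced FX graph.\n", | ||
"print(exported)\n", | ||
"# Generated Python for the traced graph.\n", | ||
"print(exported.graph_module.code)\n", | ||
"```" | ||
], | ||
"metadata": {} | ||
}, | ||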
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Export FX Graph to StableHLO using TorchXLA\n", | ||
"\n", | ||
"Once we have an exported FX graph, we can convert to StableHLO using the `torch_xla.stablehlo` module. In this case we'll use `exported_program_to_stablehlo`." | ||
], | ||
"metadata": { | ||
"id": "zuMAr3WO1PBk" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from torch_xla.stablehlo import exported_program_to_stablehlo\n", | ||
"\n", | ||
"stablehlo_program = exported_program_to_stablehlo(exported)\n", | ||
"print(stablehlo_program.get_stablehlo_text('forward')[0:4000],\"\\n...\")" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "xiP7psIQgUc-", | ||
"outputId": "0de7d5bc-01a1-4e96-b2ce-19faa9965459" | ||
}, | ||
"execution_count": 8, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"module @IrToHlo.508 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n", | ||
" func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {\n", | ||
" %0 = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>\n", | ||
" %1 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>\n", | ||
" %2 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>\n", | ||
" %3 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>\n", | ||
" %4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>\n", | ||
" %5 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>\n", | ||
" %6 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n", | ||
" %7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n", | ||
" %8 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>\n", | ||
" %output, %batch_mean, %batch_var = \"stablehlo.batch_norm_training\"(%8, %arg20, %arg19) {epsilon = 9.99999974E-6 : f3 \n", | ||
"...\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Export with dynamic batch dimension\n", | ||
"\n", | ||
"_This is a new feature and will work after 2.3 release cut, or if using `torch_xla` nightly. Once PyTorch/XLA 2.3 is released, this will be converted into a running example. Using the nightly `torch` and `torch_xla` will likely lead to notebook failures in the meantime._\n", | ||
"\n", | ||
"Dynamic batch dimensions can be specified as a part of the inital `torch.export` step. The FX Graph's symint information is used to export to dynamic StableHLO.\n", | ||
"\n", | ||
"In this example, we specify that dim 0 of the sample input is dynamic, which propagates shape using a `tensor<?x3x224x224xf32>`.\n", | ||
"\n", | ||
"```python\n", | ||
"from torch.export import Dim\n", | ||
"\n", | ||
"# Create a dynamic batch size, for the first dimension of the input\n", | ||
"batch = Dim(\"batch\", max=15)\n", | ||
"dynamic_shapes = ({0: batch},)\n", | ||
"dynamic_export = export(resnet18, sample_input, dynamic_shapes=dynamic_shapes)\n", | ||
"dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)\n", | ||
"print(dynamic_stablehlo.get_stablehlo_text('forward')[0:5000],\"\\n...\")\n", | ||
"```" | ||
], | ||
"metadata": { | ||
"id": "2Ujt2OjtpERw" | ||
} | ||
}, | ||
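{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Once dynamic export is available (see the note above), the resulting program should accept any batch size up to the declared maximum. Below is a minimal usage sketch reusing `dynamic_stablehlo` from the snippet above; it will not run until PyTorch/XLA 2.3.\n", | ||
"\n", | ||
"```python\n", | ||
"# Sketch: the same dynamic program handles different batch sizes (up to max=15).\n", | ||
"out_small = dynamic_stablehlo(torch.randn(2, 3, 224, 224))\n", | ||
"out_large = dynamic_stablehlo(torch.randn(15, 3, 224, 224))\n", | ||
"print(out_small.shape, out_large.shape)  # expected: (2, 1000) and (15, 1000)\n", | ||
"```" | ||
], | ||
"metadata": {} | ||
}, | ||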
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Saving and reloading StableHLO\n", | ||
"\n", | ||
"The `StableHLOGraphModule` has methods to `save` and `load` StableHLO artifacts. This stores StableHLO portable bytecode artifacts which have full backward compatiblity guarantees." | ||
], | ||
"metadata": { | ||
"id": "ySTEYXzG1fU6" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from torch_xla.stablehlo import StableHLOGraphModule\n", | ||
"\n", | ||
"# Save to tmp\n", | ||
"stablehlo_program.save('/tmp/stablehlo_dir')\n", | ||
"!ls /tmp/stablehlo_dir\n", | ||
"!ls /tmp/stablehlo_dir/functions\n", | ||
"\n", | ||
"# Reload and execute - Stable serialization, forward / backward compatible.\n", | ||
"reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')\n", | ||
"print(reloaded(sample_input[0]))" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "jeVTUs7jh8lk", | ||
"outputId": "4c248260-7442-495c-932b-a618e9eb2c67" | ||
}, | ||
"execution_count": 9, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"constants data functions\n", | ||
"forward.bytecode forward.meta\tforward.mlir\n", | ||
"tensor([[-0.7431, -2.5955, -0.0718, ..., -1.4230, 1.5928, -0.7693],\n", | ||
" [ 0.6199, -1.5941, -0.9018, ..., 0.2452, 0.6159, 2.4765],\n", | ||
" [-3.0291, 2.2174, -2.2809, ..., -0.9081, 1.8253, 2.2141],\n", | ||
" [ 0.9318, -0.0566, 0.8561, ..., -0.1650, 0.7882, -0.2697]],\n", | ||
" device='xla:0')\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Export to SavedModel\n", | ||
"\n", | ||
"It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).\n", | ||
"\n", | ||
"PyTorch/XLA makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed." | ||
], | ||
"metadata": { | ||
"id": "NsJyDxxnjd4B" | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Install latest TF\n", | ||
"\n", | ||
"SavedModel definition lives in TF, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`." | ||
], | ||
"metadata": { | ||
"id": "gsUYjXg75mHM" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install tensorflow-cpu" | ||
], | ||
"metadata": { | ||
"collapsed": true, | ||
"id": "d-y5rLcQjqbk" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Export to SavedModel using `torch_xla.tf_saved_model_integration`\n", | ||
"\n", | ||
"PyTorch/XLA provides a simple API for exporting StableHLO in a SavedModel `save_torch_module_as_tf_saved_model`. This uses the `torch.export` and `torch_xla.stablehlo.exported_program_to_stablehlo` functions under the hood.\n", | ||
"\n", | ||
"The input to the API is a PyTorch model, we'll use the same `resnet18` from the previous examples." | ||
], | ||
"metadata": { | ||
"id": "pty18Wc-5sGb" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n", | ||
"import tensorflow as tf\n", | ||
"\n", | ||
"save_torch_module_as_tf_saved_model(\n", | ||
" resnet18, # original pytorch torch.nn.Module\n", | ||
" sample_input, # sample inputs used to trace\n", | ||
" '/tmp/resnet_tf' # directory for tf.saved_model\n", | ||
")\n", | ||
"\n", | ||
"!ls /tmp/resnet_tf/" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "1g24aShhjG6e", | ||
"outputId": "382d14f5-7b7b-4297-8af4-0999b7ba4f3f" | ||
}, | ||
"execution_count": 10, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### Reload and call the SavedModel\n", | ||
"\n", | ||
"Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n", | ||
"\n", | ||
"_Note: the restored model does *not* require PyTorch or PyTorch/XLA to run, just XLA._" | ||
], | ||
"metadata": { | ||
"id": "shmmhYP76eX2" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"loaded_m = tf.saved_model.load('/tmp/resnet_tf')\n", | ||
"print(loaded_m.f(tf.constant(sample_input[0].numpy())))" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "2TZMQyJj6fHy", | ||
"outputId": "c2f97232-a623-4146-d369-ef75c2033136" | ||
}, | ||
"execution_count": 12, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=\n", | ||
"array([[-0.74313045, -2.595488 , -0.0718156 , ..., -1.4230162 ,\n", | ||
" 1.5928375 , -0.7693139 ],\n", | ||
" [ 0.61994374, -1.594082 , -0.901797 , ..., 0.2451565 ,\n", | ||
" 0.6159245 , 2.4764667 ],\n", | ||
" [-3.029084 , 2.2174084 , -2.2808676 , ..., -0.90810233,\n", | ||
" 1.8252805 , 2.214109 ],\n", | ||
" [ 0.93176216, -0.0566061 , 0.8560745 , ..., -0.16496754,\n", | ||
" 0.7881946 , -0.26973075]], dtype=float32)>]\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Common Troubleshooting\n", | ||
"\n", | ||
"Most issues in PyTorch to StableHLO require a GH ticket as a next step. Teams are generally quick to help resolve issues.\n", | ||
"\n", | ||
"- Issues in `torch.export`: These need to be resolved in upstream PyTorch.\n", | ||
"- Issues in `torch_xla.stablehlo`: Open a ticket on pytorch/xla repo." | ||
], | ||
"metadata": { | ||
"id": "fSk7HFVk8CqR" | ||
} | ||
} | ||
] | ||
} |