From 0fcf3a466ffc2733a0c56b8ff5b5fcd79f1707b8 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 25 Apr 2024 06:14:48 +0000 Subject: [PATCH 1/5] Update PyTorch example to support 2.3.0 --- docs/tutorials/pytorch-export.ipynb | 224 ++++++++++++++-------------- 1 file changed, 112 insertions(+), 112 deletions(-) diff --git a/docs/tutorials/pytorch-export.ipynb b/docs/tutorials/pytorch-export.ipynb index d2a4412977..7819bfeaf9 100644 --- a/docs/tutorials/pytorch-export.ipynb +++ b/docs/tutorials/pytorch-export.ipynb @@ -1,26 +1,15 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", + "metadata": { + "id": "42_6VTj7l_xS" + }, "source": [ "# Tutorial: Exporting StableHLO from PyTorch\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][pytorch-tutorial-colab]\n", - "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][pytorch-tutorial-kaggle]\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][pytorch-tutorial-kaggle]\n", "\n", "_Intro to the [`torch_xla.stablehlo`](https://github.com/pytorch/xla/blob/main/docs/stablehlo.md) module._\n", "\n", @@ -30,12 +19,9 @@ "\n", "We'll be using `torch` and `torchvision` to get a `resnet18` model, and `torch_xla` to export to StableHLO.\n", "\n", - "[pytorch-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb\n", - "[pytorch-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb" - ], - "metadata": { - "id": "42_6VTj7l_xS" - } + "[pytorch-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb\n", + "[pytorch-tutorial-kaggle]: 
https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/pytorch-export.ipynb" + ] }, { "cell_type": "code", @@ -45,11 +31,14 @@ }, "outputs": [], "source": [ - "!pip install torch_xla torch torchvision" + "!pip install torch_xla==2.3.0 torch==2.3.0 torchvision==0.18.0" ] }, { "cell_type": "markdown", + "metadata": { + "id": "V_AtPpV30Bt8" + }, "source": [ "## Export PyTorch model to StableHLO\n", "\n", @@ -60,13 +49,15 @@ "### Export model to FX graph using `torch.export`\n", "\n", "This step uses entirely vanilla PyTorch APIs to export a `resnet18` model from `torchvision`. Sample inputs are required for graph tracing, we use a `tensor<4x3x224x224xf32>` in this case." - ], - "metadata": { - "id": "V_AtPpV30Bt8" - } + ] }, { "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "GhIpxnx5fuxy" + }, + "outputs": [], "source": [ "import torch\n", "import torchvision\n", @@ -75,32 +66,22 @@ "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", "sample_input = (torch.randn(4, 3, 224, 224), )\n", "exported = export(resnet18, sample_input)" - ], - "metadata": { - "id": "GhIpxnx5fuxy" - }, - "execution_count": 7, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "zuMAr3WO1PBk" + }, "source": [ "### Export FX Graph to StableHLO using TorchXLA\n", "\n", "Once we have an exported FX graph, we can convert to StableHLO using the `torch_xla.stablehlo` module. In this case we'll use `exported_program_to_stablehlo`." 
- ], - "metadata": { - "id": "zuMAr3WO1PBk" - } + ] }, { "cell_type": "code", - "source": [ - "from torch_xla.stablehlo import exported_program_to_stablehlo\n", - "\n", - "stablehlo_program = exported_program_to_stablehlo(exported)\n", - "print(stablehlo_program.get_stablehlo_text('forward')[0:4000],\"\\n...\")" - ], + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -108,11 +89,10 @@ "id": "xiP7psIQgUc-", "outputId": "0de7d5bc-01a1-4e96-b2ce-19faa9965459" }, - "execution_count": 8, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "module @IrToHlo.508 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {\n", " func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, 
%arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {\n", @@ -129,10 +109,19 @@ "...\n" ] } + ], + "source": [ + "from torch_xla.stablehlo import exported_program_to_stablehlo\n", + "\n", + "stablehlo_program = exported_program_to_stablehlo(exported)\n", + "print(stablehlo_program.get_stablehlo_text('forward')[0:4000],\"\\n...\")" ] }, { "cell_type": "markdown", + "metadata": { + "id": "2Ujt2OjtpERw" + }, "source": [ "### Export with dynamic batch dimension\n", "\n", @@ -152,36 +141,22 @@ "dynamic_stablehlo = 
exported_program_to_stablehlo(dynamic_export)\n", "print(dynamic_stablehlo.get_stablehlo_text('forward')[0:5000],\"\\n...\")\n", "```" - ], - "metadata": { - "id": "2Ujt2OjtpERw" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ySTEYXzG1fU6" + }, "source": [ "### Saving and reloading StableHLO\n", "\n", "The `StableHLOGraphModule` has methods to `save` and `load` StableHLO artifacts. This stores StableHLO portable bytecode artifacts which have full backward compatiblity guarantees." - ], - "metadata": { - "id": "ySTEYXzG1fU6" - } + ] }, { "cell_type": "code", - "source": [ - "from torch_xla.stablehlo import StableHLOGraphModule\n", - "\n", - "# Save to tmp\n", - "stablehlo_program.save('/tmp/stablehlo_dir')\n", - "!ls /tmp/stablehlo_dir\n", - "!ls /tmp/stablehlo_dir/functions\n", - "\n", - "# Reload and execute - Stable serialization, forward / backward compatible.\n", - "reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')\n", - "print(reloaded(sample_input[0]))" - ], + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -189,11 +164,10 @@ "id": "jeVTUs7jh8lk", "outputId": "4c248260-7442-495c-932b-a618e9eb2c67" }, - "execution_count": 9, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "constants data functions\n", "forward.bytecode forward.meta\tforward.mlir\n", @@ -204,71 +178,72 @@ " device='xla:0')\n" ] } + ], + "source": [ + "from torch_xla.stablehlo import StableHLOGraphModule\n", + "\n", + "# Save to tmp\n", + "stablehlo_program.save('/tmp/stablehlo_dir')\n", + "!ls /tmp/stablehlo_dir\n", + "!ls /tmp/stablehlo_dir/functions\n", + "\n", + "# Reload and execute - Stable serialization, forward / backward compatible.\n", + "reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')\n", + "print(reloaded(sample_input[0]))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "NsJyDxxnjd4B" + }, "source": [ "## Export to SavedModel\n", "\n", "It is common to 
want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).\n", "\n", "PyTorch/XLA makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed." - ], - "metadata": { - "id": "NsJyDxxnjd4B" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "gsUYjXg75mHM" + }, "source": [ "### Install latest TF\n", "\n", "SavedModel definition lives in TF, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`." - ], - "metadata": { - "id": "gsUYjXg75mHM" - } + ] }, { "cell_type": "code", - "source": [ - "!pip install tensorflow-cpu" - ], + "execution_count": null, "metadata": { "collapsed": true, "id": "d-y5rLcQjqbk" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install tensorflow-cpu" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "pty18Wc-5sGb" + }, "source": [ "### Export to SavedModel using `torch_xla.tf_saved_model_integration`\n", "\n", "PyTorch/XLA provides a simple API for exporting StableHLO in a SavedModel `save_torch_module_as_tf_saved_model`. This uses the `torch.export` and `torch_xla.stablehlo.exported_program_to_stablehlo` functions under the hood.\n", "\n", "The input to the API is a PyTorch model, we'll use the same `resnet18` from the previous examples." 
- ], - "metadata": { - "id": "pty18Wc-5sGb" - } + ] }, { "cell_type": "code", - "source": [ - "from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n", - "import tensorflow as tf\n", - "\n", - "save_torch_module_as_tf_saved_model(\n", - " resnet18, # original pytorch torch.nn.Module\n", - " sample_input, # sample inputs used to trace\n", - " '/tmp/resnet_tf' # directory for tf.saved_model\n", - ")\n", - "\n", - "!ls /tmp/resnet_tf/" - ], + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -276,36 +251,44 @@ "id": "1g24aShhjG6e", "outputId": "382d14f5-7b7b-4297-8af4-0999b7ba4f3f" }, - "execution_count": 10, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" ] } + ], + "source": [ + "from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n", + "import tensorflow as tf\n", + "\n", + "save_torch_module_as_tf_saved_model(\n", + " resnet18, # original pytorch torch.nn.Module\n", + " sample_input, # sample inputs used to trace\n", + " '/tmp/resnet_tf' # directory for tf.saved_model\n", + ")\n", + "\n", + "!ls /tmp/resnet_tf/" ] }, { "cell_type": "markdown", + "metadata": { + "id": "shmmhYP76eX2" + }, "source": [ "### Reload and call the SavedModel\n", "\n", "Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n", "\n", "_Note: the restored model does *not* require PyTorch or PyTorch/XLA to run, just XLA._" - ], - "metadata": { - "id": "shmmhYP76eX2" - } + ] }, { "cell_type": "code", - "source": [ - "loaded_m = tf.saved_model.load('/tmp/resnet_tf')\n", - "print(loaded_m.f(tf.constant(sample_input[0].numpy())))" - ], + "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -313,11 +296,10 @@ "id": "2TZMQyJj6fHy", "outputId": "c2f97232-a623-4146-d369-ef75c2033136" }, - "execution_count": 12, "outputs": 
[ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "[]\n" ] } + ], + "source": [ + "loaded_m = tf.saved_model.load('/tmp/resnet_tf')\n", + "print(loaded_m.f(tf.constant(sample_input[0].numpy())))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "fSk7HFVk8CqR" + }, "source": [ "# Common Troubleshooting\n", "\n", @@ -341,10 +330,21 @@ "\n", "- Issues in `torch.export`: These need to be resolved in upstream PyTorch.\n", "- Issues in `torch_xla.stablehlo`: Open a ticket on pytorch/xla repo." - ], - "metadata": { - "id": "fSk7HFVk8CqR" - } + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" } - ] + }, + "nbformat": 4, + "nbformat_minor": 0 } From 1bac1af2c81c5c645525efbbe1b761a1a19af04a Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 25 Apr 2024 06:18:43 +0000 Subject: [PATCH 2/5] Update 2.4 wording --- docs/tutorials/pytorch-export.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/pytorch-export.ipynb b/docs/tutorials/pytorch-export.ipynb index 7819bfeaf9..75eb269f7a 100644 --- a/docs/tutorials/pytorch-export.ipynb +++ b/docs/tutorials/pytorch-export.ipynb @@ -125,9 +125,9 @@ "source": [ "### Export with dynamic batch dimension\n", "\n", - "_This is a new feature and will work after 2.3 release cut, or if using `torch_xla` nightly. Once PyTorch/XLA 2.3 is released, this will be converted into a running example. Using the nightly `torch` and `torch_xla` will likely lead to notebook failures in the meantime._\n", + "_This is a new feature and will be included in pip installation after the 2.4 release, or if using `torch_xla` nightly. Once PyTorch/XLA 2.4 is released, this will be converted into a running example. 
Using the nightly `torch` and `torch_xla` will likely lead to notebook failures in the meantime._\n", "\n", - "Dynamic batch dimensions can be specified as a part of the inital `torch.export` step. The FX Graph's symint information is used to export to dynamic StableHLO.\n", + "Dynamic batch dimensions can be specified as a part of the initial `torch.export` step. The FX Graph's symint information is used to export to dynamic StableHLO.\n", "\n", "In this example, we specify that dim 0 of the sample input is dynamic, which propagates shape using a `tensor`.\n", "\n", From 485f532cfc873201f7abf7742e6580b4e3bcf6d4 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 25 Apr 2024 06:23:37 +0000 Subject: [PATCH 3/5] Also update TF package --- docs/tutorials/pytorch-export.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/pytorch-export.ipynb b/docs/tutorials/pytorch-export.ipynb index 75eb269f7a..eb97b7d2c9 100644 --- a/docs/tutorials/pytorch-export.ipynb +++ b/docs/tutorials/pytorch-export.ipynb @@ -225,7 +225,7 @@ }, "outputs": [], "source": [ - "!pip install tensorflow-cpu" + "!pip install -U tensorflow-cpu" ] }, { From 0aee86cbd7e4f71283e56ebf647432d205515bfe Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 25 Apr 2024 06:24:29 +0000 Subject: [PATCH 4/5] better import scoping --- docs/tutorials/pytorch-export.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/pytorch-export.ipynb b/docs/tutorials/pytorch-export.ipynb index eb97b7d2c9..000aa90a10 100644 --- a/docs/tutorials/pytorch-export.ipynb +++ b/docs/tutorials/pytorch-export.ipynb @@ -262,7 +262,6 @@ ], "source": [ "from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model\n", - "import tensorflow as tf\n", "\n", "save_torch_module_as_tf_saved_model(\n", " resnet18, # original pytorch torch.nn.Module\n", @@ -314,6 +313,8 @@ } ], "source": [ + "import tensorflow as tf\n", + "\n", "loaded_m = 
tf.saved_model.load('/tmp/resnet_tf')\n", "print(loaded_m.f(tf.constant(sample_input[0].numpy())))" ] }, From 14dbecb45be2eee59d17e9c9adf1402acdf446dd Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 25 Apr 2024 06:27:17 +0000 Subject: [PATCH 5/5] Add common troubleshooting for PT/XLA --- docs/tutorials/pytorch-export.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/pytorch-export.ipynb b/docs/tutorials/pytorch-export.ipynb index 000aa90a10..266679d0c1 100644 --- a/docs/tutorials/pytorch-export.ipynb +++ b/docs/tutorials/pytorch-export.ipynb @@ -330,7 +330,9 @@ "Most issues in PyTorch to StableHLO require a GH ticket as a next step. Teams are generally quick to help resolve issues.\n", "\n", "- Issues in `torch.export`: These need to be resolved in upstream PyTorch.\n", - "- Issues in `torch_xla.stablehlo`: Open a ticket on pytorch/xla repo." + "- Issues in `torch_xla.stablehlo`: Open a ticket on pytorch/xla repo.\n", + "\n", + "The most common issue is out-of-sync dependencies: PyTorch/XLA and PyTorch must be on the same version, and a version mismatch can result in import errors as well as runtime failures. Beyond that, a program may fail to export due to a bug in the PyTorch/XLA bridge; in that case, a ticket with a minimal repro is helpful." ] },
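The troubleshooting paragraph added in PATCH 5/5 hinges on keeping `torch` and `torch_xla` on matching releases (e.g. both pinned to 2.3.0, as PATCH 1/5 does in the install cell). As a rough illustration — not part of the patch series or the notebook — a hypothetical helper like the following checks that two version strings agree on their `major.minor` components, ignoring local suffixes such as `+cpu`:

```python
def versions_aligned(torch_version: str, torch_xla_version: str) -> bool:
    """Return True when two version strings share the same major.minor release.

    PyTorch/XLA releases track PyTorch releases, so torch_xla 2.3.x is meant
    to run against torch 2.3.x. Local-version suffixes such as '+cpu' are
    ignored for the comparison.
    """
    def major_minor(version: str) -> tuple:
        base = version.split("+")[0]  # drop local suffix: '2.3.0+cpu' -> '2.3.0'
        return tuple(int(part) for part in base.split(".")[:2])

    return major_minor(torch_version) == major_minor(torch_xla_version)

print(versions_aligned("2.3.0", "2.3.0+cxx11"))  # True: both are 2.3
print(versions_aligned("2.3.0", "2.2.0"))        # False: 2.3 vs 2.2
```

In a live environment the two strings would come from `torch.__version__` and the installed `torch_xla` package; the helper name and the major.minor pairing rule are illustrative assumptions, not an official compatibility check.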