From f4d1000595d0186fdcaacd4034094dfcf30f2604 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 24 Apr 2024 22:47:57 -0700 Subject: [PATCH] Add StableHLO to SavedModel Tutorial (#2250) [Preview](https://github.com/GleasonK/stablehlo/blob/d8894397b73d19ee93b9509024c02f705b631718/docs/tutorials/savedmodel-embed.ipynb) Add a tutorial to demonstrate the new StableHLO to SavedModel APIs. --- docs/_toc.yaml | 2 + docs/tutorials/savedmodel-embed.ipynb | 256 ++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 docs/tutorials/savedmodel-embed.ipynb diff --git a/docs/_toc.yaml b/docs/_toc.yaml index e4a6d583ca..6836aaccbd 100644 --- a/docs/_toc.yaml +++ b/docs/_toc.yaml @@ -31,6 +31,8 @@ toc: path: /stablehlo/tutorials/jax-export - title: PyTorch Export to StableHLO path: /stablehlo/tutorials/pytorch-export + - title: StableHLO Embed in SavedModel + path: /stablehlo/tutorials/savedmodel-embed - title: Contributing section: - title: Governance diff --git a/docs/tutorials/savedmodel-embed.ipynb b/docs/tutorials/savedmodel-embed.ipynb new file mode 100644 index 0000000000..c58e1ef7b3 --- /dev/null +++ b/docs/tutorials/savedmodel-embed.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "-g5l3RLOGXcc" + }, + "source": [ + "# Tutorial: Embedding StableHLO in SavedModel\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][savedmodel-tutorial-colab]\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][savedmodel-tutorial-kaggle]\n", + "\n", + "_The [`stablehlo.savedmodel`][savedmodel-module] module._\n", + "\n", + "This tutorial will detail how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have specific APIs for emitting SavedModels, see other StableHLO tutorials for instructions on using these.\n", + "\n", + "## Tutorial Setup\n", + "\n", + "### Install required dependencies\n", + "\n", + "We'll be using the `stablehlo` nightly wheel to get StableHLO's Python APIs, and `tensorflow` for the [SavedModel][savedmodel-tf] dependency.\n", + "\n", + "[savedmodel-tf]: https://www.tensorflow.org/guide/saved_model\n", + "[savedmodel-module]: https://github.com/openxla/stablehlo/tree/main/stablehlo/integrations/python/stablehlo/savedmodel\n", + "[savedmodel-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb\n", + "[savedmodel-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "B18oDm74-YoZ" + }, + "outputs": [], + "source": [ + "!pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels\n", + "!pip install tensorflow-cpu" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eUjF3nxXGyuP" + }, + "source": [ + "## Embed StableHLO model in SavedModel\n", + "\n", + "In this section we'll take a very basic StableHLO module, and demonstrate some of the APIs to embed it in a SavedModel. In practice this StableHLO module can come from a debug dump, an export from a framework, or even converted from HLO.\n", + "\n", + "### Define a StableHLO `add` module\n", + "\n", + "For this tutorial we'll use a simple `add` model with two input arguments `arg0` and `bias`. When packaging in SavedModel, `bias` will be a constant that is stored in the SavedModel, while `arg0` is provided when calling the model." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "id": "uBygn_kU-kek" + }, + "outputs": [], + "source": [ + "MODULE_STRING = \"\"\"\n", + "func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {\n", + " %0 = stablehlo.add %arg0, %bias: tensor<1xf32>\n", + " return %0 : tensor<1xf32>\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0KEsCtO7IZY6" + }, + "source": [ + "### Parse to a StableHLO MLIR Module\n", + "\n", + "Once we have a StableHLO file / dump of interest, we can parse it back to an MLIR module using `ir.Module.parse`.\n", + "\n", + "Note that all dialects in the module must be registered, otherwise `parse` will fail." + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gTkLk_0E_HPz", + "outputId": "6b66c9c0-3ca8-4f07-bf68-cf72f7fa235c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "module {\n", + " func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {\n", + " %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>\n", + " return %0 : tensor<1xf32>\n", + " }\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "import mlir.ir as ir\n", + "import mlir.dialects.stablehlo as stablehlo\n", + "\n", + "with ir.Context() as ctx:\n", + " stablehlo.register_dialect(ctx)\n", + " module = ir.Module.parse(MODULE_STRING)\n", + "\n", + "print(module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P-gvoHbtItkg" + }, + "source": [ + "### Embed in SavedModel using `stablehlo_to_tf_saved_model`\n", + "\n", + "StableHLO's Python wheel includes a `savedmodel` module to help with packaging StableHLO in SavedModels.\n", + "\n", + "Packing in SavedModel requires a few details:\n", + "\n", + "**`input_locations`** specify where inputs to a model live, in the saved model (`InputLocation.parameter`) or passed in as input arguments during invocation (`InputLocation.input_arg`).\n", + "\n", + "**`state_dict`** can be used to specify values for the `parameter` arguments that live in the SavedModel. These are linked by `name`.\n", + "\n", + "In this example, we'll specify that the second input argument is a value with name `module.bias` which is stored in the SavedModel with the value `2`." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "id": "sMZUE-vj_Wg5" + }, + "outputs": [], + "source": [ + "from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation\n", + "import numpy as np\n", + "\n", + "input_locations = [\n", + " InputLocation.input_arg(position=0), # Parameter, non-constant\n", + " InputLocation.parameter(name='module.bias'), # Constant data in SavedModel\n", + "]\n", + "state_dict = {\n", + " 'module.bias': np.array([2], dtype='float32'),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eKArz9iCK4Pa" + }, + "source": [ + "Now we can use `stablehlo_to_tf_saved_model` to create the SavedModel in a path specified using the `saved_model_dir` argument." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Q3shJyzHKlbo", + "outputId": "48383f47-b083-4800-d4a4-0f7fd980d6c4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" + ] + } + ], + "source": [ + "from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model\n", + "\n", + "stablehlo_to_tf_saved_model(\n", + " module,\n", + " saved_model_dir='/tmp/add_model',\n", + " input_locations=input_locations,\n", + " state_dict=state_dict,\n", + ")\n", + "\n", + "!ls /tmp/add_model/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MboZEPwmLIYl" + }, + "source": [ + "### Reload and call the SavedModel\n", + "\n", + "Now we can load that SavedModel and compile using a sample input.\n", + "\n", + "Here we'll just use a TF constant with the value `3`." + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "i4KLLZ7AGB6X", + "outputId": "10fda712-819a-4002-d4e7-3702087294fd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "\n", + "restored_model = tf.saved_model.load('/tmp/add_model')\n", + "print(restored_model.f(tf.constant([3], tf.float32)))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}