-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add StableHLO to SavedModel Tutorial (#2250)
[Preview](https://github.com/GleasonK/stablehlo/blob/d8894397b73d19ee93b9509024c02f705b631718/docs/tutorials/savedmodel-embed.ipynb) Add a tutorial to demonstrate the new StableHLO to SavedModel APIs.
- Loading branch information
Showing
2 changed files
with
258 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "-g5l3RLOGXcc" | ||
}, | ||
"source": [ | ||
"# Tutorial: Embedding StableHLO in SavedModel\n", | ||
"\n", | ||
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][savedmodel-tutorial-colab]\n", | ||
"[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][savedmodel-tutorial-kaggle]\n", | ||
"\n", | ||
"_The [`stablehlo.savedmodel`][savedmodel-module] module._\n", | ||
"\n", | ||
"This tutorial will detail how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have specific APIs for emitting SavedModels, see other StableHLO tutorials for instructions on using these.\n", | ||
"\n", | ||
"## Tutorial Setup\n", | ||
"\n", | ||
"### Install required dependencies\n", | ||
"\n", | ||
"We'll be using the `stablehlo` nightly wheel to get StableHLO's Python APIs, and `tensorflow` for the [SavedModel][savedmodel-tf] dependency.\n", | ||
"\n", | ||
"[savedmodel-tf]: https://www.tensorflow.org/guide/saved_model\n", | ||
"[savedmodel-module]: https://github.com/openxla/stablehlo/tree/main/stablehlo/integrations/python/stablehlo/savedmodel\n", | ||
"[savedmodel-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb\n", | ||
"[savedmodel-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "B18oDm74-YoZ" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels\n", | ||
"!pip install tensorflow-cpu" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "eUjF3nxXGyuP" | ||
}, | ||
"source": [ | ||
"## Embed StableHLO model in SavedModel\n", | ||
"\n", | ||
"In this section we'll take a very basic StableHLO module, and demonstrate some of the APIs to embed it in a SavedModel. In practice this StableHLO module can come from a debug dump, an export from a framework, or even converted from HLO.\n", | ||
"\n", | ||
"### Define a StableHLO `add` module\n", | ||
"\n", | ||
"For this tutorial we'll use a simple `add` model with two input arguments `arg0` and `bias`. When packaging in SavedModel, `bias` will be a constant that is stored in the SavedModel, while `arg0` is provided when calling the model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 48, | ||
"metadata": { | ||
"id": "uBygn_kU-kek" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"MODULE_STRING = \"\"\"\n", | ||
"func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {\n", | ||
" %0 = stablehlo.add %arg0, %bias: tensor<1xf32>\n", | ||
" return %0 : tensor<1xf32>\n", | ||
"}\n", | ||
"\"\"\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "0KEsCtO7IZY6" | ||
}, | ||
"source": [ | ||
"### Parse to a StableHLO MLIR Module\n", | ||
"\n", | ||
"Once we have a StableHLO file / dump of interest, we can parse it back to an MLIR module using `ir.Module.parse`.\n", | ||
"\n", | ||
"Note that all dialects in the module must be registered, otherwise `parse` will fail." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 54, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "gTkLk_0E_HPz", | ||
"outputId": "6b66c9c0-3ca8-4f07-bf68-cf72f7fa235c" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"module {\n", | ||
" func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {\n", | ||
" %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>\n", | ||
" return %0 : tensor<1xf32>\n", | ||
" }\n", | ||
"}\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import mlir.ir as ir\n", | ||
"import mlir.dialects.stablehlo as stablehlo\n", | ||
"\n", | ||
"with ir.Context() as ctx:\n", | ||
" stablehlo.register_dialect(ctx)\n", | ||
" module = ir.Module.parse(MODULE_STRING)\n", | ||
"\n", | ||
"print(module)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "P-gvoHbtItkg" | ||
}, | ||
"source": [ | ||
"### Embed in SavedModel using `stablehlo_to_tf_saved_model`\n", | ||
"\n", | ||
"StableHLO's Python wheel includes a `savedmodel` module to help with packaging StableHLO in SavedModels.\n", | ||
"\n", | ||
"Packing in SavedModel requires a few details:\n", | ||
"\n", | ||
"**`input_locations`** specify where inputs to a model live, in the saved model (`InputLocation.parameter`) or passed in as input arguments during invocation (`InputLocation.input_arg`).\n", | ||
"\n", | ||
"**`state_dict`** can be used to specify values for the `parameter` arguments that live in the SavedModel. These are linked by `name`.\n", | ||
"\n", | ||
"In this example, we'll specify that the second input argument is a value with name `module.bias` which is stored in the SavedModel with the value `2`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 49, | ||
"metadata": { | ||
"id": "sMZUE-vj_Wg5" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"input_locations = [\n", | ||
" InputLocation.input_arg(position=0), # Parameter, non-constant\n", | ||
" InputLocation.parameter(name='module.bias'), # Constant data in SavedModel\n", | ||
"]\n", | ||
"state_dict = {\n", | ||
" 'module.bias': np.array([2], dtype='float32'),\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "eKArz9iCK4Pa" | ||
}, | ||
"source": [ | ||
"Now we can use `stablehlo_to_tf_saved_model` to create the SavedModel in a path specified using the `saved_model_dir` argument." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 52, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "Q3shJyzHKlbo", | ||
"outputId": "48383f47-b083-4800-d4a4-0f7fd980d6c4" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"assets\tfingerprint.pb\tsaved_model.pb\tvariables\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model\n", | ||
"\n", | ||
"stablehlo_to_tf_saved_model(\n", | ||
" module,\n", | ||
" saved_model_dir='/tmp/add_model',\n", | ||
" input_locations=input_locations,\n", | ||
" state_dict=state_dict,\n", | ||
")\n", | ||
"\n", | ||
"!ls /tmp/add_model/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "MboZEPwmLIYl" | ||
}, | ||
"source": [ | ||
"### Reload and call the SavedModel\n", | ||
"\n", | ||
"Now we can load that SavedModel and compile using a sample input.\n", | ||
"\n", | ||
"Here we'll just use a TF constant with the value `3`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 53, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "i4KLLZ7AGB6X", | ||
"outputId": "10fda712-819a-4002-d4e7-3702087294fd" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"\n", | ||
"restored_model = tf.saved_model.load('/tmp/add_model')\n", | ||
"print(restored_model.f(tf.constant([3], tf.float32)))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |