Skip to content

Commit

Permalink
Add StableHLO to SavedModel Tutorial (#2250)
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK authored Apr 25, 2024
1 parent 0b7ecf3 commit f4d1000
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/_toc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ toc:
path: /stablehlo/tutorials/jax-export
- title: PyTorch Export to StableHLO
path: /stablehlo/tutorials/pytorch-export
- title: StableHLO Embed in SavedModel
path: /stablehlo/tutorials/savedmodel-embed
- title: Contributing
section:
- title: Governance
Expand Down
256 changes: 256 additions & 0 deletions docs/tutorials/savedmodel-embed.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-g5l3RLOGXcc"
},
"source": [
"# Tutorial: Embedding StableHLO in SavedModel\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][savedmodel-tutorial-colab]\n",
"[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][savedmodel-tutorial-kaggle]\n",
"\n",
"_The [`stablehlo.savedmodel`][savedmodel-module] module._\n",
"\n",
"This tutorial will detail how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have specific APIs for emitting SavedModels, see other StableHLO tutorials for instructions on using these.\n",
"\n",
"## Tutorial Setup\n",
"\n",
"### Install required dependencies\n",
"\n",
"We'll be using the `stablehlo` nightly wheel to get StableHLO's Python APIs, and `tensorflow` for the [SavedModel][savedmodel-tf] dependency.\n",
"\n",
"[savedmodel-tf]: https://www.tensorflow.org/guide/saved_model\n",
"[savedmodel-module]: https://github.com/openxla/stablehlo/tree/main/stablehlo/integrations/python/stablehlo/savedmodel\n",
"[savedmodel-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb\n",
"[savedmodel-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B18oDm74-YoZ"
},
"outputs": [],
"source": [
"!pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels\n",
"!pip install tensorflow-cpu"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eUjF3nxXGyuP"
},
"source": [
"## Embed StableHLO model in SavedModel\n",
"\n",
"In this section we'll take a very basic StableHLO module, and demonstrate some of the APIs to embed it in a SavedModel. In practice this StableHLO module can come from a debug dump, an export from a framework, or even converted from HLO.\n",
"\n",
"### Define a StableHLO `add` module\n",
"\n",
"For this tutorial we'll use a simple `add` model with two input arguments `arg0` and `bias`. When packaging in SavedModel, `bias` will be a constant that is stored in the SavedModel, while `arg0` is provided when calling the model."
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"id": "uBygn_kU-kek"
},
"outputs": [],
"source": [
"MODULE_STRING = \"\"\"\n",
"func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {\n",
" %0 = stablehlo.add %arg0, %bias: tensor<1xf32>\n",
" return %0 : tensor<1xf32>\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0KEsCtO7IZY6"
},
"source": [
"### Parse to a StableHLO MLIR Module\n",
"\n",
"Once we have a StableHLO file / dump of interest, we can parse it back to an MLIR module using `ir.Module.parse`.\n",
"\n",
"Note that all dialects in the module must be registered, otherwise `parse` will fail."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gTkLk_0E_HPz",
"outputId": "6b66c9c0-3ca8-4f07-bf68-cf72f7fa235c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"module {\n",
" func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {\n",
" %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>\n",
" return %0 : tensor<1xf32>\n",
" }\n",
"}\n",
"\n"
]
}
],
"source": [
"import mlir.ir as ir\n",
"import mlir.dialects.stablehlo as stablehlo\n",
"\n",
"with ir.Context() as ctx:\n",
" stablehlo.register_dialect(ctx)\n",
" module = ir.Module.parse(MODULE_STRING)\n",
"\n",
"print(module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P-gvoHbtItkg"
},
"source": [
"### Embed in SavedModel using `stablehlo_to_tf_saved_model`\n",
"\n",
"StableHLO's Python wheel includes a `savedmodel` module to help with packaging StableHLO in SavedModels.\n",
"\n",
"Packing in SavedModel requires a few details:\n",
"\n",
"**`input_locations`** specify where inputs to a model live, in the saved model (`InputLocation.parameter`) or passed in as input arguments during invocation (`InputLocation.input_arg`).\n",
"\n",
"**`state_dict`** can be used to specify values for the `parameter` arguments that live in the SavedModel. These are linked by `name`.\n",
"\n",
"In this example, we'll specify that the second input argument is a value with name `module.bias` which is stored in the SavedModel with the value `2`."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"id": "sMZUE-vj_Wg5"
},
"outputs": [],
"source": [
"from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation\n",
"import numpy as np\n",
"\n",
"input_locations = [\n",
" InputLocation.input_arg(position=0), # Parameter, non-constant\n",
" InputLocation.parameter(name='module.bias'), # Constant data in SavedModel\n",
"]\n",
"state_dict = {\n",
" 'module.bias': np.array([2], dtype='float32'),\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eKArz9iCK4Pa"
},
"source": [
"Now we can use `stablehlo_to_tf_saved_model` to create the SavedModel in a path specified using the `saved_model_dir` argument."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Q3shJyzHKlbo",
"outputId": "48383f47-b083-4800-d4a4-0f7fd980d6c4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"assets\tfingerprint.pb\tsaved_model.pb\tvariables\n"
]
}
],
"source": [
"from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model\n",
"\n",
"stablehlo_to_tf_saved_model(\n",
" module,\n",
" saved_model_dir='/tmp/add_model',\n",
" input_locations=input_locations,\n",
" state_dict=state_dict,\n",
")\n",
"\n",
"!ls /tmp/add_model/"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MboZEPwmLIYl"
},
"source": [
"### Reload and call the SavedModel\n",
"\n",
"Now we can load that SavedModel and compile using a sample input.\n",
"\n",
"Here we'll just use a TF constant with the value `3`."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "i4KLLZ7AGB6X",
"outputId": "10fda712-819a-4002-d4e7-3702087294fd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"restored_model = tf.saved_model.load('/tmp/add_model')\n",
"print(restored_model.f(tf.constant([3], tf.float32)))"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit f4d1000

Please sign in to comment.