Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Apr 18, 2024
1 parent 74ceb70 commit f85d9ff
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions docs/tutorials/jax-export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
{
"cell_type": "markdown",
"source": [
"# Tutorial: JAX Export to StableHLO\n",
"# Tutorial: Exporting StableHLO from JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/examples/jax-export.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/examples/jax-export.ipynb)\n",
"\n",
Expand Down Expand Up @@ -46,12 +46,12 @@
{
"cell_type": "code",
"source": [
"#@title Define `prettyprint_stablehlo` to help with MLIR printing\n",
"#@title Define `get_stablehlo_asm` to help with MLIR printing\n",
"from jax._src.interpreters import mlir as jax_mlir\n",
"from jax._src.lib.mlir import ir\n",
"\n",
"# Prettyprint without large constants\n",
"def prettyprint_stablehlo(module_str):\n",
"# Returns prettyprint of StableHLO module without large constants\n",
"def get_stablehlo_asm(module_str):\n",
" with jax_mlir.make_ir_context():\n",
" stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())\n",
" return stablehlo_module.operation.get_asm(large_elements_limit=20)\n",
Expand All @@ -61,9 +61,10 @@
"logging.disable(logging.WARNING)"
],
"metadata": {
"id": "HqjeC_QSugYj"
"id": "HqjeC_QSugYj",
"cellView": "form"
},
"execution_count": 4,
"execution_count": 2,
"outputs": []
},
{
Expand Down Expand Up @@ -115,17 +116,17 @@
"# Create abstract input shapes:\n",
"inputs = (np.int32(1), np.int32(1),)\n",
"input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]\n",
"stablehlo_add_str = export.export(plus)(*input_shapes).mlir_module()\n",
"print(prettyprint_stablehlo(stablehlo_add_str))"
"stablehlo_add = export.export(plus)(*input_shapes).mlir_module()\n",
"print(get_stablehlo_asm(stablehlo_add))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "v-GN3vPbvFoa",
"outputId": "0d434553-6332-4e9e-fb1b-5cd07db5dab0"
"outputId": "b128c08f-3591-4142-fd67-561812bb3d4e"
},
"execution_count": 2,
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -177,15 +178,15 @@
"\n",
"# Export to StableHLO\n",
"stablehlo_resnet18_export = export.export(resnet18)(input_shape)\n",
"resnet18_stablehlo = prettyprint_stablehlo(stablehlo_resnet18_export.mlir_module())\n",
"resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())\n",
"print(resnet18_stablehlo[:600], \"\\n...\\n\", resnet18_stablehlo[-345:])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "53T7jO-v_6PC",
"outputId": "4567e611-f02d-4121-fd82-3fd66e6bec75"
"outputId": "0536386d-09d2-49a3-b51c-951c20f6e49b"
},
"execution_count": 5,
"outputs": [
Expand Down Expand Up @@ -234,15 +235,15 @@
"dyn_scope = export.SymbolicScope()\n",
"dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,3,224,224\", scope=dyn_scope), np.float32)\n",
"dyn_resnet18_export = export.export(resnet18)(dyn_input_shape)\n",
"dyn_resnet18_stablehlo = prettyprint_stablehlo(dyn_resnet18_export.mlir_module())\n",
"dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())\n",
"print(dyn_resnet18_stablehlo[:1900], \"\\n...\\n\", dyn_resnet18_stablehlo[-1000:])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sIkbtViEMJ3T",
"outputId": "a1da7467-4adb-44cf-f861-c482bbc7fc82"
"outputId": "165019c7-e771-43d6-c324-6bb4222798ec"
},
"execution_count": 6,
"outputs": [
Expand Down Expand Up @@ -294,7 +295,7 @@
"A few things to note in the exported StableHLO:\n",
"\n",
"1. The exported program now has `tensor<?x3x224x224xf32>`. These input types can be refined in many ways: SavedModel execution takes care of refinement which we'll see in the next example, but StableHLO also has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs.\n",
"2. JAX will generated guards to ensure the values of `a` are valid, in this case `a > 1` is checked. These can be washed away at compile time once refined."
"2. JAX will generate guards to ensure the values of `a` are valid, in this case `a > 1` is checked. These can be washed away at compile time once refined."
],
"metadata": {
"id": "dRIu7xlSoDUK"
Expand Down Expand Up @@ -369,7 +370,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "GXkgtX7QEiWa",
"outputId": "930b87c6-d0cb-4363-ccce-5510543ee7d9",
"outputId": "9034836e-4ba1-4210-c2c1-316777d4ad89",
"collapsed": true
},
"execution_count": 7,
Expand Down Expand Up @@ -408,7 +409,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "9Az3dXXWrVDM",
"outputId": "c3f4fd83-8461-4386-8907-dd91942312bb"
"outputId": "cac6d3b1-126e-4e66-dc4d-f2a798ede463"
},
"execution_count": 8,
"outputs": [
Expand Down

0 comments on commit f85d9ff

Please sign in to comment.