Skip to content

Commit

Permalink
Stablehlo to TF saved-model
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Apr 8, 2024
1 parent 74ce90f commit 2e61e14
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/buildAndTestCMake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ jobs:
- name: Build and Test StableHLO (with Python bindings)
shell: bash
run: |
pip install tensorflow-cpu
./build_tools/github_actions/ci_build_cmake.sh "$LLVM_BUILD_DIR" "$STABLEHLO_BUILD_DIR"
env:
CMAKE_BUILD_TYPE: Release
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ If you'd like to build the Python bindings, you'll need to install a few
additional dependencies.

```sh
pip install install -r ./llvm-project/mlir/python/requirements.txt
pip install -r ./llvm-project/mlir/python/requirements.txt
pip install tensorflow-cpu # to convert StableHLO to a TF saved model
```

If you've built MLIR & StableHLO using the script above, the Python bindings
Expand Down
15 changes: 14 additions & 1 deletion stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ declare_mlir_dialect_python_bindings(
SOURCES dialects/stablehlo.py
DIALECT_NAME stablehlo)

declare_mlir_python_sources(StablehloToSavedModelPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/savedmodel"
SOURCES
stablehlo_to_tf_saved_model.py
)

declare_mlir_python_sources(VhloPythonSources)
declare_mlir_python_sources(VhloPythonSources.Dialects
ADD_TO_PARENT VhloPythonSources
Expand Down Expand Up @@ -141,7 +147,14 @@ add_mlir_python_modules(StablehloUnifiedPythonModules
VhloPythonExtensions
COMMON_CAPI_LINK_LIBS
StablehloUnifiedPythonCAPI
)
)

add_mlir_python_modules(StablehloToSavedModelAPI
ROOT_PREFIX "${STABLEHLO_BINARY_DIR}/python_packages/stablehlo/savedmodel"
INSTALL_PREFIX "python_packages/stablehlo/savedmodel"
DECLARED_SOURCES
StablehloToSavedModelPythonSources
)

################################################################################
# Tests
Expand Down
67 changes: 67 additions & 0 deletions stablehlo/integrations/python/savedmodel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Stablehlo to TF Saved Model

`stablehlo_to_tf_saved_model.py` provides the following API to convert a
StableHLO program to a TF saved model.

```python
stablehlo_to_tf_saved_model(module: mlir.ir.Module,
saved_model_dir: os.PathLike,
input_locations: list = [],
state_dict: dict = {},
)
```

where

* `module`: A StableHLO module.
* `saved_model_dir`: Path to save TF saved-model artifacts.
* `input_locations`: The kind of each input argument, in the order the
arguments appear in the function signature: either a parameter with a name
associated with it, or a positional argument. Parameters are generally the
weights of a model with pre-trained constant values.
* `state_dict`: Mapping from parameter names to their constant values.

For example, to export a simple
[torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
model to a TF saved model using the above API, we need:

* `module`

```mlir
module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
%3 = stablehlo.add %1, %2 : tensor<2x2xf32>
return %3 : tensor<2x2xf32>
}
}
```

* `input_locations`

```python
input_locations = [
InputLocation.parameter(name='linear_layer.bias'), # bias parameter
InputLocation.parameter(name='linear_layer.weight'), # weight parameter
InputLocation.input_arg(position=0), # positional input argument
]
```

* `state_dict`

```python
state_dict = {
'linear_layer.weight': np.array(
[[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
),
'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
}
```

Note that the API depends on the python bindings for

* StableHLO: Please refer to [README.md](https://github.com/openxla/stablehlo?tab=readme-ov-file#python).
* TensorFlow: Needs `pip install tensorflow-cpu`.
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from dataclasses import dataclass
import enum
import itertools
import logging
import os
from typing import Any, Dict, List
import mlir.dialects.stablehlo as stablehlo

# TensorFlow (with the tf2xla bindings) is an optional dependency of the
# StableHLO Python package but a hard requirement of this module, so a
# readable hint is logged before the ImportError is propagated unchanged.
try:
    import tensorflow as tf
    from tensorflow.compiler.tf2xla.python import xla as tfxla
except ImportError:
    logging.error(
        'This module needs tensorflow with xla support.\n'
        'Please install tensorflow with `pip install tf-nightly`.\n'
    )
    raise


# Class to specify the input or output signature of a stablehlo function.
@dataclasses.dataclass
class VariableSignature:
    """Shape and element type of a single function argument or result.

    Attributes:
      shape: Size of each dimension.
      dtype: Element type spelled as a string (for example 'f32').
      dynamic_dims: Indices into `shape` of the dimensions flagged as
        dynamic. Every instance gets its own empty list by default.
    """

    shape: List[int]
    dtype: str
    dynamic_dims: List[int] = dataclasses.field(default_factory=list)


# Classes to specify the input type (parameter, argument) of a function.
@enum.unique
class VariableType(enum.Enum):
    """Kind of value bound to one argument of a stablehlo function."""

    # Supplied by the caller, addressed by position.
    INPUT_ARG = 'input_arg'
    # Looked up by name among the stored parameter values.
    PARAMETER = 'parameter'


@dataclasses.dataclass
class InputLocation:
    """Describes where one stablehlo function argument takes its value from.

    Attributes:
      type_: Whether the argument is a named parameter or a positional input.
      position: Index of the positional input; left at -1 for parameters.
      name: Name of the parameter; left empty for positional inputs.
    """

    type_: VariableType
    position: int = -1
    name: str = ''

    @classmethod
    def parameter(cls, name: str):
        """Return a location referring to the parameter called `name`."""
        kind = VariableType.PARAMETER
        return cls(kind, name=name)

    @classmethod
    def input_arg(cls, position: int):
        """Return a location referring to the positional input `position`."""
        kind = VariableType.INPUT_ARG
        return cls(kind, position=position)


# Class to specify stablehlo input specification.
@dataclasses.dataclass
class StableHLOFuncSpec:
    """Everything needed to export one stablehlo function.

    Attributes:
      input_signature: Signature of every stablehlo input.
      output_signature: Signature of every stablehlo output.
      input_locations: Annotation of each stablehlo argument as a constant
        (parameter) or a variable (positional input).
      bytecode: The stablehlo program in serialized form.
      state_dict: Map from constant arguments to their constant values.
    """

    input_signature: List[VariableSignature]
    output_signature: List[VariableSignature]
    input_locations: List[InputLocation]
    bytecode: bytes
    state_dict: Dict[str, Any]


class StableHLOToTFSavedModel:
    """Exports a serialized stablehlo function as a TF saved model.

    The program described by a `StableHLOFuncSpec` is invoked through
    `tfxla.call_module`. Arguments annotated as parameters are taken from the
    spec's `state_dict`; the remaining arguments become the inputs of the
    exported `tf.function`.
    """

    # Maps the stablehlo spelling of an element type to the name of the
    # attribute on the `tf` module holding the matching dtype.
    stablehlo_type_to_tf_type = {
        'i8': 'int8',
        'i16': 'int16',
        'i32': 'int32',
        'i64': 'int64',
        'f16': 'float16',
        'f32': 'float32',
        'f64': 'float64',
    }

    def __init__(self, spec: StableHLOFuncSpec):
        """Stores the program specification to be exported."""
        self.stablehlo_program = spec

    # Logic to convert stablehlo program to tf saved model

    def _get_shape_with_dynamic(self, signature: VariableSignature):
        """Returns a copy of the shape with dynamic dimensions set to None."""
        shape = list(signature.shape)
        for i in signature.dynamic_dims:
            shape[i] = None
        return shape

    def _extract_call_parameters(self, args):
        """Builds the argument list of the stablehlo call.

        Args:
          args: Positional inputs received by the exported function.

        Returns:
          One value per entry of `input_locations`, in that order: the
          `state_dict` entry for parameters, `args[position]` otherwise.
        """
        program = self.stablehlo_program
        call_args = []
        for loc in program.input_locations:
            if loc.type_ == VariableType.PARAMETER:
                call_args.append(program.state_dict[loc.name])
            else:
                call_args.append(args[loc.position])
        return call_args

    def _wrap_as_tf_func(self):
        """Returns a Python callable that runs the program via call_module."""

        def inner(*args):
            program = self.stablehlo_program
            # An output element type missing from the table raises KeyError.
            Touts = [
                self.stablehlo_type_to_tf_type[sig.dtype]
                for sig in program.output_signature
            ]
            Souts = [
                self._get_shape_with_dynamic(sig)
                for sig in program.output_signature
            ]
            call_args = self._extract_call_parameters(args)
            return tfxla.call_module(
                tuple(call_args),
                version=5,
                Tout=Touts,  # dtype information
                Sout=Souts,  # Shape information
                function_list=[],
                module=program.bytecode,
            )

        return inner

    def _make_tf_function(self):
        """Returns the callable that gets wrapped in `tf.function`."""
        return self._wrap_as_tf_func()

    def _make_input_signatures(self) -> List[tf.TensorSpec]:
        """Builds one `tf.TensorSpec` per positional input, ordered by position.

        Positions are expected to be exactly 0..n-1; a missing position
        raises KeyError. Element types absent from the dtype table are looked
        up on `tf` under their own name.
        """
        program = self.stablehlo_program
        input_pos_to_spec = {
            loc.position: sig
            for loc, sig in zip(
                program.input_locations, program.input_signature
            )
            if loc.type_ == VariableType.INPUT_ARG
        }
        tensor_specs = []
        for i in range(len(input_pos_to_spec)):
            sig = input_pos_to_spec[i]
            tf_dtype_name = self.stablehlo_type_to_tf_type.get(
                sig.dtype, sig.dtype
            )
            tensor_specs.append(
                tf.TensorSpec(
                    shape=self._get_shape_with_dynamic(sig),
                    dtype=getattr(tf, tf_dtype_name),
                    name=f'args_{i}',
                )
            )
        return tensor_specs

    def to_tf_saved_model(
        self,
        path: os.PathLike,
        serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
        function_alias: str = '',
    ) -> None:
        """Writes the program to `path` as a TF saved model.

        Note that this replaces every value of the spec's `state_dict` with a
        non-trainable `tf.Variable` named after its key.

        Args:
          path: Destination passed to `tf.saved_model.save`.
          serving_key: Key under which the concrete function is registered in
            the saved model signatures.
          function_alias: Alias registered for the exported function through
            `tf.saved_model.SaveOptions(function_aliases=...)`.
        """
        tfm = tf.Module()

        # The traced function reads parameters from `state_dict`, so the
        # variables must be in place before `get_concrete_function` runs.
        self.stablehlo_program.state_dict = {
            k: tf.Variable(v, trainable=False, name=k)
            for k, v in self.stablehlo_program.state_dict.items()
        }

        input_signatures = list(self._make_input_signatures())

        tfm.f = tf.function(
            self._make_tf_function(), input_signature=input_signatures
        )
        tfm._variables = list(self.stablehlo_program.state_dict.values())
        signatures = {
            serving_key: tfm.f.get_concrete_function(*input_signatures)
        }
        save_options = tf.saved_model.SaveOptions(
            function_aliases={
                function_alias: tfm.f,
            }
        )
        tf.saved_model.save(
            tfm,
            path,
            signatures=signatures,
            options=save_options,
        )


# Top level API for stablehlo to tf saved model


def stablehlo_to_tf_saved_model(
    module,
    saved_model_dir: os.PathLike,
    input_locations: list = [],
    state_dict: dict = {},
):
    """Converts a stablehlo module to a TF saved model stored on disk.

    The first operation of the module body is taken as the function to
    export; its type provides the input and output signatures.

    Args:
      module: The stablehlo module to export.
      saved_model_dir: Path to save the TF saved-model artifacts to.
      input_locations: One `InputLocation` per function argument, in
        argument order. When empty, every argument is treated as a positional
        input at its own index.
      state_dict: Maps the parameter names used in `input_locations` to their
        constant values.

    Raises:
      ValueError: If `input_locations` is given but its length differs from
        the number of function arguments.
    """
    target = stablehlo.get_current_version()
    func_type = module.body.operations[0].type

    input_signatures = [
        VariableSignature(
            shape=arg_type.shape,
            dtype=str(arg_type.element_type),
            dynamic_dims=[],
        )
        for arg_type in func_type.inputs
    ]
    output_signature = [
        VariableSignature(
            shape=result_type.shape,
            dtype=str(result_type.element_type),
            dynamic_dims=[],
        )
        for result_type in func_type.results
    ]

    # The mutable defaults in the signature are kept for interface
    # compatibility and must never be modified: a fresh list / dict is built
    # here so that neither the shared defaults nor the caller's objects are
    # touched and no state leaks from one call into the next.
    if input_locations:
        input_locations = list(input_locations)
    else:
        input_locations = [
            InputLocation.input_arg(position=i)
            for i in range(len(input_signatures))
        ]
    state_dict = dict(state_dict)

    if len(input_locations) != len(input_signatures):
        raise ValueError(
            f'Expected {len(input_signatures)} input locations (one per '
            f'function argument), got {len(input_locations)}.'
        )

    shlo_spec = StableHLOFuncSpec(
        input_signature=input_signatures,
        output_signature=output_signature,
        input_locations=input_locations,
        state_dict=state_dict,
        bytecode=stablehlo.serialize_portable_artifact(module, target),
    )

    StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir)
2 changes: 2 additions & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ add_custom_target(${test_name}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS
StablehloUnifiedPythonModules
StablehloToSavedModelAPI
)
add_dependencies(check-stablehlo-python ${test_name})
endfunction()

add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py)
add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py)

add_dependencies(check-stablehlo-quick check-stablehlo-python)
Loading

0 comments on commit 2e61e14

Please sign in to comment.