Skip to content

Commit

Permalink
Stablehlo to TF saved-model
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Apr 6, 2024
1 parent 69ef2ac commit 582fc75
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/buildAndTestCMake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ jobs:
- name: Build and Test StableHLO (with Python bindings)
shell: bash
run: |
pip install tensorflow-cpu
./build_tools/github_actions/ci_build_cmake.sh "$LLVM_BUILD_DIR" "$STABLEHLO_BUILD_DIR"
env:
CMAKE_BUILD_TYPE: Release
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ If you'd like to build the Python bindings, you'll need to install a few
additional dependencies.

```sh
pip install install -r ./llvm-project/mlir/python/requirements.txt
pip install -r ./llvm-project/mlir/python/requirements.txt

pip install tensorflow-cpu  # needed mainly for
                            # stablehlo to tf-saved-model conversion
```

If you've built MLIR & StableHLO using the script above, the Python bindings
Expand Down
15 changes: 14 additions & 1 deletion stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ declare_mlir_dialect_python_bindings(
SOURCES dialects/stablehlo.py
DIALECT_NAME stablehlo)

# Python sources for the StableHLO -> TF SavedModel converter: a single file,
# stablehlo_to_tf_saved_model.py, located in the savedmodel/ subdirectory of
# this directory.
declare_mlir_python_sources(StablehloToSavedModelPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/savedmodel"
SOURCES
stablehlo_to_tf_saved_model.py
)

declare_mlir_python_sources(VhloPythonSources)
declare_mlir_python_sources(VhloPythonSources.Dialects
ADD_TO_PARENT VhloPythonSources
Expand Down Expand Up @@ -141,7 +147,14 @@ add_mlir_python_modules(StablehloUnifiedPythonModules
VhloPythonExtensions
COMMON_CAPI_LINK_LIBS
StablehloUnifiedPythonCAPI
)
)

# Python module target for the StableHLO -> TF SavedModel converter, built from
# the StablehloToSavedModelPythonSources source set. The package is placed
# under python_packages/stablehlo/savedmodel, both below
# ${STABLEHLO_BINARY_DIR} (ROOT_PREFIX) and in the install layout
# (INSTALL_PREFIX).
add_mlir_python_modules(StablehloToSavedModelAPI
ROOT_PREFIX "${STABLEHLO_BINARY_DIR}/python_packages/stablehlo/savedmodel"
INSTALL_PREFIX "python_packages/stablehlo/savedmodel"
DECLARED_SOURCES
StablehloToSavedModelPythonSources
)

################################################################################
# Tests
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from dataclasses import dataclass
import enum
import itertools
import logging
import os
from typing import Any, Dict, List, Optional

import mlir.dialects.stablehlo as stablehlo

# TensorFlow (including its tf2xla Python bindings) is a hard requirement of
# this module. Log an actionable hint before re-raising so the ImportError is
# not the only thing the user sees.
try:
  import tensorflow as tf
  from tensorflow.compiler.tf2xla.python import xla as tfxla
except ImportError:
  logging.error(
      'This module needs tensorflow with xla support.\n'
      'Please install tensorflow with `pip install tf-nightly`.\n'
  )
  raise


# Class to specify the input or output signature of a stablehlo function.
@dataclasses.dataclass
class VariableSignature:
  """Shape and dtype of one stablehlo value (either argument or parameter).

  `dynamic_dims` lists the indices of `shape` that are dynamic; it defaults
  to a fresh empty list for every instance.
  """

  shape: List[int]
  dtype: str
  dynamic_dims: List[int] = dataclasses.field(default_factory=list)


# Kinds of input a stablehlo function argument can be bound to.
class VariableType(enum.Enum):
  """Input type (parameter or argument) of a stablehlo function input."""

  # Supplied positionally by the caller.
  INPUT_ARG = 'input_arg'
  # Looked up by name.
  PARAMETER = 'parameter'


@dataclasses.dataclass
class InputLocation:
  """Describes where the value of one stablehlo function input comes from.

  Use the `parameter` / `input_arg` constructors rather than filling the
  fields by hand: a parameter is identified by `name`, an input argument by
  `position`. The field that does not apply keeps its default (-1 or '').
  """

  type_: VariableType
  position: int = -1
  name: str = ''

  @classmethod
  def parameter(cls, name: str):
    """Returns a location referring to the named parameter `name`."""
    kind = VariableType.PARAMETER
    return cls(type_=kind, name=name)

  @classmethod
  def input_arg(cls, position: int):
    """Returns a location referring to the call argument at `position`."""
    kind = VariableType.INPUT_ARG
    return cls(type_=kind, position=position)


# Class to specify stablehlo input specification.
@dataclasses.dataclass
class StableHLOFuncSpec:
  """Everything needed to wrap one stablehlo function.

  Attributes:
    input_signature: stablehlo input signature.
    output_signature: stablehlo output signature.
    input_locations: annotations on stablehlo arguments as constants or
      variables.
    bytecode: serialized stablehlo format.
    state_dict: map from constant arguments to constant values.
  """

  input_signature: List[VariableSignature]
  output_signature: List[VariableSignature]
  input_locations: List[InputLocation]
  bytecode: bytes
  state_dict: Dict[str, Any]


class StableHLOToTFSavedModel:
  """Wraps a serialized StableHLO function and saves it as a TF SavedModel.

  The program described by a `StableHLOFuncSpec` is invoked through
  `tfxla.call_module` inside a `tf.function`; that function is attached to a
  `tf.Module` and written out with `tf.saved_model.save`.
  """

  # Single source of truth for the StableHLO element type -> TF dtype name
  # mapping. `__init__` gives every instance its own copy.
  stablehlo_type_to_tf_type = {
      'i8': 'int8',
      'i16': 'int16',
      'i32': 'int32',
      'i64': 'int64',
      'f16': 'float16',
      'f32': 'float32',
      'f64': 'float64',
  }

  def __init__(self, spec: StableHLOFuncSpec):
    """Stores `spec` as the program to convert.

    Args:
      spec: signatures, input locations, bytecode and constant values of the
        stablehlo function.
    """
    # Per-instance copy, so the attribute stays an instance attribute as
    # before without duplicating the table in the source.
    self.stablehlo_type_to_tf_type = dict(
        type(self).stablehlo_type_to_tf_type
    )
    self.stablehlo_program = spec

  # Logic to convert stablehlo program to tf saved model

  def _tf_dtype_name(self, dtype: str) -> str:
    """Returns the TF dtype name for `dtype`, or `dtype` itself if unmapped."""
    return self.stablehlo_type_to_tf_type.get(dtype, dtype)

  def _get_shape_with_dynamic(self, signature: VariableSignature):
    """Returns a copy of the signature's shape with dynamic dims set to None."""
    shape = list(signature.shape)
    for i in signature.dynamic_dims:
      shape[i] = None
    return shape

  def _extract_call_parameters(self, args):
    """Builds the stablehlo call operands in `input_locations` order.

    Parameters are read from the program's `state_dict` by name; everything
    else is taken from `args` by position.
    """
    call_args = []
    for loc in self.stablehlo_program.input_locations:
      # The enum members are compared through their string form, exactly as
      # the original code did.
      if str(loc.type_) == str(VariableType.PARAMETER):
        call_args.append(self.stablehlo_program.state_dict[loc.name])
      else:
        call_args.append(args[loc.position])
    return call_args

  def _wrap_as_tf_func(self):
    """Returns a Python callable that runs the program via `call_module`."""

    def inner(*args):
      outputs = self.stablehlo_program.output_signature
      Touts = [self._tf_dtype_name(sig.dtype) for sig in outputs]
      Souts = [self._get_shape_with_dynamic(sig) for sig in outputs]
      call_args = self._extract_call_parameters(args)
      return tfxla.call_module(
          tuple(call_args),
          version=5,
          Tout=Touts,  # dtype information
          Sout=Souts,  # Shape information
          function_list=[],
          module=self.stablehlo_program.bytecode,
      )

    return inner

  def _make_tf_function(self):
    """Returns the callable that is handed to `tf.function`."""
    return self._wrap_as_tf_func()

  def _make_input_signatures(self) -> List[tf.TensorSpec]:
    """Returns one `tf.TensorSpec` per input argument, ordered by position.

    Only locations of type INPUT_ARG contribute; parameters are supplied from
    the state dict and are not part of the function signature.

    Raises:
      KeyError: if the input-argument positions are not exactly 0..n-1.
    """
    program = self.stablehlo_program
    input_pos_to_spec = {
        loc.position: sig
        for loc, sig in zip(program.input_locations, program.input_signature)
        if str(loc.type_) == str(VariableType.INPUT_ARG)
    }
    specs = []
    # Index by position rather than iterating the dict so the specs come out
    # in call order regardless of the order of `input_locations`.
    for i in range(len(input_pos_to_spec)):
      sig = input_pos_to_spec[i]
      specs.append(
          tf.TensorSpec(
              shape=self._get_shape_with_dynamic(sig),
              dtype=getattr(tf, self._tf_dtype_name(sig.dtype)),
              name=f'args_{i}',
          )
      )
    return specs

  def to_tf_saved_model(
      self,
      path: os.PathLike,
      serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
      function_alias: str = '',
  ) -> None:
    """Saves the stablehlo program as a TF SavedModel under `path`.

    As a side effect, the program's `state_dict` is replaced by a dict holding
    a non-trainable `tf.Variable` for every entry.

    Args:
      path: directory the SavedModel is written to.
      serving_key: signature key under which the function is exported.
      function_alias: alias registered for the function in the save options.
    """
    tfm = tf.Module()

    # Must happen before tracing: `_extract_call_parameters` reads the
    # parameter values from this dict.
    self.stablehlo_program.state_dict = {
        k: tf.Variable(v, trainable=False, name=k)
        for k, v in self.stablehlo_program.state_dict.items()
    }

    input_signatures = list(self._make_input_signatures())

    tfm.f = tf.function(
        self._make_tf_function(), input_signature=input_signatures
    )
    tfm._variables = list(self.stablehlo_program.state_dict.values())
    signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}
    # NOTE(review): an alias entry is registered even when `function_alias`
    # is the empty default -- confirm that is intended.
    save_options = tf.saved_model.SaveOptions(
        function_aliases={
            function_alias: tfm.f,
        }
    )
    tf.saved_model.save(
        tfm,
        path,
        signatures=signatures,
        options=save_options,
    )


# Top level API for stablehlo to tf saved model


def save_stablehlo_as_tf_saved_model(
    module,
    saved_model_dir: os.PathLike,
    input_locations: Optional[list] = None,
    state_dict: Optional[dict] = None,
):
  """Serializes `module` and saves it as a TF SavedModel.

  Input and output signatures are read from the function type of the first
  operation in the module body. All dimensions are treated as static.

  Args:
    module: MLIR module holding the stablehlo program.
    saved_model_dir: directory the SavedModel is written to.
    input_locations: one `InputLocation` per function argument. When omitted
      or empty, every argument is treated as a positional input argument.
    state_dict: values for the arguments marked as parameters, keyed by
      parameter name. Defaults to an empty dict.
  """
  target = stablehlo.get_current_version()
  # NOTE(review): assumes the first operation of the module body is the entry
  # function -- confirm for modules with several functions.
  func_type = module.body.operations[0].type

  input_signatures = [
      VariableSignature(
          shape=arg.shape,
          dtype=str(arg.element_type),
          dynamic_dims=[],
      )
      for arg in func_type.inputs
  ]
  output_signature = [
      VariableSignature(
          shape=result.shape,
          dtype=str(result.element_type),
          dynamic_dims=[],
      )
      for result in func_type.results
  ]

  # Build a fresh list on every call. Appending to a shared default list made
  # the locations of the first call leak into all later calls.
  if not input_locations:
    input_locations = [
        InputLocation.input_arg(position=i)
        for i in range(len(func_type.inputs))
    ]
  if state_dict is None:
    state_dict = {}

  shlo_spec = StableHLOFuncSpec(
      input_signature=input_signatures,
      output_signature=output_signature,
      input_locations=input_locations,
      state_dict=state_dict,
      bytecode=stablehlo.serialize_portable_artifact(module, target),
  )

  StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir)
2 changes: 2 additions & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ add_custom_target(${test_name}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS
StablehloUnifiedPythonModules
StablehloToSavedModelAPI
)
add_dependencies(check-stablehlo-python ${test_name})
endfunction()

add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py)
add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py)

add_dependencies(check-stablehlo-quick check-stablehlo-python)
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir
import numpy as np
from savedmodel.stablehlo_to_tf_saved_model import InputLocation, save_stablehlo_as_tf_saved_model
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Convert a stablehlo program, expressing a nn.Linear layer with constant values
# for weight and bias, to saved model.

mlir_module_string = """
module @IrToHlo.12 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
%3 = stablehlo.add %1, %2 : tensor<2x2xf32>
return %3 : tensor<2x2xf32>\n
}
}
"""

ctx = ir.Context()
stablehlo.register_dialect(ctx)
module = ir.Module.parse(mlir_module_string, ctx)

# %arg0 and %arg1 are bound to the named constants below; %arg2 is the only
# runtime input of the saved function.
input_locations = [
    InputLocation.parameter(name='linear_layer.bias'),
    InputLocation.parameter(name='linear_layer.weight'),
    InputLocation.input_arg(position=0),
]
state_dict = {
    'linear_layer.weight': np.array(
        [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
    ),
    'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
}


# Use a self-cleaning temporary directory so the test leaves nothing behind.
with tempfile.TemporaryDirectory() as saved_model_dir:
  save_stablehlo_as_tf_saved_model(
      module,
      saved_model_dir=saved_model_dir,
      input_locations=input_locations,
      state_dict=state_dict,
  )

  saved_model = saved_model_utils.read_saved_model(saved_model_dir)
  assert saved_model is not None

print('StableHLO convertion to TF Saved Model seems to work!')

0 comments on commit 582fc75

Please sign in to comment.