Skip to content

Commit

Permalink
Stablehlo to TF saved-model
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Apr 5, 2024
1 parent 69ef2ac commit 84afaa7
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 1 deletion.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ If you'd like to build the Python bindings, you'll need to install a few
additional dependencies.

```sh
pip install install -r ./llvm-project/mlir/python/requirements.txt
pip install -r ./llvm-project/mlir/python/requirements.txt

pip install pip install tensorflow-cpu # needed mainly for
# stablehlo to tf-save-model connversion
```

If you've built MLIR & StableHLO using the script above, the Python bindings
Expand Down
7 changes: 7 additions & 0 deletions stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ declare_mlir_dialect_python_bindings(
SOURCES dialects/stablehlo.py
DIALECT_NAME stablehlo)

declare_mlir_python_sources(StablehloToSavdedModelPythonSources
ADD_TO_PARENT StablehloPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/savedmodel"
SOURCES
stablehlo_to_tf_saved_model.py
)

declare_mlir_python_sources(VhloPythonSources)
declare_mlir_python_sources(VhloPythonSources.Dialects
ADD_TO_PARENT VhloPythonSources
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from dataclasses import dataclass
import enum
import itertools
import logging
import os
import shutil
from typing import Any, Dict, List, Mapping, Optional, Tuple
import mlir
import mlir.dialects.stablehlo as stablehlo
from mlir.ir import Context, InsertionPoint, Location, Module
import numpy as np

try:
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla as tfxla
except ImportError:
logging.error(
'This module is need tensorflow with xla support.\n'
'Please install tensorflow with `pip install tf-nightly`.\n'
)
raise


# Stablehlo spec classes to express the stablehlo input specification
@dataclass
class VariableSignature: # either argument or parameters
shape: List[int]
dtype: str
dynamic_dims: List[int] = dataclasses.field(default_factory=list)


class VariableType(enum.Enum):
INPUT_ARG = 'input_arg'
PARAMETER = 'parameter'
CONSTANT = 'constant'


@dataclass
class InputLocation:
type_: VariableType
position: int = -1
name: str = ''

@classmethod
def parameter(cls, name: str):
return cls(type_=VariableType.PARAMETER, name=name)

@classmethod
def input_arg(cls, position: int):
return cls(type_=VariableType.INPUT_ARG, position=position)

@classmethod
def constant(cls, position):
return cls(type_=VariableType.CONSTANT, position=position)


@dataclass
class StableHLOFuncSpec:
# stablehlo input signature
input_signature: List[VariableSignature]
# stablehlo output signature
output_signature: List[VariableSignature]
# annotations on stablehlo arguments as constants or variables
input_locations: List[InputLocation]
# serialized stablehlo format
bytecode: bytes
# map from constant arguments to constant values
state_dict: Dict[str, Any]


# Logic to convert stablehlo program to tf saved model


def _get_shape_with_dynamic(signature: VariableSignature):
shape = copy.copy(signature.shape)
for i in signature.dynamic_dims:
shape[i] = None
return shape


def _extract_call_parameters(args, shlo_spec):
call_args = []
for loc in shlo_spec.input_locations:
if loc.type_ == VariableType.PARAMETER:
call_args.append(shlo_spec.state_dict[loc.name])
elif loc.type_ == VariableType.CONSTANT:
call_args.append(shlo_spec.additional_constants[loc.position])
else:
call_args.append(args[loc.position])
return call_args


def _wrap_as_tf_func(shlo_spec):
def inner(*args):
Touts = [sig.dtype for sig in shlo_spec.output_signature]
Souts = [_get_shape_with_dynamic(sig) for sig in shlo_spec.output_signature]
call_args = _extract_call_parameters(args, shlo_spec)
m = tfxla.call_module(
tuple(call_args),
version=5,
Tout=Touts, # dtype information
Sout=Souts, # Shape information
function_list=[],
module=shlo_spec.bytecode,
)
return m

return inner


def make_tf_function(shlo_spec: StableHLOFuncSpec):
return _wrap_as_tf_func(shlo_spec)


def _make_input_signatures(shlo_spec: StableHLOFuncSpec) -> List[tf.TensorSpec]:
input_pos_to_spec = {
loc.position: spec
for loc, spec in itertools.chain(
zip(shlo_spec.input_locations, shlo_spec.input_signature), []
)
if loc.type_ == VariableType.INPUT_ARG
}
primitive_type_to_tf_type = {'int': 'int32', 'float': 'float32'}
for i in range(len(input_pos_to_spec)):
spec = input_pos_to_spec[i]
shape = _get_shape_with_dynamic(spec)
yield tf.TensorSpec(
shape=shape,
dtype=getattr(
tf,
primitive_type_to_tf_type[spec.dtype]
if spec.dtype in primitive_type_to_tf_type
else spec.dtype,
),
name=f'args_{i}',
)


def save_stablehlo_graph_as_tf(
stablehlo_program: StableHLOFuncSpec,
path: os.PathLike,
serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
function_alias: str = '',
) -> None:
bundle = copy.deepcopy(stablehlo_program)
tfm = tf.Module()

bundle.state_dict = {
k: tf.Variable(v, trainable=False, name=k)
for k, v in bundle.state_dict.items()
}

input_signatures = list(_make_input_signatures(bundle))

tfm.f = tf.function(
make_tf_function(bundle), input_signature=input_signatures
)
tfm._variables = list(bundle.state_dict.values())
signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}
save_options = tf.saved_model.SaveOptions(
function_aliases={
function_alias: tfm.f,
}
)
tf.saved_model.save(
tfm,
path,
signatures=signatures,
options=save_options,
)


# Top level API for stablehlo to tf saved model

import numpy as np


def save_stablehlo_as_tf_saved_model(
module,
saved_model_dir: os.PathLike,
input_locations: list = [],
state_dict: dict = {},
):
target = stablehlo.get_current_version()
stablehlo_type_to_torch_type = {'i32': 'int32', 'f32': 'float32'}
input_signatures = [
VariableSignature(
shape=input.shape,
dtype=stablehlo_type_to_torch_type[str(input.element_type)],
dynamic_dims=[],
)
for input in module.body.operations[0].type.inputs
]
output_signature = [
VariableSignature(
shape=result.shape,
dtype=stablehlo_type_to_torch_type[str(result.element_type)],
dynamic_dims=[],
)
for result in module.body.operations[0].type.results
]

if input_locations == []:
for i in range(len(module.body.operations[0].type.inputs)):
input_locations.append(InputLocation.input_arg(position=i))

shlo_spec = StableHLOFuncSpec(
input_signature=input_signatures,
output_signature=output_signature,
input_locations=input_locations,
state_dict=state_dict,
bytecode=stablehlo.serialize_portable_artifact(module, target),
)

save_stablehlo_graph_as_tf(shlo_spec, saved_model_dir)
1 change: 1 addition & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ endfunction()
add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py)
add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py)

add_dependencies(check-stablehlo-quick check-stablehlo-python)
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir
from mlir.stablehlo_to_tf_saved_model import InputLocation, save_stablehlo_as_tf_saved_model
import numpy as np
import tensorflow as tf


# Convert a stablehlo program, expressing a nn.Lienar layer with constant values
# for weight and bias, to saved model

mlir_module_string = """
module @IrToHlo.12 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
%3 = stablehlo.add %1, %2 : tensor<2x2xf32>
return %3 : tensor<2x2xf32>\n
}
}
"""

ctx = ir.Context()
stablehlo.register_dialect(ctx)
module = ir.Module.parse(mlir_module_string, ctx)

input_locations = [
InputLocation.parameter(name='linear_layer.bias'),
InputLocation.parameter(name='linear_layer.weight'),
InputLocation.input_arg(position=0),
]
state_dict = {
'linear_layer.weight': np.array(
[[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
),
'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
}


saved_model_dir = 'example_saved_model'
save_stablehlo_as_tf_saved_model(
module,
saved_model_dir=saved_model_dir,
input_locations=input_locations,
state_dict=state_dict,
)
tf_exported = tf.saved_model.load(saved_model_dir)
tf_mlir = tf.mlir.experimental.convert_saved_model(saved_model_dir, '')
print('\n TF Model\n', tf_mlir)

try:
os.remove(saved_model_dir)
print(f"File '{saved_model_dir}' deleted successfully")
except FileNotFoundError:
print(f"Error: File '{saved_model_dir}' not found")

0 comments on commit 84afaa7

Please sign in to comment.