Skip to content

Commit

Permalink
address feedback:1
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Apr 23, 2024
1 parent 9fa5bed commit 07af947
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 44 deletions.
11 changes: 7 additions & 4 deletions stablehlo/integrations/python/stablehlo/savedmodel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ where

* `module`: An StableHLO module.
* `saved_model_dir`: Path to save TF saved-model artifacts.
* `input_locations`: Type of each input arguments: either it could be a
* `target_version`: Serialization version of StableHLO. Default: current
stablehlo version.
* `input_locations`: List of input argument types: either it could be a
parameter with a name associated with it or a positional argument. The
parameters are generally the weights or biases of a model with pre-trained
constant values.
* `state_dict`: Mapping of named input parameters with constants.
constant values. Default: empty list.
* `state_dict`: Mapping of named input parameters with constants. Default:
empty list.

For example, to export a simple
[torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
Expand Down Expand Up @@ -68,7 +71,7 @@ state_dict = {
The above API depends on

* MLIR Python bindings: To express an MLIR module.
* TensorFlow: To save the TF saved model artifacts.
* TensorFlow: Only used to work with TF saved model artifacts.

## Testing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@ class StableHLOToTFSavedModel:

def __init__(self, spec: StableHLOFuncSpec):
self.stablehlo_type_to_tf_type = {
'i1': 'bool',
'i8': 'int8',
'i16': 'i32',
'i32': 'int32',
'i64': 'int64',
'f16': 'float16',
'f32': 'float32',
'f64': 'float64',
'bf16': 'bfloat16',
}
self.stablehlo_program = spec

Expand All @@ -100,16 +102,6 @@ def _get_shape_with_dynamic(self, signature: VariableSignature):
shape[i] = None
return shape

stablehlo_type_to_tf_type = {
'i8': 'int8',
'i16': 'i32',
'i32': 'int32',
'i64': 'int64',
'f16': 'float16',
'f32': 'float32',
'f64': 'float64',
}

def _extract_call_parameters(self, args):
call_args = []
for loc in self.stablehlo_program.input_locations:
Expand All @@ -121,10 +113,14 @@ def _extract_call_parameters(self, args):

def _wrap_as_tf_func(self):
def inner(*args):
Touts = [
self.stablehlo_type_to_tf_type[sig.dtype]
for sig in self.stablehlo_program.output_signature
]
try:
Touts = [
self.stablehlo_type_to_tf_type[sig.dtype]
for sig in self.stablehlo_program.output_signature
]
except KeyError as e:
raise KeyError(f'TensorFlow type mapping not found: {e}') from None

Souts = [
self._get_shape_with_dynamic(sig)
for sig in self.stablehlo_program.output_signature
Expand Down Expand Up @@ -160,14 +156,16 @@ def _make_input_signatures(self) -> List[tf.TensorSpec]:
for i in range(len(input_pos_to_spec)):
spec = input_pos_to_spec[i]
shape = self._get_shape_with_dynamic(spec)
try:
dtype = getattr(tf, self.stablehlo_type_to_tf_type[spec.dtype])
except KeyError as e:
raise KeyError(
f'TensorFlow type mapping not found for {spec.dtype}: {e}'
) from None

yield tf.TensorSpec(
shape=shape,
dtype=getattr(
tf,
self.stablehlo_type_to_tf_type[spec.dtype]
if spec.dtype in self.stablehlo_type_to_tf_type
else spec.dtype,
),
dtype=dtype,
name=f'args_{i}',
)

Expand Down Expand Up @@ -210,10 +208,10 @@ def to_tf_saved_model(
def stablehlo_to_tf_saved_model(
module: ir.Module,
saved_model_dir: os.PathLike,
target_version: str = stablehlo.get_current_version(),
input_locations: list = [],
state_dict: dict = {},
):
target = stablehlo.get_current_version()
input_signatures = [
VariableSignature(
shape=input.shape,
Expand All @@ -240,7 +238,7 @@ def stablehlo_to_tf_saved_model(
output_signature=output_signature,
input_locations=input_locations,
state_dict=state_dict,
bytecode=stablehlo.serialize_portable_artifact(module, target),
bytecode=stablehlo.serialize_portable_artifact(module, target_version),
)

StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,37 @@
import tempfile
import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir
import numpy as np
from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation, stablehlo_to_tf_saved_model
import numpy as np
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Convert a stablehlo program, expressing a nn.Linear layer with constant values
# for weight and bias, to saved model.
# Convert a stablehlo program, expressing addition of an argument with constant
# values for weight and bias, to saved model.

mlir_module_string = """
MODULE_STRING = """
module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
%3 = stablehlo.add %1, %2 : tensor<2x2xf32>
return %3 : tensor<2x2xf32>\n
func.func @main(%bias: tensor<1xf32>, %weight: tensor<1xf32>, %arg0: tensor<1xf32>) -> tensor<1xf32> {
%0 = stablehlo.add %arg0, %weight: tensor<1xf32>
%1 = stablehlo.add %0, %bias : tensor<1xf32>
return %1 : tensor<1xf32>\n
}
}
"""

ctx = ir.Context()
stablehlo.register_dialect(ctx)
module = ir.Module.parse(mlir_module_string, ctx)
module = ir.Module.parse(MODULE_STRING, ctx)

input_locations = [
InputLocation.parameter(name='linear_layer.bias'),
InputLocation.parameter(name='linear_layer.weight'),
InputLocation.input_arg(position=0),
]
state_dict = {
'linear_layer.weight': np.array(
[[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
),
'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
'linear_layer.weight': np.array([1], dtype='float32'),
'linear_layer.bias': np.array([2], dtype='float32'),
}


Expand All @@ -62,6 +58,6 @@
state_dict=state_dict,
)

saved_model = saved_model_utils.read_saved_model(saved_model_dir)
assert saved_model != None
print(f'StableHLO convertion to TF Saved Model seems to work!')
restored_model = tf.saved_model.load(saved_model_dir)
restored_result = restored_model.f(tf.constant([3], tf.float32))
assert np.allclose(restored_result[0], tf.constant([6], tf.float32))

0 comments on commit 07af947

Please sign in to comment.