Showing 5 changed files with 313 additions and 1 deletion.
231 changes: 231 additions & 0 deletions
stablehlo/integrations/python/savedmodel/stablehlo_to_tf_saved_model.py
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from dataclasses import dataclass
import enum
import itertools
import logging
import os
import shutil
from typing import Any, Dict, List, Mapping, Optional, Tuple
import mlir
import mlir.dialects.stablehlo as stablehlo
from mlir.ir import Context, InsertionPoint, Location, Module
import numpy as np

try:
  import tensorflow as tf
  from tensorflow.compiler.tf2xla.python import xla as tfxla
except ImportError:
  logging.error(
      'This module requires TensorFlow with XLA support.\n'
      'Please install TensorFlow with `pip install tf-nightly`.\n'
  )
  raise


# StableHLO spec classes that express the StableHLO input specification.
@dataclass
class VariableSignature:  # either argument or parameter
  shape: List[int]
  dtype: str
  dynamic_dims: List[int] = dataclasses.field(default_factory=list)


class VariableType(enum.Enum):
  INPUT_ARG = 'input_arg'
  PARAMETER = 'parameter'
  CONSTANT = 'constant'


@dataclass
class InputLocation:
  type_: VariableType
  position: int = -1
  name: str = ''

  @classmethod
  def parameter(cls, name: str):
    return cls(type_=VariableType.PARAMETER, name=name)

  @classmethod
  def input_arg(cls, position: int):
    return cls(type_=VariableType.INPUT_ARG, position=position)

  @classmethod
  def constant(cls, position):
    return cls(type_=VariableType.CONSTANT, position=position)
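
  # Example (illustrative only, names are hypothetical): for a StableHLO
  # program whose arguments are (bias, weight, activation), with bias and
  # weight baked in as parameters and activation supplied at call time, the
  # input locations would be:
  #   [InputLocation.parameter(name='linear.bias'),
  #    InputLocation.parameter(name='linear.weight'),
  #    InputLocation.input_arg(position=0)]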


@dataclass
class StableHLOFuncSpec:
  # StableHLO input signature
  input_signature: List[VariableSignature]
  # StableHLO output signature
  output_signature: List[VariableSignature]
  # annotations marking each StableHLO argument as an input, parameter, or
  # constant
  input_locations: List[InputLocation]
  # serialized StableHLO bytecode
  bytecode: bytes
  # map from parameter names to their values
  state_dict: Dict[str, Any]
  # values for arguments marked as CONSTANT, indexed by position
  additional_constants: List[Any] = dataclasses.field(default_factory=list)


# Logic to convert a StableHLO program to a TF saved model.


def _get_shape_with_dynamic(signature: VariableSignature):
  shape = copy.copy(signature.shape)
  for i in signature.dynamic_dims:
    shape[i] = None
  return shape


def _extract_call_parameters(args, shlo_spec):
  call_args = []
  for loc in shlo_spec.input_locations:
    if loc.type_ == VariableType.PARAMETER:
      call_args.append(shlo_spec.state_dict[loc.name])
    elif loc.type_ == VariableType.CONSTANT:
      call_args.append(shlo_spec.additional_constants[loc.position])
    else:
      call_args.append(args[loc.position])
  return call_args


def _wrap_as_tf_func(shlo_spec):
  def inner(*args):
    Touts = [sig.dtype for sig in shlo_spec.output_signature]
    Souts = [_get_shape_with_dynamic(sig) for sig in shlo_spec.output_signature]
    call_args = _extract_call_parameters(args, shlo_spec)
    # XlaCallModule carries the serialized StableHLO module and is compiled
    # and executed by XLA.
    m = tfxla.call_module(
        tuple(call_args),
        version=5,
        Tout=Touts,  # dtype information
        Sout=Souts,  # shape information
        function_list=[],
        module=shlo_spec.bytecode,
    )
    return m

  return inner


def make_tf_function(shlo_spec: StableHLOFuncSpec):
  return _wrap_as_tf_func(shlo_spec)


def _make_input_signatures(shlo_spec: StableHLOFuncSpec) -> List[tf.TensorSpec]:
  input_pos_to_spec = {
      loc.position: spec
      for loc, spec in itertools.chain(
          zip(shlo_spec.input_locations, shlo_spec.input_signature), []
      )
      if loc.type_ == VariableType.INPUT_ARG
  }
  primitive_type_to_tf_type = {'int': 'int32', 'float': 'float32'}
  for i in range(len(input_pos_to_spec)):
    spec = input_pos_to_spec[i]
    shape = _get_shape_with_dynamic(spec)
    yield tf.TensorSpec(
        shape=shape,
        dtype=getattr(
            tf,
            primitive_type_to_tf_type[spec.dtype]
            if spec.dtype in primitive_type_to_tf_type
            else spec.dtype,
        ),
        name=f'args_{i}',
    )
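  # For example, a single [2, 2] f32 runtime argument yields
  # tf.TensorSpec(shape=[2, 2], dtype=tf.float32, name='args_0').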


def save_stablehlo_graph_as_tf(
    stablehlo_program: StableHLOFuncSpec,
    path: os.PathLike,
    serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    function_alias: str = '',
) -> None:
  bundle = copy.deepcopy(stablehlo_program)
  tfm = tf.Module()

  bundle.state_dict = {
      k: tf.Variable(v, trainable=False, name=k)
      for k, v in bundle.state_dict.items()
  }

  input_signatures = list(_make_input_signatures(bundle))

  tfm.f = tf.function(
      make_tf_function(bundle), input_signature=input_signatures
  )
  tfm._variables = list(bundle.state_dict.values())
  signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}
  save_options = tf.saved_model.SaveOptions(
      function_aliases={
          function_alias: tfm.f,
      }
  )
  tf.saved_model.save(
      tfm,
      path,
      signatures=signatures,
      options=save_options,
  )


# Top-level API for StableHLO to TF saved model.


def save_stablehlo_as_tf_saved_model(
    module,
    saved_model_dir: os.PathLike,
    input_locations: Optional[list] = None,
    state_dict: Optional[dict] = None,
):
  # Avoid mutable default arguments: create a fresh list/dict per call.
  if input_locations is None:
    input_locations = []
  if state_dict is None:
    state_dict = {}

  target = stablehlo.get_current_version()
  stablehlo_type_to_tf_type = {'i32': 'int32', 'f32': 'float32'}
  input_signatures = [
      VariableSignature(
          shape=input.shape,
          dtype=stablehlo_type_to_tf_type[str(input.element_type)],
          dynamic_dims=[],
      )
      for input in module.body.operations[0].type.inputs
  ]
  output_signature = [
      VariableSignature(
          shape=result.shape,
          dtype=stablehlo_type_to_tf_type[str(result.element_type)],
          dynamic_dims=[],
      )
      for result in module.body.operations[0].type.results
  ]

  if not input_locations:
    for i in range(len(module.body.operations[0].type.inputs)):
      input_locations.append(InputLocation.input_arg(position=i))

  shlo_spec = StableHLOFuncSpec(
      input_signature=input_signatures,
      output_signature=output_signature,
      input_locations=input_locations,
      state_dict=state_dict,
      bytecode=stablehlo.serialize_portable_artifact(module, target),
  )

  save_stablehlo_graph_as_tf(shlo_spec, saved_model_dir)
70 changes: 70 additions & 0 deletions
stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil

import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir
from mlir.stablehlo_to_tf_saved_model import InputLocation, save_stablehlo_as_tf_saved_model
import numpy as np
import tensorflow as tf


# Convert a StableHLO program, expressing an nn.Linear layer with constant
# values for weight and bias, to a saved model.

mlir_module_string = """
module @IrToHlo.12 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    %2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
    %3 = stablehlo.add %1, %2 : tensor<2x2xf32>
    return %3 : tensor<2x2xf32>
  }
}
"""

ctx = ir.Context()
stablehlo.register_dialect(ctx)
module = ir.Module.parse(mlir_module_string, ctx)

input_locations = [
    InputLocation.parameter(name='linear_layer.bias'),
    InputLocation.parameter(name='linear_layer.weight'),
    InputLocation.input_arg(position=0),
]
state_dict = {
    'linear_layer.weight': np.array(
        [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
    ),
    'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
}


saved_model_dir = 'example_saved_model'
save_stablehlo_as_tf_saved_model(
    module,
    saved_model_dir=saved_model_dir,
    input_locations=input_locations,
    state_dict=state_dict,
)
tf_exported = tf.saved_model.load(saved_model_dir)
tf_mlir = tf.mlir.experimental.convert_saved_model(saved_model_dir, '')
print('\n TF Model\n', tf_mlir)
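
# Illustrative check (not part of the original test): invoke the exported
# serving signature directly. The input name 'args_0' follows the f'args_{i}'
# naming used by _make_input_signatures, and the default serving key is the
# one used when saving.
serving_fn = tf_exported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
sample_input = np.random.rand(2, 2).astype('float32')
print('\n Serving output\n', serving_fn(args_0=tf.constant(sample_input)))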

# Clean up: a saved model is a directory, so remove it with shutil.rmtree.
try:
  shutil.rmtree(saved_model_dir)
  print(f"Directory '{saved_model_dir}' deleted successfully")
except FileNotFoundError:
  print(f"Error: Directory '{saved_model_dir}' not found")