diff --git a/.github/workflows/buildAndTestCMake.yml b/.github/workflows/buildAndTestCMake.yml index 2cdb0725c4..26cc7ab3da 100644 --- a/.github/workflows/buildAndTestCMake.yml +++ b/.github/workflows/buildAndTestCMake.yml @@ -93,6 +93,7 @@ jobs: - name: Build and Test StableHLO (with Python bindings) shell: bash run: | + pip install tensorflow-cpu ./build_tools/github_actions/ci_build_cmake.sh "$LLVM_BUILD_DIR" "$STABLEHLO_BUILD_DIR" env: CMAKE_BUILD_TYPE: Release diff --git a/README.md b/README.md index 5295b8d766..48f17e6bfa 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,8 @@ If you'd like to build the Python bindings, you'll need to install a few additional dependencies. ```sh -pip install install -r ./llvm-project/mlir/python/requirements.txt +pip install -r ./llvm-project/mlir/python/requirements.txt +pip install tensorflow-cpu # to convert stablehlo to tf-save-model ``` If you've built MLIR & StableHLO using the script above, the Python bindings diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index 142bef1127..506cce9454 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -46,6 +46,12 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/stablehlo.py DIALECT_NAME stablehlo) +declare_mlir_python_sources(StablehloToSavedModelPythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/savedmodel" + SOURCES + stablehlo_to_tf_saved_model.py +) + declare_mlir_python_sources(VhloPythonSources) declare_mlir_python_sources(VhloPythonSources.Dialects ADD_TO_PARENT VhloPythonSources @@ -141,7 +147,14 @@ add_mlir_python_modules(StablehloUnifiedPythonModules VhloPythonExtensions COMMON_CAPI_LINK_LIBS StablehloUnifiedPythonCAPI - ) +) + +add_mlir_python_modules(StablehloToSavedModelAPI + ROOT_PREFIX "${STABLEHLO_BINARY_DIR}/python_packages/stablehlo/savedmodel" + INSTALL_PREFIX "python_packages/stablehlo/savedmodel" + DECLARED_SOURCES + StablehloToSavedModelPythonSources +) ################################################################################ # Tests diff --git a/stablehlo/integrations/python/savedmodel/README.md b/stablehlo/integrations/python/savedmodel/README.md new file mode 100644 index 0000000000..17d6bc6120 --- /dev/null +++ b/stablehlo/integrations/python/savedmodel/README.md @@ -0,0 +1,65 @@ +# Stablehlo to TF Saved Model + +`stablehlo_to_tf_saved_model.py` provides the following API to convert a +stablehlo program to TF saved model. + +```python +save_stablehlo_as_tf_saved_model(module: mlir.ir.Module, + saved_model_dir: os.PathLike, + input_locations: list = [], + state_dict: dict = {}, +) +``` +where + - `module`: An StableHLO module. + - `saved_model_dir`: Path to save TF saved-model artifacts. + - `input_locations`: Type of each input arguments: either it could be a + parameter with a name associated with it or a positional argument. The + parameters are generally the weights of a model with pre-trained constant + values. + - `state_dict`: Mapping of input parameters with constants. + + +For example, in order to export a simple +[torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) +model to TF saved model using the above API, we need + +* `module` + +```mlir + module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { + + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32> + %3 = stablehlo.add %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32>\n + } +} +``` + +* `input_locations` + +```python +input_locations = [ + InputLocation.parameter(name='linear_layer.bias'), # bias parameter + InputLocation.parameter(name='linear_layer.weight'), # weight parameter + InputLocation.input_arg(position=0), # positional input argument +] +``` + +* `state_dict` + +``` +state_dict = { + 'linear_layer.weight': np.array( + [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32' + ), + 'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'), +} +``` + +Note that the API depends on the python bindings for +* StableHLO: Please refer to [README.md](https://github.com/openxla/stablehlo?tab=readme-ov-file#python). +* TensorFlow: Needs `pip install tensorflow-cpu`. diff --git a/stablehlo/integrations/python/savedmodel/stablehlo_to_tf_saved_model.py b/stablehlo/integrations/python/savedmodel/stablehlo_to_tf_saved_model.py new file mode 100644 index 0000000000..9cdd231c8c --- /dev/null +++ b/stablehlo/integrations/python/savedmodel/stablehlo_to_tf_saved_model.py @@ -0,0 +1,245 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import dataclasses +from dataclasses import dataclass +import enum +import itertools +import logging +import os +from typing import Any, Dict, List +import mlir.dialects.stablehlo as stablehlo + +try: + import tensorflow as tf + from tensorflow.compiler.tf2xla.python import xla as tfxla +except ImportError: + logging.error( + 'This module is need tensorflow with xla support.\n' + 'Please install tensorflow with `pip install tf-nightly`.\n' + ) + raise + + +# Class to specifiy the input or output signature of a stablehlo function. +@dataclass +class VariableSignature: # either argument or parameters + shape: List[int] + dtype: str + dynamic_dims: List[int] = dataclasses.field(default_factory=list) + + +# Classes to specify the input type (parameter, argument) of a function. +class VariableType(enum.Enum): + INPUT_ARG = 'input_arg' + PARAMETER = 'parameter' + + +@dataclass +class InputLocation: + type_: VariableType + position: int = -1 + name: str = '' + + @classmethod + def parameter(cls, name: str): + return cls(type_=VariableType.PARAMETER, name=name) + + @classmethod + def input_arg(cls, position: int): + return cls(type_=VariableType.INPUT_ARG, position=position) + + +# Class to specify stablehlo input specification. +@dataclass +class StableHLOFuncSpec: + # stablehlo input signature + input_signature: List[VariableSignature] + # stablehlo output signature + output_signature: List[VariableSignature] + # annotations on stablehlo arguments as constants or variables + input_locations: List[InputLocation] + # serialized stablehlo format + bytecode: bytes + # map from constant arguments to constant values + state_dict: Dict[str, Any] + + +class StableHLOToTFSavedModel: + + def __init__(self, spec: StableHLOFuncSpec): + self.stablehlo_type_to_tf_type = { + 'i8': 'int8', + 'i16': 'i32', + 'i32': 'int32', + 'i64': 'int64', + 'f16': 'float16', + 'f32': 'float32', + 'f64': 'float64', + } + self.stablehlo_program = spec + + # Logic to convert stablehlo program to tf saved model + + def _get_shape_with_dynamic(self, signature: VariableSignature): + shape = copy.copy(signature.shape) + for i in signature.dynamic_dims: + shape[i] = None + return shape + + stablehlo_type_to_tf_type = { + 'i8': 'int8', + 'i16': 'i32', + 'i32': 'int32', + 'i64': 'int64', + 'f16': 'float16', + 'f32': 'float32', + 'f64': 'float64', + } + + def _extract_call_parameters(self, args): + call_args = [] + for loc in self.stablehlo_program.input_locations: + if str(loc.type_) == str(VariableType.PARAMETER): + call_args.append(self.stablehlo_program.state_dict[loc.name]) + else: + call_args.append(args[loc.position]) + return call_args + + def _wrap_as_tf_func(self): + def inner(*args): + Touts = [ + self.stablehlo_type_to_tf_type[sig.dtype] + for sig in self.stablehlo_program.output_signature + ] + Souts = [ + self._get_shape_with_dynamic(sig) + for sig in self.stablehlo_program.output_signature + ] + call_args = self._extract_call_parameters(args) + m = tfxla.call_module( + tuple(call_args), + version=5, + Tout=Touts, # dtype information + Sout=Souts, # Shape information + function_list=[], + module=self.stablehlo_program.bytecode, + ) + return m + + return inner + + def _make_tf_function(self): + return self._wrap_as_tf_func() + + def _make_input_signatures(self) -> List[tf.TensorSpec]: + input_pos_to_spec = { + loc.position: spec + for loc, spec in itertools.chain( + zip( + self.stablehlo_program.input_locations, + self.stablehlo_program.input_signature, + ), + [], + ) + if str(loc.type_) == str(VariableType.INPUT_ARG) + } + for i in range(len(input_pos_to_spec)): + spec = input_pos_to_spec[i] + shape = self._get_shape_with_dynamic(spec) + yield tf.TensorSpec( + shape=shape, + dtype=getattr( + tf, + self.stablehlo_type_to_tf_type[spec.dtype] + if spec.dtype in self.stablehlo_type_to_tf_type + else spec.dtype, + ), + name=f'args_{i}', + ) + + def to_tf_saved_model( + self, + path: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = '', + ) -> None: + tfm = tf.Module() + + self.stablehlo_program.state_dict = { + k: tf.Variable(v, trainable=False, name=k) + for k, v in self.stablehlo_program.state_dict.items() + } + + input_signatures = list(self._make_input_signatures()) + + tfm.f = tf.function( + self._make_tf_function(), input_signature=input_signatures + ) + tfm._variables = list(self.stablehlo_program.state_dict.values()) + signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} + save_options = tf.saved_model.SaveOptions( + function_aliases={ + function_alias: tfm.f, + } + ) + tf.saved_model.save( + tfm, + path, + signatures=signatures, + options=save_options, + ) + + +# Top level API for stablehlo to tf saved model + + +def save_stablehlo_as_tf_saved_model( + module, + saved_model_dir: os.PathLike, + input_locations: list = [], + state_dict: dict = {}, +): + target = stablehlo.get_current_version() + input_signatures = [ + VariableSignature( + shape=input.shape, + dtype=str(input.element_type), + dynamic_dims=[], + ) + for input in module.body.operations[0].type.inputs + ] + output_signature = [ + VariableSignature( + shape=result.shape, + dtype=str(result.element_type), + dynamic_dims=[], + ) + for result in module.body.operations[0].type.results + ] + + if input_locations == []: + for i in range(len(module.body.operations[0].type.inputs)): + input_locations.append(InputLocation.input_arg(position=i)) + + shlo_spec = StableHLOFuncSpec( + input_signature=input_signatures, + output_signature=output_signature, + input_locations=input_locations, + state_dict=state_dict, + bytecode=stablehlo.serialize_portable_artifact(module, target), + ) + + StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir) diff --git a/stablehlo/integrations/python/tests/CMakeLists.txt b/stablehlo/integrations/python/tests/CMakeLists.txt index 51cc62dd37..3c1737a470 100644 --- a/stablehlo/integrations/python/tests/CMakeLists.txt +++ b/stablehlo/integrations/python/tests/CMakeLists.txt @@ -20,6 +20,7 @@ add_custom_target(${test_name} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS StablehloUnifiedPythonModules + StablehloToSavedModelAPI ) add_dependencies(check-stablehlo-python ${test_name}) endfunction() @@ -27,6 +28,7 @@ endfunction() add_stablehlo_python_test(stablehlo-python-chlo chlo.py) add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py) add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py) +add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py) add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py) add_dependencies(check-stablehlo-quick check-stablehlo-python) diff --git a/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py b/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py new file mode 100644 index 0000000000..36d21734cf --- /dev/null +++ b/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py @@ -0,0 +1,67 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import mlir.dialects.stablehlo as stablehlo +import mlir.ir as ir +import numpy as np +from savedmodel.stablehlo_to_tf_saved_model import InputLocation, save_stablehlo_as_tf_saved_model +import tensorflow as tf +from tensorflow.python.tools import saved_model_utils + +# Convert a stablehlo program, expressing a nn.Linear layer with constant values +# for weight and bias, to saved model. + +mlir_module_string = """ + module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { + + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32> + %3 = stablehlo.add %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32>\n + } +} +""" + +ctx = ir.Context() +stablehlo.register_dialect(ctx) +module = ir.Module.parse(mlir_module_string, ctx) + +input_locations = [ + InputLocation.parameter(name='linear_layer.bias'), + InputLocation.parameter(name='linear_layer.weight'), + InputLocation.input_arg(position=0), +] +state_dict = { + 'linear_layer.weight': np.array( + [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32' + ), + 'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'), +} + + +saved_model_dir = tempfile.mkdtemp() +save_stablehlo_as_tf_saved_model( + module, + saved_model_dir=saved_model_dir, + input_locations=input_locations, + state_dict=state_dict, +) + +saved_model = saved_model_utils.read_saved_model(saved_model_dir) +assert saved_model != None +print(f'StableHLO convertion to TF Saved Model seems to work!')