diff --git a/.github/workflows/buildAndTestCMake.yml b/.github/workflows/buildAndTestCMake.yml index 2cdb0725c40..26cc7ab3dae 100644 --- a/.github/workflows/buildAndTestCMake.yml +++ b/.github/workflows/buildAndTestCMake.yml @@ -93,6 +93,7 @@ jobs: - name: Build and Test StableHLO (with Python bindings) shell: bash run: | + pip install tensorflow-cpu ./build_tools/github_actions/ci_build_cmake.sh "$LLVM_BUILD_DIR" "$STABLEHLO_BUILD_DIR" env: CMAKE_BUILD_TYPE: Release diff --git a/README.md b/README.md index 7c8ad8041e6..686bab63392 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ If you'd like to build the Python bindings, you'll need to install a few additional dependencies. ```sh -pip install install -r ./llvm-project/mlir/python/requirements.txt +pip install -r ./llvm-project/mlir/python/requirements.txt ``` If you've built MLIR & StableHLO using the script above, the Python bindings @@ -165,6 +165,14 @@ We also make nightly wheels available on our GitHub Releases page. pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels ``` +## StableHLO to TensorFLow SavedModel + +This repository offers tooling for the conversion of a StableHLO program, +including its metadata (representing trained weights and biases), into a +TensorFlow SavedModel. Please refer to +[README.md](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/stablehlo/savedmodel/README.md) +for details. + ## Community Building an amazing portability layer between ML frameworks and ML compilers diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index 142bef11276..f1c9fffb253 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -46,6 +46,13 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/stablehlo.py DIALECT_NAME stablehlo) +declare_mlir_python_sources(StablehloToSavedModelPythonSources + ADD_TO_PARENT StablehloPythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" + SOURCES + stablehlo/savedmodel/stablehlo_to_tf_saved_model.py +) + declare_mlir_python_sources(VhloPythonSources) declare_mlir_python_sources(VhloPythonSources.Dialects ADD_TO_PARENT VhloPythonSources @@ -141,7 +148,7 @@ add_mlir_python_modules(StablehloUnifiedPythonModules VhloPythonExtensions COMMON_CAPI_LINK_LIBS StablehloUnifiedPythonCAPI - ) +) ################################################################################ # Tests diff --git a/stablehlo/integrations/python/stablehlo/savedmodel/README.md b/stablehlo/integrations/python/stablehlo/savedmodel/README.md new file mode 100644 index 00000000000..fe6970f20ef --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/savedmodel/README.md @@ -0,0 +1,84 @@ +# StableHLO to Tensorflow SavedModel + +`stablehlo_to_tf_saved_model.py` provides the following API to convert a +stablehlo program to TensorFlow SavedModel. + +```python +stablehlo_to_tf_saved_model( + module: mlir.ir.Module, + saved_model_dir: os.PathLike, + input_locations: list = [], + state_dict: dict = {}, +) +``` + +where + +* `module`: An StableHLO module. +* `saved_model_dir`: Path to save TF saved-model artifacts. +* `target_version`: Serialization version of StableHLO. Default: current + stablehlo version. +* `input_locations`: List of input argument types: either it could be a + parameter with a name associated with it or a positional argument. The + parameters are generally the weights or biases of a model with pre-trained + constant values. Default: empty list. +* `state_dict`: Mapping of named input parameters with constants. Default: + empty list. + +For example, to export a simple +[torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) +model to TensorFlow SavedModel using the above API, we need the following +arguments to the API. + +* `module` + +```mlir + module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { + + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32> + %3 = stablehlo.add %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32>\n + } +} +``` + +* `input_locations` + +```python +input_locations = [ + InputLocation.parameter(name='linear_layer.bias'), # bias parameter + InputLocation.parameter(name='linear_layer.weight'), # weight parameter + InputLocation.input_arg(position=0), # positional input argument +] +``` + +* `state_dict` + +```python +state_dict = { + 'linear_layer.weight': np.array( + [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32' + ), + 'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'), +} +``` + +## Python package dependencies + +The above API depends on + +* MLIR Python bindings: To express an MLIR module. +* TensorFlow: Only used to work with TF saved model artifacts. + +## Testing + +The repository provides [stablehlo_to_tf_saved_model_test.py](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/tests/[stablehlo_to_tf_saved_model_test.py) +to test the API and here is how to run it. + +```sh +pip install tensorflow-cpu +PYTHONPATH="./build/python_packages/stablehlo" python3 ../../tests/stablehlo_to_tf_saved_model_test.py +``` diff --git a/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py b/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py new file mode 100644 index 00000000000..534c045139c --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py @@ -0,0 +1,244 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import dataclasses +from dataclasses import dataclass +import enum +import itertools +import logging +import os +from typing import Any, Dict, List +import mlir.dialects.stablehlo as stablehlo +import mlir.ir as ir + +try: + import tensorflow as tf + from tensorflow.compiler.tf2xla.python import xla as tfxla +except ImportError: + logging.error( + 'This module is need tensorflow with xla support.\n' + 'Please install tensorflow with `pip install tf-nightly`.\n' + ) + raise + + +# Class to specifiy the input or output signature of a stablehlo function. +@dataclass +class VariableSignature: # either argument or parameters + shape: List[int] + dtype: str + dynamic_dims: List[int] = dataclasses.field(default_factory=list) + + +# Classes to specify the input type (parameter, argument) of a function. +class VariableType(enum.Enum): + INPUT_ARG = 'input_arg' + PARAMETER = 'parameter' + + +@dataclass +class InputLocation: + type_: VariableType + position: int = -1 + name: str = '' + + @classmethod + def parameter(cls, name: str): + return cls(type_=VariableType.PARAMETER, name=name) + + @classmethod + def input_arg(cls, position: int): + return cls(type_=VariableType.INPUT_ARG, position=position) + + +# Class to specify stablehlo input specification. +@dataclass +class StableHLOFuncSpec: + # stablehlo input signature + input_signature: List[VariableSignature] + # stablehlo output signature + output_signature: List[VariableSignature] + # annotations on stablehlo arguments as constants or variables + input_locations: List[InputLocation] + # serialized stablehlo format + bytecode: bytes + # map from constant arguments to constant values + state_dict: Dict[str, Any] + + +class StableHLOToTFSavedModel: + + def __init__(self, spec: StableHLOFuncSpec): + self.stablehlo_type_to_tf_type = { + 'i1': 'bool', + 'i8': 'int8', + 'i16': 'i32', + 'i32': 'int32', + 'i64': 'int64', + 'f16': 'float16', + 'f32': 'float32', + 'f64': 'float64', + 'bf16': 'bfloat16', + } + self.stablehlo_program = spec + + # Logic to convert stablehlo program to tf saved model + + def _get_shape_with_dynamic(self, signature: VariableSignature): + shape = copy.copy(signature.shape) + for i in signature.dynamic_dims: + shape[i] = None + return shape + + def _extract_call_parameters(self, args): + call_args = [] + for loc in self.stablehlo_program.input_locations: + if str(loc.type_) == str(VariableType.PARAMETER): + call_args.append(self.stablehlo_program.state_dict[loc.name]) + else: + call_args.append(args[loc.position]) + return call_args + + def _wrap_as_tf_func(self): + def inner(*args): + try: + Touts = [ + self.stablehlo_type_to_tf_type[sig.dtype] + for sig in self.stablehlo_program.output_signature + ] + except KeyError as e: + raise KeyError(f'TensorFlow type mapping not found: {e}') from None + + Souts = [ + self._get_shape_with_dynamic(sig) + for sig in self.stablehlo_program.output_signature + ] + call_args = self._extract_call_parameters(args) + m = tfxla.call_module( + tuple(call_args), + version=5, + Tout=Touts, # dtype information + Sout=Souts, # Shape information + function_list=[], + module=self.stablehlo_program.bytecode, + ) + return m + + return inner + + def _make_tf_function(self): + return self._wrap_as_tf_func() + + def _make_input_signatures(self) -> List[tf.TensorSpec]: + input_pos_to_spec = { + loc.position: spec + for loc, spec in itertools.chain( + zip( + self.stablehlo_program.input_locations, + self.stablehlo_program.input_signature, + ), + [], + ) + if str(loc.type_) == str(VariableType.INPUT_ARG) + } + for i in range(len(input_pos_to_spec)): + spec = input_pos_to_spec[i] + shape = self._get_shape_with_dynamic(spec) + try: + dtype = getattr(tf, self.stablehlo_type_to_tf_type[spec.dtype]) + except KeyError as e: + raise KeyError( + f'TensorFlow type mapping not found for {spec.dtype}: {e}' + ) from None + + yield tf.TensorSpec( + shape=shape, + dtype=dtype, + name=f'args_{i}', + ) + + def to_tf_saved_model( + self, + path: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = '', + ) -> None: + tfm = tf.Module() + + self.stablehlo_program.state_dict = { + k: tf.Variable(v, trainable=False, name=k) + for k, v in self.stablehlo_program.state_dict.items() + } + + input_signatures = list(self._make_input_signatures()) + + tfm.f = tf.function( + self._make_tf_function(), input_signature=input_signatures + ) + tfm._variables = list(self.stablehlo_program.state_dict.values()) + signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} + save_options = tf.saved_model.SaveOptions( + function_aliases={ + function_alias: tfm.f, + } + ) + tf.saved_model.save( + tfm, + path, + signatures=signatures, + options=save_options, + ) + + +# Top level API for stablehlo to tf saved model + + +def stablehlo_to_tf_saved_model( + module: ir.Module, + saved_model_dir: os.PathLike, + target_version: str = stablehlo.get_current_version(), + input_locations: list = [], + state_dict: dict = {}, +): + input_signatures = [ + VariableSignature( + shape=input.shape, + dtype=str(input.element_type), + dynamic_dims=[], + ) + for input in module.body.operations[0].type.inputs + ] + output_signature = [ + VariableSignature( + shape=result.shape, + dtype=str(result.element_type), + dynamic_dims=[], + ) + for result in module.body.operations[0].type.results + ] + + if input_locations == []: + for i in range(len(module.body.operations[0].type.inputs)): + input_locations.append(InputLocation.input_arg(position=i)) + + shlo_spec = StableHLOFuncSpec( + input_signature=input_signatures, + output_signature=output_signature, + input_locations=input_locations, + state_dict=state_dict, + bytecode=stablehlo.serialize_portable_artifact(module, target_version), + ) + + StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir) diff --git a/stablehlo/integrations/python/tests/CMakeLists.txt b/stablehlo/integrations/python/tests/CMakeLists.txt index 51cc62dd372..e6d5f58709d 100644 --- a/stablehlo/integrations/python/tests/CMakeLists.txt +++ b/stablehlo/integrations/python/tests/CMakeLists.txt @@ -27,6 +27,7 @@ endfunction() add_stablehlo_python_test(stablehlo-python-chlo chlo.py) add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py) add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py) +add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py) add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py) add_dependencies(check-stablehlo-quick check-stablehlo-python) diff --git a/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py b/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py new file mode 100644 index 00000000000..1740fc87180 --- /dev/null +++ b/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py @@ -0,0 +1,63 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import mlir.dialects.stablehlo as stablehlo +import mlir.ir as ir +from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation, stablehlo_to_tf_saved_model +import numpy as np +import tensorflow as tf +from tensorflow.python.tools import saved_model_utils + +# Convert a stablehlo program, expressing addition of an argument with constant +# values for weight and bias, to saved model. + +MODULE_STRING = """ +module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { + + func.func @main(%bias: tensor<1xf32>, %weight: tensor<1xf32>, %arg0: tensor<1xf32>) -> tensor<1xf32> { + %0 = stablehlo.add %arg0, %weight: tensor<1xf32> + %1 = stablehlo.add %0, %bias : tensor<1xf32> + return %1 : tensor<1xf32>\n + } +} +""" + +ctx = ir.Context() +stablehlo.register_dialect(ctx) +module = ir.Module.parse(MODULE_STRING, ctx) + +input_locations = [ + InputLocation.parameter(name='linear_layer.bias'), + InputLocation.parameter(name='linear_layer.weight'), + InputLocation.input_arg(position=0), +] +state_dict = { + 'linear_layer.weight': np.array([1], dtype='float32'), + 'linear_layer.bias': np.array([2], dtype='float32'), +} + + +saved_model_dir = tempfile.mkdtemp() +stablehlo_to_tf_saved_model( + module, + saved_model_dir=saved_model_dir, + input_locations=input_locations, + state_dict=state_dict, +) + +restored_model = tf.saved_model.load(saved_model_dir) +restored_result = restored_model.f(tf.constant([3], tf.float32)) +assert np.allclose(restored_result[0], tf.constant([6], tf.float32))