Stablehlo to TF saved-model (#2157)
**Draft PR**

Provides an API to convert a StableHLO program to a TensorFlow SavedModel, and
adds a test that converts a simple linear model.

Things to be done:

- [x] Provide documentation of the API
- [x] Test should use the API from the preinstalled stablehlo Python
package
- [x] Make sure the API is provided as part of the stablehlo Python wheel.
sdasgup3 authored Apr 23, 2024
1 parent b13e8da commit ace3f39
Showing 7 changed files with 410 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/buildAndTestCMake.yml
@@ -93,6 +93,7 @@ jobs:
- name: Build and Test StableHLO (with Python bindings)
shell: bash
run: |
pip install tensorflow-cpu
./build_tools/github_actions/ci_build_cmake.sh "$LLVM_BUILD_DIR" "$STABLEHLO_BUILD_DIR"
env:
CMAKE_BUILD_TYPE: Release
10 changes: 9 additions & 1 deletion README.md
@@ -140,7 +140,7 @@ If you'd like to build the Python bindings, you'll need to install a few
additional dependencies.

```sh
pip install install -r ./llvm-project/mlir/python/requirements.txt
pip install -r ./llvm-project/mlir/python/requirements.txt
```

If you've built MLIR & StableHLO using the script above, the Python bindings
@@ -165,6 +165,14 @@ We also make nightly wheels available on our GitHub Releases page.
pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels
```

## StableHLO to TensorFlow SavedModel

This repository offers tooling for the conversion of a StableHLO program,
including its metadata (representing trained weights and biases), into a
TensorFlow SavedModel. Please refer to
[README.md](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/stablehlo/savedmodel/README.md)
for details.

## Community

Building an amazing portability layer between ML frameworks and ML compilers
9 changes: 8 additions & 1 deletion stablehlo/integrations/python/CMakeLists.txt
@@ -46,6 +46,13 @@ declare_mlir_dialect_python_bindings(
SOURCES dialects/stablehlo.py
DIALECT_NAME stablehlo)

declare_mlir_python_sources(StablehloToSavedModelPythonSources
ADD_TO_PARENT StablehloPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
SOURCES
stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
)

declare_mlir_python_sources(VhloPythonSources)
declare_mlir_python_sources(VhloPythonSources.Dialects
ADD_TO_PARENT VhloPythonSources
@@ -141,7 +148,7 @@ add_mlir_python_modules(StablehloUnifiedPythonModules
VhloPythonExtensions
COMMON_CAPI_LINK_LIBS
StablehloUnifiedPythonCAPI
)
)

################################################################################
# Tests
84 changes: 84 additions & 0 deletions stablehlo/integrations/python/stablehlo/savedmodel/README.md
@@ -0,0 +1,84 @@
# StableHLO to TensorFlow SavedModel

`stablehlo_to_tf_saved_model.py` provides the following API to convert a
StableHLO program to a TensorFlow SavedModel.

```python
stablehlo_to_tf_saved_model(
    module: mlir.ir.Module,
    saved_model_dir: os.PathLike,
    target_version: str = stablehlo.get_current_version(),
    input_locations: list = [],
    state_dict: dict = {},
)
```

where

* `module`: A StableHLO module.
* `saved_model_dir`: Path where the TF SavedModel artifacts are written.
* `target_version`: Serialization version of StableHLO. Default: current
  StableHLO version.
* `input_locations`: List describing each argument of the StableHLO function,
  in order: either a named parameter or a positional input argument.
  Parameters are generally the weights or biases of a model with pre-trained
  constant values. Default: empty list, in which case every argument is
  treated as a positional input.
* `state_dict`: Mapping from parameter names to their constant values.
  Default: empty dict.

For example, to export a simple
[torch.nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
model to TensorFlow SavedModel using the above API, we need the following
arguments to the API.

* `module`

```mlir
module @linearmodule attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[2,2]{0,1}"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    %1 = stablehlo.dot_general %arg2, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    %2 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<2xf32>) -> tensor<2x2xf32>
    %3 = stablehlo.add %1, %2 : tensor<2x2xf32>
    return %3 : tensor<2x2xf32>
  }
}
```

* `input_locations`

```python
input_locations = [
    InputLocation.parameter(name='linear_layer.bias'),  # bias parameter
    InputLocation.parameter(name='linear_layer.weight'),  # weight parameter
    InputLocation.input_arg(position=0),  # positional input argument
]
```

* `state_dict`

```python
state_dict = {
    'linear_layer.weight': np.array(
        [[0.19075723, -0.13815854], [0.46516803, 0.12362058]], dtype='float32'
    ),
    'linear_layer.bias': np.array([-0.37076423, 0.03301], dtype='float32'),
}
```
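
To see how these three arguments fit together, here is a minimal plain-Python
sketch (not part of the API) of how each `@main` argument could be resolved
from `input_locations` and `state_dict`, followed by a hand-written reference
for what the module computes: `x @ weight^T + bias`. The helper names
`resolve_call_args` and `linear_reference`, the tuple encoding of locations,
and the sample input `x` are assumptions made for illustration; the real API
takes `InputLocation` objects and NumPy arrays.

```python
# Illustrative stand-ins for InputLocation: (kind, name-or-position).
input_locations = [
    ('parameter', 'linear_layer.bias'),    # feeds %arg0
    ('parameter', 'linear_layer.weight'),  # feeds %arg1
    ('input_arg', 0),                      # feeds %arg2
]

state_dict = {
    'linear_layer.weight': [[0.19075723, -0.13815854], [0.46516803, 0.12362058]],
    'linear_layer.bias': [-0.37076423, 0.03301],
}


def resolve_call_args(input_locations, state_dict, args):
    """Return the values passed to @main, in argument order."""
    call_args = []
    for kind, key in input_locations:
        if kind == 'parameter':
            call_args.append(state_dict[key])  # looked up by name
        else:
            call_args.append(args[key])  # looked up by position
    return call_args


def linear_reference(bias, weight, x):
    """Plain-Python equivalent of the module body: x @ weight^T + bias."""
    return [
        [sum(x_k * w_k for x_k, w_k in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


x = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical 2x2 input
bias, weight, inp = resolve_call_args(input_locations, state_dict, [x])
y = linear_reference(bias, weight, inp)
```

The order of `input_locations` matters: entry `i` says where the value of
`%arg{i}` comes from, which is why the bias appears first here even though it
is applied last in the computation.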

## Python package dependencies

The above API depends on

* MLIR Python bindings: To express an MLIR module.
* TensorFlow: Only used to work with TF saved model artifacts.
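
A caller may want to check up front that both dependencies can be found before
invoking the API. The following is a minimal sketch, assuming the top-level
module names `mlir` and `tensorflow`; the helper `is_importable` is
hypothetical and not part of the StableHLO package.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can be located on the current path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec raises if a parent package is missing or has no spec.
        return False


missing = [name for name in ('mlir', 'tensorflow') if not is_importable(name)]
if missing:
    print('Missing dependencies:', ', '.join(missing))
```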

## Testing

The repository provides [stablehlo_to_tf_saved_model_test.py](https://github.com/openxla/stablehlo/blob/main/stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py)
to test the API. Run it as follows:

```sh
pip install tensorflow-cpu
PYTHONPATH="./build/python_packages/stablehlo" python3 ../../tests/stablehlo_to_tf_saved_model_test.py
```
244 changes: 244 additions & 0 deletions stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
@@ -0,0 +1,244 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from dataclasses import dataclass
import enum
import itertools
import logging
import os
from typing import Any, Dict, List
import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir

try:
  import tensorflow as tf
  from tensorflow.compiler.tf2xla.python import xla as tfxla
except ImportError:
  logging.error(
      'This module needs TensorFlow with XLA support.\n'
      'Please install TensorFlow with `pip install tf-nightly`.\n'
  )
  raise


# Class to specify the input or output signature of a stablehlo function.
@dataclass
class VariableSignature:  # either argument or parameters
  shape: List[int]
  dtype: str
  dynamic_dims: List[int] = dataclasses.field(default_factory=list)


# Classes to specify the input type (parameter, argument) of a function.
class VariableType(enum.Enum):
  INPUT_ARG = 'input_arg'
  PARAMETER = 'parameter'


@dataclass
class InputLocation:
  type_: VariableType
  position: int = -1
  name: str = ''

  @classmethod
  def parameter(cls, name: str):
    return cls(type_=VariableType.PARAMETER, name=name)

  @classmethod
  def input_arg(cls, position: int):
    return cls(type_=VariableType.INPUT_ARG, position=position)


# Class to specify stablehlo input specification.
@dataclass
class StableHLOFuncSpec:
  # stablehlo input signature
  input_signature: List[VariableSignature]
  # stablehlo output signature
  output_signature: List[VariableSignature]
  # annotations on stablehlo arguments as constants or variables
  input_locations: List[InputLocation]
  # serialized stablehlo format
  bytecode: bytes
  # map from constant arguments to constant values
  state_dict: Dict[str, Any]


class StableHLOToTFSavedModel:

  def __init__(self, spec: StableHLOFuncSpec):
    # Maps StableHLO element type names to TensorFlow dtype attribute names.
    self.stablehlo_type_to_tf_type = {
        'i1': 'bool',
        'i8': 'int8',
        'i16': 'int16',
        'i32': 'int32',
        'i64': 'int64',
        'f16': 'float16',
        'f32': 'float32',
        'f64': 'float64',
        'bf16': 'bfloat16',
    }
    self.stablehlo_program = spec

  # Logic to convert stablehlo program to tf saved model

  def _get_shape_with_dynamic(self, signature: VariableSignature):
    shape = copy.copy(signature.shape)
    for i in signature.dynamic_dims:
      shape[i] = None
    return shape

  def _extract_call_parameters(self, args):
    call_args = []
    for loc in self.stablehlo_program.input_locations:
      if str(loc.type_) == str(VariableType.PARAMETER):
        call_args.append(self.stablehlo_program.state_dict[loc.name])
      else:
        call_args.append(args[loc.position])
    return call_args

  def _wrap_as_tf_func(self):
    def inner(*args):
      try:
        Touts = [
            self.stablehlo_type_to_tf_type[sig.dtype]
            for sig in self.stablehlo_program.output_signature
        ]
      except KeyError as e:
        raise KeyError(f'TensorFlow type mapping not found: {e}') from None

      Souts = [
          self._get_shape_with_dynamic(sig)
          for sig in self.stablehlo_program.output_signature
      ]
      call_args = self._extract_call_parameters(args)
      m = tfxla.call_module(
          tuple(call_args),
          version=5,
          Tout=Touts,  # dtype information
          Sout=Souts,  # Shape information
          function_list=[],
          module=self.stablehlo_program.bytecode,
      )
      return m

    return inner

  def _make_tf_function(self):
    return self._wrap_as_tf_func()

  def _make_input_signatures(self) -> List[tf.TensorSpec]:
    input_pos_to_spec = {
        loc.position: spec
        for loc, spec in itertools.chain(
            zip(
                self.stablehlo_program.input_locations,
                self.stablehlo_program.input_signature,
            ),
            [],
        )
        if str(loc.type_) == str(VariableType.INPUT_ARG)
    }
    for i in range(len(input_pos_to_spec)):
      spec = input_pos_to_spec[i]
      shape = self._get_shape_with_dynamic(spec)
      try:
        dtype = getattr(tf, self.stablehlo_type_to_tf_type[spec.dtype])
      except KeyError as e:
        raise KeyError(
            f'TensorFlow type mapping not found for {spec.dtype}: {e}'
        ) from None

      yield tf.TensorSpec(
          shape=shape,
          dtype=dtype,
          name=f'args_{i}',
      )

  def to_tf_saved_model(
      self,
      path: os.PathLike,
      serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
      function_alias: str = '',
  ) -> None:
    tfm = tf.Module()

    self.stablehlo_program.state_dict = {
        k: tf.Variable(v, trainable=False, name=k)
        for k, v in self.stablehlo_program.state_dict.items()
    }

    input_signatures = list(self._make_input_signatures())

    tfm.f = tf.function(
        self._make_tf_function(), input_signature=input_signatures
    )
    tfm._variables = list(self.stablehlo_program.state_dict.values())
    signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}
    save_options = tf.saved_model.SaveOptions(
        function_aliases={
            function_alias: tfm.f,
        }
    )
    tf.saved_model.save(
        tfm,
        path,
        signatures=signatures,
        options=save_options,
    )


# Top level API for stablehlo to tf saved model


def stablehlo_to_tf_saved_model(
    module: ir.Module,
    saved_model_dir: os.PathLike,
    target_version: str = stablehlo.get_current_version(),
    input_locations: list = [],
    state_dict: dict = {},
):
  input_signatures = [
      VariableSignature(
          shape=input.shape,
          dtype=str(input.element_type),
          dynamic_dims=[],
      )
      for input in module.body.operations[0].type.inputs
  ]
  output_signature = [
      VariableSignature(
          shape=result.shape,
          dtype=str(result.element_type),
          dynamic_dims=[],
      )
      for result in module.body.operations[0].type.results
  ]

  if not input_locations:
    # Treat every argument as positional. Build a fresh list instead of
    # appending to the shared default `[]`, which would persist across calls.
    input_locations = [
        InputLocation.input_arg(position=i)
        for i in range(len(module.body.operations[0].type.inputs))
    ]

  shlo_spec = StableHLOFuncSpec(
      input_signature=input_signatures,
      output_signature=output_signature,
      input_locations=input_locations,
      state_dict=state_dict,
      bytecode=stablehlo.serialize_portable_artifact(module, target_version),
  )

  StableHLOToTFSavedModel(shlo_spec).to_tf_saved_model(saved_model_dir)
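
A note on the `[]` and `{}` defaults in the signature above: Python evaluates
default argument values once, when the function is defined, so any in-place
mutation of a default (for example via `append`) is visible to later calls.
The sketch below uses hypothetical function names, not code from this commit,
to contrast in-place mutation with rebinding the name to a fresh list.

```python
def fill_in_place(n, locations=[]):
    # Mutates the shared default list; state leaks into later calls.
    if locations == []:
        for i in range(n):
            locations.append(i)
    return locations


def fill_fresh(n, locations=None):
    # Rebinds the local name to a new list; the default is never touched.
    if not locations:
        locations = list(range(n))
    return locations


first = fill_in_place(2)   # the default list now holds two entries
second = fill_in_place(3)  # default is no longer empty, so it is reused as-is
fresh_first = fill_fresh(2)
fresh_second = fill_fresh(3)
```

With `fill_in_place`, the second call returns the very same list object as the
first, so a later model with a different number of arguments would silently
inherit stale locations; `fill_fresh` builds a new list on each call.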
1 change: 1 addition & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
@@ -27,6 +27,7 @@ endfunction()
add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py)
add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py)

add_dependencies(check-stablehlo-quick check-stablehlo-python)