Add bfloat16 to float16 converter in onnxruntime specific rewriter (#1492)

Prior to this PR, because numpy does not support bfloat16, the onnxruntime Python interface could not execute models with bfloat16 inputs/outputs. This PR provides a temporary solution: it converts model inputs/outputs to float16 and adds op.Cast nodes after/before them to make such models runnable. The pass lives under the onnxruntime rewriter and is off by default. NOTE: We should delete this offline pass once numpy supports bfloat16. BLOCKED: onnxruntime does not appear to implement bfloat16 for op.Add and op.Mul on either CUDA or CPU (see the test case).

---------

Co-authored-by: Justin Chu <[email protected]>
1 parent 34a0cd1 · commit 280fb39 · 3 changed files with 198 additions and 0 deletions.
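For reference, a minimal sketch of how this offline pass might be applied to an existing bfloat16 model (the file paths are hypothetical; the entry point is dtype_adapter_for_bfloat16_model from this commit):

    import onnx

    from onnxscript import ir
    from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter

    # Load a model whose graph inputs/outputs are bfloat16 (path is hypothetical).
    model_proto = onnx.load("model_bf16.onnx")
    model = ir.serde.deserialize_model(model_proto)

    # Rewrite graph inputs/outputs to float16 and insert boundary Cast nodes.
    bfloat16_converter.dtype_adapter_for_bfloat16_model(model)

    # Serialize back to a ModelProto that onnxruntime's Python API can feed
    # with float16 numpy arrays (path is hypothetical).
    onnx.save(ir.serde.serialize_model(model), "model_fp16_io.onnx")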
onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py (101 additions, 0 deletions)
import logging

from onnxscript import ir

logger = logging.getLogger(__name__)
CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16"


def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None:
    if value.dtype != ir.DataType.BFLOAT16:
        return
    value.dtype = ir.DataType.FLOAT16
    _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value)


def _convert_outputs_from_bfloat16_to_float16(value: ir.Value) -> None:
    if value.dtype != ir.DataType.BFLOAT16:
        return
    _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value)


def _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value: ir.Input) -> None:
    user_nodes_and_indices = tuple(value.uses())

    attr = ir.AttrInt64(name="to", value=ir.DataType.BFLOAT16)
    cast = ir.Node(
        domain="",
        op_type="Cast",
        inputs=[value],
        num_outputs=1,
        attributes=[attr],
    )
    cast.outputs[0].dtype = ir.DataType.BFLOAT16
    cast.outputs[0].shape = value.shape

    # Redirect every consumer of the (now float16) input to the bfloat16 cast output.
    for node, index in tuple(value.uses()):
        if node is cast:
            continue
        node.replace_input_with(index, cast.outputs[0])

    # NOTE: A safer way to insert the cast node is to prepend it to the first node
    # of the graph
    assert user_nodes_and_indices[0][0].graph is not None, "The node should belong to a graph"
    user_nodes_and_indices[0][0].graph[0].prepend(cast)


def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> None:
    node = value.producer()
    index = value.index()
    if node is None or index is None:
        logger.warning("Output value %s has no producer or index", value)
        return

    attr = ir.AttrInt64(name="to", value=ir.DataType.FLOAT16)
    cast = ir.Node(
        domain="",
        op_type="Cast",
        inputs=[node.outputs[index]],
        num_outputs=1,
        attributes=[attr],
    )
    cast.outputs[0].dtype = ir.DataType.FLOAT16
    cast.outputs[0].shape = node.outputs[index].shape
    # To prevent naming conflicts, we need to append a suffix to the output name of the cast node
    # TODO: Remove this after the naming authority covers this case
    cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX  # type: ignore[operator]
    node.append(cast)

    assert node.graph is not None, "Node graph should not be None"
    # Update graph/function outputs
    for idx, graph_or_function_output in enumerate(node.graph.outputs):
        if graph_or_function_output == node.outputs[index]:
            node.graph.outputs[idx] = cast.outputs[0]
            # Swap the output name of the node with the output name of the cast node to
            # preserve the output name in the graph
            node.outputs[index].name, cast.outputs[0].name = (
                cast.outputs[0].name,
                node.outputs[index].name,
            )


def dtype_adapter_for_bfloat16_model(model: ir.Model) -> None:
    """Adapt the model's datatype if it is bfloat16.

    Because onnxruntime does not support bfloat16 as an input/output datatype, we need to
    convert the bfloat16 datatype to float16. This function converts the bfloat16
    datatype to float16 and inserts Cast nodes to convert float16 to bfloat16.

    Model:
        inputs(float16) -> Cast(bfloat16) -> nodes(bfloat16) -> Cast(float16) -> outputs(float16)

    TODO: Delete this function after onnxruntime supports bfloat16.

    Args:
        model: The model to adapt.
    """
    for input in model.graph.inputs:
        _convert_inputs_from_bfloat16_to_float16(input)
    for output in model.graph.outputs:
        _convert_outputs_from_bfloat16_to_float16(output)
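Applied to the test model below (an Add feeding a Mul, with both node outputs exposed as graph outputs), the adapter produces one Cast per graph input and one per graph output, roughly:

    v0, v1, v2 (float16) -> Cast(to=BFLOAT16) -> Add -> Mul -> Cast(to=FLOAT16) -> outputs (float16)

That is three input casts plus two output casts, the five Cast nodes the test asserts, with the original output names preserved by the name swap above.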
onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py (96 additions, 0 deletions)
import unittest

import numpy as np
import onnx
import onnx.checker
import onnx.shape_inference
import onnxruntime

from onnxscript import ir
from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter


class Bfloat16ConversionTest(unittest.TestCase):
    def setUp(self) -> None:
        self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
        self.v0.dtype = ir.DataType.BFLOAT16
        self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
        self.v1.dtype = ir.DataType.BFLOAT16
        self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
        self.v2.dtype = ir.DataType.BFLOAT16

        self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
        self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16
        self.mul_node = ir.Node(
            "", "Mul", inputs=(self.add_node.outputs[0], self.v2), num_outputs=1
        )
        self.mul_node.outputs[0].dtype = ir.DataType.BFLOAT16
        self.graph = ir.Graph(
            name="bfloat16_conversion_test",
            inputs=(self.v0, self.v1, self.v2),
            outputs=(self.add_node.outputs[0], self.mul_node.outputs[0]),
            nodes=(self.add_node, self.mul_node),
            opset_imports={"": 18},
        )
        self.original_output_names = [output.name for output in self.graph.outputs]
        self.model = ir.Model(
            graph=self.graph,
            ir_version=8,
            producer_name="bfloat16_conversion_test",
        )
        bfloat16_converter.dtype_adapter_for_bfloat16_model(self.model)

    def test_input_and_output_are_float16(self):
        for input in self.model.graph.inputs:
            self.assertEqual(input.dtype, ir.DataType.FLOAT16)
        for output in self.model.graph.outputs:
            self.assertEqual(output.dtype, ir.DataType.FLOAT16)

    def test_cast_nodes_are_inserted(self):
        cast_node_count = 0
        for node in self.model.graph:
            if node.op_type == "Cast":
                cast_node_count += 1
        # One Cast per graph input (3) plus one per graph output (2).
        self.assertEqual(cast_node_count, 5)

        for input in self.model.graph.inputs:
            for input_user, _ in input.uses():
                self.assertEqual(input_user.op_type, "Cast")
                self.assertEqual(input_user.outputs[0].dtype, ir.DataType.BFLOAT16)
        for output in self.model.graph.outputs:
            self.assertEqual(output.producer().op_type, "Cast")
            self.assertEqual(output.producer().inputs[0].dtype, ir.DataType.BFLOAT16)

    def test_graph_output_name_is_preserved(self):
        self.assertEqual(
            [output.name for output in self.model.graph.outputs],
            self.original_output_names,
        )

    def test_bfloat16_converted_model_runtime(self):
        model_proto = ir.serde.serialize_model(self.model)
        model_proto_filled_shape_type = onnx.shape_inference.infer_shapes(
            model_proto, check_type=True, strict_mode=True, data_prop=True
        )
        onnx.checker.check_model(model_proto_filled_shape_type, full_check=True)
        try:
            ort_session = onnxruntime.InferenceSession(
                model_proto_filled_shape_type.SerializeToString()
            )
            v0 = np.random.randn(2, 3, 4).astype(np.float16)
            v1 = np.random.randn(2, 3, 4).astype(np.float16)
            v2 = np.random.randn(2, 3, 4).astype(np.float16)
            ort_inputs = {"v0": v0, "v1": v1, "v2": v2}
            ort_outputs = ort_session.run(None, ort_inputs)
            # The first graph output is the Add result; the second is the Mul result.
            expected_add = v0 + v1
            expected_mul = (v0 + v1) * v2
            np.testing.assert_allclose(ort_outputs[0], expected_add, rtol=1e-2, atol=1e-2)
            np.testing.assert_allclose(ort_outputs[1], expected_mul, rtol=1e-2, atol=1e-2)
        except Exception as e:
            # onnxruntime currently lacks bfloat16 kernels for Add/Mul, so session
            # creation or execution may fail with NOT_IMPLEMENTED; accept that.
            self.assertIn(
                "[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Add(14)",
                str(e),
            )


if __name__ == "__main__":
    unittest.main()