Add bfloat16 to float16 converter in onnxruntime specific rewriter (#1492)

Prior to this PR, because numpy does not support bfloat16, the
onnxruntime Python interface could not execute models with bfloat16
inputs/outputs. This PR provides a temporary workaround that converts
the model inputs/outputs to float16 and adds op.Cast nodes after/before
them to make such models executable. The pass lives under the
onnxruntime rewriter and is off by default.

NOTE: We should delete this offline pass once numpy supports bfloat16.
BLOCKED: onnxruntime does not appear to implement bfloat16 for op.Add
and op.Mul on either CUDA or CPU (see the test case).
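
For context, a minimal sketch of how the pass might be applied (the model
path here is hypothetical; dtype_adapter_for_bfloat16_model is the entry
point added in this PR):

import onnx
import onnxruntime

from onnxscript import ir
from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter

# Hypothetical bfloat16 model; numpy cannot represent its inputs/outputs directly.
model = ir.serde.deserialize_model(onnx.load("model_bf16.onnx"))

# Rewrite the graph boundary: inputs/outputs become float16, with Cast nodes
# bridging to the bfloat16 interior.
bfloat16_converter.dtype_adapter_for_bfloat16_model(model)

# The adapted model now accepts and returns float16 numpy arrays.
session = onnxruntime.InferenceSession(
    ir.serde.serialize_model(model).SerializeToString()
)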

---------

Co-authored-by: Justin Chu <[email protected]>
titaiwangms and justinchuby authored May 7, 2024
1 parent 34a0cd1 commit 280fb39
Showing 3 changed files with 198 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxscript/rewriter/onnxruntime/__init__.py
@@ -51,6 +51,7 @@ def rewrite(
    if pattern_rules:
        count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
        print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")

    model_proto = ir.serde.serialize_model(model)
    remove_unused.remove_unused_nodes(model_proto)
    remove_unused_function.remove_unused_functions(model_proto)
101 changes: 101 additions & 0 deletions onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py
@@ -0,0 +1,101 @@
import logging

from onnxscript import ir

logger = logging.getLogger(__name__)
CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16"


def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None:
    if value.dtype != ir.DataType.BFLOAT16:
        return
    value.dtype = ir.DataType.FLOAT16
    _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value)


def _convert_outputs_from_bfloat16_to_float16(value: ir.Value) -> None:
    if value.dtype != ir.DataType.BFLOAT16:
        return
    _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value)


def _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value: ir.Input) -> None:
    # Snapshot the existing consumers before the Cast node is wired in.
    user_nodes_and_indices = tuple(value.uses())

    attr = ir.AttrInt64(name="to", value=ir.DataType.BFLOAT16)
    cast = ir.Node(
        domain="",
        op_type="Cast",
        inputs=[value],
        num_outputs=1,
        attributes=[attr],
    )
    cast.outputs[0].dtype = ir.DataType.BFLOAT16
    cast.outputs[0].shape = value.shape

    for node, index in tuple(value.uses()):
        if node is cast:
            continue
        node.replace_input_with(index, cast.outputs[0])

    # NOTE: A safer way to insert the cast node is to prepend it to the first node
    # of the graph
    assert user_nodes_and_indices[0][0].graph is not None, "The node should belong to a graph"
    user_nodes_and_indices[0][0].graph[0].prepend(cast)


def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> None:
    node = value.producer()
    index = value.index()
    if node is None or index is None:
        logger.warning("Output value %s has no producer or index", value)
        return

    attr = ir.AttrInt64(name="to", value=ir.DataType.FLOAT16)
    cast = ir.Node(
        domain="",
        op_type="Cast",
        inputs=[node.outputs[index]],
        num_outputs=1,
        attributes=[attr],
    )
    cast.outputs[0].dtype = ir.DataType.FLOAT16
    cast.outputs[0].shape = node.outputs[index].shape
    # To prevent naming conflicts, we need to append a suffix to the output name of the cast node
    # TODO: Remove this after naming authority covers this case
    cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX  # type: ignore[operator]
    node.append(cast)

    assert node.graph is not None, "Node graph should not be None"
    # Update graph/function outputs
    for idx, graph_or_function_output in enumerate(node.graph.outputs):
        if graph_or_function_output == node.outputs[index]:
            node.graph.outputs[idx] = cast.outputs[0]
            # Swap the output name of the node with the output name of the cast node to
            # preserve the output name in the graph
            node.outputs[index].name, cast.outputs[0].name = (
                cast.outputs[0].name,
                node.outputs[index].name,
            )


def dtype_adapter_for_bfloat16_model(model: ir.Model) -> None:
    """Adapt the model datatype if it's bfloat16.

    Because onnxruntime does not support bfloat16 as an input/output datatype, we need to
    convert the bfloat16 datatype to float16. This function will convert the bfloat16
    datatype to float16 and insert Cast nodes to convert float16 to bfloat16.

    Model:
        inputs(float16) -> Cast(bfloat16) -> nodes(bfloat16) -> Cast(float16) -> outputs(float16)

    TODO: Delete this function after onnxruntime supports bfloat16.

    Args:
        model: The model to adapt.
    """
    for input in model.graph.inputs:
        _convert_inputs_from_bfloat16_to_float16(input)
    for output in model.graph.outputs:
        _convert_outputs_from_bfloat16_to_float16(output)
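
To illustrate the boundary rewrite, a minimal sketch (using the same ir
constructors as the test below; the graph and names here are made up for
illustration) that adapts a one-op bfloat16 model:

from onnxscript import ir
from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter

x = ir.Input(name="x", shape=ir.Shape([2, 2]))
x.dtype = ir.DataType.BFLOAT16
identity = ir.Node("", "Identity", inputs=(x,), num_outputs=1)
identity.outputs[0].dtype = ir.DataType.BFLOAT16
graph = ir.Graph(
    name="identity_bf16",
    inputs=(x,),
    outputs=(identity.outputs[0],),
    nodes=(identity,),
    opset_imports={"": 18},
)
model = ir.Model(graph=graph, ir_version=8, producer_name="bf16_example")

bfloat16_converter.dtype_adapter_for_bfloat16_model(model)

# Boundary after adaptation:
#   x(float16) -> Cast(to=BFLOAT16) -> Identity(bfloat16) -> Cast(to=FLOAT16) -> output(float16)
assert all(inp.dtype == ir.DataType.FLOAT16 for inp in model.graph.inputs)
assert all(out.dtype == ir.DataType.FLOAT16 for out in model.graph.outputs)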
96 changes: 96 additions & 0 deletions onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
@@ -0,0 +1,96 @@
import unittest

import numpy as np
import onnx
import onnx.checker
import onnx.shape_inference
import onnxruntime

from onnxscript import ir
from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter


class Bfloat16ConversionTest(unittest.TestCase):
    def setUp(self) -> None:
        self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
        self.v0.dtype = ir.DataType.BFLOAT16
        self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
        self.v1.dtype = ir.DataType.BFLOAT16
        self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
        self.v2.dtype = ir.DataType.BFLOAT16

        self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
        self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16
        self.mul_node = ir.Node(
            "", "Mul", inputs=(self.add_node.outputs[0], self.v2), num_outputs=1
        )
        self.mul_node.outputs[0].dtype = ir.DataType.BFLOAT16
        self.graph = ir.Graph(
            name="bfloat16_conversion_test",
            inputs=(self.v0, self.v1, self.v2),
            outputs=(self.add_node.outputs[0], self.mul_node.outputs[0]),
            nodes=(self.add_node, self.mul_node),
            opset_imports={"": 18},
        )
        self.original_output_names = [output.name for output in self.graph.outputs]
        self.model = ir.Model(
            graph=self.graph,
            ir_version=8,
            producer_name="bfloat16_conversion_test",
        )
        bfloat16_converter.dtype_adapter_for_bfloat16_model(self.model)

    def test_input_and_output_are_float16(self):
        for input in self.model.graph.inputs:
            self.assertEqual(input.dtype, ir.DataType.FLOAT16)
        for output in self.model.graph.outputs:
            self.assertEqual(output.dtype, ir.DataType.FLOAT16)

    def test_cast_nodes_are_inserted(self):
        cast_node_count = 0
        for node in self.model.graph:
            if node.op_type == "Cast":
                cast_node_count += 1
        self.assertEqual(cast_node_count, 5)

        for input in self.model.graph.inputs:
            for input_user, _ in input.uses():
                self.assertEqual(input_user.op_type, "Cast")
                self.assertEqual(input_user.outputs[0].dtype, ir.DataType.BFLOAT16)
        for output in self.model.graph.outputs:
            self.assertEqual(output.producer().op_type, "Cast")
            self.assertEqual(output.producer().inputs[0].dtype, ir.DataType.BFLOAT16)

    def test_graph_output_name_is_preserved(self):
        self.assertEqual(
            [output.name for output in self.model.graph.outputs],
            self.original_output_names,
        )

    def test_bfloat16_converted_model_runtime(self):
        model_proto = ir.serde.serialize_model(self.model)
        model_proto_filled_shape_type = onnx.shape_inference.infer_shapes(
            model_proto, check_type=True, strict_mode=True, data_prop=True
        )
        onnx.checker.check_model(model_proto_filled_shape_type, full_check=True)
        try:
            ort_session = onnxruntime.InferenceSession(
                model_proto_filled_shape_type.SerializeToString()
            )
            v0 = np.random.randn(2, 3, 4).astype(np.float16)
            v1 = np.random.randn(2, 3, 4).astype(np.float16)
            v2 = np.random.randn(2, 3, 4).astype(np.float16)
            ort_inputs = {"v0": v0, "v1": v1, "v2": v2}
            ort_outputs = ort_session.run(None, ort_inputs)
            # Graph outputs are (add_output, mul_output), so compare each against
            # its own reference value.
            expected_add_output = v0 + v1
            expected_mul_output = (v0 + v1) * v2
            np.testing.assert_allclose(ort_outputs[0], expected_add_output, rtol=1e-2, atol=1e-2)
            np.testing.assert_allclose(ort_outputs[1], expected_mul_output, rtol=1e-2, atol=1e-2)
        except Exception as e:
            self.assertIn(
                "[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Add(14)",
                str(e),
            )


if __name__ == "__main__":
    unittest.main()
