diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py
index 67a9ee12c..4e9007e36 100644
--- a/onnxscript/rewriter/onnxruntime/__init__.py
+++ b/onnxscript/rewriter/onnxruntime/__init__.py
@@ -51,6 +51,7 @@ def rewrite(
     if pattern_rules:
         count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
         print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")
+    model_proto = ir.serde.serialize_model(model)
     remove_unused.remove_unused_nodes(model_proto)
     remove_unused_function.remove_unused_functions(model_proto)
 
diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py
new file mode 100644
index 000000000..e4afb432d
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py
@@ -0,0 +1,101 @@
+import logging
+
+from onnxscript import ir
+
+logger = logging.getLogger(__name__)
+CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16"
+
+
+def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None:
+    if value.dtype != ir.DataType.BFLOAT16:
+        return
+    value.dtype = ir.DataType.FLOAT16
+    _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value)
+
+
+def _convert_outputs_from_bfloat16_to_float16(value: ir.Value) -> None:
+    if value.dtype != ir.DataType.BFLOAT16:
+        return
+    _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value)
+
+
+def _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value: ir.Input) -> None:
+    user_nodes_and_indices = tuple(value.uses())
+
+    attr = ir.AttrInt64(name="to", value=ir.DataType.BFLOAT16)
+    cast = ir.Node(
+        domain="",
+        op_type="Cast",
+        inputs=[value],
+        num_outputs=1,
+        attributes=[attr],
+    )
+    cast.outputs[0].dtype = ir.DataType.BFLOAT16
+    cast.outputs[0].shape = value.shape
+
+    for node, index in tuple(value.uses()):
+        if node is cast:
+            continue
+        node.replace_input_with(index, cast.outputs[0])
+
+    # NOTE: A safer way to insert the cast node is to prepend it to the first node
+    # of the graph
+    assert user_nodes_and_indices[0][0].graph is not None, "The node should belong to a graph"
+    user_nodes_and_indices[0][0].graph[0].prepend(cast)
+
+
+def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> None:
+    node = value.producer()
+    index = value.index()
+    if node is None or index is None:
+        logger.warning("Output value %s has no producer or index", value)
+        return
+
+    attr = ir.AttrInt64(name="to", value=ir.DataType.FLOAT16)
+    cast = ir.Node(
+        domain="",
+        op_type="Cast",
+        inputs=[node.outputs[index]],
+        num_outputs=1,
+        attributes=[attr],
+    )
+    cast.outputs[0].dtype = ir.DataType.FLOAT16
+    cast.outputs[0].shape = node.outputs[index].shape
+    # To prevent naming conflicts, we need to append a suffix to the output name of the cast node
+    # TODO: Remove this after naming authority covers this case
+    cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX  # type: ignore[operator]
+    node.append(cast)
+
+    assert node.graph is not None, "Node graph should not be None"
+    # Update graph/function outputs
+    for idx, graph_or_function_output in enumerate(node.graph.outputs):
+        if graph_or_function_output == node.outputs[index]:
+            node.graph.outputs[idx] = cast.outputs[0]
+            # Swap the output name of the node with the output name of the cast node to
+            # preserve the output name in the graph
+            node.outputs[index].name, cast.outputs[0].name = (
+                cast.outputs[0].name,
+                node.outputs[index].name,
+            )
+
+
+def dtype_adapter_for_bfloat16_model(model: ir.Model) -> None:
+    """Adapt the input and output dtypes of a bfloat16 model.
+
+    Because onnxruntime does not support bfloat16 as an input/output datatype, the graph
+    inputs and outputs are converted from bfloat16 to float16, and Cast nodes are inserted
+    so that the rest of the graph keeps computing in bfloat16.
+
+    Model:
+        inputs(float16) -> Cast(bfloat16) -> nodes(bfloat16) -> Cast(float16) -> outputs(float16)
+
+    TODO: Delete this function after onnxruntime supports bfloat16.
+
+    Args:
+        model: The model to adapt.
+
+    """
+    for input in model.graph.inputs:
+        _convert_inputs_from_bfloat16_to_float16(input)
+    for output in model.graph.outputs:
+        _convert_outputs_from_bfloat16_to_float16(output)
diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
new file mode 100644
index 000000000..ed53d2f64
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py
@@ -0,0 +1,96 @@
+import unittest
+
+import numpy as np
+import onnx
+import onnx.checker
+import onnx.shape_inference
+import onnxruntime
+
+from onnxscript import ir
+from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter
+
+
+class Bfloat16ConversionTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
+        self.v0.dtype = ir.DataType.BFLOAT16
+        self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
+        self.v1.dtype = ir.DataType.BFLOAT16
+        self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
+        self.v2.dtype = ir.DataType.BFLOAT16
+
+        self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
+        self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16
+        self.mul_node = ir.Node(
+            "", "Mul", inputs=(self.add_node.outputs[0], self.v2), num_outputs=1
+        )
+        self.mul_node.outputs[0].dtype = ir.DataType.BFLOAT16
+        self.graph = ir.Graph(
+            name="bfloat16_conversion_test",
+            inputs=(self.v0, self.v1, self.v2),
+            outputs=(self.add_node.outputs[0], self.mul_node.outputs[0]),
+            nodes=(self.add_node, self.mul_node),
+            opset_imports={"": 18},
+        )
+        self.original_output_names = [output.name for output in self.graph.outputs]
+        self.model = ir.Model(
+            graph=self.graph,
+            ir_version=8,
+            producer_name="bfloat16_conversion_test",
+        )
+        bfloat16_converter.dtype_adapter_for_bfloat16_model(self.model)
+
+    def test_input_and_output_are_float16(self):
+        for input in self.model.graph.inputs:
+            self.assertEqual(input.dtype, ir.DataType.FLOAT16)
+        for output in self.model.graph.outputs:
+            self.assertEqual(output.dtype, ir.DataType.FLOAT16)
+
+    def test_cast_nodes_are_inserted(self):
+        cast_node_count = 0
+        for node in self.model.graph:
+            if node.op_type == "Cast":
+                cast_node_count += 1
+        self.assertEqual(cast_node_count, 5)
+
+        for input in self.model.graph.inputs:
+            for input_user, _ in input.uses():
+                self.assertEqual(input_user.op_type, "Cast")
+                self.assertEqual(input_user.outputs[0].dtype, ir.DataType.BFLOAT16)
+        for output in self.model.graph.outputs:
+            self.assertEqual(output.producer().op_type, "Cast")
+            self.assertEqual(output.producer().inputs[0].dtype, ir.DataType.BFLOAT16)
+
+    def test_graph_output_name_is_preserved(self):
+        self.assertEqual(
+            [output.name for output in self.model.graph.outputs],
+            self.original_output_names,
+        )
+
+    def test_bfloat16_converted_model_runtime(self):
+        model_proto = ir.serde.serialize_model(self.model)
+        model_proto_filled_shape_type = onnx.shape_inference.infer_shapes(
+            model_proto, check_type=True, strict_mode=True, data_prop=True
+        )
+        onnx.checker.check_model(model_proto_filled_shape_type, full_check=True)
+        try:
+            ort_session = onnxruntime.InferenceSession(
+                model_proto_filled_shape_type.SerializeToString()
+            )
+            v0 = np.random.randn(2, 3, 4).astype(np.float16)
+            v1 = np.random.randn(2, 3, 4).astype(np.float16)
+            v2 = np.random.randn(2, 3, 4).astype(np.float16)
+            ort_inputs = {"v0": v0, "v1": v1, "v2": v2}
+            ort_outputs = ort_session.run(None, ort_inputs)
+            expected_outputs = (v0 + v1, (v0 + v1) * v2)  # (Add output, Mul output)
+            np.testing.assert_allclose(ort_outputs[0], expected_outputs[0], rtol=1e-2, atol=1e-2)
+            np.testing.assert_allclose(ort_outputs[1], expected_outputs[1], rtol=1e-2, atol=1e-2)
+        except Exception as e:
+            self.assertIn(
+                "[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Add(14)",
+                str(e),
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
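
Below is a minimal usage sketch (not part of the diff above) showing how the new converter would typically be driven. The model paths are hypothetical placeholders; onnx.load/onnx.save are the standard ONNX helpers, and ir.serde.deserialize_model is assumed to be the counterpart of the ir.serde.serialize_model call used in the diff.

# Usage sketch (assumption, not part of this PR): load a bfloat16 model, adapt its
# graph inputs/outputs to float16, and save the result for onnxruntime.
import onnx

from onnxscript import ir
from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter

model_proto = onnx.load("model_bf16.onnx")  # hypothetical input path
model = ir.serde.deserialize_model(model_proto)

# Rewrites bfloat16 graph inputs/outputs to float16 and inserts Cast nodes so the
# internal computation stays in bfloat16 (see dtype_adapter_for_bfloat16_model above).
bfloat16_converter.dtype_adapter_for_bfloat16_model(model)

onnx.save(ir.serde.serialize_model(model), "model_bf16_f16_io.onnx")  # hypothetical output path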