From 1a0ba3f69f5075754ecae9c92abce9360861a7a5 Mon Sep 17 00:00:00 2001 From: pengwa Date: Tue, 26 Mar 2024 13:09:20 +0800 Subject: [PATCH] Fix softmax export (#20057) ### Description Why do we need to define the softmax export logic here? For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model, https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302 If dtype is specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. However, the existing ONNX exporter does the cast after the operation, which is not correct. (https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27). This override can serve as a workaround until PyTorch fixes the issue in a coming release. (TODO: pengwa - add PyTorch versions when the issue is fixed). @thiagocrepaldi We may need a fix in the PyTorch repo as well. 
### Motivation and Context --- .../ortmodule/_custom_op_symbolic_registry.py | 33 ++++++++++-- .../python/orttraining_test_ortmodule_api.py | 54 +++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index f81aef5f6b9c4..dd7fea3ceda10 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -10,7 +10,7 @@ from packaging import version from packaging.version import Version from torch.onnx import register_custom_op_symbolic -from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args +from torch.onnx.symbolic_helper import parse_args from onnxruntime.training.utils import pytorch_type_to_onnx_dtype @@ -176,9 +176,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): try: # Tolerant to the case when sizes of indices are not available or not usable (for example # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.) 
- indices_shape = _get_tensor_sizes(indices) + indices_shape = sym_help._get_tensor_sizes(indices) if indices_shape is not None and hasattr(weight.type(), "with_sizes"): - output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) + output_type = weight.type().with_sizes([*indices_shape, sym_help._get_tensor_dim_size(weight, 1)]) output.setType(output_type) except IndexError: output.setType(weight.type()) @@ -845,3 +845,30 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): ) return res + + +# Adapted from torch.onnx.symbolic_opset13.softmax - +# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27 +# We don't need to override symbolic_opset9 because training supports opsets >= 13. +# +# Why do we need to define the softmax export logic here? +# For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model, +# https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302 +# If dtype is specified, the input tensor is cast to dtype before the operation is performed. +# This is useful for preventing data type overflows. But the existing ONNX exporter casts after the operation. +# This override can serve as a workaround until PyTorch fixes the issue in a coming release. +# (TODO: pengwa - add PyTorch versions when the issue is fixed). 
+@register_symbolic("softmax") +@parse_args("v", "i", "none") +def softmax(g, input, dim, dtype=None): + from torch.onnx import _type_utils + + casted_input = input + need_cast_for_compute = dtype and dtype.node().kind() != "prim::Constant" + if need_cast_for_compute: + parsed_dtype = sym_help._get_const(dtype, "i", "dtype") + casted_input = g.op("Cast", input, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()) + + softmax = g.op("Softmax", casted_input, axis_i=dim) + + return softmax diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7afad9145ed27..d6f55e787c320 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -33,6 +33,7 @@ from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype DEFAULT_OPSET = 17 @@ -6496,3 +6497,56 @@ def run_step(model, x, y, z): torch.cuda.synchronize() if original_val is not None: os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val + + +@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32]) +def test_overridden_softmax_export(softmax_compute_type): + class CustomSoftmaxExportTest(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, attn_weight): + return torch.nn.functional.softmax(attn_weight, dim=-1, dtype=softmax_compute_type) + + device = "cuda" + pt_model = CustomSoftmaxExportTest().to(device) + ort_model = ORTModule( + copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overridden_softmax_export") + ) + + def run_step(model, attn_weight): + prediction = model(attn_weight) + 
prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + attn_weight = torch.randn([20, 6, 10, 10], dtype=torch.float, device=device, requires_grad=True) + ort_attn_weight = copy.deepcopy(attn_weight) + pt_prediction = run_step(pt_model, attn_weight) + ort_prediction = run_step(ort_model, ort_attn_weight) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(attn_weight.grad, ort_attn_weight.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + + # Check the ONNX Softmax is running in float32. + execution_mgr = ort_model._torch_module._execution_manager._training_manager + from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name + + # Keep the logic aligned with _graph_execution_manager.py + path = os.path.join( + execution_mgr._debug_options.save_onnx_models.path, + _get_onnx_file_name( + execution_mgr._debug_options.save_onnx_models.name_prefix, "torch_exported", execution_mgr._export_mode + ), + ) + + onnx_model = onnx.load(path) + onnx_nodes = [n for n in onnx_model.graph.node] + + assert onnx_nodes[0].op_type == "Cast" + to_attr = onnx_nodes[0].attribute[0] + assert to_attr.name == "to" + to_value = to_attr.i + assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"