diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index f81aef5f6b9c4..dd7fea3ceda10 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -10,7 +10,7 @@ from packaging import version from packaging.version import Version from torch.onnx import register_custom_op_symbolic -from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args +from torch.onnx.symbolic_helper import parse_args from onnxruntime.training.utils import pytorch_type_to_onnx_dtype @@ -176,9 +176,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): try: # Tolerant to the case when sizes of indices are not available or not usable (for example # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.) - indices_shape = _get_tensor_sizes(indices) + indices_shape = sym_help._get_tensor_sizes(indices) if indices_shape is not None and hasattr(weight.type(), "with_sizes"): - output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)]) + output_type = weight.type().with_sizes([*indices_shape, sym_help._get_tensor_dim_size(weight, 1)]) output.setType(output_type) except IndexError: output.setType(weight.type()) @@ -845,3 +845,30 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): ) return res + + +# Adapted from torch.onnx.symbolic_opset13.softmax - +# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27 +# We don't need overloads symbolic_opset9 because training support opsets >= 13. +# +# Why we need to define softmax export logic here? +# For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model, +# https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302 +# If dtype is specified, the input tensor is casted to dtype before the operation is performed. +# This is useful for preventing data type overflows. While existing ONNX exporter do the cast after the operation. +# This override can be a workaround before PyTorch fix the issues in coming releases. +# (TODO: pengwa - add PyTorch versions when the issue is fixed). +@register_symbolic("softmax") +@parse_args("v", "i", "none") +def softmax(g, input, dim, dtype=None): + from torch.onnx import _type_utils + + casted_input = input + need_cast_for_compute = dtype and dtype.node().kind() != "prim::Constant" + if need_cast_for_compute: + parsed_dtype = sym_help._get_const(dtype, "i", "dtype") + casted_input = g.op("Cast", input, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()) + + softmax = g.op("Softmax", casted_input, axis_i=dim) + + return softmax diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7afad9145ed27..7c356301470b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -33,6 +33,7 @@ from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype DEFAULT_OPSET = 17 @@ -6496,3 +6497,54 @@ def run_step(model, x, y, z): torch.cuda.synchronize() if original_val is not None: os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val + + +@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32]) +def test_overriden_softmax_export(softmax_compute_type): + class CustomSoftmaxExportTest(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, attn_weight): + return torch.nn.functional.softmax(attn_weight, dim=-1, dtype=softmax_compute_type) + + device = "cuda" + pt_model = CustomSoftmaxExportTest().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overriden_softmax_export")) + + def run_step(model, attn_weight): + prediction = model(attn_weight) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + attn_weight = torch.randn([20, 6, 10, 10], dtype=torch.float, device=device, requires_grad=True) + ort_attn_weight = copy.deepcopy(attn_weight) + pt_prediction = run_step(pt_model, attn_weight) + ort_prediction = run_step(ort_model, ort_attn_weight) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(attn_weight.grad, ort_attn_weight.grad) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + + # Check the ONNX Softmax is running in float32. + execution_mgr = ort_model._torch_module._execution_manager._training_manager + from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name + + # Keep the logic aligned with _graph_execution_manager.py + path = os.path.join( + execution_mgr._debug_options.save_onnx_models.path, + _get_onnx_file_name( + execution_mgr._debug_options.save_onnx_models.name_prefix, "torch_exported", execution_mgr._export_mode + ), + ) + + onnx_model = onnx.load(path) + onnx_nodes = [n for n in onnx_model.graph.node] + + assert onnx_nodes[0].op_type == "Cast" + to_attr = onnx_nodes[0].attribute[0] + assert to_attr.name == "to" + to_value = to_attr.i + assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"