Fix softmax export (#20057)
### Description


Why do we need to define softmax export logic here?
For the usage `nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32)` in the model
(https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302):
if dtype is specified, the input tensor is cast to dtype before the
operation is performed. This is useful for preventing data type
overflows. The existing ONNX exporter, however, does the cast after the
operation, which is not correct
(https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27).
This override serves as a workaround until PyTorch fixes the issue in an
upcoming release.
(TODO: pengwa - add PyTorch versions when the issue is fixed).
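
A minimal eager-mode sketch of the difference (illustrative only; assumes float16 softmax is supported on the device it runs on):

```python
import torch

x = torch.randn(2, 4, dtype=torch.float16)

# `dtype=` casts the *input* before the op, so these two are equivalent:
a = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32)
b = torch.nn.functional.softmax(x.to(torch.float32), dim=-1)
assert torch.allclose(a, b)

# Casting the *output* instead computes softmax in float16 first; this is
# what the old exported graph did, and it can lose precision or overflow.
c = torch.nn.functional.softmax(x, dim=-1).to(torch.float32)
```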


@thiagocrepaldi We may need a fix in the PyTorch repo as well.

pengwa authored Mar 26, 2024
1 parent 7d976cf commit 1a0ba3f
Showing 2 changed files with 84 additions and 3 deletions.
@@ -10,7 +10,7 @@
from packaging import version
from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
-from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
+from torch.onnx.symbolic_helper import parse_args

from onnxruntime.training.utils import pytorch_type_to_onnx_dtype

@@ -176,9 +176,9 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    try:
        # Tolerant to the case when sizes of indices are not available or not usable (for example
        # when DeepSpeed stage3 enabled, all weights size is (0), this will fail.)
-       indices_shape = _get_tensor_sizes(indices)
+       indices_shape = sym_help._get_tensor_sizes(indices)
        if indices_shape is not None and hasattr(weight.type(), "with_sizes"):
-           output_type = weight.type().with_sizes([*indices_shape, _get_tensor_dim_size(weight, 1)])
+           output_type = weight.type().with_sizes([*indices_shape, sym_help._get_tensor_dim_size(weight, 1)])
            output.setType(output_type)
    except IndexError:
        output.setType(weight.type())
@@ -845,3 +845,30 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
    )

    return res


# Adapted from torch.onnx.symbolic_opset13.softmax -
# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27
# We don't need to override symbolic_opset9 because training supports opsets >= 13.
#
# Why do we need to define softmax export logic here?
# For the usage `nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)` in the model,
# https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/models/mistral/modeling_mistral.py#L302
# if dtype is specified, the input tensor is cast to dtype before the operation is performed.
# This is useful for preventing data type overflows, but the existing ONNX exporter does the cast after the operation.
# This override serves as a workaround until PyTorch fixes the issue in an upcoming release.
# (TODO: pengwa - add PyTorch versions when the issue is fixed).
@register_symbolic("softmax")
@parse_args("v", "i", "none")
def softmax(g, input, dim, dtype=None):
    from torch.onnx import _type_utils

    # When an explicit dtype is given, cast the *input* before Softmax
    # (matching eager-mode semantics) instead of casting the output.
    casted_input = input
    need_cast_for_compute = dtype and dtype.node().kind() != "prim::Constant"
    if need_cast_for_compute:
        parsed_dtype = sym_help._get_const(dtype, "i", "dtype")
        casted_input = g.op("Cast", input, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type())

    softmax = g.op("Softmax", casted_input, axis_i=dim)

    return softmax
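
For reference, the same cast-before-softmax logic could be wired up through the public torch.onnx API; a sketch only, assuming a PyTorch version whose `register_custom_op_symbolic` accepts `aten::` overrides (the file's existing `register_custom_op_symbolic` import suggests the same mechanism underlies `register_symbolic`):

from torch.onnx import _type_utils, register_custom_op_symbolic
from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v", "i", "none")
def softmax_cast_first(g, input, dim, dtype=None):
    # Same logic as the override above: cast the input, then run Softmax.
    casted_input = input
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        casted_input = g.op("Cast", input, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type())
    return g.op("Softmax", casted_input, axis_i=dim)

# Override the built-in symbolic for opset 13 and above.
register_custom_op_symbolic("aten::softmax", softmax_cast_first, 13)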
@@ -33,6 +33,7 @@
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule, _fallback, _io, _utils
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
from onnxruntime.training.ortmodule.options import _SkipCheck
+from onnxruntime.training.utils import pytorch_type_to_onnx_dtype

DEFAULT_OPSET = 17

@@ -6496,3 +6497,56 @@ def run_step(model, x, y, z):
    torch.cuda.synchronize()
    if original_val is not None:
        os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = original_val


@pytest.mark.parametrize("softmax_compute_type", [torch.float16, torch.float32])
def test_overridden_softmax_export(softmax_compute_type):
    class CustomSoftmaxExportTest(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, attn_weight):
            return torch.nn.functional.softmax(attn_weight, dim=-1, dtype=softmax_compute_type)

    device = "cuda"
    pt_model = CustomSoftmaxExportTest().to(device)
    ort_model = ORTModule(
        copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="overridden_softmax_export")
    )

    def run_step(model, attn_weight):
        prediction = model(attn_weight)
        prediction.sum().backward()
        return prediction

    # Reset the manual seed to reset the random generator.
    torch.manual_seed(2333)
    attn_weight = torch.randn([20, 6, 10, 10], dtype=torch.float, device=device, requires_grad=True)
    ort_attn_weight = copy.deepcopy(attn_weight)
    pt_prediction = run_step(pt_model, attn_weight)
    ort_prediction = run_step(ort_model, ort_attn_weight)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    _test_helpers.assert_values_are_close(attn_weight.grad, ort_attn_weight.grad)
    _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)

    # Check that the exported ONNX Softmax runs in the requested compute type.
    execution_mgr = ort_model._torch_module._execution_manager._training_manager
    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

    # Keep the logic aligned with _graph_execution_manager.py.
    path = os.path.join(
        execution_mgr._debug_options.save_onnx_models.path,
        _get_onnx_file_name(
            execution_mgr._debug_options.save_onnx_models.name_prefix, "torch_exported", execution_mgr._export_mode
        ),
    )

    onnx_model = onnx.load(path)
    onnx_nodes = list(onnx_model.graph.node)

    # The input should be cast to the compute type *before* the Softmax node.
    assert onnx_nodes[0].op_type == "Cast"
    to_attr = onnx_nodes[0].attribute[0]
    assert to_attr.name == "to"
    to_value = to_attr.i
    assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
