Skip to content

Commit

Permalink
Add Op(_native_batch_norm_legit_no_training and _native_batch_norm_le…
Browse files Browse the repository at this point in the history
…git) | feat(torchlib) (#1116)

Fix #817

Add the support of `_native_batch_norm_legit_no_training` and
`_native_batch_norm_legit`, which are two new aten ops to replace
aten::native_batch_norm according to
https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510.

Previous to this PR, due to lack of support of
`_native_batch_norm_legit_no_training` and `_native_batch_norm_legit`,
the exporter decomposes `native_batch_norm` to a bunch of other nodes
and drags down the performance.

NOTE: The mismatch result size between CUDA/CPU export doesn't happen
even with these nodes supported. Could be fixed somewhere else.

Tested with the code:

```python
import torch

import onnxruntime


def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()
    export_output = torch.onnx.dynamo_export(model, x)

    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())
    export_output.save("coat_lite_mini.onnx")
    export_output.save_diagnostics("debug_bn.sarif")

    session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(
        x
    )
    ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, onnx_format_args)}
    print(session.run(None, ort_input))


repro_split()
```
  • Loading branch information
titaiwangms authored Oct 28, 2023
1 parent 70843ef commit f35e844
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5432,7 +5432,28 @@ def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) ->
raise NotImplementedError()


@torch_op("aten::native_batch_norm", trace_only=True)
# NOTE: https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510
# _native_batch_norm_legit_no_training and _native_batch_norm_legit are meant to
# replace native_batch_norm within unknown time period.
# TODO: Refactor this after native_batch_norm is deprecated.
@torch_op("aten::_native_batch_norm_legit_no_training", trace_only=True)
def aten_native_batch_norm_no_training(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
running_mean: Optional[TFloat] = None,
running_var: Optional[TFloat] = None,
momentum: float = 0.9,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
"""_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""

return aten_native_batch_norm(
input, weight, bias, running_mean, running_var, False, momentum, eps
)


@torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
def aten_native_batch_norm(
input: TFloat,
weight: Optional[TFloat] = None,
Expand Down

0 comments on commit f35e844

Please sign in to comment.