Add Op(_native_batch_norm_legit_no_training and _native_batch_norm_legit) | feat(torchlib) (#1116)

Fix #817

Add support for `_native_batch_norm_legit_no_training` and `_native_batch_norm_legit`, two new aten ops that replace `aten::native_batch_norm`, following https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1501-L1510 (see the sketch after the repro code below for how the two ops relate).

Prior to this PR, because `_native_batch_norm_legit_no_training` and `_native_batch_norm_legit` were unsupported, the exporter decomposed `native_batch_norm` into a large number of other nodes, which dragged down performance.

NOTE: The result-size mismatch between CUDA and CPU export still occurs even with these ops supported; it could be fixed somewhere else.

Tested with the code:

```python
import torch
import onnxruntime


def repro_split():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(64)
            self.conv = torch.nn.Conv2d(64, 64, 3)

        def forward(self, x):
            x = self.bn(x)
            x = self.conv(x)
            return torch.split(x, [16, 24, 24], 1)

    model = Model().cuda().eval()
    x = torch.randn(1, 64, 32, 32).cuda()

    # Export with the dynamo-based exporter and check that ORT can load the model.
    export_output = torch.onnx.dynamo_export(model, x)
    onnxruntime.InferenceSession(export_output.model_proto.SerializeToString())
    export_output.save("coat_lite_mini.onnx")
    export_output.save_diagnostics("debug_bn.sarif")

    # Run the exported model in onnxruntime.
    session = onnxruntime.InferenceSession("coat_lite_mini.onnx")
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(x)
    ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, onnx_format_args)}
    print(session.run(None, ort_input))


repro_split()
```
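For reference, a minimal sketch (not part of this PR) that exercises the two aten ops directly. The call signatures follow the decomposition linked above; the tensor shapes and hyperparameters here are illustrative only. Per that decomposition, `_native_batch_norm_legit_no_training` is equivalent to `_native_batch_norm_legit` with `training=False`.

```python
import torch

# Illustrative inputs matching the repro's channel count (64).
x = torch.randn(1, 64, 32, 32)
weight = torch.ones(64)
bias = torch.zeros(64)
running_mean = torch.zeros(64)
running_var = torch.ones(64)

# Inference-only variant: no `training` flag, running stats are not updated.
# Signature per the linked decomposition:
# (input, weight, bias, running_mean, running_var, momentum, eps)
out_no_training, _, _ = torch.ops.aten._native_batch_norm_legit_no_training(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5
)

# "Legit" variant: takes an explicit `training` flag; with training=False
# it decomposes to the same computation as the no-training op.
out_legit, _, _ = torch.ops.aten._native_batch_norm_legit(
    x, weight, bias, running_mean, running_var, False, 0.1, 1e-5
)

torch.testing.assert_close(out_no_training, out_legit)
```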