Add Ops(_native_batch_norm_legit_functional) | feat(torchlib) (#1143)
Fix #1140 

Add
(1) `aten::_native_batch_norm_legit.no_stats`
(2) `aten::_to_copy`
(3) `aten::_native_batch_norm_legit_functional`

`aten::_native_batch_norm_legit_functional` is only invoked by the
Functionalization pass, so it can't be tested in op_test; it will be
added to op_test on the converter side. The only difference between this
op and `aten::_native_batch_norm_legit` is the number of outputs:
`aten::_native_batch_norm_legit_functional` additionally returns
running_mean and running_var, following
https://github.com/pytorch/pytorch/blob/1488bafb274fcc82c8aac429bad61738bc3f950e/torch/_decomp/decompositions.py#L1804-L1826
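
A minimal sketch of the output difference (not part of this commit; it
assumes a PyTorch build where both overloads are callable eagerly):

    import torch

    x = torch.randn(2, 3, 4, 4)
    weight, bias = torch.ones(3), torch.zeros(3)
    running_mean, running_var = torch.zeros(3), torch.ones(3)

    # _native_batch_norm_legit returns 3 outputs: (output, save_mean, save_rstd).
    out, save_mean, save_rstd = torch.ops.aten._native_batch_norm_legit(
        x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
    )

    # The functional variant additionally returns the updated running stats,
    # for a total of 5 outputs.
    out, save_mean, save_rstd, new_mean, new_var = (
        torch.ops.aten._native_batch_norm_legit_functional(
            x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
        )
    )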

The sample inputs for `aten_native_batch_norm_legit` are split into two
functions so they can be fed separately into the different ONNX variants,
since the variants require different sets of arguments (see the sketch
below).
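
Illustrative only (tensor shapes are arbitrary, and the eager calls are an
assumption, not taken from this commit):

    import torch

    x = torch.randn(2, 3, 4, 4)
    weight, bias = torch.ones(3), torch.zeros(3)

    # Overload with running stats: 8 positional arguments.
    torch.ops.aten._native_batch_norm_legit(
        x, weight, bias, torch.zeros(3), torch.ones(3), True, 0.1, 1e-5
    )

    # .no_stats overload: running_mean/running_var are dropped, 6 arguments.
    torch.ops.aten._native_batch_norm_legit.no_stats(
        x, weight, bias, True, 0.1, 1e-5
    )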

---------

Co-authored-by: Justin Chu <[email protected]>
titaiwangms and justinchuby authored Nov 10, 2023
1 parent fdef96c commit 88ee668
Showing 4 changed files with 226 additions and 4 deletions.
139 changes: 136 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2046,7 +2046,7 @@ def aten_convolution_overrideable(
raise NotImplementedError()


@torch_op("aten::copy")
@torch_op(("aten::copy", "aten::_to_copy"))
def aten_copy(
self: TTensor, src: TTensor, non_blocking: bool = False # pylint: disable=unused-argument
) -> TTensor:
@@ -5456,6 +5456,20 @@ def aten__native_batch_norm_no_training(
)


@torch_op("aten::_native_batch_norm_legit.no_stats", trace_only=True)
def aten__native_batch_norm_no_stats(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
"""_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)"""

return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps)


@torch_op(("aten::native_batch_norm", "aten::_native_batch_norm_legit"), trace_only=True)
def aten_native_batch_norm(
input: TFloat,
@@ -5556,12 +5570,131 @@ def _aten_native_batch_norm_inference_onnx(
momentum=momentum,
training_mode=training,
)
# NOTE: mean and var are omitted in inference mode
# Cannot return two duplicate outputs, so compute each in its own variable
empty_mean = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
empty_var = op.Cast(op.Shape(input, start=0, end=0), to=FLOAT.dtype)
empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var


# TODO: This op duplicates code from aten_native_batch_norm and should be
# refactored later. https://github.com/microsoft/onnxscript/issues/1125
# NOTE: This op is invoked by PyTorch Functionalization and is not in
# native_functions.yaml; it can be found in torch/_decomp/decompositions.py
@torch_op("aten::_native_batch_norm_legit_functional", trace_only=True)
def aten__native_batch_norm_legit_functional(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
running_mean: Optional[TFloat] = None,
running_var: Optional[TFloat] = None,
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
if weight is None: # Set to 1.0 as default
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

if bias is None: # Set to 0.0 as default
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

axes = list(range(len(input.shape)))
axes.pop(1)
axes = op.Constant(value_ints=axes)
if running_mean is None: # Using input mean
running_mean = op.Squeeze(op.ReduceMean(input, axes))

if running_var is None: # Using input var
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))

# Split into two private functions because ONNX BatchNormalization produces
# 3 outputs in training mode but only 1 output in inference mode
if training is True:
norm, mean, var, new_mean, new_var = _aten__native_batch_norm_training_functional_onnx(
input, weight, bias, running_mean, running_var, axes, training, momentum, eps
)
else:
(
norm,
mean,
var,
new_mean,
new_var,
) = _aten__native_batch_norm_inference_functional_onnx(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
return norm, mean, var, new_mean, new_var


@torch_op("aten::_native_batch_norm_legit_functional", private=True)
def _aten__native_batch_norm_training_functional_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
running_mean: TFloat,
running_var: TFloat,
axes: INT64,
training: bool,
momentum: float,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
# Assert(training is True)
norm, running_mean, running_var = op.BatchNormalization(
input,
weight,
bias,
running_mean,
running_var,
epsilon=eps,
momentum=momentum,
training_mode=training,
)
# Compute var and rstd
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr = op.Mul(input_sub_mean, input_sub_mean)
var = op.ReduceMean(sqr, axes, keepdims=False)
rstd = op.Div(1.0, op.Sqrt(var + eps))
# Get mean again with size = [1, C]
mean = op.ReduceMean(input, axes, keepdims=False)
# NOTE: Fixed to be FLOAT dtype
running_mean = op.Cast(running_mean, to=FLOAT.dtype)
running_var = op.Cast(running_var, to=FLOAT.dtype)
return norm, mean, rstd, running_mean, running_var


@torch_op("aten::_native_batch_norm_legit_functional", private=True)
def _aten__native_batch_norm_inference_functional_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
running_mean: TFloat,
running_var: TFloat,
training: bool,
momentum: float,
eps: float,
) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]:
# Assert(training is False)
norm = op.BatchNormalization(
input,
weight,
bias,
running_mean,
running_var,
epsilon=eps,
momentum=momentum,
training_mode=training,
)
# NOTE: mean and var are omitted in inference mode
# Cannot return two duplicate outputs, so compute each in its own variable
empty_mean = op.CastLike(op.Shape(input, start=0, end=0), norm)
empty_var = op.CastLike(op.Shape(input, start=0, end=0), norm)
return norm, empty_mean, empty_var, running_mean, running_var


def aten_native_batch_norm_backward(
grad_out: TensorType,
input: TensorType,
76 changes: 76 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1289,6 +1289,52 @@ def sample_inputs_scaled_dot_product_flash_attention(
yield from samples


# NOTE: The `_native_batch_norm_legit` tests generate two kinds of args:
# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
# 2. (input, weight, bias, training, momentum, eps)
# These require two different function signatures, which is why there are
# two sample_inputs functions here.
def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
samples = common_methods_invocations.sample_inputs_batch_norm(
op_info, device, dtype, requires_grad, **kwargs
)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
training = sample.kwargs.get("training", True)
momentum = sample.kwargs.get("momentum", 0.5)
eps = sample.kwargs.get("eps", 1e-5)
if args[0] is not None and args[1] is not None:
yield opinfo_core.SampleInput(
sample.input,
args=(args[2], args[3], args[0], args[1], training, momentum, eps),
)


def sample_inputs__native_batch_norm_legit_no_stats(
op_info, device, dtype, requires_grad, **kwargs
):
samples = common_methods_invocations.sample_inputs_batch_norm(
op_info, device, dtype, requires_grad, **kwargs
)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
training = sample.kwargs.get("training", True)
momentum = sample.kwargs.get("momentum", 0.5)
eps = sample.kwargs.get("eps", 1e-5)
if args[0] is not None and args[1] is None:
yield opinfo_core.SampleInput(
sample.input, args=(args[2], args[3], training, momentum, eps)
)


# NOTE: How to create an OpInfo:
# 1. Create a function that generates sample inputs for the op.
# This function should yield SampleInputs.
@@ -1633,4 +1679,34 @@ def sample_inputs_scaled_dot_product_flash_attention(
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit",
aten_name="_native_batch_norm_legit",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit_functional",
aten_name="_native_batch_norm_legit_functional",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit,
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit.no_stats",
aten_name="_native_batch_norm_legit.no_stats",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
),
]
1 change: 0 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -184,7 +184,6 @@ def run_test_output_match(

# Obtain the tolerance for the op
rtol, atol = torchlib_op_info.get_tolerance(dtype)

for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
14 changes: 14 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1680,6 +1680,20 @@ def _where_input_wrangler(
reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975",
),
TorchLibOpInfo("native_batch_norm", core_ops.aten_native_batch_norm, trace_only=True),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, trace_only=True
),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit.no_stats",
core_ops.aten__native_batch_norm_no_stats,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten._native_batch_norm_legit_functional",
core_ops.aten__native_batch_norm_legit_functional,
trace_only=True,
compare_shape_only_for_output=(3, 4),
),
TorchLibOpInfo(
"ops.aten.native_group_norm",
core_ops.aten_native_group_norm,
