
Commit

Merge branch 'main' into justinchu/remove-global-avg-pool
justinchuby authored Jul 24, 2024
2 parents 888b27b + 874365e commit c94031a
Showing 3 changed files with 60 additions and 13 deletions.
15 changes: 2 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3840,19 +3840,6 @@ def aten_grid_sampler_3d_backward(
    raise NotImplementedError()


def aten_group_norm(
    input: TensorType,
    num_groups: int,
    weight: Optional[TensorType] = None,
    bias: Optional[TensorType] = None,
    eps: float = 1e-05,
    cudnn_enabled: bool = True,
) -> TensorType:
    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""

    raise NotImplementedError()


def aten_gru_cell(
    input: TensorType,
    hx: TensorType,
@@ -6087,7 +6074,9 @@ def _aten_native_group_norm_onnx(
    axes_unsqueeze = op.Range(1, input_rank - 1, 1)
    weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
    bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
    weight_full_shape = op.CastLike(weight_full_shape, norm)
    norm_mul_weight = op.Mul(norm, weight_full_shape)
    bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
    norm_result = op.Add(norm_mul_weight, bias_full_shape)
    # Compute mean and rstd, but using the Torch algorithm
    # The returned shape for mean and rstd should be [N, group, -1]
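The two op.CastLike calls added above keep the affine step type-safe: ONNX Mul and Add require both operands to share one element type, so a float32 weight or bias must first be cast to the dtype of the normalized tensor (for example float16). A minimal NumPy sketch of the cast being mirrored; the names and shapes here are illustrative, not from the repository:

import numpy as np

norm = np.random.randn(2, 4, 8).astype(np.float16)  # normalized activations in fp16
weight = np.ones((4, 1), dtype=np.float32)          # affine weight kept in fp32

# ONNX Mul rejects mixed fp16/fp32 inputs; CastLike aligns the weight's
# dtype with the normalized tensor before the elementwise multiply.
weight_like_norm = weight.astype(norm.dtype)  # NumPy analogue of op.CastLike(weight_full_shape, norm)
out = norm * weight_like_norm
assert out.dtype == np.float16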
50 changes: 50 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -542,6 +542,56 @@ def aten_glu_backward_jvp(
    raise NotImplementedError()


@torch_op("aten::group_norm", trace_only=True)
def aten_group_norm(
input: TFloat,
num_groups: int,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
eps: float = 1e-05,
cudnn_enabled: bool = True,
) -> TensorType:
"""group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""

# Actually we don't need N,C,HxW value because the input tensor has that information
if weight is None: # Set to 1.0 as default, the shape is Channel size
weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

if bias is None: # Set to 0.0 as default, the shape is Channel size
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

# Because onnx.GroupNorm() need size=group for weight and bias
# But the torch's aten function's input need size=channel, the size mismatched
# So we have to use onnx.InstanceNorm() to simulate
neg_1 = op.Constant(value_ints=[-1])
# Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
group_tensor = op.Reshape(num_groups, neg_1)
# 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1]
shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
input_reshaped = op.Reshape(input, shape_input)
weight_inst_norm = op.Expand(
op.CastLike(op.Constant(value_float=1.0), input), group_tensor
)
bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor)
norm = op.InstanceNormalization(
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
)
# Reshape back to input's shape
norm = op.Reshape(norm, op.Shape(input))
# Using the input weight and bias to do affine
# But need to unsqueeze to the target shape for broading cast easy
input_rank = Rank(input)
one = op.Constant(value_int=1)
axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
weight_full_shape = op.CastLike(weight_full_shape, norm)
norm_mul_weight = op.Mul(norm, weight_full_shape)
bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
norm_result = op.Add(norm_mul_weight, bias_full_shape)
return norm_result


def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType:
"""glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor"""

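A quick way to sanity-check this decomposition outside ONNX (not part of the commit): the same reshape-to-instance-norm trick expressed in plain PyTorch, compared against torch.nn.functional.group_norm. The helper name and tensor shapes below are invented for illustration.

import torch
import torch.nn.functional as F

def group_norm_via_instance_norm(x, num_groups, weight, bias, eps=1e-5):
    n = x.shape[0]
    # Fold each group's channels and spatial positions into one axis,
    # so per-instance normalization equals per-group normalization
    norm = F.instance_norm(x.reshape(n, num_groups, -1), eps=eps)
    norm = norm.reshape(x.shape)
    # Channel-wise affine: weight/bias of shape [C] -> [C, 1, ..., 1]
    affine_shape = [-1] + [1] * (x.dim() - 2)
    return norm * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
torch.testing.assert_close(
    group_norm_via_instance_norm(x, 3, w, b),
    F.group_norm(x, 3, w, b),
)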
8 changes: 8 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1663,6 +1663,14 @@ def _where_input_wrangler(
        matcher=lambda sample: sample.args[1] == 2,
        reason="fixme: 'bicubic' mode in ORT implemented differently with Torch",
    ),
    TorchLibOpInfo(
        "nn.functional.group_norm",
        nn_ops.aten_group_norm,
        tolerance={torch.float16: (1e-2, 7e-3)},
    ).xfail(
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
        reason="op.InstanceNormalization is used to emulate GroupNorm and does not support inputs with zero-sized dimensions",
    ),
TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
TorchLibOpInfo(
"hstack",
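For reference, the xfail matcher above flags any sample whose input contains a zero-sized dimension. A small illustration, with shapes invented for the example:

import torch

matcher = lambda sample_input: any(dim == 0 for dim in sample_input.shape)

assert matcher(torch.randn(0, 6, 4))      # empty batch: expected to fail
assert not matcher(torch.randn(2, 6, 4))  # regular input: runs normally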
