Add Op (group_norm) | feat(torchlib) #1750

Merged: 3 commits merged on Jul 24, 2024
13 changes: 0 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3840,19 +3840,6 @@ def aten_grid_sampler_3d_backward(
    raise NotImplementedError()


def aten_group_norm(
    input: TensorType,
    num_groups: int,
    weight: Optional[TensorType] = None,
    bias: Optional[TensorType] = None,
    eps: float = 1e-05,
    cudnn_enabled: bool = True,
) -> TensorType:
    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""

    raise NotImplementedError()


def aten_gru_cell(
    input: TensorType,
    hx: TensorType,
50 changes: 50 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -593,6 +593,56 @@ def aten_glu_backward_jvp(
    raise NotImplementedError()


@torch_op("aten::group_norm", trace_only=True)
def aten_group_norm(
    input: TFloat,
    num_groups: int,
    weight: Optional[TFloat] = None,
    bias: Optional[TFloat] = None,
    eps: float = 1e-05,
    cudnn_enabled: bool = True,
) -> TensorType:
    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""

    # Explicit N, C, HxW values are not needed: the input tensor already carries that information
    if weight is None:  # Default to 1.0; the shape is the channel size
        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))

    if bias is None:  # Default to 0.0; the shape is the channel size
        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))

    # onnx GroupNormalization expects weight and bias of size num_groups,
    # while aten group_norm takes them with size num_channels, so the sizes mismatch.
    # Simulate GroupNorm with onnx InstanceNormalization instead.
    neg_1 = op.Constant(value_ints=[-1])
    # Create weight_inst_norm and bias_inst_norm, adapted from the Torch ONNX exporter
    group_tensor = op.Reshape(num_groups, neg_1)
    # A 0 in the shape list keeps that dimension unchanged; InstanceNormalization needs shape [0, group, -1]
    shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
    input_reshaped = op.Reshape(input, shape_input)
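    # For example, an input of shape (N, C, H, W) is reshaped here to
    # (N, num_groups, (C // num_groups) * H * W), so the InstanceNormalization
    # below normalizes each group of channels together with the spatial dims.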
    weight_inst_norm = op.Expand(
        op.CastLike(op.Constant(value_float=1.0), input), group_tensor
    )
    bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor)
    norm = op.InstanceNormalization(
        input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
    )
    # Reshape back to the input's shape
    norm = op.Reshape(norm, op.Shape(input))
    # Apply the affine transform with the original weight and bias,
    # unsqueezing them to the target shape so that broadcasting works
    input_rank = Rank(input)
    one = op.Constant(value_int=1)
    axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
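    # For example, a 4-D input gives axes [1, 2], so a weight of shape (C,) is
    # unsqueezed to (C, 1, 1) and broadcasts against norm of shape (N, C, H, W).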
    weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
    bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
    weight_full_shape = op.CastLike(weight_full_shape, norm)
    norm_mul_weight = op.Mul(norm, weight_full_shape)
    bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
    norm_result = op.Add(norm_mul_weight, bias_full_shape)
    return norm_result


def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType:
"""glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor"""

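As a sanity check on the approach, here is a minimal PyTorch sketch (not part of the PR; shapes and eps are arbitrary) showing that the decomposition used above, namely reshape to (N, num_groups, -1), per-group normalization as InstanceNormalization would compute it, reshape back, and a per-channel affine transform, reproduces torch.nn.functional.group_norm:

import torch
import torch.nn.functional as F

# Hypothetical shapes and parameters, chosen only for this check.
N, C, H, W = 2, 6, 4, 4
num_groups, eps = 3, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Reshape to (N, num_groups, -1) and normalize over each group, which is what
# InstanceNormalization computes on the reshaped tensor.
x_grouped = x.reshape(N, num_groups, -1)
mean = x_grouped.mean(dim=-1, keepdim=True)
var = x_grouped.var(dim=-1, unbiased=False, keepdim=True)
norm = (x_grouped - mean) / torch.sqrt(var + eps)

# Reshape back to the input shape and apply the per-channel affine transform;
# weight and bias are unsqueezed so they broadcast over the spatial dims.
norm = norm.reshape(N, C, H, W)
result = norm * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

torch.testing.assert_close(result, F.group_norm(x, num_groups, weight, bias, eps))
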
8 changes: 8 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1696,6 +1696,14 @@ def _where_input_wrangler(
        matcher=lambda sample: sample.args[1] == 2,
        reason="fixme: 'bicubic' mode in ORT implemented differently with Torch",
    ),
    TorchLibOpInfo(
        "nn.functional.group_norm",
        nn_ops.aten_group_norm,
        tolerance={torch.float16: (1e-2, 7e-3)},
    ).xfail(
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
        reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input",
    ),
    TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
    TorchLibOpInfo(
        "hstack",
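To make the test entry concrete, here is a small sketch (not from the PR; it assumes the tolerance tuple is interpreted as (rtol, atol), which this diff does not spell out) of the inputs the xfail matcher excludes and of a float16 comparison at the stated tolerance:

import torch

# Mirrors the xfail matcher: any zero-sized dimension in the input trips it.
def has_zero_dim(shape):
    return any(dim == 0 for dim in shape)

print(has_zero_dim((2, 6, 4, 4)))  # False: the sample is compared normally
print(has_zero_dim((0, 6, 4)))     # True: expected failure with the InstanceNormalization decomposition

# Hypothetical float16 comparison at the tolerance given above.
expected = torch.randn(2, 6, 4, 4, dtype=torch.float16)
actual = expected + 1e-3  # stand-in for the torchlib output
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=7e-3)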