diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 2a69b1a5a..b577c6535 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3840,19 +3840,6 @@ def aten_grid_sampler_3d_backward(
     raise NotImplementedError()
 
 
-def aten_group_norm(
-    input: TensorType,
-    num_groups: int,
-    weight: Optional[TensorType] = None,
-    bias: Optional[TensorType] = None,
-    eps: float = 1e-05,
-    cudnn_enabled: bool = True,
-) -> TensorType:
-    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""
-
-    raise NotImplementedError()
-
-
 def aten_gru_cell(
     input: TensorType,
     hx: TensorType,
@@ -6087,7 +6074,9 @@ def _aten_native_group_norm_onnx(
     axes_unsqueeze = op.Range(1, input_rank - 1, 1)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
     bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+    weight_full_shape = op.CastLike(weight_full_shape, norm)
     norm_mul_weight = op.Mul(norm, weight_full_shape)
+    bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
     norm_result = op.Add(norm_mul_weight, bias_full_shape)
     # Compute mean and rstd, but using Torch algorithm
     # The returned shape for mean and vstd should be [N, group, -1]
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 5e0da20d0..6243499fb 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -593,6 +593,56 @@ def aten_glu_backward_jvp(
     raise NotImplementedError()
 
 
+@torch_op("aten::group_norm", trace_only=True)
+def aten_group_norm(
+    input: TFloat,
+    num_groups: int,
+    weight: Optional[TFloat] = None,
+    bias: Optional[TFloat] = None,
+    eps: float = 1e-05,
+    cudnn_enabled: bool = True,
+) -> TensorType:
+    """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""
+
+    # We do not need the N, C, HxW values explicitly; the input tensor's shape carries them
+    if weight is None:  # Default to 1.0; the shape is the channel size
+        weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+    if bias is None:  # Default to 0.0; the shape is the channel size
+        bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+    # onnx.GroupNormalization expects weight and bias of size=num_groups,
+    # but the torch aten op takes weight and bias of size=num_channels, so the sizes mismatch.
+    # We therefore simulate GroupNorm with onnx.InstanceNormalization instead.
+    neg_1 = op.Constant(value_ints=[-1])
+    # Create weight_inst_norm and bias_inst_norm, adapted from the Torch ONNX converter
+    group_tensor = op.Reshape(num_groups, neg_1)
+    # A 0 in the shape list keeps that dimension unchanged; InstanceNorm needs shape [0, group, -1]
+    shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
+    input_reshaped = op.Reshape(input, shape_input)
+    weight_inst_norm = op.Expand(
+        op.CastLike(op.Constant(value_float=1.0), input), group_tensor
+    )
+    bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor)
+    norm = op.InstanceNormalization(
+        input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
+    )
+    # Reshape back to the input's shape
+    norm = op.Reshape(norm, op.Shape(input))
+    # Apply the affine transform with the input weight and bias,
+    # unsqueezing them to the target shape so they broadcast correctly
+    input_rank = Rank(input)
+    one = op.Constant(value_int=1)
+    axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
+    weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
+    bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+    weight_full_shape = op.CastLike(weight_full_shape, norm)
+    norm_mul_weight = op.Mul(norm, weight_full_shape)
+    bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
+    norm_result = op.Add(norm_mul_weight, bias_full_shape)
+    return norm_result
+
+
 def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType:
     """glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor"""
 
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 386534915..b66f7214a 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1696,6 +1696,14 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.args[1] == 2,
         reason="fixme: 'bicubic' mode in ORT implemented differently with Torch",
     ),
+    TorchLibOpInfo(
+        "nn.functional.group_norm",
+        nn_ops.aten_group_norm,
+        tolerance={torch.float16: (1e-2, 7e-3)},
+    ).xfail(
+        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
+        reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support inputs with zero-sized dimensions",
+    ),
     TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
     TorchLibOpInfo(
         "hstack",
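
For reviewers who want to sanity-check the simulation, here is a minimal standalone sketch (not part of the change) of the same recipe in plain PyTorch: reshape to [N, group, -1], normalize with instance-norm statistics, reshape back, then apply the per-channel affine. The shapes, num_groups, and random inputs below are illustrative assumptions, not values taken from the patch.

import torch
import torch.nn.functional as F

N, C, H, W = 2, 6, 4, 4  # illustrative; C must be divisible by num_groups
num_groups = 3
eps = 1e-5

x = torch.randn(N, C, H, W)
weight = torch.randn(C)  # per-channel scale (size C, not num_groups)
bias = torch.randn(C)    # per-channel shift

# Step 1: fold each group of channels into one "channel": [N, C, H, W] -> [N, group, -1]
x_grouped = x.reshape(N, num_groups, -1)

# Step 2: instance norm with the default unit scale and zero shift computes
# exactly GroupNorm's per-(sample, group) mean/variance normalization
normed = F.instance_norm(x_grouped, eps=eps)

# Step 3: restore the shape and apply the per-channel affine; weight/bias are
# unsqueezed to [C, 1, 1] so they broadcast over the spatial dimensions
normed = normed.reshape(N, C, H, W)
result = normed * weight[:, None, None] + bias[:, None, None]

expected = F.group_norm(x, num_groups, weight=weight, bias=bias, eps=eps)
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-5)  # loose tolerance for fp accumulation order

The reshape also explains the xfail in the test data: when any input dimension is zero-sized, each (sample, group) slice is empty, leaving InstanceNormalization with no elements to compute statistics over.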