From e3a93661d7e7b1988b8e17711a68c64e3303c968 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 6 Dec 2023 05:08:10 +0800 Subject: [PATCH] AddOp(aten_replication_pad1d) | feat(torchlib) (#1203) - A large change in extra_opinfo.py, is just because re-sort the functions by name. --- onnxscript/function_libs/torch_lib/ops/nn.py | 10 +- .../function_libs/torch_lib/extra_opinfo.py | 2182 +++++++++-------- .../function_libs/torch_lib/ops_test_data.py | 4 + 3 files changed, 1117 insertions(+), 1079 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 7be4ef1a8..4eda5627f 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -1496,10 +1496,18 @@ def aten_relu6(self: TReal) -> TReal: return op.Min(op.Relu(self), six) +@torch_op("aten::replication_pad1d") def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType: """replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor""" - raise NotImplementedError() + # assert len(padding) == 2 + # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y] + start = op.Slice(padding, [0], [1], axes=[0]) + end = op.Slice(padding, [1], [2], axes=[0]) + padding_onnx = op.Concat( + op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0 + ) + return op.Pad(self, padding_onnx, mode="edge") def aten_replication_pad1d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 64e6a0622..60b9eb4f8 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -20,31 +20,110 @@ M = 10 -def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info - shapes = ( - (), - (1,), - (3,), - (1, 1), - (1, 2), - (2, 1), - (1, 1, 1), - (2, 2, 2), - ) + shapes = [ + [3], + [], + [3, 2], + [2, 3, 2], + ] for shape in shapes: - t = torch_testing.make_tensor( - shape, - low=0, - high=1, - device=device, - dtype=dtype, - requires_grad=requires_grad, - **kwargs, - ) - yield opinfo_core.SampleInput(t) + for p in (0, 0.5, 1): + t = torch_testing.make_tensor( + shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, + ) + yield opinfo_core.SampleInput(t, args=(p,)) + yield opinfo_core.SampleInput(t, kwargs={"p": p}) + + +def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad, **kwargs): + del op_info + + shapes = [ + [3], + [], + [3, 2], + [2, 3, 2], + ] + + for shape in shapes: + for p in (0, 1): + t = torch_testing.make_tensor( + shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, + ) + yield opinfo_core.SampleInput(t, args=(p,)) + yield opinfo_core.SampleInput(t, kwargs={"p": p}) + + +def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): + del op_info + # input_shape, output_size, kernal, dilation, padding, stride + cases = ( + ( + (1, 12, 12), + (4, 5), + (2, 2), + {"dilation": (1, 1), "padding": (0, 0), "stride": (1, 1)}, + ), + ( + (1, 8, 30), + (4, 5), + (2, 2), + {"dilation": (1, 1), "padding": (1, 1), "stride": (1, 1)}, + ), + ( + (1, 8, 9), + (4, 4), + (2, 2), + {"dilation": (1, 1), "padding": (0, 0), "stride": (1, 1)}, + ), + ( + (1, 8, 25), + (4, 4), + (2, 2), + {"dilation": 
(1, 1), "padding": (1, 1), "stride": (1, 1)}, + ), + ( + (1, 8, 9), + (4, 4), + (2, 2), + {"dilation": (1, 1), "padding": (1, 1), "stride": (2, 2)}, + ), + ( + (1, 9, 4), + (4, 4), + (3, 3), + {"dilation": (1, 1), "padding": (1, 1), "stride": (2, 2)}, + ), + ( + (1, 18, 16), + (2, 2), + (1, 1), + {"dilation": (2, 2), "padding": (3, 3), "stride": (2, 2)}, + ), + ) + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + for shape, output_size, kernel_size, kwargs in cases: + tensor = make_arg(shape) + yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs) def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs): @@ -194,373 +273,408 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): ) -def _prepare_data_for_fft_ops(device, dtype, requires_grad=False): - # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79 - is_fp16_or_chalf = dtype in (torch.complex32, torch.half) - if not is_fp16_or_chalf: - oned_tensor = functools.partial( - opinfo_core.make_tensor, - (31,), - device=device, - dtype=dtype, - requires_grad=requires_grad, - ) - nd_tensor = functools.partial( - opinfo_core.make_tensor, - (S, S + 1, S + 2), - device=device, - dtype=dtype, - requires_grad=requires_grad, - ) - else: - low = None - high = None - shapes = ((2, 8, 9), (33,)) +def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs - oned_tensor = functools.partial( - opinfo_core.make_tensor, - shapes[1], - device=device, - low=low, - high=high, - dtype=dtype, - requires_grad=requires_grad, + def make_input(shape): + return common_methods_invocations.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad ) - nd_tensor = functools.partial( - opinfo_core.make_tensor, - shapes[0], + + def make_long_input(shape, *, low, high, noncontiguous=False): + return common_methods_invocations.make_tensor( + shape, device=device, + dtype=torch.long, low=low, high=high, - dtype=dtype, - requires_grad=requires_grad, + noncontiguous=noncontiguous, ) - return oned_tensor, nd_tensor - - -def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): - del self # Unused - oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) - - for normalization, forward in itertools.product((0, 1, 2), (True, False)): - # 1-D - yield opinfo_core.SampleInput( - oned_tensor(), dim=(0,), normalization=normalization, forward=forward - ) - # N-D - for dim in [ - (0,), - (1,), - (2,), - (1, 2), - (0, 1), - (0, 1, 2), - ]: - yield opinfo_core.SampleInput( - nd_tensor(), dim=dim, normalization=normalization, forward=forward + for max_norm in (0.5, 1.0, 5.0): + for norm_type in (0.8, 1.0, 2.0, 2.5): + idx = make_long_input((6,), low=0, high=S) + weights = make_input((S, S)) * 2 + yield common_methods_invocations.SampleInput( + weights, + args=(idx,), + kwargs={"max_norm": max_norm, "norm_type": norm_type}, ) -def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_): - del self # Unused - oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) +def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs - for normalization, one_sided in itertools.product((0, 1, 2), (True, True)): - # 1-D - yield opinfo_core.SampleInput( - 
oned_tensor(), dim=(0,), normalization=normalization, onesided=one_sided + def make_input(shape): + return common_methods_invocations.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad ) - # N-D - for dim in [ - (0,), - (1,), - (2,), - (1, 2), - (0, 1), - (0, 1, 2), - ]: - yield opinfo_core.SampleInput( - nd_tensor(), dim=dim, normalization=normalization, onesided=one_sided - ) - -def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): - del self # Unused - oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) - - for normalization in (0, 1, 2): - # 1-D - yield opinfo_core.SampleInput( - oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12 + def make_long_input(shape, *, low, high, noncontiguous=False): + return common_methods_invocations.make_tensor( + shape, + device=device, + dtype=torch.long, + low=low, + high=high, + noncontiguous=noncontiguous, ) - # N-D - for dim in [ - (0,), - (1,), - (2,), - (1, 2), - (0, 1), - (0, 1, 2), - ]: - yield opinfo_core.SampleInput( - nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6 - ) - -def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs): - del op_info # unused - del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - - # Ordered as input shape, normalized_shape, eps - cases: tuple[tuple[int], tuple[int], float] = ( # type: ignore[assignment] - ((1, 2, 3), (1, 2, 3), 0.5), - ((2, 2, 3), (2, 3), -0.5), - ((1,), (1,), 1e-5), - ((1, 2), (2,), 1e-5), - ((0, 1), (1,), 1e-5), - ) + def make_per_sample_weight(flag, idx): + # a tensor of float / double weights, or None + # to indicate all weights should be taken to be 1 + if flag: + return make_input(idx.reshape(-1).shape) + return None - for input_shape, normalized_shape, eps in cases: # type: ignore[misc] - # Shape of weight and bias should be the same as normalized_shape - weight = make_arg(normalized_shape) # type: ignore[has-type] - bias = make_arg(normalized_shape) # type: ignore[has-type] - yield opinfo_core.SampleInput( - make_arg(input_shape), # type: ignore[has-type] - args=(normalized_shape, weight, bias, eps), # type: ignore[has-type] - ) - yield opinfo_core.SampleInput( - make_arg(input_shape), # type: ignore[has-type] - args=(normalized_shape, None, bias, eps), # type: ignore[has-type] - ) - yield opinfo_core.SampleInput( - make_arg(input_shape), # type: ignore[has-type] - args=(normalized_shape, weight, None, eps), # type: ignore[has-type] - ) - yield opinfo_core.SampleInput( - make_arg(input_shape), # type: ignore[has-type] - args=(normalized_shape, None, None, eps), # type: ignore[has-type] - ) + offsets = [ + torch.tensor([0, 2, 3], device=device, dtype=torch.long), + torch.tensor([0, 0, 2], device=device, dtype=torch.long), + torch.tensor([0, 2, 2, 4], device=device, dtype=torch.long), + ] + for offset in offsets: + for include_last_offset in (True, False): + for generate_per_sample_weight in (True, False): + for mode in ( + 0, + 1, + 2, + ): # ('sum', 'mean', 'max') + # per_sample_weights only support mode='sum' + if generate_per_sample_weight and mode in (1, 2): # ('mean', 'max'): + continue + # 1-D index tensor + indices = make_long_input((S,), low=0, high=M) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 0 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + 
"mode": mode, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + }, + ) -class _TestParamsMaxPoolEmptyStrideBase: - # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 - def __init__(self): - self.kwargs = { - "kernel_size": [3], - "stride": [()], - "ceil_mode": [True, False], - "padding": [0, 1], - "dilation": [1], - } + indices = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 1 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + "mode": mode, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + }, + ) - # fmt: off - self.shapes = [ - [1, 2, None], # batch - [2], # channels - [3, 6] # signal - ] - # fmt: on + if mode != 2: # "max" mode in 2-D index tensor make aten func crash + # 2-D index tensor + indices = make_long_input((S, S), low=0, high=M) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 2 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + "mode": mode, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + }, + ) - def _gen_shape(self): - for shape in itertools.product(*self.shapes): - # shape[0] is None indicates missing batch dimension - if shape[0] is None: - shape = shape[1:] + indices = make_long_input((S, S), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 3 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + "mode": mode, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + }, + ) - yield shape, torch.contiguous_format - # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format - if len(self.shapes) == 4 and len(shape) == 4: - yield shape, torch.channels_last - def _gen_kwargs(self): - keys = self.kwargs.keys() - for values in itertools.product(*self.kwargs.values()): - yield dict(zip(keys, values)) +def sample_inputs_embedding_bag_padding_idx(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs - def gen_input_params(self): - yield from itertools.product(self._gen_shape(), self._gen_kwargs()) + def make_input(shape): + return common_methods_invocations.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + def make_long_input(shape, *, low, high, noncontiguous=False): + return common_methods_invocations.make_tensor( + shape, + device=device, + dtype=torch.long, + low=low, + high=high, + noncontiguous=noncontiguous, + ) -class _TestParamsMaxPool1dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): - def __init__(self): - super().__init__() - self.kwargs["kernel_size"] += [(3,)] - self.kwargs["stride"] += [(2,)] - self.kwargs["padding"] += [(1,)] - self.kwargs["dilation"] += [(1,)] + def make_per_sample_weight(flag, idx): + # a tensor of float / double weights, or None + # to indicate all weights should be taken to be 1 + if flag: + return make_input(idx.reshape(-1).shape) + return None + offsets = [ + torch.tensor([0, 2, 3], device=device, dtype=torch.long), + # Below case not work for FullGraph mode, guess due to 
op.While() bug: + # when the initial condition is False, it still excute the loop body once. + # torch.tensor([0, 0, 2], device=device, dtype=torch.long), + # torch.tensor([0, 2, 2, 4], device=device, dtype=torch.long), + ] + for offset in offsets: + for include_last_offset in (True, False): + for generate_per_sample_weight in (True, False): + for mode in ( + 0, + 1, + 2, + ): # ('sum', 'mean', 'max') + # per_sample_weights only support mode='sum' + if generate_per_sample_weight and mode in (1, 2): # ('mean', 'max'): + continue -class _TestParamsMaxPool2dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): - def __init__(self): - super().__init__() - self.kwargs["kernel_size"] += [(3, 2)] - self.kwargs["stride"] += [(2, 1)] - self.kwargs["padding"] += [(1, 1)] - self.kwargs["dilation"] += [(1, 2)] + for padding_idx in (-1, 0, 1, 2, 3): + # 1-D index tensor + indices = make_long_input((S,), low=0, high=M) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 0 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + "scale_grad_by_freq": False, + "mode": mode, + "sparse": False, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + "padding_idx": padding_idx, + }, + ) - self.shapes.append([6]) + indices = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight( + generate_per_sample_weight, indices + ) + # 1 + yield common_methods_invocations.SampleInput( + make_input((M, S)), + args=(indices,), + kwargs={ + "offsets": offset, + "scale_grad_by_freq": False, + "mode": mode, + "sparse": False, + "per_sample_weights": per_sample_weights, + "include_last_offset": include_last_offset, + "padding_idx": padding_idx, + }, + ) + # if mode != 2: # "max" mode in 2-D index tensor make aten func crash + # # 2-D index tensor + # indices = make_long_input((S, S), low=0, high=M) + # per_sample_weights = make_per_sample_weight( + # generate_per_sample_weight, indices + # ) + # # 2 + # yield common_methods_invocations.SampleInput( + # make_input((M, S)), + # args=(indices,), + # kwargs={ + # "offsets": offset, + # "mode": mode, + # "per_sample_weights": per_sample_weights, + # "include_last_offset": include_last_offset, + # "padding_idx": padding_idx, + # }, + # ) -class _TestParamsMaxPool3dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): - def __init__(self): - super().__init__() - self.kwargs["kernel_size"] += [(3, 2, 3)] - self.kwargs["stride"] += [(2, 1, 2)] - self.kwargs["dilation"] += [(1, 2, 1)] + # indices = make_long_input((S, S), low=0, high=M, noncontiguous=True) + # per_sample_weights = make_per_sample_weight( + # generate_per_sample_weight, indices + # ) + # # 3 + # yield common_methods_invocations.SampleInput( + # make_input((M, S)), + # args=(indices,), + # kwargs={ + # "offsets": offset, + # "mode": mode, + # "per_sample_weights": per_sample_weights, + # "include_last_offset": include_last_offset, + # "padding_idx": padding_idx, + # }, + # ) - self.shapes.append([6]) - self.shapes.append([5]) +def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs): + del op_info -def sample_inputs_max_pool_empty_strides(op_info, device, dtype, requires_grad, **kwargs): - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False + shapes = ( + (), + (1,), + (3,), + (1, 1), + (1, 2), + (2, 1), + (1, 1, 1), + (2, 2, 2), ) - # FIXME: (RuntimeError: 
non-empty 3D or 4D (batch mode) tensor expected for input) - - params_generator_type_dict = { - "ops.aten.max_pool1d": _TestParamsMaxPool1dEmptyStride, - "ops.aten.max_pool2d": _TestParamsMaxPool2dEmptyStride, - "ops.aten.max_pool3d": _TestParamsMaxPool3dEmptyStride, - } - - params_generator = params_generator_type_dict[op_info.name]() - for (shape, memory_format), kwargs in params_generator.gen_input_params(): - arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) - yield opinfo_core.SampleInput(arg, kwargs=kwargs) - + for shape in shapes: + t = torch_testing.make_tensor( + shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, + ) + yield opinfo_core.SampleInput(t) -def sample_inputs_max_pool1d_with_indices(op_info, device, dtype, requires_grad, **kwargs): - del op_info - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False - ) - params_generator = ( - common_methods_invocations._TestParamsMaxPool1d() # pylint: disable=protected-access - ) - for (shape, memory_format), kwargs in params_generator.gen_input_params(): - arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) - yield opinfo_core.SampleInput(arg, kwargs=kwargs) +def _prepare_data_for_fft_ops(device, dtype, requires_grad=False): + # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79 + is_fp16_or_chalf = dtype in (torch.complex32, torch.half) + if not is_fp16_or_chalf: + oned_tensor = functools.partial( + opinfo_core.make_tensor, + (31,), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + nd_tensor = functools.partial( + opinfo_core.make_tensor, + (S, S + 1, S + 2), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + else: + low = None + high = None + shapes = ((2, 8, 9), (33,)) -def sample_inputs_max_pool2d_with_indices(op_info, device, dtype, requires_grad, **kwargs): - del op_info - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False - ) - params_generator = ( - common_methods_invocations._TestParamsMaxPool2d() # pylint: disable=protected-access - ) - for (shape, memory_format), kwargs in params_generator.gen_input_params(): - arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) - yield opinfo_core.SampleInput(arg, kwargs=kwargs) + oned_tensor = functools.partial( + opinfo_core.make_tensor, + shapes[1], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + nd_tensor = functools.partial( + opinfo_core.make_tensor, + shapes[0], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + return oned_tensor, nd_tensor -def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad, **kwargs): - del op_info - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False - ) - params_generator = ( - common_methods_invocations._TestParamsMaxPool3d() # pylint: disable=protected-access - ) - for (shape, memory_format), kwargs in params_generator.gen_input_params(): - arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) - yield opinfo_core.SampleInput(arg, kwargs=kwargs) +def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_): + del self # Unused + oned_tensor, nd_tensor = 
_prepare_data_for_fft_ops(device, dtype, requires_grad) -def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs): - del op_info - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) + for normalization, forward in itertools.product((0, 1, 2), (True, False)): + # 1-D + yield opinfo_core.SampleInput( + oned_tensor(), dim=(0,), normalization=normalization, forward=forward + ) + # N-D + for dim in [ + (0,), + (1,), + (2,), + (1, 2), + (0, 1), + (0, 1, 2), + ]: + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, forward=forward + ) - # Ordered as input shape, C,N,HxW, and kwargs for group and eps - cases = ( - ((1, 6, 3), (6,), (6,), 1, 6, 3, {"group": 2, "eps": 0.5}), - ((2, 6, 3), (6,), (6,), 2, 6, 3, {"group": 3, "eps": -0.5}), - ((5, 5, 5), (5,), (5,), 5, 5, 5, {"group": 1, "eps": 1e-5}), - ((5, 8, 10), (8,), (8,), 5, 8, 10, {"group": 4, "eps": 1e-5}), - ) - for input_shape, weight, bias, N, C, HxW, kwargs in cases: - # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) - channels = input_shape[1] if len(input_shape) > 1 else 0 - weight = make_arg(channels) if channels > 0 else None - bias = make_arg(channels) if channels > 0 else None +def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_): + del self # Unused + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) + for normalization, one_sided in itertools.product((0, 1, 2), (True, True)): + # 1-D yield opinfo_core.SampleInput( - make_arg(input_shape), - args=( - weight, - bias, - N, - C, - HxW, - ), - kwargs=kwargs, + oned_tensor(), dim=(0,), normalization=normalization, onesided=one_sided ) + # N-D + for dim in [ + (0,), + (1,), + (2,), + (1, 2), + (0, 1), + (0, 1, 2), + ]: + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, onesided=one_sided + ) -def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): - del op_info - # input_shape, output_size, kernal, dilation, padding, stride - cases = ( - ( - (1, 12, 12), - (4, 5), - (2, 2), - {"dilation": (1, 1), "padding": (0, 0), "stride": (1, 1)}, - ), - ( - (1, 8, 30), - (4, 5), - (2, 2), - {"dilation": (1, 1), "padding": (1, 1), "stride": (1, 1)}, - ), - ( - (1, 8, 9), - (4, 4), - (2, 2), - {"dilation": (1, 1), "padding": (0, 0), "stride": (1, 1)}, - ), - ( - (1, 8, 25), - (4, 4), - (2, 2), - {"dilation": (1, 1), "padding": (1, 1), "stride": (1, 1)}, - ), - ( - (1, 8, 9), - (4, 4), - (2, 2), - {"dilation": (1, 1), "padding": (1, 1), "stride": (2, 2)}, - ), - ( - (1, 9, 4), - (4, 4), - (3, 3), - {"dilation": (1, 1), "padding": (1, 1), "stride": (2, 2)}, - ), - ( - (1, 18, 16), - (2, 2), - (1, 1), - {"dilation": (2, 2), "padding": (3, 3), "stride": (2, 2)}, - ), - ) +def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): + del self # Unused + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - for shape, output_size, kernel_size, kwargs in cases: - tensor = make_arg(shape) - yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs) + for normalization in (0, 1, 2): + # 1-D + yield opinfo_core.SampleInput( + oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12 + ) + # N-D + for dim in [ + (0,), + (1,), + (2,), + 
(1, 2), + (0, 1), + (0, 1, 2), + ]: + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6 + ) def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs): @@ -612,124 +726,42 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args) -def sample_inputs_native_dropout( - op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs -): - del op_info # Unused - del kwargs # Unused - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - - if valid_input_dim: - cases = ((S,) * i for i in valid_input_dim) - else: - cases = ((S, S), (S,), ()) - # ONNX requires 0 <= p < 1 - p_vals = [0.0] - - training_vals = [True, False] - - for case, p, training in itertools.product(cases, p_vals, training_vals): - yield opinfo_core.SampleInput(make_arg(case), p=p, train=training) - - -def sample_inputs_normal_tensor_float(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del requires_grad - del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False - ) - samples = ( - ((S, S), 0.0), - ((S, S, S), 4.2), - ) - for mean, std in samples: - yield opinfo_core.SampleInput(make_arg(mean), std) - - -def sample_inputs_normal_float_tensor(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del requires_grad - del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False - ) - samples = ( - (4.2, (S, S)), - (-2.0, (S, S, S)), - ) - for mean, std in samples: - yield opinfo_core.SampleInput(mean, make_arg(std, low=0.0)) - - -def sample_inputs_normal_tensor_tensor(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del requires_grad +def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs): + del op_info # unused del kwargs - make_arg = functools.partial( - torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False - ) - samples = ( - ((S, S), (S, S)), - ((S, S, S), (S, S, S)), - ) - for mean, std in samples: - yield opinfo_core.SampleInput(make_arg(mean), make_arg(std, low=0.0)) - - -def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del device # Unused - del requires_grad # Unused - del kwargs # Unused - - shapes = ( - (M,), - (S, S), - (S, S, S), - ) - - for shape in shapes: - yield opinfo_core.SampleInput(shape, kwargs=dict(dtype=dtype)) - - -def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del kwargs # Unused - make_arg = functools.partial( torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad ) - shapes = ( - (M,), - (S, S), - (S, S, S), - ) - - for shape in shapes: - yield opinfo_core.SampleInput(make_arg(shape)) - - -def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del kwargs # Unused - - make_arg = functools.partial( - torch_testing.make_tensor, - device=device, - dtype=torch.float32, - requires_grad=requires_grad, - ) - shapes = ( - (M,), - (S, S), - (S, S, S), + + # Ordered as input shape, normalized_shape, eps + cases: tuple[tuple[int], tuple[int], float] = ( # type: ignore[assignment] + ((1, 2, 3), (1, 2, 3), 0.5), + ((2, 2, 3), (2, 3), -0.5), + ((1,), (1,), 1e-5), + ((1, 2), (2,), 1e-5), + ((0, 1), (1,), 1e-5), 
) - for shape in shapes: - yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype)) + for input_shape, normalized_shape, eps in cases: # type: ignore[misc] + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) # type: ignore[has-type] + bias = make_arg(normalized_shape) # type: ignore[has-type] + yield opinfo_core.SampleInput( + make_arg(input_shape), # type: ignore[has-type] + args=(normalized_shape, weight, bias, eps), # type: ignore[has-type] + ) + yield opinfo_core.SampleInput( + make_arg(input_shape), # type: ignore[has-type] + args=(normalized_shape, None, bias, eps), # type: ignore[has-type] + ) + yield opinfo_core.SampleInput( + make_arg(input_shape), # type: ignore[has-type] + args=(normalized_shape, weight, None, eps), # type: ignore[has-type] + ) + yield opinfo_core.SampleInput( + make_arg(input_shape), # type: ignore[has-type] + args=(normalized_shape, None, None, eps), # type: ignore[has-type] + ) def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): @@ -764,457 +796,392 @@ def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, **kwargs) -def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): - high = 10 - - for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): - # With high - yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) +def sample_inputs__log_softmax( + op_info, + device, + dtype, + requires_grad, + **kwargs, +): + del op_info # Unused + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), (0,)), + ((S, S), (0,)), + ((S, S), (1,)), + ((S, S), (-1,)), + ((S, M, S), (2,)), + ((S, 0, 0), (-1,)), + ] -def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs): - low = 2 - high = 10 + for (shape, dim), half_to_float in itertools.product(cases, (False,)): + # NOTE: softmax with half to float conversion is not supported on CPU + # So we don't test it here + kwargs = dict(half_to_float=half_to_float) + yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs) - for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput( - low, high, sample.input.shape, *sample.args, **sample.kwargs - ) +def sample_inputs_max_pool_empty_strides(op_info, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False + ) -def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): - high = 10 + # FIXME: (RuntimeError: non-empty 3D or 4D (batch mode) tensor expected for input) - for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): - # With high - yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) + params_generator_type_dict = { + "ops.aten.max_pool1d": _TestParamsMaxPool1dEmptyStride, + "ops.aten.max_pool2d": _TestParamsMaxPool2dEmptyStride, + "ops.aten.max_pool3d": _TestParamsMaxPool3dEmptyStride, + } + params_generator = params_generator_type_dict[op_info.name]() + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield opinfo_core.SampleInput(arg, kwargs=kwargs) -def 
sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs): - high = 10 - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) +def sample_inputs_max_pool1d_with_indices(op_info, device, dtype, requires_grad, **kwargs): + del op_info + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False + ) + params_generator = ( + common_methods_invocations._TestParamsMaxPool1d() # pylint: disable=protected-access + ) + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield opinfo_core.SampleInput(arg, kwargs=kwargs) -def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs): - low = 2 - high = 10 +def sample_inputs_max_pool2d_with_indices(op_info, device, dtype, requires_grad, **kwargs): + del op_info + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False + ) + params_generator = ( + common_methods_invocations._TestParamsMaxPool2d() # pylint: disable=protected-access + ) + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield opinfo_core.SampleInput(arg, kwargs=kwargs) - for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) +def sample_inputs_max_pool3d_with_indices(op_info, device, dtype, requires_grad, **kwargs): + del op_info + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=False + ) + params_generator = ( + common_methods_invocations._TestParamsMaxPool3d() # pylint: disable=protected-access + ) + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield opinfo_core.SampleInput(arg, kwargs=kwargs) -def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs): - low = 2 - high = 10 - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) +def sample_inputs_native_group_norm(op_info, device, dtype, requires_grad, **kwargs): + del op_info + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + # Ordered as input shape, C,N,HxW, and kwargs for group and eps + cases = ( + ((1, 6, 3), (6,), (6,), 1, 6, 3, {"group": 2, "eps": 0.5}), + ((2, 6, 3), (6,), (6,), 2, 6, 3, {"group": 3, "eps": -0.5}), + ((5, 5, 5), (5,), (5,), 5, 5, 5, {"group": 1, "eps": 1e-5}), + ((5, 8, 10), (8,), (8,), 5, 8, 10, {"group": 4, "eps": 1e-5}), + ) -def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): - del op # Unused - del device # Unused - del requires_grad # Unused - del kwargs # Unused + for input_shape, weight, bias, N, C, HxW, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight = make_arg(channels) if channels > 0 else None + 
bias = make_arg(channels) if channels > 0 else None - shapes = ((M,), (S, S)) + yield opinfo_core.SampleInput( + make_arg(input_shape), + args=( + weight, + bias, + N, + C, + HxW, + ), + kwargs=kwargs, + ) - for shape in shapes: - yield opinfo_core.SampleInput(input=shape, kwargs=dict(dtype=dtype)) +def sample_inputs_native_dropout( + op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs +): + del op_info # Unused + del kwargs # Unused + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) -def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs + if valid_input_dim: + cases = ((S,) * i for i in valid_input_dim) + else: + cases = ((S, S), (S,), ()) + # ONNX requires 0 <= p < 1 + p_vals = [0.0] - def mt(shape, **kwargs): - return torch_testing.make_tensor( - shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs - ) + training_vals = [True, False] - yield opinfo_core.SampleInput(mt(100), n_fft=10, return_complex=True) - yield opinfo_core.SampleInput(mt(100), n_fft=10, return_complex=False) - if dtype.is_complex: - yield opinfo_core.SampleInput(mt(100), n_fft=10) + for case, p, training in itertools.product(cases, p_vals, training_vals): + yield opinfo_core.SampleInput(make_arg(case), p=p, train=training) - yield opinfo_core.SampleInput(mt(10), n_fft=7, return_complex=True) - yield opinfo_core.SampleInput(mt((10, 100)), n_fft=16, hop_length=4, return_complex=True) - window = mt(16, low=0.5, high=2.0) - yield opinfo_core.SampleInput( - mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True) +# NOTE: In `_native_batch_norm_legit` tests, it generates two kinds of args: +# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps) +# 2. (input, weight, bias, training, momentum, eps) +# which requires two function signatures to take the inputs, that's why we have +# two sample_inputs functions here instead. 
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): + samples = common_methods_invocations.sample_inputs_batch_norm( + op_info, device, dtype, requires_grad, **kwargs ) - yield opinfo_core.SampleInput( - mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get("training", True) + momentum = sample.kwargs.get("momentum", 0.5) + eps = sample.kwargs.get("eps", 1e-5) + if args[0] is not None and args[1] is not None: + yield opinfo_core.SampleInput( + sample.input, + args=(args[2], args[3], args[0], args[1], training, momentum, eps), + ) + + +def sample_inputs__native_batch_norm_legit_no_stats( + op_info, device, dtype, requires_grad, **kwargs +): + samples = common_methods_invocations.sample_inputs_batch_norm( + op_info, device, dtype, requires_grad, **kwargs ) - if not dtype.is_complex: - yield opinfo_core.SampleInput( - mt((10, 100)), n_fft=16, window=window, onesided=False, return_complex=True - ) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get("training", True) + momentum = sample.kwargs.get("momentum", 0.5) + eps = sample.kwargs.get("eps", 1e-5) + if args[0] is not None and args[1] is None: + yield opinfo_core.SampleInput( + sample.input, args=(args[2], args[3], training, momentum, eps) + ) -def sample_inputs_tensor_bool(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_normal_tensor_float(op_info, device, dtype, requires_grad, **kwargs): del op_info - del device del requires_grad del kwargs - yield opinfo_core.SampleInput(True, dtype=dtype) - yield opinfo_core.SampleInput(False, dtype=dtype) + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False + ) + samples = ( + ((S, S), 0.0), + ((S, S, S), 4.2), + ) + for mean, std in samples: + yield opinfo_core.SampleInput(make_arg(mean), std) -def sample_inputs_tensor_float(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_normal_float_tensor(op_info, device, dtype, requires_grad, **kwargs): del op_info - del device del requires_grad del kwargs - yield opinfo_core.SampleInput(3.0, dtype=dtype) - yield opinfo_core.SampleInput(-1.0, dtype=dtype) + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False + ) + samples = ( + (4.2, (S, S)), + (-2.0, (S, S, S)), + ) + for mean, std in samples: + yield opinfo_core.SampleInput(mean, make_arg(std, low=0.0)) -def sample_inputs_tensor_int(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_normal_tensor_tensor(op_info, device, dtype, requires_grad, **kwargs): del op_info - del device del requires_grad del kwargs - yield opinfo_core.SampleInput(2, dtype=dtype) - yield opinfo_core.SampleInput(-5, dtype=dtype) + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=False + ) + samples = ( + ((S, S), (S, S)), + ((S, S, S), (S, S, S)), + ) + for mean, std in samples: + yield opinfo_core.SampleInput(make_arg(mean), make_arg(std, low=0.0)) -def sample_inputs_bernoulli_p(op_info, 
device, dtype, requires_grad, **kwargs): - del op_info +def sample_inputs_rand(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del device # Unused + del requires_grad # Unused + del kwargs # Unused - shapes = [ - [3], - [], - [3, 2], - [2, 3, 2], - ] + shapes = ( + (M,), + (S, S), + (S, S, S), + ) for shape in shapes: - for p in (0, 0.5, 1): - t = torch_testing.make_tensor( - shape, - low=0, - high=1, - device=device, - dtype=dtype, - requires_grad=requires_grad, - **kwargs, - ) - yield opinfo_core.SampleInput(t, args=(p,)) - yield opinfo_core.SampleInput(t, kwargs={"p": p}) + yield opinfo_core.SampleInput(shape, kwargs=dict(dtype=dtype)) -def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad, **kwargs): - del op_info +def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused - shapes = [ - [3], - [], - [3, 2], - [2, 3, 2], - ] + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + shapes = ( + (M,), + (S, S), + (S, S, S), + ) for shape in shapes: - for p in (0, 1): - t = torch_testing.make_tensor( - shape, - low=0, - high=1, - device=device, - dtype=dtype, - requires_grad=requires_grad, - **kwargs, - ) - yield opinfo_core.SampleInput(t, args=(p,)) - yield opinfo_core.SampleInput(t, kwargs={"p": p}) + yield opinfo_core.SampleInput(make_arg(shape)) -def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs +def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused - def make_input(shape): - return common_methods_invocations.make_tensor( - shape, device=device, dtype=dtype, requires_grad=requires_grad - ) + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=torch.float32, + requires_grad=requires_grad, + ) + shapes = ( + (M,), + (S, S), + (S, S, S), + ) - def make_long_input(shape, *, low, high, noncontiguous=False): - return common_methods_invocations.make_tensor( - shape, - device=device, - dtype=torch.long, - low=low, - high=high, - noncontiguous=noncontiguous, - ) + for shape in shapes: + yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype)) - for max_norm in (0.5, 1.0, 5.0): - for norm_type in (0.8, 1.0, 2.0, 2.5): - idx = make_long_input((6,), low=0, high=S) - weights = make_input((S, S)) * 2 - yield common_methods_invocations.SampleInput( - weights, - args=(idx,), - kwargs={"max_norm": max_norm, "norm_type": norm_type}, - ) +def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): + high = 10 -def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield opinfo_core.SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) - def make_input(shape): - return common_methods_invocations.make_tensor( - shape, device=device, dtype=dtype, requires_grad=requires_grad - ) - def make_long_input(shape, *, low, high, noncontiguous=False): - return common_methods_invocations.make_tensor( - shape, - device=device, - dtype=torch.long, - low=low, - high=high, - noncontiguous=noncontiguous, +def sample_inputs_randint_low(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, 
**kwargs): + # With low and high + yield opinfo_core.SampleInput( + low, high, sample.input.shape, *sample.args, **sample.kwargs ) - def make_per_sample_weight(flag, idx): - # a tensor of float / double weights, or None - # to indicate all weights should be taken to be 1 - if flag: - return make_input(idx.reshape(-1).shape) - return None - offsets = [ - torch.tensor([0, 2, 3], device=device, dtype=torch.long), - torch.tensor([0, 0, 2], device=device, dtype=torch.long), - torch.tensor([0, 2, 2, 4], device=device, dtype=torch.long), - ] - for offset in offsets: - for include_last_offset in (True, False): - for generate_per_sample_weight in (True, False): - for mode in ( - 0, - 1, - 2, - ): # ('sum', 'mean', 'max') - # per_sample_weights only support mode='sum' - if generate_per_sample_weight and mode in (1, 2): # ('mean', 'max'): - continue +def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): + high = 10 - # 1-D index tensor - indices = make_long_input((S,), low=0, high=M) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 0 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "mode": mode, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - }, - ) + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) - indices = make_long_input((S,), low=0, high=M, noncontiguous=True) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 1 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "mode": mode, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - }, - ) - if mode != 2: # "max" mode in 2-D index tensor make aten func crash - # 2-D index tensor - indices = make_long_input((S, S), low=0, high=M) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 2 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "mode": mode, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - }, - ) +def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs): + high = 10 + + for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) - indices = make_long_input((S, S), low=0, high=M, noncontiguous=True) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 3 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "mode": mode, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - }, - ) +def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 -def sample_inputs_embedding_bag_padding_idx(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) - def make_input(shape): - 
return common_methods_invocations.make_tensor( - shape, device=device, dtype=dtype, requires_grad=requires_grad - ) - def make_long_input(shape, *, low, high, noncontiguous=False): - return common_methods_invocations.make_tensor( - shape, - device=device, - dtype=torch.long, - low=low, - high=high, - noncontiguous=noncontiguous, - ) +def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 - def make_per_sample_weight(flag, idx): - # a tensor of float / double weights, or None - # to indicate all weights should be taken to be 1 - if flag: - return make_input(idx.reshape(-1).shape) - return None + for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): + # With low and high + yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) - offsets = [ - torch.tensor([0, 2, 3], device=device, dtype=torch.long), - # Below case not work for FullGraph mode, guess due to op.While() bug: - # when the initial condition is False, it still excute the loop body once. - # torch.tensor([0, 0, 2], device=device, dtype=torch.long), - # torch.tensor([0, 2, 2, 4], device=device, dtype=torch.long), - ] - for offset in offsets: - for include_last_offset in (True, False): - for generate_per_sample_weight in (True, False): - for mode in ( - 0, - 1, - 2, - ): # ('sum', 'mean', 'max') - # per_sample_weights only support mode='sum' - if generate_per_sample_weight and mode in (1, 2): # ('mean', 'max'): - continue - for padding_idx in (-1, 0, 1, 2, 3): - # 1-D index tensor - indices = make_long_input((S,), low=0, high=M) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 0 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "scale_grad_by_freq": False, - "mode": mode, - "sparse": False, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - "padding_idx": padding_idx, - }, - ) +def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): + del op # Unused + del device # Unused + del requires_grad # Unused + del kwargs # Unused - indices = make_long_input((S,), low=0, high=M, noncontiguous=True) - per_sample_weights = make_per_sample_weight( - generate_per_sample_weight, indices - ) - # 1 - yield common_methods_invocations.SampleInput( - make_input((M, S)), - args=(indices,), - kwargs={ - "offsets": offset, - "scale_grad_by_freq": False, - "mode": mode, - "sparse": False, - "per_sample_weights": per_sample_weights, - "include_last_offset": include_last_offset, - "padding_idx": padding_idx, - }, - ) + shapes = ((M,), (S, S)) - # if mode != 2: # "max" mode in 2-D index tensor make aten func crash - # # 2-D index tensor - # indices = make_long_input((S, S), low=0, high=M) - # per_sample_weights = make_per_sample_weight( - # generate_per_sample_weight, indices - # ) - # # 2 - # yield common_methods_invocations.SampleInput( - # make_input((M, S)), - # args=(indices,), - # kwargs={ - # "offsets": offset, - # "mode": mode, - # "per_sample_weights": per_sample_weights, - # "include_last_offset": include_last_offset, - # "padding_idx": padding_idx, - # }, - # ) + for shape in shapes: + yield opinfo_core.SampleInput(input=shape, kwargs=dict(dtype=dtype)) - # indices = make_long_input((S, S), low=0, high=M, noncontiguous=True) - # per_sample_weights = make_per_sample_weight( - # generate_per_sample_weight, indices - # ) - # # 3 - # yield 
common_methods_invocations.SampleInput( - # make_input((M, S)), - # args=(indices,), - # kwargs={ - # "offsets": offset, - # "mode": mode, - # "per_sample_weights": per_sample_weights, - # "include_last_offset": include_last_offset, - # "padding_idx": padding_idx, - # }, - # ) +def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs -def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple = ( # ignore + ((2, 3), (1, 2)), + ((4, 5), (0, 1)), + ((6, 7), (1, 1)), + ((8, 9), (1, 0)), + ) + + make_inp = opinfo_core.partial( + torch.testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + for shape, pad in cases: + yield opinfo_core.SampleInput(make_inp(shape), args=(pad,)) + + +def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwargs): del op_info - # Case `target_end == 1`, where `target_end = (input.size(dimension) - size) // step + 1`. - t = torch_testing.make_tensor( - (2, 3, 4), - device=device, - dtype=dtype, - requires_grad=requires_grad, - **kwargs, + del kwargs + + cases: tuple = ( # ignore + ((2, 3), (1, 2)), + ((4, 5), (0, 1)), + ((6, 7), (1, 1)), + ((8, 9), (1, 0)), ) - for dimension, size, step in [ - (1, 2, 2), - (-1, 2, 2), - (-2, 2, 2), - ]: - yield opinfo_core.SampleInput(t, args=(dimension, size, step)) + + make_inp = opinfo_core.partial( + torch.testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + for shape, pad in cases: + yield opinfo_core.SampleInput(make_inp(shape), args=(pad,)) def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): @@ -1232,74 +1199,18 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs) ((L, L, L), (L, L, L), (1, 0, L, 1)), ((L, L, L), (L, L // 2, L), (1, L // 2, L, 1)), ((L, L, L), (L, L // 4, L), (1, L // 2, L, 2)), - ((L, L, L), (L, L, L), (2, 0, L, 1)), - ((L, L, L), (L, L, L // 2), (2, L // 2, L, 1)), - ((L, L, L), (L, L, L // 4), (2, L // 2, L, 2)), - ((L, L, L), (L, L // 2, L), (1, L // 2, L * 2, 1)), # end > L - ((L, L, L), (L, L, L), (-2, 0, L, 1)), # negative dim - ((L, L, L), (L, L, L // 4), (-1, L // 2, L * 2, 2)), # end > L and negative dim - ) - - for input_shape, src_shape, args in cases: - input_ = make_arg(input_shape) - src = make_arg(src_shape) - yield opinfo_core.SampleInput(input_, args=(src, *args)) - - -def sample_inputs__log_softmax( - op_info, - device, - dtype, - requires_grad, - **kwargs, -): - del op_info # Unused - - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - cases = [ - ((S,), (0,)), - ((S, S), (0,)), - ((S, S), (1,)), - ((S, S), (-1,)), - ((S, M, S), (2,)), - ((S, 0, 0), (-1,)), - ] - - for (shape, dim), half_to_float in itertools.product(cases, (False,)): - # NOTE: softmax with half to float conversion is not supported on CPU - # So we don't test it here - kwargs = dict(half_to_float=half_to_float) - yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs) - - -def sample_inputs__softmax( - op_info, - device, - dtype, - requires_grad, - **kwargs, -): - del op_info # Unused - - make_arg = functools.partial( - torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - cases = [ - ((S,), (0,)), - ((S, S), (0,)), - ((S, S), (1,)), - ((S, S), (-1,)), - ((S, M, S), (2,)), - ((S, 0, 0), (-1,)), - ] + ((L, L, L), (L, L, L), (2, 0, L, 1)), + ((L, L, L), (L, L, L // 2), (2, L // 
2, L, 1)), + ((L, L, L), (L, L, L // 4), (2, L // 2, L, 2)), + ((L, L, L), (L, L // 2, L), (1, L // 2, L * 2, 1)), # end > L + ((L, L, L), (L, L, L), (-2, 0, L, 1)), # negative dim + ((L, L, L), (L, L, L // 4), (-1, L // 2, L * 2, 2)), # end > L and negative dim + ) - for (shape, dim), half_to_float in itertools.product(cases, (False,)): - # NOTE: softmax with half to float conversion is not supported on CPU - # So we don't test it here - kwargs = dict(half_to_float=half_to_float) - yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs) + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield opinfo_core.SampleInput(input_, args=(src, *args)) def sample_inputs__scaled_dot_product_flash_attention( @@ -1381,69 +1292,177 @@ def sample_inputs__scaled_dot_product_efficient_attention( yield from samples -# NOTE: In `_native_batch_norm_legit` tests, it generates two kinds of args: -# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps) -# 2. (input, weight, bias, training, momentum, eps) -# which requires two function signatures to take the inputs, that's why we have -# two sample_inputs functions here instead. -def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): - samples = common_methods_invocations.sample_inputs_batch_norm( - op_info, device, dtype, requires_grad, **kwargs +def sample_inputs__softmax( + op_info, + device, + dtype, + requires_grad, + **kwargs, +): + del op_info # Unused + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad ) - for sample in samples: - # torch.native_batch_norm does not support 0 numel tensors - # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) - if sample.input.numel() == 0: - continue - args = sample.args - training = sample.kwargs.get("training", True) - momentum = sample.kwargs.get("momentum", 0.5) - eps = sample.kwargs.get("eps", 1e-5) - if args[0] is not None and args[1] is not None: - yield opinfo_core.SampleInput( - sample.input, - args=(args[2], args[3], args[0], args[1], training, momentum, eps), - ) + cases = [ + ((S,), (0,)), + ((S, S), (0,)), + ((S, S), (1,)), + ((S, S), (-1,)), + ((S, M, S), (2,)), + ((S, 0, 0), (-1,)), + ] + for (shape, dim), half_to_float in itertools.product(cases, (False,)): + # NOTE: softmax with half to float conversion is not supported on CPU + # So we don't test it here + kwargs = dict(half_to_float=half_to_float) + yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs) -def sample_inputs__native_batch_norm_legit_no_stats( - op_info, device, dtype, requires_grad, **kwargs -): - samples = common_methods_invocations.sample_inputs_batch_norm( - op_info, device, dtype, requires_grad, **kwargs + +def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + def mt(shape, **kwargs): + return torch_testing.make_tensor( + shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + yield opinfo_core.SampleInput(mt(100), n_fft=10, return_complex=True) + yield opinfo_core.SampleInput(mt(100), n_fft=10, return_complex=False) + if dtype.is_complex: + yield opinfo_core.SampleInput(mt(100), n_fft=10) + + yield opinfo_core.SampleInput(mt(10), n_fft=7, return_complex=True) + yield opinfo_core.SampleInput(mt((10, 100)), n_fft=16, hop_length=4, return_complex=True) + + window = mt(16, low=0.5, high=2.0) + yield 
opinfo_core.SampleInput( + mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True) ) - for sample in samples: - # torch.native_batch_norm does not support 0 numel tensors - # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) - if sample.input.numel() == 0: - continue - args = sample.args - training = sample.kwargs.get("training", True) - momentum = sample.kwargs.get("momentum", 0.5) - eps = sample.kwargs.get("eps", 1e-5) - if args[0] is not None and args[1] is None: - yield opinfo_core.SampleInput( - sample.input, args=(args[2], args[3], training, momentum, eps) - ) + yield opinfo_core.SampleInput( + mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True) + ) + if not dtype.is_complex: + yield opinfo_core.SampleInput( + mt((10, 100)), n_fft=16, window=window, onesided=False, return_complex=True + ) -def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_tensor_bool(op_info, device, dtype, requires_grad, **kwargs): del op_info + del device + del requires_grad del kwargs + yield opinfo_core.SampleInput(True, dtype=dtype) + yield opinfo_core.SampleInput(False, dtype=dtype) - cases: tuple = ( # ignore - ((2, 3), (1, 2)), - ((4, 5), (0, 1)), - ((6, 7), (1, 1)), - ((8, 9), (1, 0)), - ) - make_inp = opinfo_core.partial( - torch.testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad +def sample_inputs_tensor_float(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del device + del requires_grad + del kwargs + yield opinfo_core.SampleInput(3.0, dtype=dtype) + yield opinfo_core.SampleInput(-1.0, dtype=dtype) + + +def sample_inputs_tensor_int(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del device + del requires_grad + del kwargs + yield opinfo_core.SampleInput(2, dtype=dtype) + yield opinfo_core.SampleInput(-5, dtype=dtype) + + +def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): + del op_info + # Case `target_end == 1`, where `target_end = (input.size(dimension) - size) // step + 1`. 
+ t = torch_testing.make_tensor( + (2, 3, 4), + device=device, + dtype=dtype, + requires_grad=requires_grad, + **kwargs, ) + for dimension, size, step in [ + (1, 2, 2), + (-1, 2, 2), + (-2, 2, 2), + ]: + yield opinfo_core.SampleInput(t, args=(dimension, size, step)) - for shape, pad in cases: - yield opinfo_core.SampleInput(make_inp(shape), args=(pad,)) + +class _TestParamsMaxPoolEmptyStrideBase: + # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 + def __init__(self): + self.kwargs = { + "kernel_size": [3], + "stride": [()], + "ceil_mode": [True, False], + "padding": [0, 1], + "dilation": [1], + } + + # fmt: off + self.shapes = [ + [1, 2, None], # batch + [2], # channels + [3, 6] # signal + ] + # fmt: on + + def _gen_shape(self): + for shape in itertools.product(*self.shapes): + # shape[0] is None indicates missing batch dimension + if shape[0] is None: + shape = shape[1:] + + yield shape, torch.contiguous_format + # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format + if len(self.shapes) == 4 and len(shape) == 4: + yield shape, torch.channels_last + + def _gen_kwargs(self): + keys = self.kwargs.keys() + for values in itertools.product(*self.kwargs.values()): + yield dict(zip(keys, values)) + + def gen_input_params(self): + yield from itertools.product(self._gen_shape(), self._gen_kwargs()) + + +class _TestParamsMaxPool1dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): + def __init__(self): + super().__init__() + self.kwargs["kernel_size"] += [(3,)] + self.kwargs["stride"] += [(2,)] + self.kwargs["padding"] += [(1,)] + self.kwargs["dilation"] += [(1,)] + + +class _TestParamsMaxPool2dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): + def __init__(self): + super().__init__() + self.kwargs["kernel_size"] += [(3, 2)] + self.kwargs["stride"] += [(2, 1)] + self.kwargs["padding"] += [(1, 1)] + self.kwargs["dilation"] += [(1, 2)] + + self.shapes.append([6]) + + +class _TestParamsMaxPool3dEmptyStride(_TestParamsMaxPoolEmptyStrideBase): + def __init__(self): + super().__init__() + self.kwargs["kernel_size"] += [(3, 2, 3)] + self.kwargs["stride"] += [(2, 1, 2)] + self.kwargs["dilation"] += [(1, 2, 1)] + + self.shapes.append([6]) + self.shapes.append([5]) # NOTE: How to create an OpInfo: @@ -1462,31 +1481,20 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar # the `op` field explicitly. OP_DB: List[opinfo_core.OpInfo] = [ opinfo_core.OpInfo( - "ops.aten._fft_c2c", - aten_name="_fft_c2c", - dtypes=common_dtype.complex_types(), - sample_inputs_func=sample_inputs__fft_c2c, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten._fft_c2r", - aten_name="_fft_c2r", - dtypes=common_dtype.complex_types(), - sample_inputs_func=sample_inputs__fft_c2r, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten._fft_r2c", - aten_name="_fft_r2c", - dtypes=common_dtype.floating_types(), - sample_inputs_func=sample_inputs__fft_r2c, + "ops.aten.bernoulli.p", + aten_name="bernoulli.p", + # dtypes can be a tuple of (torch.float, torch.double). 
+ dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs_bernoulli_p, supports_out=False, ), opinfo_core.OpInfo( - "ops.aten._local_scalar_dense", - aten_name="_local_scalar_dense", + # Deterministic bernoulli sampling where p is either 0 or 1 + "ops.aten.bernoulli.p_deterministic", + op=torch.ops.aten.bernoulli.p, + aten_name="bernoulli.p", dtypes=common_dtype.all_types(), - sample_inputs_func=sample_inputs__local_scalar_dense, + sample_inputs_func=sample_inputs_bernoulli_p_deterministic, supports_out=False, ), opinfo_core.OpInfo( @@ -1496,6 +1504,20 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_col2im, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.conv3d", + aten_name="conv3d", + dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.convolution", + aten_name="convolution", + dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), + sample_inputs_func=sample_inputs_convolution, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.embedding_bag", aten_name="embedding_bag", @@ -1518,24 +1540,24 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.conv3d", - aten_name="conv3d", - dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), - sample_inputs_func=sample_inputs_conv3d, + "ops.aten._fft_c2c", + aten_name="_fft_c2c", + dtypes=common_dtype.complex_types(), + sample_inputs_func=sample_inputs__fft_c2c, supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.convolution", - aten_name="convolution", - dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), - sample_inputs_func=sample_inputs_convolution, + "ops.aten._fft_c2r", + aten_name="_fft_c2r", + dtypes=common_dtype.complex_types(), + sample_inputs_func=sample_inputs__fft_c2r, supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.reflection_pad1d", - aten_name="ops.aten.reflection_pad1d", - dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), - sample_inputs_func=sample_inputs_reflection_pad1d, + "ops.aten._fft_r2c", + aten_name="_fft_r2c", + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), opinfo_core.OpInfo( @@ -1553,6 +1575,21 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_layer_norm, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._local_scalar_dense", + aten_name="_local_scalar_dense", + dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs__local_scalar_dense, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten._log_softmax", + op=torch.ops.aten._log_softmax, # pylint: disable=protected-access + aten_name="_log_softmax", + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=sample_inputs__log_softmax, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.max_pool1d", variant_test_name="empty_strides", @@ -1591,6 +1628,36 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_native_group_norm, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit", + aten_name="_native_batch_norm_legit", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + 
dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit, + ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit_functional", + aten_name="_native_batch_norm_legit_functional", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit, + ), + opinfo_core.OpInfo( + "ops.aten._native_batch_norm_legit.no_stats", + aten_name="_native_batch_norm_legit.no_stats", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats, + ), opinfo_core.OpInfo( "ops.aten.normal.float_Tensor", aten_name="normal.Tensor_Tensor", @@ -1612,27 +1679,6 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_normal_tensor_tensor, supports_out=False, ), - opinfo_core.OpInfo( - "nn.functional.max_pool1d_with_indices", - aten_name="max_pool1d_with_indices", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_max_pool1d_with_indices, - supports_out=False, - ), - opinfo_core.OpInfo( - "nn.functional.max_pool2d_with_indices", - aten_name="max_pool2d_with_indices", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_max_pool2d_with_indices, - supports_out=False, - ), - opinfo_core.OpInfo( - "nn.functional.max_pool3d_with_indices", - aten_name="max_pool3d_with_indices", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_max_pool3d_with_indices, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.rand", aten_name="rand", @@ -1721,6 +1767,62 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_like_fns_dtype, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.reflection_pad1d", + aten_name="ops.aten.reflection_pad1d", + dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), + sample_inputs_func=sample_inputs_reflection_pad1d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.replication_pad1d", + aten_name="ops.aten.replication_pad1d", + dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16), + sample_inputs_func=sample_inputs_replication_pad1d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten._scaled_dot_product_flash_attention", + aten_name="_scaled_dot_product_flash_attention", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support + # dim<=3 input. 
+ sample_inputs_func=sample_inputs__scaled_dot_product_flash_attention, + supports_out=False, + supports_forward_ad=False, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + ), + opinfo_core.OpInfo( + "ops.aten._scaled_dot_product_efficient_attention", + aten_name="_scaled_dot_product_efficient_attention", + # only support CUDA + dtypes=common_dtype.empty_types(), + dtypesIfCUDA=common_dtype.floating_types_and(torch.bfloat16), + # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support + # dim<=3 input. + sample_inputs_func=sample_inputs__scaled_dot_product_efficient_attention, + supports_out=False, + supports_forward_ad=False, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=[common_device_type.onlyCUDA], + ), + opinfo_core.OpInfo( + "ops.aten.slice_scatter", + aten_name="slice_scatter", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten._softmax", + op=torch.ops.aten._softmax, # pylint: disable=protected-access + aten_name="_softmax", + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=sample_inputs__softmax, + supports_out=False, + ), # NOTE: torch.STFT has pre-padding and it's not supported by aten::stft # This custom OpInfo uses aten::stft directly. opinfo_core.OpInfo( @@ -1751,23 +1853,6 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar sample_inputs_func=sample_inputs_tensor_int, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.bernoulli.p", - aten_name="bernoulli.p", - # dtypes can be a tuple of (torch.float, torch.double). - dtypes=common_dtype.all_types(), - sample_inputs_func=sample_inputs_bernoulli_p, - supports_out=False, - ), - opinfo_core.OpInfo( - # Deterministic bernoulli sampling where p is either 0 or 1 - "ops.aten.bernoulli.p_deterministic", - op=torch.ops.aten.bernoulli.p, - aten_name="bernoulli.p", - dtypes=common_dtype.all_types(), - sample_inputs_func=sample_inputs_bernoulli_p_deterministic, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.unfold", aten_name="unfold", @@ -1776,83 +1861,24 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.slice_scatter", - aten_name="slice_scatter", - dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), - sample_inputs_func=sample_inputs_slice_scatter, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten._log_softmax", - op=torch.ops.aten._log_softmax, # pylint: disable=protected-access - aten_name="_log_softmax", - dtypes=common_dtype.floating_types_and_half(), - sample_inputs_func=sample_inputs__log_softmax, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten._softmax", - op=torch.ops.aten._softmax, # pylint: disable=protected-access - aten_name="_softmax", - dtypes=common_dtype.floating_types_and_half(), - sample_inputs_func=sample_inputs__softmax, - supports_out=False, - ), - opinfo_core.OpInfo( - "ops.aten._scaled_dot_product_flash_attention", - aten_name="_scaled_dot_product_flash_attention", + "nn.functional.max_pool1d_with_indices", + aten_name="max_pool1d_with_indices", dtypes=common_dtype.floating_types_and(torch.bfloat16), - # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support - # dim<=3 input. 
- sample_inputs_func=sample_inputs__scaled_dot_product_flash_attention, - supports_out=False, - supports_forward_ad=False, - supports_fwgrad_bwgrad=True, - check_batched_forward_grad=False, - ), - opinfo_core.OpInfo( - "ops.aten._scaled_dot_product_efficient_attention", - aten_name="_scaled_dot_product_efficient_attention", - # only support CUDA - dtypes=common_dtype.empty_types(), - dtypesIfCUDA=common_dtype.floating_types_and(torch.bfloat16), - # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support - # dim<=3 input. - sample_inputs_func=sample_inputs__scaled_dot_product_efficient_attention, + sample_inputs_func=sample_inputs_max_pool1d_with_indices, supports_out=False, - supports_forward_ad=False, - supports_fwgrad_bwgrad=True, - check_batched_forward_grad=False, - decorators=[common_device_type.onlyCUDA], ), opinfo_core.OpInfo( - "ops.aten._native_batch_norm_legit", - aten_name="_native_batch_norm_legit", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - assert_jit_shape_analysis=True, - sample_inputs_func=sample_inputs__native_batch_norm_legit, - ), - opinfo_core.OpInfo( - "ops.aten._native_batch_norm_legit_functional", - aten_name="_native_batch_norm_legit_functional", + "nn.functional.max_pool2d_with_indices", + aten_name="max_pool2d_with_indices", dtypes=common_dtype.floating_types_and(torch.bfloat16), - dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - assert_jit_shape_analysis=True, - sample_inputs_func=sample_inputs__native_batch_norm_legit, + sample_inputs_func=sample_inputs_max_pool2d_with_indices, + supports_out=False, ), opinfo_core.OpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - aten_name="_native_batch_norm_legit.no_stats", + "nn.functional.max_pool3d_with_indices", + aten_name="max_pool3d_with_indices", dtypes=common_dtype.floating_types_and(torch.bfloat16), - dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16), - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - assert_jit_shape_analysis=True, - sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats, + sample_inputs_func=sample_inputs_max_pool3d_with_indices, + supports_out=False, ), ] diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index abf319872..2d5b1ab72 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1230,6 +1230,10 @@ def _where_input_wrangler( dtypes=(torch.int64,), reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", ), + TorchLibOpInfo( + "ops.aten.replication_pad1d", + nn_ops.aten_replication_pad1d, + ), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d,
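
Note on the new test inputs: the `sample_inputs_replication_pad1d` generator and the `ops.aten.replication_pad1d` OpInfo added above feed small tensors with `(pad_left, pad_right)` pairs through the op. The snippet below is a minimal, informal sketch of the semantics those samples exercise -- it is not part of the patch or the test harness. It assumes a recent PyTorch where `torch.nn.functional.pad(..., mode="replicate")` serves as the eager reference, and it uses illustrative batched 3-D shapes rather than the exact 2-D shapes in the opinfo cases.

    # Informal sanity check (illustrative only): aten::replication_pad1d should
    # agree with edge/"replicate" padding for (pad_left, pad_right) pairs like
    # the ones generated by sample_inputs_replication_pad1d.
    import torch
    import torch.nn.functional as F

    cases = (  # (shape, (pad_left, pad_right)) -- hypothetical batched variants
        ((1, 2, 3), (1, 2)),
        ((1, 4, 5), (0, 1)),
        ((1, 6, 7), (1, 1)),
        ((1, 8, 9), (1, 0)),
    )

    for shape, pad in cases:
        x = torch.randn(shape)
        expected = F.pad(x, pad, mode="replicate")              # eager reference
        actual = torch.ops.aten.replication_pad1d(x, list(pad))
        torch.testing.assert_close(actual, expected)
        # The last axis grows by pad_left + pad_right, filled with boundary values.
        assert actual.shape[-1] == shape[-1] + pad[0] + pad[1]

The registered OpInfo covers a wider dtype set (`common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16)`); the float32 check above is only meant to illustrate that replication ("edge") padding extends the last dimension by pad_left + pad_right using the edge values of the input.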