From 98861b05bf33093609e24323f2ddbf21ef11e71d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 6 Jan 2025 13:11:49 -0800
Subject: [PATCH] [torchlib] Implement window functions (#1995)

- BlackmanWindow
- Hann
- Hamming

Fixes https://github.com/pytorch/pytorch/issues/142458
---
 .../function_libs/torch_lib/ops/core.py       | 43 ++++++++++++++++---
 tests/function_libs/torch_lib/extra_opinfo.py | 31 +++++++++++++
 .../function_libs/torch_lib/ops_test_data.py  |  7 +++
 3 files changed, 75 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 584c178d5..1145e9b13 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1495,10 +1495,19 @@ def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
     return op.BitwiseXor(self, other)
 
 
-def aten_blackman_window(window_length: int) -> TensorType:
+@torch_op("aten::blackman_window", trace_only=True)
+def aten_blackman_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    return op.BlackmanWindow(window_length, output_datatype=dtype)
 
 
 def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
@@ -3921,16 +3930,38 @@ def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
     return op.And(self, op.Not(other))
 
 
-def aten_hamming_window(window_length: int) -> TensorType:
+@torch_op("aten::hamming_window", trace_only=True)
+def aten_hamming_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    # ONNX uses different alpha/beta values for the Hamming window.
+    # Whereas PyTorch uses alpha=0.54, beta=0.46, ONNX uses
+    # alpha=0.543478, beta=0.456522. This causes a slight difference
+    # in the output values, but we still use the HammingWindow op for performance.
+    return op.HammingWindow(window_length, output_datatype=dtype)
 
 
-def aten_hann_window(window_length: int) -> TensorType:
+@torch_op("aten::hann_window", trace_only=True)
+def aten_hann_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    return op.HannWindow(window_length, output_datatype=dtype)
 
 
 def aten_hardshrink(self: TensorType, lambd: float = 0.5) -> TensorType:
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 91f1df916..4dc486c5e 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -1935,6 +1935,16 @@ def shape(size, rank, with_batch_channel=True):
     )
 
 
+def sample_inputs_window_functions(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    del device
+    del requires_grad
+
+    for window_length in [2, 3, 7, 10, 32]:
+        yield opinfo_core.SampleInput(window_length, kwargs=dict(dtype=dtype))
+
+
 class _TestParamsMaxPoolEmptyStrideBase:
     # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
     def __init__(self):
@@ -2037,6 +2047,13 @@ def __init__(self):
         sample_inputs_func=sample_inputs_bernoulli_p_deterministic,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.blackman_window",
+        aten_name="blackman_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.col2im",
         aten_name="col2im",
@@ -2115,6 +2132,20 @@ def __init__(self):
         lhs_make_tensor_kwargs=dict(low=0),
         rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0),
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.hamming_window",
+        aten_name="hamming_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.hann_window",
+        aten_name="hann_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.index.Tensor",
         aten_name="index.Tensor",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 07164d594..bebd9a8ab 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -695,6 +695,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64),
     TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8),
     TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor),
+    TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window),
     TorchLibOpInfo("bmm", core_ops.aten_bmm),
     TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to),
     TorchLibOpInfo("cat", core_ops.aten_cat).skip(
@@ -1630,6 +1631,12 @@ def _where_input_wrangler(
         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
         reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input",
     ),
+    TorchLibOpInfo(
+        "ops.aten.hamming_window",
+        core_ops.aten_hamming_window,
+        tolerance={torch.float32: (8e-2, 6e-3)},
+    ),
+    TorchLibOpInfo("ops.aten.hann_window", core_ops.aten_hann_window),
    TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
     TorchLibOpInfo(
         "nn.functional.grid_sample",
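
Note on the dtype plumbing in core.py: the integer `dtype` values flowing through these
functions are ONNX TensorProto datatype codes, so the default of `1` means float32, and
`None` or `-1` (PyTorch's "dtype unspecified") falls back to the same. A minimal sketch
illustrating this, assuming only the `onnx` package; `normalize_dtype` is a hypothetical
helper, not a name from this patch:

    import onnx

    # 1 is onnx.TensorProto.FLOAT, the code the window functions default to.
    assert onnx.TensorProto.FLOAT == 1

    def normalize_dtype(dtype):
        # Mirrors the fallback in aten_blackman_window, aten_hamming_window,
        # and aten_hann_window: None or -1 means "use float32".
        if dtype is None or dtype == -1:
            return onnx.TensorProto.FLOAT
        return dtype

    print(normalize_dtype(None))                     # 1
    print(normalize_dtype(onnx.TensorProto.DOUBLE))  # 11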
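
The relaxed float32 tolerance on `ops.aten.hamming_window` follows from the alpha/beta
mismatch called out in the code comment: PyTorch's 0.54/0.46 versus ONNX's 25/46 and
21/46. A quick NumPy sketch, not part of the patch, showing the size of that gap for a
periodic window:

    import numpy as np

    N = 10  # one of the lengths exercised by sample_inputs_window_functions
    n = np.arange(N)
    # Periodic Hamming window with PyTorch's default coefficients.
    torch_style = 0.54 - 0.46 * np.cos(2 * np.pi * n / N)
    # Same window with the ONNX HammingWindow coefficients (25/46, 21/46).
    onnx_style = 25 / 46 - 21 / 46 * np.cos(2 * np.pi * n / N)

    # The gap peaks at the window edges (where the cosine equals 1), at about 7e-3.
    print(np.abs(torch_style - onnx_style).max())  # ~6.96e-3

An exact comparison would therefore fail, which is why the ops_test_data.py entry
carries a custom tolerance that absorbs this difference.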