[torchlib] Implement window functions (#1995)
- BlackmanWindow
- HannWindow
- HammingWindow

Fixes pytorch/pytorch#142458
justinchuby authored Jan 6, 2025
1 parent b064539 commit 98861b0
Showing 3 changed files with 75 additions and 6 deletions.
43 changes: 37 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -1495,10 +1495,19 @@ def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
     return op.BitwiseXor(self, other)


-def aten_blackman_window(window_length: int) -> TensorType:
+@torch_op("aten::blackman_window", trace_only=True)
+def aten_blackman_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    return op.BlackmanWindow(window_length, output_datatype=dtype)
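
A minimal sketch (not part of the commit) of why the direct mapping is sound: ONNX's BlackmanWindow and torch.blackman_window both default to the periodic form 0.42 - 0.5*cos(2*pi*n/N) + 0.08*cos(4*pi*n/N), so the outputs agree to float precision:

import math

import torch


def onnx_style_blackman(length: int) -> torch.Tensor:
    # Periodic Blackman window as ONNX's BlackmanWindow defines it
    # (denominator is `length`, not `length - 1`).
    n = torch.arange(length, dtype=torch.float32)
    return (
        0.42
        - 0.5 * torch.cos(2 * math.pi * n / length)
        + 0.08 * torch.cos(4 * math.pi * n / length)
    )


# torch.blackman_window is periodic by default, so the two match.
torch.testing.assert_close(onnx_style_blackman(10), torch.blackman_window(10))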


 def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:

@@ -3921,16 +3930,38 @@ def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
     return op.And(self, op.Not(other))


-def aten_hamming_window(window_length: int) -> TensorType:
+@torch_op("aten::hamming_window", trace_only=True)
+def aten_hamming_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    # ONNX uses different alpha/beta values for the Hamming window.
+    # Whereas PyTorch uses alpha=0.54, beta=0.46, ONNX uses
+    # alpha=0.543478, beta=0.456522. This causes a slight difference
+    # in the output values, but we still use the HammingWindow op for performance.
+    return op.HammingWindow(window_length, output_datatype=dtype)
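
The coefficient mismatch called out in the comment is small but measurable. A rough sketch (not in the commit; the hamming helper below is illustrative only) of the worst-case elementwise gap, which the widened tolerance for this op in ops_test_data.py accounts for:

import math

import torch


def hamming(length: int, alpha: float, beta: float) -> torch.Tensor:
    # Periodic generalized Hamming window: alpha - beta * cos(2*pi*n/length).
    n = torch.arange(length, dtype=torch.float32)
    return alpha - beta * torch.cos(2 * math.pi * n / length)


torch_style = hamming(32, 0.54, 0.46)  # PyTorch's coefficients
onnx_style = hamming(32, 25 / 46, 21 / 46)  # ONNX HammingWindow's coefficients
print((torch_style - onnx_style).abs().max())  # ~7e-3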


-def aten_hann_window(window_length: int) -> TensorType:
+@torch_op("aten::hann_window", trace_only=True)
+def aten_hann_window(
+    window_length: int,
+    dtype: int = 1,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
+) -> TensorType:
     """hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

-    raise NotImplementedError()
+    if dtype is None or dtype == -1:
+        dtype = 1
+    return op.HannWindow(window_length, output_datatype=dtype)
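
Hann needs no such caveat: ONNX's HannWindow, sin^2(pi*n/N), is identical to the periodic 0.5 - 0.5*cos(2*pi*n/N) window that torch.hann_window produces by default. A quick sketch (outside the commit) of the identity:

import math

import torch

length = 10
n = torch.arange(length, dtype=torch.float32)
# sin^2(pi*n/N) == 0.5 - 0.5*cos(2*pi*n/N): the periodic Hann window.
onnx_style = torch.sin(math.pi * n / length) ** 2
torch.testing.assert_close(onnx_style, torch.hann_window(length))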


 def aten_hardshrink(self: TensorType, lambd: float = 0.5) -> TensorType:
31 changes: 31 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -1935,6 +1935,16 @@ def shape(size, rank, with_batch_channel=True):
 )


+def sample_inputs_window_functions(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    del device
+    del requires_grad
+
+    for window_length in [2, 3, 7, 10, 32]:
+        yield opinfo_core.SampleInput(window_length, kwargs=dict(dtype=dtype))
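
Since the ATen window ops take a plain Python int rather than a tensor input, each SampleInput carries the window length positionally and the torch dtype as a keyword. A hypothetical illustration (run inside extra_opinfo.py's namespace) of what the sampler yields:

import torch

for sample in sample_inputs_window_functions(None, "cpu", torch.float32, False):
    print(sample.input, sample.kwargs)  # e.g. 2 {'dtype': torch.float32}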


 class _TestParamsMaxPoolEmptyStrideBase:
     # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
     def __init__(self):

@@ -2037,6 +2047,13 @@ def __init__(self):
         sample_inputs_func=sample_inputs_bernoulli_p_deterministic,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.blackman_window",
+        aten_name="blackman_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.col2im",
         aten_name="col2im",
@@ -2115,6 +2132,20 @@ def __init__(self):
         lhs_make_tensor_kwargs=dict(low=0),
         rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0),
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.hamming_window",
+        aten_name="hamming_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.hann_window",
+        aten_name="hann_window",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_window_functions,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.index.Tensor",
         aten_name="index.Tensor",
7 changes: 7 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -695,6 +695,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64),
     TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8),
     TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor),
+    TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window),
     TorchLibOpInfo("bmm", core_ops.aten_bmm),
     TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to),
     TorchLibOpInfo("cat", core_ops.aten_cat).skip(
@@ -1630,6 +1631,12 @@ def _where_input_wrangler(
         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
         reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input",
     ),
+    TorchLibOpInfo(
+        "ops.aten.hamming_window",
+        core_ops.aten_hamming_window,
+        tolerance={torch.float32: (8e-2, 6e-3)},
+    ),
+    TorchLibOpInfo("ops.aten.hann_window", core_ops.aten_hann_window),
     TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
     TorchLibOpInfo(
         "nn.functional.grid_sample",