diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 2f29aab9cd..ac95ecd1ed 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -690,6 +690,7 @@ def _get_im2col_indices_along_dim( return block_mask + def _get_im2col_padded_input(input, padding_h, padding_w): # Input is always 4-D tensor (N, C, H, W) # Padding tensor has the following format: (padding_h, padding_w) @@ -1520,7 +1521,6 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType: raise NotImplementedError() -@torch_op("aten::pad", trace_only=True) def aten_pad( self: TensorType, pad: INT64, mode: str = "constant", value: Optional[float] = None ) -> TensorType: