Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 25, 2024
1 parent fb5d31c commit 219868e
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,12 @@ def _get_im2col_padded_input(input, padding_h, padding_w):
# Padding tensor has the following format: (padding_h, padding_w)
# Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
pad = op.Concat(
op.Constant(value_ints=[0, 0]), op.Unsqueeze(padding_h, [0]), op.Unsqueeze(padding_w, [0]),
op.Constant(value_ints=[0, 0]), op.Unsqueeze(padding_h, [0]), op.Unsqueeze(padding_w, [0]),
op.Constant(value_ints=[0, 0]),
op.Unsqueeze(padding_h, [0]),
op.Unsqueeze(padding_w, [0]),
op.Constant(value_ints=[0, 0]),
op.Unsqueeze(padding_h, [0]),
op.Unsqueeze(padding_w, [0]),
axis=0,
)
return op.Pad(input, pad)
Expand Down Expand Up @@ -1516,6 +1520,7 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType:
raise NotImplementedError()


@torch_op("aten::pad", trace_only=True)
def aten_pad(
self: TensorType, pad: INT64, mode: str = "constant", value: Optional[float] = None
) -> TensorType:
Expand Down

0 comments on commit 219868e

Please sign in to comment.