Skip to content

Commit

Permalink
Remove extra ops
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Aug 5, 2024
1 parent ac99103 commit 41200fc
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
TTensor,
)
Expand Down Expand Up @@ -659,7 +660,7 @@ def aten_huber_loss_backward(


def _get_im2col_indices_along_dim(
input_d: int,
input_d: TInt,
kernel_size_d: int,
dilation_d: int,
padding_d: int,
Expand All @@ -671,21 +672,18 @@ def _get_im2col_indices_along_dim(
# each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
# with steps = stride

blocks_d = op.Add(input_d, padding_d * 2)
blocks_d = op.Sub(blocks_d, (dilation_d * (kernel_size_d - 1)))
blocks_d = input_d + (padding_d * 2) - (dilation_d * (kernel_size_d - 1))

Check warning on line 675 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L675

Added line #L675 was not covered by tests

# Stride kernel over input and find starting indices along dim d
blocks_d_indices = op.Range(0, blocks_d, stride_d)
blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0])

Check warning on line 679 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L678-L679

Added lines #L678 - L679 were not covered by tests

# Apply dilation on kernel and find its indices along dim d
kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d)
kernel_grid = op.Unsqueeze(kernel_grid, [0])
kernel_mask = op.Unsqueeze(kernel_grid, [1])

Check warning on line 683 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L682-L683

Added lines #L682 - L683 were not covered by tests

# Broadcast and add kernel staring positions (indices) with
# kernel_grid along dim d, to get block indices along dim d
blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0])
# Reshape to [1, -1]
kernel_mask = op.Reshape(kernel_grid, op.Constant(value_ints=[-1, 1]))
block_mask = op.Add(blocks_d_indices, kernel_mask)

Check warning on line 687 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L687

Added line #L687 was not covered by tests

return block_mask

Check warning on line 689 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L689

Added line #L689 was not covered by tests
Expand Down

0 comments on commit 41200fc

Please sign in to comment.