From 41200fcbbf9f55928c3c08b9dba0b447099d3acf Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Mon, 5 Aug 2024 17:29:41 +0000 Subject: [PATCH] Remove extra ops --- onnxscript/function_libs/torch_lib/ops/nn.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index ac95ecd1e..6c73a6f7c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -27,6 +27,7 @@ TFloat, TFloatOrBFloat16, TFloatOrUInt8, + TInt, TReal, TTensor, ) @@ -659,7 +660,7 @@ def aten_huber_loss_backward( def _get_im2col_indices_along_dim( - input_d: int, + input_d: TInt, kernel_size_d: int, dilation_d: int, padding_d: int, @@ -671,21 +672,18 @@ def _get_im2col_indices_along_dim( # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) # with steps = stride - blocks_d = op.Add(input_d, padding_d * 2) - blocks_d = op.Sub(blocks_d, (dilation_d * (kernel_size_d - 1))) + blocks_d = input_d + (padding_d * 2) - (dilation_d * (kernel_size_d - 1)) # Stride kernel over input and find starting indices along dim d blocks_d_indices = op.Range(0, blocks_d, stride_d) + blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0]) # Apply dilation on kernel and find its indices along dim d kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d) - kernel_grid = op.Unsqueeze(kernel_grid, [0]) + kernel_mask = op.Unsqueeze(kernel_grid, [1]) # Broadcast and add kernel staring positions (indices) with # kernel_grid along dim d, to get block indices along dim d - blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0]) - # Reshape to [1, -1] - kernel_mask = op.Reshape(kernel_grid, op.Constant(value_ints=[-1, 1])) block_mask = op.Add(blocks_d_indices, kernel_mask) return block_mask