diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 6c73a6f7c..84f75b1a4 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -672,7 +672,7 @@ def _get_im2col_indices_along_dim( # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) # with steps = stride - blocks_d = input_d + (padding_d * 2) - (dilation_d * (kernel_size_d - 1)) + blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1))) # Stride kernel over input and find starting indices along dim d blocks_d_indices = op.Range(0, blocks_d, stride_d)