diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4c22df181..63f692954 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4391,7 +4391,10 @@ def aten_instance_norm( ), "running_mean and running_var must be provided when use_input_stats is False" batch_size = op.Shape(input, start=0, end=1) - bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0)) + bn_input = op.Reshape( + input, + op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0), + ) weight = op.Tile(weight, batch_size) bias = op.Tile(bias, batch_size) running_mean = op.Tile(running_mean, batch_size)