diff --git a/hir/src/ops/matmul.rs b/hir/src/ops/matmul.rs index 0d3d6f2e47..4b947c868c 100644 --- a/hir/src/ops/matmul.rs +++ b/hir/src/ops/matmul.rs @@ -65,7 +65,7 @@ impl Expansion for MatMulInference { } if implicit_n { let b = InOut::In(1); - let n_axis = axes.axis((b, axes.rank(b) - 2))?; + let n_axis = axes.axis((b, axes.rank(b) - 1))?; axes = axes.remove_output_axis(0, n_axis.outputs[0][0])?; } target.wire_node(