diff --git a/core/src/ops/einsum/mod.rs b/core/src/ops/einsum/mod.rs index 4c969d2dac..8a78e4e9e9 100644 --- a/core/src/ops/einsum/mod.rs +++ b/core/src/ops/einsum/mod.rs @@ -295,7 +295,7 @@ impl TypedOp for EinSum { } fn cost(&self, inputs: &[&TypedFact]) -> TractResult> { - let shapes: TVec<&[TDim]> = inputs.iter().map(|t| &*t.shape).collect(); + let shapes = self.actual_input_shapes_from_facts(inputs)?; let oshape = eval::output_shape(&self.axes, &shapes)?; let ks = self .axes