diff --git a/core/src/model/fact.rs b/core/src/model/fact.rs index aa88506d04..c36176cb30 100644 --- a/core/src/model/fact.rs +++ b/core/src/model/fact.rs @@ -105,7 +105,9 @@ impl ShapeFact { self.dims.remove(axis); if let Some(concrete) = &mut self.concrete { concrete.remove(axis); - } + } else { + self.compute_concrete(); + }; Ok(()) } @@ -124,6 +126,14 @@ impl ShapeFact { let void: &[usize] = &[]; Self::from(void) } + + pub fn consistent(&self) -> TractResult<()> { + ensure!( + self.concrete + == self.dims.iter().map(|d| d.to_usize()).collect::>>().ok() + ); + Ok(()) + } } impl std::ops::Deref for ShapeFact { @@ -241,6 +251,7 @@ impl TypedFact { } pub fn consistent(&self) -> TractResult<()> { + self.shape.consistent()?; if let Some(k) = &self.konst { if !self.matches(k.as_ref(), None)? { bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);