Skip to content

Commit

Permalink
fix shapefact optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 5, 2023
1 parent 30530cc commit d1d95f5
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion core/src/model/fact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ impl ShapeFact {
self.dims.remove(axis);
if let Some(concrete) = &mut self.concrete {
concrete.remove(axis);
}
} else {
self.compute_concrete();
};
Ok(())
}

Expand All @@ -124,6 +126,14 @@ impl ShapeFact {
let void: &[usize] = &[];
Self::from(void)
}

pub fn consistent(&self) -> TractResult<()> {
ensure!(
self.concrete
== self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
);
Ok(())
}
}

impl std::ops::Deref for ShapeFact {
Expand Down Expand Up @@ -241,6 +251,7 @@ impl TypedFact {
}

pub fn consistent(&self) -> TractResult<()> {
self.shape.consistent()?;
if let Some(k) = &self.konst {
if !self.matches(k.as_ref(), None)? {
bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
Expand Down

0 comments on commit d1d95f5

Please sign in to comment.