Skip to content

Commit

Permalink
Deduplicate Const when wireing new node in typed graph + fixes in NNEF
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 11, 2024
1 parent e78057f commit c0abc05
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
6 changes: 6 additions & 0 deletions core/src/model/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
) -> TractResult<TVec<OutletId>> {
let op = op.into();
let name = name.into();
if let Some(konst) = op.downcast_ref::<Const>() {
// only if no opaque fact is present.
if konst.1.is_none() {
return Ok(tvec![self.add_const(name, konst.0.clone())?]);
}
}
if self.nodes.iter().any(|n| n.name == name) {
bail!("Duplicate node name: {name}");
}
Expand Down
22 changes: 5 additions & 17 deletions nnef/src/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,16 +664,9 @@ impl CoerceFrom<Value> for (Arc<Tensor>, DatumType) {
impl CoerceFrom<Value> for OutletId {
fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult<Self> {
match from {
Value::Tensor(t) => {
Ok(builder.wire_as_outlets(tract_core::ops::konst::Const::new(t.clone()), &[])?[0])
}
Value::Scalar(f) => {
Ok(builder
.wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(*f)), &[])?[0])
}
Value::Dim(i) => Ok(builder
.wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(i.clone())), &[])?
[0]),
Value::Tensor(t) => builder.add_const(t.clone()),
Value::Scalar(f) => builder.add_const(rctensor0(*f)),
Value::Dim(i) => builder.add_const(rctensor0(i.clone())),
Value::Wire(outlet) => Ok(*outlet),
Value::Tuple(tuple) if tuple.len() == 1 => OutletId::coerce(builder, &tuple[0]),
Value::Array(inputs) => {
Expand All @@ -689,13 +682,8 @@ impl CoerceFrom<Value> for OutletId {
.wire_as_outlets(tract_core::ops::array::TypedConcat::new(0), &outlets)
.map(|o| o[0])
}
Value::String(s) => Ok(builder
.wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(s.clone())), &[])?
[0]),
Value::Bool(b) => {
Ok(builder
.wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(*b)), &[])?[0])
}
Value::String(s) => builder.add_const(rctensor0(s.clone())),
Value::Bool(b) => builder.add_const(rctensor0(*b)),
_ => bail!("Can not build an outletid from {:?}", from),
}
}
Expand Down

0 comments on commit c0abc05

Please sign in to comment.