From 187848845e0a37c16c3b1413dcfd95ffe8692c86 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Wed, 11 Dec 2024 15:27:26 +0100 Subject: [PATCH] Deduplicate Const when wireing new node in typed graph + fixes in NNEF --- core/src/model/typed.rs | 6 ++++++ nnef/src/deser.rs | 22 +++++----------------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/core/src/model/typed.rs b/core/src/model/typed.rs index f451a3aef0..b8449e6127 100644 --- a/core/src/model/typed.rs +++ b/core/src/model/typed.rs @@ -48,6 +48,12 @@ impl SpecialOps> for TypedModel { ) -> TractResult> { let op = op.into(); let name = name.into(); + if let Some(konst) = op.downcast_ref::() { + // only if no opaque fact is present. + if konst.1.is_none() { + return Ok(tvec![self.add_const(name, konst.0.clone())?]); + } + } if self.nodes.iter().any(|n| n.name == name) { bail!("Duplicate node name: {name}"); } diff --git a/nnef/src/deser.rs b/nnef/src/deser.rs index c26105f437..68a87b3063 100644 --- a/nnef/src/deser.rs +++ b/nnef/src/deser.rs @@ -664,16 +664,9 @@ impl CoerceFrom for (Arc, DatumType) { impl CoerceFrom for OutletId { fn coerce(builder: &mut ModelBuilder, from: &Value) -> TractResult { match from { - Value::Tensor(t) => { - Ok(builder.wire_as_outlets(tract_core::ops::konst::Const::new(t.clone()), &[])?[0]) - } - Value::Scalar(f) => { - Ok(builder - .wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(*f)), &[])?[0]) - } - Value::Dim(i) => Ok(builder - .wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(i.clone())), &[])? - [0]), + Value::Tensor(t) => builder.add_const(t.clone()), + Value::Scalar(f) => builder.add_const(rctensor0(*f)), + Value::Dim(i) => builder.add_const(rctensor0(i.clone())), Value::Wire(outlet) => Ok(*outlet), Value::Tuple(tuple) if tuple.len() == 1 => OutletId::coerce(builder, &tuple[0]), Value::Array(inputs) => { @@ -689,13 +682,8 @@ impl CoerceFrom for OutletId { .wire_as_outlets(tract_core::ops::array::TypedConcat::new(0), &outlets) .map(|o| o[0]) } - Value::String(s) => Ok(builder - .wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(s.clone())), &[])? - [0]), - Value::Bool(b) => { - Ok(builder - .wire_as_outlets(tract_core::ops::konst::Const::new(rctensor0(*b)), &[])?[0]) - } + Value::String(s) => builder.add_const(rctensor0(s.clone())), + Value::Bool(b) => builder.add_const(rctensor0(*b)), _ => bail!("Can not build an outletid from {:?}", from), } }