diff --git a/nnef/src/registry.rs b/nnef/src/registry.rs index e30b189bbc..3f1fcea42a 100644 --- a/nnef/src/registry.rs +++ b/nnef/src/registry.rs @@ -219,15 +219,19 @@ impl Registry { // mitigation of nnef "scalar" type mismatch with tract-core more // strict types - if (!a_dt.is_quantized() || !b_dt.is_quantized()) && a_dt != b_dt { - if builder.model.node(a.node).op_is::() { - a = builder.wire_as_outlets(tract_core::ops::cast::cast(b_dt), &[a])?[0]; - a_dt = b_dt; - } else { - b = builder.wire_as_outlets(tract_core::ops::cast::cast(a_dt), &[b])?[0]; - b_dt = a_dt; - }; + if !a_dt.is_quantized() || !b_dt.is_quantized() { + if a_dt != b_dt { + if builder.model.node(a.node).op_is::() { + a = builder.wire_as_outlets(tract_core::ops::cast::cast(b_dt), &[a])?[0]; + a_dt = b_dt; + } else { + b = builder.wire_as_outlets(tract_core::ops::cast::cast(a_dt), &[b])?[0]; + b_dt = a_dt; + } + } let operating_dt = bin.1.operating_datum_type(a_dt, b_dt)?; + // avoid cast unified dtype to happen when all inputs quantized + // that can be unaligned at process time a = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[a])?[0]; b = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[b])?[0]; }