Skip to content

Commit

Permalink
fix: allow correct TDim less handling
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienBalianSonos committed Mar 1, 2024
1 parent b8466cd commit a6892bb
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions nnef/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,19 @@ impl Registry {

// mitigation of nnef "scalar" type mismatch with tract-core more
// strict types
if (!a_dt.is_quantized() || !b_dt.is_quantized()) && a_dt != b_dt {
if builder.model.node(a.node).op_is::<tract_core::ops::konst::Const>() {
a = builder.wire_as_outlets(tract_core::ops::cast::cast(b_dt), &[a])?[0];
a_dt = b_dt;
} else {
b = builder.wire_as_outlets(tract_core::ops::cast::cast(a_dt), &[b])?[0];
b_dt = a_dt;
};
if !a_dt.is_quantized() || !b_dt.is_quantized() {
if a_dt != b_dt {
if builder.model.node(a.node).op_is::<tract_core::ops::konst::Const>() {
a = builder.wire_as_outlets(tract_core::ops::cast::cast(b_dt), &[a])?[0];
a_dt = b_dt;
} else {
b = builder.wire_as_outlets(tract_core::ops::cast::cast(a_dt), &[b])?[0];
b_dt = a_dt;
}
}
let operating_dt = bin.1.operating_datum_type(a_dt, b_dt)?;
// avoid cast unified dtype to happen when all inputs quantized
// that can be unaligned at process time
a = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[a])?[0];
b = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[b])?[0];
}
Expand Down

0 comments on commit a6892bb

Please sign in to comment.