From fa1d9d187916c359cb243016b4fa246c9ee61303 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 12 Dec 2023 10:31:05 +0100 Subject: [PATCH] fix decovn --- core/src/ops/cnn/mod.rs | 21 +++++++++++++++++++++ nnef/src/ser.rs | 1 + 2 files changed, 22 insertions(+) diff --git a/core/src/ops/cnn/mod.rs b/core/src/ops/cnn/mod.rs index e16bb79566..ef6a0ebca7 100644 --- a/core/src/ops/cnn/mod.rs +++ b/core/src/ops/cnn/mod.rs @@ -90,3 +90,24 @@ pub fn rewrite_conv_with_n_axis( Ok(None) } +pub fn rewrite_deconv_with_n_axis( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + name: &str, + deconv: &DeconvUnary, +) -> TractResult> { + if !deconv.pool_spec.data_format.has_n() { + let mut new = deconv.clone(); + new.pool_spec.data_format = deconv.pool_spec.data_format.with_n(); + let mut patch = TypedModelPatch::default(); + let mut wire = patch.taps(model, &node.inputs)?; + wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0]; + wire = patch.wire_node(name, new, &wire)?; + wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?; + patch.shunt_outside(model, node.id.into(), wire[0])?; + return Ok(Some(patch)); + } + Ok(None) +} + diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index 6e5b979a69..fc065a1575 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -9,6 +9,7 @@ pub fn to_proto_model(framework: &Nnef, model: &TypedModel) -> TractResult