Skip to content

Commit

Permalink
fix decovn
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 12, 2023
1 parent 80c5246 commit fa1d9d1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
21 changes: 21 additions & 0 deletions core/src/ops/cnn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,24 @@ pub fn rewrite_conv_with_n_axis(
Ok(None)
}

pub fn rewrite_deconv_with_n_axis(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
deconv: &DeconvUnary,
) -> TractResult<Option<TypedModelPatch>> {
if !deconv.pool_spec.data_format.has_n() {
let mut new = deconv.clone();
new.pool_spec.data_format = deconv.pool_spec.data_format.with_n();
let mut patch = TypedModelPatch::default();
let mut wire = patch.taps(model, &node.inputs)?;
wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
wire = patch.wire_node(name, new, &wire)?;
wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
return Ok(Some(patch));
}
Ok(None)
}

1 change: 1 addition & 0 deletions nnef/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub fn to_proto_model(framework: &Nnef, model: &TypedModel) -> TractResult<Proto
tract_core::ops::einsum::rewrite_einsums_as_matmul(&mut fixed_model)?;
Rewriter::default()
.with_rule_for("rewrite_conv_with_n_axis", tract_core::ops::cnn::rewrite_conv_with_n_axis)
.with_rule_for("rewrite_deconv_with_n_axis", tract_core::ops::cnn::rewrite_deconv_with_n_axis)
.with_rule_for("rewrite_kernel_in_oihw", crate::ops::nnef::ser::rewrite_kernel_in_oihw)
.rewrite(&(), &mut fixed_model)?;
let mut into_ast = IntoAst::new(framework, &fixed_model);
Expand Down

0 comments on commit fa1d9d1

Please sign in to comment.