diff --git a/nnef/src/ops/core/einsum.rs b/nnef/src/ops/core/einsum.rs
index 7f38167757..9274bd9c11 100644
--- a/nnef/src/ops/core/einsum.rs
+++ b/nnef/src/ops/core/einsum.rs
@@ -1,11 +1,9 @@
 use crate::internal::*;
 use crate::ser::*;
-use tract_core::ops::einsum::BasicMatMul;
 use tract_core::ops::einsum::EinSum;
 use tract_core::tract_data::itertools::Itertools;
 
 pub fn register(registry: &mut Registry) {
-    registry.register_dumper(ser_basic_matmul);
     registry.register_dumper(ser);
     registry.register_primitive(
         "tract_core_einsum",
@@ -54,23 +52,6 @@ pub fn ser(ast: &mut IntoAst, node: &TypedNode, op: &EinSum) -> TractResult<Option<Arc<RValue>>> {
-pub fn ser_basic_matmul(ast: &mut IntoAst, node: &TypedNode, op: &BasicMatMul) -> TractResult<Option<Arc<RValue>>> {
-    let inputs = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect_vec();
-    if op.transpose_c {
-        Ok(Some(invocation(
-            "matmul",
-            &[Arc::new(inputs[1].clone()), Arc::new(inputs[0].clone())],
-            &[("transposeA", logical(!op.transpose_b)), ("transposeB", logical(!op.transpose_a))],
-        )))
-    } else {
-        Ok(Some(invocation(
-            "matmul",
-            &[Arc::new(inputs[0].clone()), Arc::new(inputs[1].clone())],
-            &[("transposeA", logical(op.transpose_a)), ("transposeB", logical(op.transpose_b))],
-        )))
-    }
-}
-
 pub fn ser_einsum(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
     let einsum = node.op_as::<EinSum>().unwrap();
     let inputs: Vec<_> = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect();
diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs
index 321305d50f..8f82f75c5d 100644
--- a/nnef/src/ops/nnef/deser.rs
+++ b/nnef/src/ops/nnef/deser.rs
@@ -340,12 +340,12 @@ pub fn conv_or_deconv(
     }
 
     let mut inputs = tvec!(input, kernel, bias);
-    let (group, pool_spec) = read_conv_parameters(
+    let (group, pool_spec) = dbg!(read_conv_parameters(
         builder,
         invocation,
         kernel_fact.shape.as_concrete().context("Except fixed kernel shape")?,
         &input_fact,
-    )?;
+    )?);
 
     let output_dt: Option<DatumType> = if input_fact.datum_type.is_float() {
         None
diff --git a/nnef/src/ops/nnef/mod.rs b/nnef/src/ops/nnef/mod.rs
index b10a521486..10fa0b55cc 100644
--- a/nnef/src/ops/nnef/mod.rs
+++ b/nnef/src/ops/nnef/mod.rs
@@ -129,6 +129,8 @@ pub fn tract_nnef() -> Registry {
     primitive(&mut registry, "box", deser::sum_pool);
     registry.register_dumper(ser::sum_pool);
 
+    registry.register_dumper(ser::basic_matmul);
+
     for frag in stdlib {
         if frag.body.is_some() {
             registry.register_fragment(frag);
diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs
index eef73cad2c..cc8c1208c0 100644
--- a/nnef/src/ops/nnef/ser.rs
+++ b/nnef/src/ops/nnef/ser.rs
@@ -9,7 +9,9 @@ use tract_core::ops::cnn::Conv;
 use tract_core::ops::cnn::DeconvUnary;
 use tract_core::ops::cnn::KernelFormat;
 use tract_core::ops::cnn::PoolSpec;
+use tract_core::ops::einsum::BasicMatMul;
 use tract_core::ops::nn::DataFormat;
+use tract_core::tract_data::itertools::Itertools;
 
 pub fn source(
     ast: &mut IntoAst,
@@ -29,6 +31,23 @@ pub fn source(
     Ok(None)
 }
 
+pub fn basic_matmul(ast: &mut IntoAst, node: &TypedNode, op: &BasicMatMul) -> TractResult<Option<Arc<RValue>>> {
+    let inputs = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect_vec();
+    if op.transpose_c {
+        Ok(Some(invocation(
+            "matmul",
+            &[Arc::new(inputs[1].clone()), Arc::new(inputs[0].clone())],
+            &[("transposeA", logical(!op.transpose_b)), ("transposeB", logical(!op.transpose_a))],
+        )))
+    } else {
+        Ok(Some(invocation(
+            "matmul",
+            &[Arc::new(inputs[0].clone()), Arc::new(inputs[1].clone())],
+            &[("transposeA", logical(op.transpose_a)), ("transposeB", logical(op.transpose_b))],
+        )))
+    }
+}
+
 pub fn konst(
     ast: &mut IntoAst,
     node: &TypedNode,
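Not part of the patch: a minimal standalone Rust sketch of the algebraic identity the relocated basic_matmul dumper relies on. When transpose_c is set, the dumper swaps the two inputs and inverts their transpose flags instead of wrapping the result in an explicit transpose, which is valid because (A * B)^T == B^T * A^T. The Mat alias, the matmul/transpose helpers, and the sample matrices below are ad-hoc stand-ins for this demonstration, not tract or NNEF APIs.

// Checks (A * B)^T == B^T * A^T on a concrete 2x3 * 3x2 example,
// i.e. the swap-and-invert-flags trick used by the basic_matmul dumper.
type Mat = Vec<Vec<f32>>;

fn matmul(a: &Mat, b: &Mat) -> Mat {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut c = vec![vec![0.0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for x in 0..k {
                c[i][j] += a[i][x] * b[x][j];
            }
        }
    }
    c
}

fn transpose(a: &Mat) -> Mat {
    let (n, m) = (a.len(), a[0].len());
    (0..m).map(|j| (0..n).map(|i| a[i][j]).collect()).collect()
}

fn main() {
    let a: Mat = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; // 2x3
    let b: Mat = vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]]; // 3x2

    // What the op computes when transpose_c is true: (A * B)^T.
    let transposed_product = transpose(&matmul(&a, &b));
    // What the dumper emits instead: matmul with swapped, transposed operands.
    let swapped = matmul(&transpose(&b), &transpose(&a));

    assert_eq!(transposed_product, swapped);
    println!("(A*B)^T == B^T * A^T holds: {:?}", swapped);
}

In the dumper itself the flags are inverted rather than the operands being explicitly transposed: emitting matmul(B, A) with transposeA = !transpose_b and transposeB = !transpose_a composes the existing per-operand transposes with the extra transpose implied by transpose_c, so no transpose node has to be serialized.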