Skip to content

Commit

Permalink
matmul is standard nnef
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 13, 2023
1 parent 6789787 commit 09942f6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
19 changes: 0 additions & 19 deletions nnef/src/ops/core/einsum.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use crate::internal::*;
use crate::ser::*;
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::einsum::EinSum;
use tract_core::tract_data::itertools::Itertools;

pub fn register(registry: &mut Registry) {
registry.register_dumper(ser_basic_matmul);
registry.register_dumper(ser);
registry.register_primitive(
"tract_core_einsum",
Expand Down Expand Up @@ -54,23 +52,6 @@ pub fn ser(ast: &mut IntoAst, node: &TypedNode, op: &EinSum) -> TractResult<Opti
}
}

pub fn ser_basic_matmul(ast: &mut IntoAst, node: &TypedNode, op: &BasicMatMul) -> TractResult<Option<Arc<RValue>>> {
let inputs = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect_vec();
if op.transpose_c {
Ok(Some(invocation(
"matmul",
&[Arc::new(inputs[1].clone()), Arc::new(inputs[0].clone())],
&[("transposeA", logical(!op.transpose_b)), ("transposeB", logical(!op.transpose_a))],
)))
} else {
Ok(Some(invocation(
"matmul",
&[Arc::new(inputs[0].clone()), Arc::new(inputs[1].clone())],
&[("transposeA", logical(op.transpose_a)), ("transposeB", logical(op.transpose_b))],
)))
}
}

pub fn ser_einsum(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
let einsum = node.op_as::<EinSum>().unwrap();
let inputs: Vec<_> = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect();
Expand Down
4 changes: 2 additions & 2 deletions nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,12 @@ pub fn conv_or_deconv(
}

let mut inputs = tvec!(input, kernel, bias);
let (group, pool_spec) = read_conv_parameters(
let (group, pool_spec) = dbg!(read_conv_parameters(
builder,
invocation,
kernel_fact.shape.as_concrete().context("Except fixed kernel shape")?,
&input_fact,
)?;
)?);

let output_dt: Option<DatumType> = if input_fact.datum_type.is_float() {
None
Expand Down
2 changes: 2 additions & 0 deletions nnef/src/ops/nnef/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ pub fn tract_nnef() -> Registry {
primitive(&mut registry, "box", deser::sum_pool);
registry.register_dumper(ser::sum_pool);

registry.register_dumper(ser::basic_matmul);

for frag in stdlib {
if frag.body.is_some() {
registry.register_fragment(frag);
Expand Down
19 changes: 19 additions & 0 deletions nnef/src/ops/nnef/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use tract_core::ops::cnn::Conv;
use tract_core::ops::cnn::DeconvUnary;
use tract_core::ops::cnn::KernelFormat;
use tract_core::ops::cnn::PoolSpec;
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::nn::DataFormat;
use tract_core::tract_data::itertools::Itertools;

pub fn source(
ast: &mut IntoAst,
Expand All @@ -29,6 +31,23 @@ pub fn source(
Ok(None)
}

pub fn basic_matmul(ast: &mut IntoAst, node: &TypedNode, op: &BasicMatMul) -> TractResult<Option<Arc<RValue>>> {
let inputs = node.inputs.iter().map(|i| (*ast.mapping[i]).clone()).collect_vec();
if op.transpose_c {
Ok(Some(invocation(
"matmul",
&[Arc::new(inputs[1].clone()), Arc::new(inputs[0].clone())],
&[("transposeA", logical(!op.transpose_b)), ("transposeB", logical(!op.transpose_a))],
)))
} else {
Ok(Some(invocation(
"matmul",
&[Arc::new(inputs[0].clone()), Arc::new(inputs[1].clone())],
&[("transposeA", logical(op.transpose_a)), ("transposeB", logical(op.transpose_b))],
)))
}
}

pub fn konst(
ast: &mut IntoAst,
node: &TypedNode,
Expand Down

0 comments on commit 09942f6

Please sign in to comment.