Nnef unit tests #1285

Merged: 15 commits, Dec 13, 2023
2 changes: 1 addition & 1 deletion .travis/debug-tests.sh
@@ -9,4 +9,4 @@ fi
 
 # useful as debug_asserts will come into play
 cargo test -p tract-core
-cargo test -p test-onnx-core -p test-onnx-nnef-cycle -p test-unit-core
+cargo test -p test-onnx-core -p test-nnef-cycle -p test-unit-core
2 changes: 1 addition & 1 deletion .travis/onnx-tests.sh
@@ -41,4 +41,4 @@ opset=onnx_"${1:-1_13_0}"
 
 cargo -q test -p test-unit-core $CARGO_EXTRA -q
 cargo -q test -p test-onnx-core $CARGO_EXTRA -q --no-default-features --features $opset
-cargo -q test -p test-onnx-nnef-cycle $CARGO_EXTRA -q --no-default-features
+cargo -q test -p test-nnef-cycle $CARGO_EXTRA -q --no-default-features
2 changes: 1 addition & 1 deletion .travis/regular-tests.sh
@@ -40,7 +40,7 @@ export CACHEDIR
 # useful as debug_asserts will come into play
 cargo -q test -q -p tract-core --features paranoid_assertions $CARGO_EXTRA
 cargo -q test -q -p test-onnx-core $CARGO_EXTRA
-cargo -q test -q -p test-onnx-nnef-cycle $CARGO_EXTRA
+cargo -q test -q -p test-nnef-cycle $CARGO_EXTRA
 
 cargo check -p tract-nnef --features complex $CARGO_EXTRA
 cargo check -p tract --no-default-features $CARGO_EXTRA
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -46,7 +46,7 @@ members = [
     "test-rt/suite-onnx",
     "test-rt/test-unit-core",
     "test-rt/test-onnx-core",
-    "test-rt/test-onnx-nnef-cycle",
+    "test-rt/test-nnef-cycle",
     "test-rt/test-tflite",
 ]
 
8 changes: 7 additions & 1 deletion core/src/ops/cnn/conv/conv.rs
@@ -35,6 +35,10 @@ pub struct Conv {
     pub pool_spec: PoolSpec,
     pub kernel_fmt: KernelFormat,
     pub group: usize,
+    // None -> floats
+    // Some(I32) -> output is I32 (use quantized kernels, but output will be i32). last 2 Q inputs
+    // are ignored
+    // Some(QXX) -> quantized XX, but parameters are ignored (I8, U8, or I32) in favor of last 2 Q inputs
     pub q_params: Option<DatumType>,
 }
 
@@ -808,6 +812,7 @@ impl EvalOp for Conv {
 
 impl TypedOp for Conv {
     fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        ensure!(self.q_params.is_some() || inputs[0].datum_type.is_float());
         let q_inputs = if self.q_params.is_some() { 6 } else { 0 };
         if inputs.len() != 3 + q_inputs {
             bail!("Wrong number of inputs: expected {} got {}", 3 + q_inputs, inputs.len());
@@ -839,7 +844,8 @@ impl TypedOp for Conv {
             inputs[2].rank() == 0
                 || (inputs[2].rank() == 1
                     && inputs[2].shape.volume() == self.output_channels().to_dim()),
-            "Bias should be scalar or a vector with one value per output channel, got:{:?}",
+            "Bias should be scalar or a vector with one value per output channel. Output channels is {}, bias is {:?}",
+            self.output_channels(),
             inputs[2]
         );
         let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
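The `q_params` convention documented in the new comment is subtle, so here is a minimal sketch of the intended mapping. `ZpScale` and the cut-down `DatumType` below are simplified stand-ins for tract's types, illustration only, not the crate's actual implementation:

```rust
// Simplified stand-ins for tract's quantization types; illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ZpScale {
    zero_point: i32,
    scale: f32,
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum DatumType {
    F32,
    I32,
    QI8(ZpScale),
    QU8(ZpScale),
}

// How the comment on `Conv::q_params` translates into an output type.
// `runtime_zp_scale` stands for the zero-point/scale carried by the last
// two quantization inputs of the op.
fn conv_output_type(q_params: Option<DatumType>, runtime_zp_scale: ZpScale) -> DatumType {
    match q_params {
        // None -> a plain float convolution.
        None => DatumType::F32,
        // Some(I32) -> quantized kernels, raw i32 accumulator as output;
        // the output zero-point/scale inputs are ignored.
        Some(DatumType::I32) => DatumType::I32,
        // Some(QXX) -> quantized output; the parameters embedded in the
        // type are ignored in favor of the runtime Q inputs.
        Some(DatumType::QI8(_)) => DatumType::QI8(runtime_zp_scale),
        Some(DatumType::QU8(_)) => DatumType::QU8(runtime_zp_scale),
        Some(other) => other,
    }
}

fn main() {
    let rt = ZpScale { zero_point: 0, scale: 0.0039 };
    assert_eq!(conv_output_type(None, rt), DatumType::F32);
    assert_eq!(conv_output_type(Some(DatumType::I32), rt), DatumType::I32);
    // Embedded params are replaced by the runtime ones.
    let embedded = ZpScale { zero_point: 3, scale: 1.0 };
    assert_eq!(conv_output_type(Some(DatumType::QI8(embedded)), rt), DatumType::QI8(rt));
}
```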
43 changes: 43 additions & 0 deletions core/src/ops/cnn/mod.rs
@@ -68,3 +68,46 @@ pub fn wire_reshape_bias_for_bin(
     }
     Ok(bias)
 }
+
+pub fn rewrite_conv_with_n_axis(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    name: &str,
+    conv: &Conv,
+) -> TractResult<Option<TypedModelPatch>> {
+    if !conv.pool_spec.data_format.has_n() {
+        let mut new = conv.clone();
+        new.pool_spec.data_format = conv.pool_spec.data_format.with_n();
+        let mut patch = TypedModelPatch::default();
+        let mut wire = patch.taps(model, &node.inputs)?;
+        wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
+        wire = patch.wire_node(name, new, &wire)?;
+        wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
+        patch.shunt_outside(model, node.id.into(), wire[0])?;
+        return Ok(Some(patch));
+    }
+    Ok(None)
+}
+
+pub fn rewrite_deconv_with_n_axis(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    name: &str,
+    deconv: &DeconvUnary,
+) -> TractResult<Option<TypedModelPatch>> {
+    if !deconv.pool_spec.data_format.has_n() {
+        let mut new = deconv.clone();
+        new.pool_spec.data_format = deconv.pool_spec.data_format.with_n();
+        let mut patch = TypedModelPatch::default();
+        let mut wire = patch.taps(model, &node.inputs)?;
+        wire[0] = patch.wire_node(format!("{name}.add_n"), AxisOp::Add(0), &[wire[0]])?[0];
+        wire = patch.wire_node(name, new, &wire)?;
+        wire = patch.wire_node(format!("{name}.rm_n"), AxisOp::Rm(0), &wire)?;
+        patch.shunt_outside(model, node.id.into(), wire[0])?;
+        return Ok(Some(patch));
+    }
+    Ok(None)
+}
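Both helpers follow the same recipe: if the op's data format lacks a batch axis, rewire it so the format gains a leading N, then strip that axis again behind the op. A standalone sketch of the shape-level effect, using plain `Vec<usize>` shapes instead of tract's `TypedFact` machinery (names here are illustrative):

```rust
// AxisOp::Add(0) prepends a trivial batch axis; AxisOp::Rm(0) removes it.
fn add_n_axis(shape: &[usize]) -> Vec<usize> {
    let mut with_n = vec![1]; // N = 1
    with_n.extend_from_slice(shape);
    with_n
}

fn rm_n_axis(shape: &[usize]) -> Vec<usize> {
    assert_eq!(shape[0], 1, "the synthetic N axis must stay trivial");
    shape[1..].to_vec()
}

fn main() {
    // A CHW-style input without a batch axis...
    let chw = [40, 128];
    // ...is wrapped as NCHW for the N-aware conv, then unwrapped,
    // so the surrounding graph never sees the extra axis.
    let nchw = add_n_axis(&chw);
    assert_eq!(nchw, vec![1, 40, 128]);
    assert_eq!(rm_n_axis(&nchw), vec![40, 128]);
}
```

This is exactly the `add_n`/`rm_n` pair of `unsqueeze`/`squeeze` nodes that shows up in the updated expected graphs below.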
2 changes: 1 addition & 1 deletion examples/nnef-dump-mobilenet-v2/README.md
@@ -42,7 +42,7 @@ convention, and the variant of Mobilenet we picked operates on inputs of
 224x224 pixels:
 
 ```sh
-tract mobilenet_v2_1.4_224_frozen.pb -i 1x224x224x3xf32 dump --nnef mobilenet.nnef.tgz
+tract mobilenet_v2_1.4_224_frozen.pb -i 1,224,224,3,f32 dump --nnef mobilenet.nnef.tgz
 ```
 
 ## Running with tract_nnef
590 changes: 344 additions & 246 deletions harness/pre-optimized-graphes/hey_snips_v4_model17/expected

Large diffs are not rendered by default.

44 changes: 25 additions & 19 deletions harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected
@@ -128,13 +128,15 @@ fragment tract_core_properties(
 graph network(input) -> (output) {
   input = external(shape = [24, 40]);
   i"lda.output.delay" = tract_pulse_delay(input, axis = 0, delay = 0, overlap = 4);
+  i"lda.output.add_n" = unsqueeze(i"lda.output.delay", axes = [0]);
   i"lda.kernel.0" = variable<scalar>(label = "lda.kernel.0", shape = [200, 40, 5]);
   i"lda.bias.0" = variable<scalar>(label = "lda.bias.0", shape = [200]);
-  i"lda.output_input" = transpose(unsqueeze(i"lda.output.delay", axes = [0]), axes = [0, 2, 1]);
+  i"lda.output_input" = transpose(i"lda.output.add_n", axes = [0, 2, 1]);
   i"lda.output_conv" = conv(i"lda.output_input", i"lda.kernel.0", i"lda.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
-  i"lda.output" = squeeze(transpose(i"lda.output_conv", axes = [0, 2, 1]), axes = [0]);
+  i"lda.output" = transpose(i"lda.output_conv", axes = [0, 2, 1]);
+  i"lda.output.rm_n" = squeeze(i"lda.output", axes = [0]);
   i"tdnn1.affine.output.filters_as_co_ci" = variable<scalar>(label = "tdnn1.affine.output.filters_as_co_ci", shape = [256, 200]);
-  i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.output.filters_as_co_ci", i"lda.output", transposeA = false, transposeB = true);
+  i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.output.filters_as_co_ci", i"lda.output.rm_n", transposeA = false, transposeB = true);
   i"tdnn1.affine.output.bias.reshape" = variable<scalar>(label = "tdnn1.affine.output.bias.reshape", shape = [256, 1]);
   i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape");
   i"tdnn1.relu.output.low.cst" = [[0.0]];
@@ -146,26 +148,28 @@ graph network(input) -> (output) {
   i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.scaled");
   i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip");
   i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 1, delay = 0, overlap = 2);
+  i"tdnn2.affine.output.add_n" = unsqueeze(i"tdnn2.affine.output.delay", axes = [0]);
   i"tdnn2.affine.kernel.0" = variable<scalar>(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]);
   i"tdnn2.affine.bias.0" = variable<scalar>(label = "tdnn2.affine.bias.0", shape = [256]);
-  i"tdnn2.affine.output_input" = unsqueeze(i"tdnn2.affine.output.delay", axes = [0]);
-  i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output_input", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
-  i"tdnn2.affine.output" = squeeze(i"tdnn2.affine.output_conv", axes = [0]);
+  i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.add_n", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
+  i"tdnn2.affine.output" = i"tdnn2.affine.output_conv";
+  i"tdnn2.affine.output.rm_n" = squeeze(i"tdnn2.affine.output", axes = [0]);
   i"tdnn2.relu.output.low.cst" = [[0.0]];
-  i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn2.relu.output.low.cst");
+  i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output.rm_n", i"tdnn2.relu.output.low.cst");
   i"tdnn2.renorm.reduced.sq" = square(i"tdnn2.relu.output.low");
   i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [0]);
   i"tdnn2.renorm.scaled-recip" = [[0.00390625]];
   i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn2.renorm.scaled-recip");
   i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled");
   i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip");
+  i"tdnn3.affine.output.add_n" = unsqueeze(i"tdnn2.renorm.output", axes = [0]);
   i"tdnn3.affine.kernel.0" = variable<scalar>(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]);
   i"tdnn3.affine.bias.0" = variable<scalar>(label = "tdnn3.affine.bias.0", shape = [256]);
-  i"tdnn3.affine.output_input" = unsqueeze(i"tdnn2.renorm.output", axes = [0]);
-  i"tdnn3.affine.output_conv" = conv(i"tdnn3.affine.output_input", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]);
-  i"tdnn3.affine.output" = squeeze(i"tdnn3.affine.output_conv", axes = [0]);
+  i"tdnn3.affine.output_conv" = conv(i"tdnn3.affine.output.add_n", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]);
+  i"tdnn3.affine.output" = i"tdnn3.affine.output_conv";
+  i"tdnn3.affine.output.rm_n" = squeeze(i"tdnn3.affine.output", axes = [0]);
   i"tdnn3.relu.output.low.cst" = [[0.0]];
-  i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn3.relu.output.low.cst");
+  i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output.rm_n", i"tdnn3.relu.output.low.cst");
   i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low");
   i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [0]);
   i"tdnn3.renorm.scaled-recip" = [[0.00390625]];
@@ -210,27 +214,29 @@ graph network(input) -> (output) {
   i"fastlstm1.h_new.split-over-1.128..256" = add(i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0", i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice");
   i"fastlstm1.h_new.concat-1" = concat([i"fastlstm1.c_final", i"fastlstm1.h_new.split-over-1.128..256"], axis = 0);
   i"tdnn4.affine.output.delay" = tract_pulse_delay(i"fastlstm1.h_new.concat-1", axis = 1, delay = 0, overlap = 2);
+  i"tdnn4.affine.output.add_n" = unsqueeze(i"tdnn4.affine.output.delay", axes = [0]);
   i"tdnn4.affine.kernel.0" = variable<scalar>(label = "tdnn4.affine.kernel.0", shape = [256, 256, 3]);
   i"tdnn4.affine.bias.0" = variable<scalar>(label = "tdnn4.affine.bias.0", shape = [256]);
-  i"tdnn4.affine.output_input" = unsqueeze(i"tdnn4.affine.output.delay", axes = [0]);
-  i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output_input", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
-  i"tdnn4.affine.output" = squeeze(i"tdnn4.affine.output_conv", axes = [0]);
+  i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.add_n", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
+  i"tdnn4.affine.output" = i"tdnn4.affine.output_conv";
+  i"tdnn4.affine.output.rm_n" = squeeze(i"tdnn4.affine.output", axes = [0]);
   i"tdnn4.relu.output.low.cst" = [[0.0]];
-  i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn4.relu.output.low.cst");
+  i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output.rm_n", i"tdnn4.relu.output.low.cst");
   i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low");
   i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [0]);
   i"tdnn4.renorm.scaled-recip" = [[0.00390625]];
   i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn4.renorm.scaled-recip");
   i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled");
   i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip");
   i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 1, delay = 0, overlap = 2);
+  i"tdnn5.affine.output.add_n" = unsqueeze(i"tdnn5.affine.output.delay", axes = [0]);
   i"tdnn5.affine.kernel.0" = variable<scalar>(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]);
   i"tdnn5.affine.bias.0" = variable<scalar>(label = "tdnn5.affine.bias.0", shape = [256]);
-  i"tdnn5.affine.output_input" = unsqueeze(i"tdnn5.affine.output.delay", axes = [0]);
-  i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output_input", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
-  i"tdnn5.affine.output" = squeeze(i"tdnn5.affine.output_conv", axes = [0]);
+  i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.add_n", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
+  i"tdnn5.affine.output" = i"tdnn5.affine.output_conv";
+  i"tdnn5.affine.output.rm_n" = squeeze(i"tdnn5.affine.output", axes = [0]);
   i"tdnn5.relu.output.low.cst" = [[0.0]];
-  i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn5.relu.output.low.cst");
+  i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output.rm_n", i"tdnn5.relu.output.low.cst");
   i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low");
   i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [0]);
   i"tdnn5.renorm.scaled-recip" = [[0.00390625]];
2 changes: 1 addition & 1 deletion nnef/src/ast.rs
@@ -44,7 +44,7 @@ impl QuantFormat {
             QuantFormat::Linear { params, bits, signed } => match (bits, signed) {
                 (8, true) => DatumType::QI8(*params),
                 (8, false) => DatumType::QU8(*params),
-                (32, true) => DatumType::I32,
+                (32, true) => DatumType::QI32(*params),
                 (32, false) => DatumType::U32,
                 _ => todo!(),
             },
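The point of the fix: a `QI32` keeps its zero-point and scale, so it can still be dequantized downstream, while a bare `I32` cannot. A hedged sketch, assuming the `DatumType` and `QParams` shapes used elsewhere in this diff:

```rust
use tract_core::internal::*; // assumed to expose DatumType and QParams, as in this file

// Illustration only: why QI32(params) is strictly more useful than I32.
fn dequantize_scalar(dt: &DatumType, raw: i32) -> Option<f32> {
    match dt {
        DatumType::QI32(QParams::ZpScale { zero_point, scale }) => {
            Some((raw - zero_point) as f32 * scale)
        }
        // A plain I32 carries no quantization parameters to apply.
        DatumType::I32 => None,
        _ => None,
    }
}
```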
54 changes: 34 additions & 20 deletions nnef/src/ast/dump.rs
@@ -1,3 +1,5 @@
+use std::io::Write;
+
 use crate::ast::*;
 use tract_core::internal::*;
 use tract_itertools::Itertools;
@@ -319,28 +321,40 @@ impl<'a> Dumper<'a> {
     }
 
     fn identifier(&mut self, id: &Identifier) -> TractResult<()> {
-        if id.0.len() == 0 {
-            return Ok(());
+        write_identifier(&mut self.w, id, self.nnef.allow_extended_identifier_syntax, false)
     }
+}
+
+pub fn write_identifier(
+    w: &mut dyn Write,
+    id: &Identifier,
+    allow_extended_identifier_syntax: bool,
+    force_double_quotes: bool,
+) -> TractResult<()> {
+    if id.0.len() == 0 {
+        return Ok(());
+    }
+    let first = id.0.chars().next().unwrap();
+    let force_double_quotes = if force_double_quotes { "\"" } else { "" };
+    if (first.is_alphabetic() || first == '_')
+        && id.0.chars().all(|c| c.is_alphanumeric() || c == '_')
+    {
+        write!(w, "{force_double_quotes}{}{force_double_quotes}", id.0)?;
+    } else if allow_extended_identifier_syntax {
+        write!(w, "i\"{}\"", id.0.replace('\\', "\\\\").replace('\"', "\\\""))?;
+    } else {
+        write!(w, "{force_double_quotes}")?;
+        if !(first.is_alphabetic() || first == '_') {
+            write!(w, "_")?;
+        }
-        let first = id.0.chars().next().unwrap();
-        if (first.is_alphabetic() || first == '_')
-            && id.0.chars().all(|c| c.is_alphanumeric() || c == '_')
-        {
-            write!(self.w, "{}", id.0)?;
-        } else if self.nnef.allow_extended_identifier_syntax {
-            write!(self.w, "i\"{}\"", id.0.replace('\\', "\\\\").replace('\"', "\\\""))?;
-        } else {
-            if !(first.is_alphabetic() || first == '_') {
-                write!(self.w, "_")?;
-            }
-            for c in id.0.chars() {
-                if c.is_alphanumeric() {
-                    write!(self.w, "{c}")?;
-                } else {
-                    write!(self.w, "_")?;
-                }
+        for c in id.0.chars() {
+            if c.is_alphanumeric() {
+                write!(w, "{c}")?;
+            } else {
+                write!(w, "_")?;
             }
         }
-        Ok(())
+        write!(w, "{force_double_quotes}")?;
    }
+    Ok(())
 }
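A few expected outcomes of the new free function, inferred from the code above. The `escape` helper below is hypothetical, and this assumes `Identifier` is the crate's `pub struct Identifier(pub String)` tuple struct (consistent with the `id.0` accesses in the diff):

```rust
#[cfg(test)]
mod sketch {
    use super::*;

    // Hypothetical helper capturing write_identifier's output as a String.
    fn escape(id: &str, extended: bool, quoted: bool) -> String {
        let mut buf: Vec<u8> = vec![];
        write_identifier(&mut buf, &Identifier(id.to_string()), extended, quoted).unwrap();
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn identifier_escaping_examples() {
        // Plain identifiers pass through, optionally force-quoted.
        assert_eq!(escape("conv_1", false, false), "conv_1");
        assert_eq!(escape("conv_1", false, true), "\"conv_1\"");
        // With extended syntax, anything goes inside i"...".
        assert_eq!(escape("lda.output", true, false), "i\"lda.output\"");
        // Without it, non-identifier characters degrade to underscores.
        assert_eq!(escape("lda.output", false, false), "lda_output");
    }
}
```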
23 changes: 10 additions & 13 deletions nnef/src/ast/quant.rs
@@ -2,18 +2,19 @@ use std::str::FromStr;
 
 use nom::branch::permutation;
 use nom::character::complete::digit1;
-use nom::combinator::map_res;
-use nom::sequence::delimited;
+use nom::combinator::{map_res, recognize};
+use nom::sequence::{delimited, pair};
 use tract_core::internal::*;
 
-use nom::{bytes::complete::*, multi::*};
 use nom::branch::alt;
+use nom::{bytes::complete::*, multi::*};
 use nom::{combinator::all_consuming, IResult};
 use nom::{combinator::opt, number::complete::float};
 
 use crate::ast::*;
 
-use super::parse::{logical_literal, stag, translate_error, direct_identifier, escaped_identifier};
+use super::dump::write_identifier;
+use super::parse::{direct_identifier, escaped_identifier, logical_literal, stag, translate_error};
 
 #[inline(never)]
 pub fn parse_quantization(doc: &str) -> TractResult<Vec<(Identifier, QuantFormat)>> {
@@ -31,7 +32,7 @@ fn quantization(i: &str) -> IResult<&str, (Identifier, QuantFormat)> {
 }
 
 fn integer_numeric<T: FromStr>(i: &str) -> IResult<&str, T> {
-    map_res(digit1, |s: &str| s.parse::<T>())(i)
+    map_res(recognize(pair(opt(tag("-")), digit1)), |s: &str| s.parse::<T>())(i)
 }
 
 // <qparam> ::= "<identifier>": <qparam>
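The switch from `map_res(digit1, …)` to `recognize(pair(opt(tag("-")), digit1))` is what lets negative values (e.g. a negative `zero_point`) parse. A self-contained sketch of the same combinator stack, nom 7 style and independent of this crate's helpers:

```rust
use std::str::FromStr;

use nom::bytes::complete::tag;
use nom::character::complete::digit1;
use nom::combinator::{map_res, opt, recognize};
use nom::sequence::pair;
use nom::IResult;

// Accepts an optional leading '-' followed by digits, then parses the
// whole recognized slice with FromStr.
fn integer_numeric<T: FromStr>(i: &str) -> IResult<&str, T> {
    map_res(recognize(pair(opt(tag("-")), digit1)), |s: &str| s.parse::<T>())(i)
}

fn main() {
    assert_eq!(integer_numeric::<i32>("-128"), Ok(("", -128)));
    assert_eq!(integer_numeric::<i32>("42;"), Ok((";", 42)));
    // The previous map_res(digit1, ...) version failed on "-128".
    assert!(integer_numeric::<i32>("abc").is_err());
}
```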
@@ -84,18 +85,14 @@ pub(crate) fn write_quant_format(
     format: QuantFormat,
     allow_extended_identifier_syntax: bool,
 ) -> TractResult<()> {
-    let escaped_name = if allow_extended_identifier_syntax {
-        format!("i\"{}\"", name.0)
-    } else {
-        format!("\"{}\"", name.0)
-    };
+    write_identifier(w, name, allow_extended_identifier_syntax, true)?;
     match format {
         QuantFormat::Linear {
             params: QParams::ZpScale {zero_point, scale}, bits, signed
-        } => writeln!(w, "{}: zero_point_linear_quantize(zero_point = {}, scale = {:.9}, bits = {}, signed = {}, symmetric = {});", escaped_name, zero_point, scale, bits, signed, zero_point == 0)?,
+        } => writeln!(w, ": zero_point_linear_quantize(zero_point = {zero_point}, scale = {scale:.9}, bits = {bits}, signed = {signed}, symmetric = {});", zero_point == 0)?,
         QuantFormat::Linear {
             params: QParams::MinMax {min, max}, bits, signed: _
-        } => writeln!(w, "{}: linear_quantize(max = {:.9}, min = {:.9}, bits = {});", escaped_name, max, min, bits)?, // FIXME we lazily use rust debug escaping form here
+        } => writeln!(w, ": linear_quantize(max = {max:.9}, min = {min:.9}, bits = {bits});")?,
     }
     Ok(())
 }
@@ -187,7 +184,7 @@ mod test {
             ]
         );
     }
-
+
     #[test]
     fn test_quant_file_1() {
         assert_eq!(
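For reference, `write_quant_format` now produces quantization-file lines like the following (illustrative values; the identifier is emitted by `write_identifier` with `force_double_quotes = true`, hence the plain double quotes, or in `i"..."` form under the extended syntax):

```
"conv_1_output": zero_point_linear_quantize(zero_point = 0, scale = 0.003921569, bits = 8, signed = true, symmetric = true);
"lda_output": linear_quantize(max = 1.000000000, min = -1.000000000, bits = 8);
```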