From 7bc1ef34cfbdff79da80bb220919df45ec823961 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 11 Dec 2023 11:41:40 +0100 Subject: [PATCH] give up on extending tflite support for now --- test-rt/suite-unit/src/q_flavours.rs | 12 ++++++++++++ test-rt/test-tflite/suite.rs | 12 ++++++++---- tflite/src/rewriter.rs | 7 ++++--- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/test-rt/suite-unit/src/q_flavours.rs b/test-rt/suite-unit/src/q_flavours.rs index 8a52812d14..6c516d3a4e 100644 --- a/test-rt/suite-unit/src/q_flavours.rs +++ b/test-rt/suite-unit/src/q_flavours.rs @@ -53,6 +53,7 @@ impl Test for QFlavoursProblem { .run(tvec![self.input.clone().into_tvalue()])? .remove(0) .into_tensor(); + dbg!(&output); let reference = self.input.cast_to::()?; let comparison = output.cast_to::()?; comparison.close_enough(&reference, approx) @@ -62,5 +63,16 @@ impl Test for QFlavoursProblem { pub fn suite() -> TractResult { let mut suite = TestSuite::default(); suite.add_arbitrary::("proptest", ()); + suite.add( + "trivial_0", + QFlavoursProblem { + input: tensor0(0u8) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 0, scale: 1. }), + ) + .unwrap() + .into_owned(), + }, + ); Ok(suite) } diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 8881ee09d8..628ae1988b 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -136,8 +136,8 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool { return true; } } - let [section, unit] = t else { return false }; - ["deconv"].contains(&&**section) + let [section, _unit] = t else { return false }; + ["deconv", "q_flavours"].contains(&&**section) } fn compatible_conv_f32(qcp: &ConvProblem) -> bool { @@ -154,8 +154,12 @@ fn compatible_conv_q(qcp: &QConvProblem) -> bool { if odt != idt.unquantized() { return false; } - // per-layer (will convert all to u8) - if qcp.qp.iter().all(|qp| qp.is_uniform()) { + + // all u8 and per-layer + if idt.unquantized() == u8::datum_type() + && kdt.unquantized() == u8::datum_type() + && qcp.qp.iter().all(|qp| qp.is_uniform()) + { return true; } // all i8 and no zero_point diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs index 6f3e747d58..25ca044820 100644 --- a/tflite/src/rewriter.rs +++ b/tflite/src/rewriter.rs @@ -5,8 +5,7 @@ use tract_core::ops::cnn::KernelFormat; use tract_core::ops::cnn::{ConvUnary, PaddingSpec}; use tract_core::ops::einsum::BasicMatMul; use tract_core::ops::element_wise::ElementWiseOp; -use tract_core::ops::math::{add, sub, Recip}; -use tract_core::ops::matmul::mir_quant::wire_ensure_q8_flavour; +use tract_core::ops::math::Recip; use tract_core::ops::nn::{DataFormat, Softmax}; use tract_core::tract_data::itertools::Itertools; @@ -16,7 +15,7 @@ pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> { .with_rule_for("trivial_axes_around_matmul", trivial_axes_around_matmul) .with_rule_for("kernel_in_ohwi", kernel_in_ohwi) .with_rule_for("bias_as_vector", bias_as_vector) - .with_rule_for("per_layer_in_u8", per_layer_in_u8) +// .with_rule_for("per_layer_in_u8", per_layer_in_u8) .with_rule_for("make_1d_2d", make_1d_2d) .with_rule_for("force_n_axis", force_n_axis) .with_rule_for("nchw-to-nhwc", nchw_to_nhwc) @@ -128,6 +127,7 @@ fn bias_as_vector( Ok(Some(patch)) } +/* fn per_layer_in_u8( _ctx: &(), model: &TypedModel, @@ -156,6 +156,7 @@ fn per_layer_in_u8( patch.shunt_outside(model, node.id.into(), output[0])?; Ok(Some(patch)) } +*/ fn force_n_axis( _ctx: &(),