Skip to content

Commit

Permalink
give up on extending tflite support for now
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 11, 2023
1 parent fe41722 commit 7bc1ef3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
12 changes: 12 additions & 0 deletions test-rt/suite-unit/src/q_flavours.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl Test for QFlavoursProblem {
.run(tvec![self.input.clone().into_tvalue()])?
.remove(0)
.into_tensor();
dbg!(&output);
let reference = self.input.cast_to::<f32>()?;
let comparison = output.cast_to::<f32>()?;
comparison.close_enough(&reference, approx)
Expand All @@ -62,5 +63,16 @@ impl Test for QFlavoursProblem {
pub fn suite() -> TractResult<TestSuite> {
let mut suite = TestSuite::default();
suite.add_arbitrary::<QFlavoursProblem>("proptest", ());
suite.add(
"trivial_0",
QFlavoursProblem {
input: tensor0(0u8)
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 0, scale: 1. }),
)
.unwrap()
.into_owned(),
},
);
Ok(suite)
}
12 changes: 8 additions & 4 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
return true;
}
}
let [section, unit] = t else { return false };
["deconv"].contains(&&**section)
let [section, _unit] = t else { return false };
["deconv", "q_flavours"].contains(&&**section)
}

fn compatible_conv_f32(qcp: &ConvProblem) -> bool {
Expand All @@ -154,8 +154,12 @@ fn compatible_conv_q(qcp: &QConvProblem) -> bool {
if odt != idt.unquantized() {
return false;
}
// per-layer (will convert all to u8)
if qcp.qp.iter().all(|qp| qp.is_uniform()) {

// all u8 and per-layer
if idt.unquantized() == u8::datum_type()
&& kdt.unquantized() == u8::datum_type()
&& qcp.qp.iter().all(|qp| qp.is_uniform())
{
return true;
}
// all i8 and no zero_point
Expand Down
7 changes: 4 additions & 3 deletions tflite/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use tract_core::ops::cnn::KernelFormat;
use tract_core::ops::cnn::{ConvUnary, PaddingSpec};
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::{add, sub, Recip};
use tract_core::ops::matmul::mir_quant::wire_ensure_q8_flavour;
use tract_core::ops::math::Recip;
use tract_core::ops::nn::{DataFormat, Softmax};
use tract_core::tract_data::itertools::Itertools;

Expand All @@ -16,7 +15,7 @@ pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
.with_rule_for("trivial_axes_around_matmul", trivial_axes_around_matmul)
.with_rule_for("kernel_in_ohwi", kernel_in_ohwi)
.with_rule_for("bias_as_vector", bias_as_vector)
.with_rule_for("per_layer_in_u8", per_layer_in_u8)
// .with_rule_for("per_layer_in_u8", per_layer_in_u8)
.with_rule_for("make_1d_2d", make_1d_2d)
.with_rule_for("force_n_axis", force_n_axis)
.with_rule_for("nchw-to-nhwc", nchw_to_nhwc)
Expand Down Expand Up @@ -128,6 +127,7 @@ fn bias_as_vector(
Ok(Some(patch))
}

/*
fn per_layer_in_u8(
_ctx: &(),
model: &TypedModel,
Expand Down Expand Up @@ -156,6 +156,7 @@ fn per_layer_in_u8(
patch.shunt_outside(model, node.id.into(), output[0])?;
Ok(Some(patch))
}
*/

fn force_n_axis(
_ctx: &(),
Expand Down

0 comments on commit 7bc1ef3

Please sign in to comment.