Skip to content

Commit

Permalink
write test filter in rust instead
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 20, 2023
1 parent 87d16df commit af31422
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 35 deletions.
72 changes: 62 additions & 10 deletions test-rt/infra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#![allow(clippy::len_zero)]
use core::fmt;
use std::collections::HashMap;
use std::fmt::Debug;
use std::io::Write;

use downcast_rs::Downcast;
use dyn_clone::DynClone;
use itertools::Itertools;
use proptest::prelude::{any_with, Arbitrary};
use proptest::strategy::Strategy;
use proptest::test_runner::{Config, FileFailurePersistence, TestRunner};
use tract_core::internal::Approximation;
use tract_core::runtime::Runtime;
Expand Down Expand Up @@ -68,9 +71,20 @@ impl TestSuite {
id: impl ToString,
params: A::Parameters,
) where
A::Parameters: Clone + Send + Sync,
A::Parameters: Clone + Send + Sync + Debug,
{
self.add(id, ProptestWrapper::<A>(params));
self.add(id, ProptestWrapper::<A>(params, |_| true));
}

pub fn add_arbitrary_with_filter<A: Arbitrary + Test + Clone>(
&mut self,
id: impl ToString,
params: A::Parameters,
filter: fn(&A) -> bool,
) where
A::Parameters: Clone + Send + Sync + Debug,
{
self.add(id, ProptestWrapper::<A>(params, filter));
}

pub fn with(mut self, id: impl ToString, test: impl Into<TestSuite>) -> Self {
Expand Down Expand Up @@ -234,14 +248,48 @@ impl TestSuite {
}
}

#[derive(Clone, Debug)]
struct ProptestWrapper<A: Arbitrary + Test + Clone>(A::Parameters)
/*
trait TestFilter<A>: DynClone + Send + Sync {
fn filter(&self, a: &A) -> bool;
}
dyn_clone::clone_trait_object!(<A> TestFilter<A>);
#[derive(Clone)]
struct AcceptAllFilter;
impl<A> TestFilter<A> for AcceptAllFilter {
fn filter(&self, _a: &A) -> bool {
true
}
}
#[derive(Clone)]
struct FilterWrapper<A, F>(F);
impl<A: Clone, F: Clone> TestFilter<A> for FilterWrapper<A, F> {
fn filter(&self, a: &A) -> bool {
(self.0)(a)
}
}
*/

#[derive(Clone)]
struct ProptestWrapper<A: Arbitrary + Test + Clone>(A::Parameters, fn(&A) -> bool)
where
A::Parameters: Clone + Send + Sync + Debug;

impl<A: Arbitrary + Test + Clone + Send + Sync> Debug for ProptestWrapper<A>
where
A::Parameters: Clone + Send + Sync;
A::Parameters: Clone + Send + Sync + Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self.0)
}
}

impl<A: Arbitrary + Test + Clone> Test for ProptestWrapper<A>
where
A::Parameters: Clone + Send + Sync,
A::Parameters: Clone + Send + Sync + Debug,
{
fn run_with_approx(
&self,
Expand All @@ -253,10 +301,14 @@ where
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any_with::<A>(self.0.clone()), |v| {
v.run_with_approx(id, runtime, approx)
.map_err(|e| proptest::test_runner::TestCaseError::Fail(format!("{e:?}").into()))
})?;
runner.run(
&any_with::<A>(self.0.clone()).prop_filter("Test case filter", |a| self.1(a)),
|v| {
v.run_with_approx(id, runtime, approx).map_err(|e| {
proptest::test_runner::TestCaseError::Fail(format!("{e:?}").into())
})
},
)?;
Ok(())
}
}
2 changes: 2 additions & 0 deletions test-rt/test-tflite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
lazy_static.workspace = true
regex.workspace = true
infra = { path = "../infra" }
tract-core = { path = "../../core" }
suite-onnx = { path = "../suite-onnx" }
suite-unit = { path = "../suite-unit" }

Expand All @@ -17,6 +18,7 @@ regex.workspace = true
lazy_static.workspace = true
log.workspace = true
tflitec.workspace = true
tract-core = { path = "../../core", version = "=0.20.22-pre" }
tract-tflite = { path = "../../tflite", version = "=0.20.22-pre" }
tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.22-pre" }
infra = { path = "../infra" }
Expand Down
63 changes: 38 additions & 25 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use infra::Test;
use regex::Regex;
use suite_unit::conv_f32::{ConvProblem, ConvProblemParams};
use suite_unit::conv_q::{QConvProblem, QConvProblemParams};
use tract_core::internal::*;

pub fn suite() -> &'static infra::TestSuite {
lazy_static::lazy_static! {
Expand All @@ -20,14 +21,15 @@ fn mk_suite() -> infra::TestSuite {
let cv =
ConvProblemParams { no_group: true, geo_rank: Some(1..3), ..ConvProblemParams::default() };
unit.get_sub_mut("conv_f32").add_arbitrary::<ConvProblem>("proptest", cv.clone());
unit.get_sub_mut("conv_q").add_arbitrary::<QConvProblem>(
unit.get_sub_mut("conv_q").add_arbitrary_with_filter::<QConvProblem>(
"proptest",
QConvProblemParams {
conv: cv,
no_kernel_zero_point: true,
tflite_rules: true,
..QConvProblemParams::default()
},
compatible_conv_q,
);
infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}
Expand Down Expand Up @@ -128,34 +130,45 @@ fn skip_onnx(t: &[String]) -> bool {
}

fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
if let Some(cp) = case.downcast_ref::<ConvProblem>() {
if !compatible_conv_f32(cp) {
return true;
}
}
if let Some(qcp) = case.downcast_ref::<QConvProblem>() {
if !is_tflite_compatible(qcp) {
return true
if !compatible_conv_q(qcp) {
return true;
}
}
let [section, unit] = t else { return false };
let unit_exclude_patterns = patterns(
"
# grouping and depthwise
group.*
lazy_im2col_big
lazy_im2col_big_2
batch_3d
bias_3d_1
# kernel with non 0 zero_point
kernel_zp
a0_b0_0
# tflite does not support mixed type convolution
i8_u8.*
u8_i8.*
",
);
["deconv"].contains(&&**section) || unit_exclude_patterns.iter().any(|pat| pat.is_match(unit))
let [section, _unit] = t else { return false };
["deconv"].contains(&&**section)
}

fn is_tflite_compatible(qcp: &QConvProblem) -> bool {
false
fn compatible_conv_f32(qcp: &ConvProblem) -> bool {
qcp.group == 1 && qcp.kernel.ndim() == 4
}

fn compatible_conv_q(qcp: &QConvProblem) -> bool {
if qcp.group != 1 {
return false;
}
let idt = qcp.data.datum_type();
let kdt = qcp.kernel.datum_type();
// all u8 and per-layer
if idt.unquantized() == u8::datum_type()
&& kdt.unquantized() == u8::datum_type()
&& qcp.qp.iter().all(|qp| qp.is_uniform())
{
return true;
}
// all i8 and no zero_point
if idt.unquantized() == i8::datum_type()
&& kdt.unquantized() == i8::datum_type()
&& qcp.qp[0].is_zero().unwrap()
&& qcp.qp[2].is_zero().unwrap()
&& qcp.qp[4].is_zero().unwrap()
{
return true;
}
return false;
}

0 comments on commit af31422

Please sign in to comment.