Skip to content

Commit

Permalink
test filters
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 14, 2023
1 parent 95ca310 commit 23258e4
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 16 deletions.
29 changes: 17 additions & 12 deletions nnef/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
"rewrite_deconv_with_n_axis",
tract_core::ops::cnn::rewrite_deconv_with_n_axis,
)
.with_rule_for("rewrite_kernel_conv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw)
.with_rule_for("rewrite_kernel_deconv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw)
.with_rule_for("rewrite_consistent_quantized_conv", crate::ops::nnef::ser::rewrite_consistent_quantized_conv)
.with_rule_for(
"rewrite_kernel_conv_in_oihw",
crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw,
)
.with_rule_for(
"rewrite_kernel_deconv_in_oihw",
crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw,
)
.with_rule_for(
"rewrite_consistent_quantized_conv",
crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
)
.rewrite(&(), model)
}

Expand Down Expand Up @@ -342,25 +351,21 @@ impl<'a> IntoAst<'a> {
force_variable: bool,
) -> TractResult<Arc<RValue>> {
let mut name: Identifier = name.as_ref().into();
let have_tract_core = self.ensure_registry(&"tract_core".into()).is_ok();
if !force_variable && tensor.len() <= 8 {
if tensor.datum_type() == String::datum_type() {
return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<String>()?, |f| {
string(f)
})
.into());
} else if tensor.datum_type() == DatumType::F16 {
return Ok(
Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into()
);
} else if tensor.datum_type() == DatumType::F32 {
return Ok(
Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
);
} else if tensor.datum_type() == DatumType::F64 {
return Ok(
Self::dump_rec_tensor(&tensor.to_array_view::<f64>()?, |f| numeric(f)).into()
);
} else if self.ensure_registry(&"tract_core".into()).is_ok() {
} else if have_tract_core && tensor.datum_type() == DatumType::F16 {
let array = Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
} else if have_tract_core && tensor.datum_type().is_integer() {
if let Ok(value) = tensor.cast_to::<i64>() {
let value =
Self::dump_rec_tensor(&value.to_array_view::<i64>().unwrap(), |i| {
Expand Down
74 changes: 70 additions & 4 deletions test-rt/test-f16/suite.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use infra::Test;
use suite_unit::conv_q::{QConvProblem, QConvProblemParams};

pub fn suite() -> &'static infra::TestSuite {
lazy_static::lazy_static! {
Expand All @@ -13,16 +14,81 @@ fn mk_suite() -> infra::TestSuite {
onnx.ignore(&ignore_onnx);
let mut unit = suite_unit::suite().unwrap().clone();
unit.ignore_case(&ignore_unit);
unit.get_sub_mut("conv_q").add_arbitrary_with_filter::<QConvProblem>(
"proptest",
QConvProblemParams::default(),
compatible_conv_q,
);
infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}

fn ignore_unit(_t: &[String], _tc: &dyn Test) -> bool {
false
fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
if let Some(qcp) = case.downcast_ref::<QConvProblem>() {
if !compatible_conv_q(qcp) {
return true;
}
}
let [section, _unit] = t else { return false };
["q_flavours"].contains(&&**section)
}

fn ignore_onnx(t: &[String]) -> bool {
t.last().unwrap().starts_with("test_logsoftmax_large_number")
r#"
test_averagepool_2d_ceil
test_averagepool_2d_pads_count_include_pad
test_averagepool_2d_precomputed_pads_count_include_pad
test_averagepool_2d_same_lower
test_cast_STRING_to_FLOAT
test_castlike_STRING_to_FLOAT_expanded
test_constantlike_ones_with_input
test_constantlike_threes_with_shape_and_dtype
test_constantlike_zeros_without_input_dtype
test_cumsum_1d_exclusive
test_cumsum_1d_reverse_exclusive
test_cumsum_2d
test_dequantizelinear
test_dropout_random
test_dynamicquantizelinear
test_dynamicquantizelinear_max_adjusted
test_dynamicquantizelinear_min_adjusted
test_gemm_broadcast
test_gemm_nobroadcast
test_maxpool_2d_ceil
test_maxpool_2d_same_lower
test_maxpool_with_argmax_2d_precomputed_pads
test_mod_broadcast
test_mod_int64_fmod
test_mod_mixed_sign_float16
test_mod_mixed_sign_float32
test_mod_mixed_sign_float64
test_mod_mixed_sign_int16
test_mod_mixed_sign_int32
test_mod_mixed_sign_int64
test_mod_mixed_sign_int8
test_mod_uint16
test_mod_uint32
test_mod_uint64
test_mod_uint8
test_matmulinteger
test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded
test_nonzero_example
test_quantizelinear
test_qlinearmatmul_2D
test_qlinearmatmul_3D
test_reduce_prod_default_axes_keepdims_example
test_reshape_reordered_dims
test_resize_upsample_scales_linear_align_corners
test_resize_downsample_scales_linear
test_unsqueeze
"#
.trim()
.lines()
.any(|s| t.last().unwrap() == s.trim())
|| t.last().unwrap().starts_with("test_logsoftmax_large_number")
|| t.last().unwrap().starts_with("test_softmax_large_number")
|| t.last().unwrap().starts_with("test_resize")
|| t.last().unwrap() == "test_reduce_prod_default_axes_keepdims_example"
}

fn compatible_conv_q(qcp: &QConvProblem) -> bool {
qcp.qp.iter().all(|t| t.len() == 1)
}

0 comments on commit 23258e4

Please sign in to comment.