diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs
index 838999aa38..3168b88871 100644
--- a/nnef/src/ser.rs
+++ b/nnef/src/ser.rs
@@ -12,9 +12,18 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
             "rewrite_deconv_with_n_axis",
             tract_core::ops::cnn::rewrite_deconv_with_n_axis,
         )
-        .with_rule_for("rewrite_kernel_conv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw)
-        .with_rule_for("rewrite_kernel_deconv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw)
-        .with_rule_for("rewrite_consistent_quantized_conv", crate::ops::nnef::ser::rewrite_consistent_quantized_conv)
+        .with_rule_for(
+            "rewrite_kernel_conv_in_oihw",
+            crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw,
+        )
+        .with_rule_for(
+            "rewrite_kernel_deconv_in_oihw",
+            crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw,
+        )
+        .with_rule_for(
+            "rewrite_consistent_quantized_conv",
+            crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
+        )
         .rewrite(&(), model)
 }
 
@@ -342,25 +351,21 @@ impl<'a> IntoAst<'a> {
         force_variable: bool,
     ) -> TractResult<Arc<RValue>> {
         let mut name: Identifier = name.as_ref().into();
+        let have_tract_core = self.ensure_registry(&"tract_core".into()).is_ok();
         if !force_variable && tensor.len() <= 8 {
             if tensor.datum_type() == String::datum_type() {
                 return Ok(Self::dump_rec_tensor(&tensor.to_array_view::<String>()?, |f| {
                     string(f)
                 })
                 .into());
-            } else if tensor.datum_type() == DatumType::F16 {
-                return Ok(
-                    Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into()
-                );
             } else if tensor.datum_type() == DatumType::F32 {
                 return Ok(
                     Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
                 );
-            } else if tensor.datum_type() == DatumType::F64 {
-                return Ok(
-                    Self::dump_rec_tensor(&tensor.to_array_view::<f64>()?, |f| numeric(f)).into()
-                );
-            } else if self.ensure_registry(&"tract_core".into()).is_ok() {
+            } else if have_tract_core && tensor.datum_type() == DatumType::F16 {
+                let array = Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
+                return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
+            } else if have_tract_core && tensor.datum_type().is_integer() {
                 if let Ok(value) = tensor.cast_to::<i64>() {
                     let value = Self::dump_rec_tensor(&value.to_array_view::<i64>().unwrap(), |i| {
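
Note (editor's sketch, not part of the patch): the ser.rs change stops dumping small f16 tensors as bare NNEF literals. An NNEF numeric literal carries no dtype annotation, so the raw values would be re-read as f32; the values are now emitted as a plain numeric array wrapped in tract_core_cast(..., to = "f16"), which is why the branch is gated on the tract_core registry being available (have_tract_core). A minimal round-trip illustration of that idea, assuming tract_core::internal re-exports f16, tensor1 and Tensor::cast_to:

    use tract_core::internal::*;

    fn main() -> TractResult<()> {
        // A small f16 constant, like the ones the serializer inlines.
        let t = tensor1(&[f16::from_f32(0.5), f16::from_f32(1.5)]);
        // Dump side: the values print as dtype-less numeric literals (f32-like)...
        let printed = t.cast_to::<f32>()?.into_owned();
        // ...load side: the tract_core_cast invocation restores the datum type.
        let restored = printed.cast_to::<f16>()?.into_owned();
        assert_eq!(restored, t); // 0.5 and 1.5 are exact in both f16 and f32
        Ok(())
    }
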
diff --git a/test-rt/test-f16/suite.rs b/test-rt/test-f16/suite.rs
index 0ba641798b..f98cbe15a3 100644
--- a/test-rt/test-f16/suite.rs
+++ b/test-rt/test-f16/suite.rs
@@ -1,4 +1,5 @@
 use infra::Test;
+use suite_unit::conv_q::{QConvProblem, QConvProblemParams};
 
 pub fn suite() -> &'static infra::TestSuite {
     lazy_static::lazy_static! {
@@ -13,16 +14,81 @@ fn mk_suite() -> infra::TestSuite {
     onnx.ignore(&ignore_onnx);
     let mut unit = suite_unit::suite().unwrap().clone();
     unit.ignore_case(&ignore_unit);
+    unit.get_sub_mut("conv_q").add_arbitrary_with_filter::<QConvProblem>(
+        "proptest",
+        QConvProblemParams::default(),
+        compatible_conv_q,
+    );
     infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
 }
 
-fn ignore_unit(_t: &[String], _tc: &dyn Test) -> bool {
-    false
+fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
+    if let Some(qcp) = case.downcast_ref::<QConvProblem>() {
+        if !compatible_conv_q(qcp) {
+            return true;
+        }
+    }
+    let [section, _unit] = t else { return false };
+    ["q_flavours"].contains(&&**section)
 }
 
 fn ignore_onnx(t: &[String]) -> bool {
-    t.last().unwrap().starts_with("test_logsoftmax_large_number")
+    r#"
+test_averagepool_2d_ceil
+test_averagepool_2d_pads_count_include_pad
+test_averagepool_2d_precomputed_pads_count_include_pad
+test_averagepool_2d_same_lower
+test_cast_STRING_to_FLOAT
+test_castlike_STRING_to_FLOAT_expanded
+test_constantlike_ones_with_input
+test_constantlike_threes_with_shape_and_dtype
+test_constantlike_zeros_without_input_dtype
+test_cumsum_1d_exclusive
+test_cumsum_1d_reverse_exclusive
+test_cumsum_2d
+test_dequantizelinear
+test_dropout_random
+test_dynamicquantizelinear
+test_dynamicquantizelinear_max_adjusted
+test_dynamicquantizelinear_min_adjusted
+test_gemm_broadcast
+test_gemm_nobroadcast
+test_maxpool_2d_ceil
+test_maxpool_2d_same_lower
+test_maxpool_with_argmax_2d_precomputed_pads
+test_mod_broadcast
+test_mod_int64_fmod
+test_mod_mixed_sign_float16
+test_mod_mixed_sign_float32
+test_mod_mixed_sign_float64
+test_mod_mixed_sign_int16
+test_mod_mixed_sign_int32
+test_mod_mixed_sign_int64
+test_mod_mixed_sign_int8
+test_mod_uint16
+test_mod_uint32
+test_mod_uint64
+test_mod_uint8
+test_matmulinteger
+test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded
+test_nonzero_example
+test_quantizelinear
+test_qlinearmatmul_2D
+test_qlinearmatmul_3D
+test_reduce_prod_default_axes_keepdims_example
+test_reshape_reordered_dims
+test_resize_upsample_scales_linear_align_corners
+test_resize_downsample_scales_linear
+test_unsqueeze
+"#
+    .trim()
+    .lines()
+    .any(|s| t.last().unwrap() == s.trim())
+        || t.last().unwrap().starts_with("test_logsoftmax_large_number")
         || t.last().unwrap().starts_with("test_softmax_large_number")
         || t.last().unwrap().starts_with("test_resize")
-        || t.last().unwrap() == "test_reduce_prod_default_axes_keepdims_example"
+}
+
+fn compatible_conv_q(qcp: &QConvProblem) -> bool {
+    qcp.qp.iter().all(|t| t.len() == 1)
 }
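
Note (editor's sketch, not part of the patch): compatible_conv_q keeps only QConvProblem cases whose quantization-parameter tensors (qp) each hold a single value, i.e. per-tensor rather than per-channel quantization; the same predicate both filters the freshly generated "proptest" cases and, via ignore_unit, skips incompatible pre-existing ones. A minimal illustration of the predicate, using a hypothetical is_per_tensor helper and assuming tract_core::internal provides tensor0/tensor1:

    use tract_core::internal::*;

    // Mirrors compatible_conv_q's rule: every quantization-parameter tensor
    // must contain exactly one element (per-tensor zero points and scales).
    fn is_per_tensor(qp: &[Tensor]) -> bool {
        qp.iter().all(|t| t.len() == 1)
    }

    fn main() {
        let scalar = tensor0(0.5f32); // rank 0, len() == 1: kept
        let per_channel = tensor1(&[0.5f32, 0.25]); // len() == 2: filtered out
        assert!(is_per_tensor(&[scalar.clone()]));
        assert!(!is_per_tensor(&[scalar, per_channel]));
    }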