From 1c6f1306aced695136fe74ad03f4eba1eefb7fa3 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 13:44:50 +0100 Subject: [PATCH] fix approx test --- core/src/ops/nn/softmax/mod.rs | 5 +++-- linalg/src/generic/softmax.rs | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index d30654c63b..8a9f7f8c03 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -365,7 +365,8 @@ mod test { fn check(&self) -> Result<()> { let inputs = tvec!(self.data.clone().into_tvalue()); let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float()); - let softmax = Softmax { axes: self.axes.clone(), quant_output_dt }; + let softmax = + Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() }; // Compute quantized output let result = softmax.eval(inputs)?; @@ -375,7 +376,7 @@ mod test { // Compute reference output let input_float = self.data.cast_to::()?; let inputs_float = tvec!(input_float.into_owned().into_tvalue()); - let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None }; + let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() }; let reference_float = softmax_float.eval(inputs_float)?; let reference_array = args_1!(reference_float); let reference = reference_array.to_array_view::()?; diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 632d4f8bbe..cdbda86216 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -117,7 +117,7 @@ proptest::proptest! { #[test] fn t_xnnpack(x in -100f32..0.) { use tract_data::internal::{tensor0, Approximation}; - tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::Approximate).unwrap(); + tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::SuperApproximate).unwrap(); } }