diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs
index d30654c63b..8a9f7f8c03 100644
--- a/core/src/ops/nn/softmax/mod.rs
+++ b/core/src/ops/nn/softmax/mod.rs
@@ -365,7 +365,8 @@ mod test {
         fn check(&self) -> Result<()> {
             let inputs = tvec!(self.data.clone().into_tvalue());
             let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
-            let softmax = Softmax { axes: self.axes.clone(), quant_output_dt };
+            let softmax =
+                Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() };
 
             // Compute quantized output
             let result = softmax.eval(inputs)?;
@@ -375,7 +376,7 @@
             // Compute reference output
             let input_float = self.data.cast_to::<f32>()?;
             let inputs_float = tvec!(input_float.into_owned().into_tvalue());
-            let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None };
+            let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() };
             let reference_float = softmax_float.eval(inputs_float)?;
             let reference_array = args_1!(reference_float);
             let reference = reference_array.to_array_view::<f32>()?;
diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs
index 632d4f8bbe..cdbda86216 100644
--- a/linalg/src/generic/softmax.rs
+++ b/linalg/src/generic/softmax.rs
@@ -117,7 +117,7 @@ proptest::proptest! {
     #[test]
     fn t_xnnpack(x in -100f32..0.) {
         use tract_data::internal::{tensor0, Approximation};
-        tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::Approximate).unwrap();
+        tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::SuperApproximate).unwrap();
     }
 }
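
Note on the `core` change: the struct literals now use Rust's functional update syntax (`..Softmax::default()`), so any `Softmax` fields the test does not set explicitly come from `Default` instead of being spelled out, and the test keeps compiling if `Softmax` gains new fields. A minimal sketch of the pattern, using a made-up struct that only stands in for tract's real `Softmax`:

```rust
// Hypothetical stand-in for tract's Softmax; any field beyond `axes` and
// `quant_output_dt` is invented here purely for illustration.
#[derive(Debug, Default)]
struct Softmax {
    axes: Vec<usize>,
    quant_output_dt: Option<u8>,
    exp_variant: u8, // imagined extra field the call site does not care about
}

fn main() {
    // Name only the fields under test; `..Softmax::default()` fills the rest.
    let softmax = Softmax { axes: vec![1], quant_output_dt: None, ..Softmax::default() };
    println!("{softmax:?}");
}
```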