diff --git a/Cargo.toml b/Cargo.toml index b4e7ee848b..0c3527aeea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ members = [ "test-rt/infra", "test-rt/suite-unit", "test-rt/suite-onnx", + "test-rt/test-f16", "test-rt/test-unit-core", "test-rt/test-onnx-core", "test-rt/test-nnef-cycle", diff --git a/core/src/ops/cnn/conv/conv.rs b/core/src/ops/cnn/conv/conv.rs index 44534d336e..30eb43ab11 100644 --- a/core/src/ops/cnn/conv/conv.rs +++ b/core/src/ops/cnn/conv/conv.rs @@ -819,11 +819,11 @@ impl TypedOp for Conv { } if self.q_params.is_some() { ensure!(inputs[3].datum_type == i32::datum_type()); - ensure!(inputs[4].datum_type == f32::datum_type()); + ensure!(inputs[4].datum_type.is_float()); ensure!(inputs[5].datum_type == i32::datum_type()); - ensure!(inputs[6].datum_type == f32::datum_type()); + ensure!(inputs[6].datum_type.is_float()); ensure!(inputs[7].datum_type == i32::datum_type()); - ensure!(inputs[8].datum_type == f32::datum_type()); + ensure!(inputs[8].datum_type.is_float()); } ensure!(self.pool_spec.rank() + 2 == inputs[1].rank()); if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c() diff --git a/core/src/ops/fft.rs b/core/src/ops/fft.rs index a7bb4c88ae..1fa697e0b7 100644 --- a/core/src/ops/fft.rs +++ b/core/src/ops/fft.rs @@ -1,9 +1,9 @@ use crate::internal::*; +use num_complex::Complex; use rustfft::num_traits::{Float, FromPrimitive}; use rustfft::{FftDirection, FftNum}; use tract_data::itertools::Itertools; use tract_ndarray::Axis; -use num_complex::Complex; #[derive(Clone, Debug, Hash)] pub struct Fft { @@ -11,14 +11,11 @@ pub struct Fft { pub inverse: bool, } - - impl Fft { fn eval_t( &self, tensor: &mut Tensor, - ) -> TractResult<()> - { + ) -> TractResult<()> { let mut iterator_shape: TVec = tensor.shape().into(); iterator_shape.pop(); // last dim is [re, im] iterator_shape[self.axis] = 1; @@ -30,16 +27,20 @@ impl Fft { for coords in tract_ndarray::indices(&*iterator_shape) { v.clear(); let mut slice = array.slice_each_axis_mut(|ax| { - if ax.axis.index() == self.axis || ax.stride == 1 { // ax.stride == 1 => last dim + if ax.axis.index() == self.axis || ax.stride == 1 { + // ax.stride == 1 => last dim (..).into() } else { let c = coords[ax.axis.index()] as isize; (c..=c).into() } }); - v.extend(slice.iter().tuples().map(|(r,i)| Complex::new(*r,*i))); + v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i))); fft.process(&mut v); - slice.iter_mut().zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())).for_each(|(s, v)| *s = v); + slice + .iter_mut() + .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())) + .for_each(|(s, v)| *s = v); } Ok(()) } @@ -65,6 +66,11 @@ impl EvalOp for Fft { fn eval(&self, inputs: TVec) -> TractResult> { let mut tensor = args_1!(inputs).into_tensor(); match tensor.datum_type() { + DatumType::F16 => { + let mut temp = tensor.cast_to::()?.into_owned(); + self.eval_t::(&mut temp)?; + tensor = temp.cast_to::()?.into_owned(); + } DatumType::F32 => self.eval_t::(&mut tensor)?, DatumType::F64 => self.eval_t::(&mut tensor)?, _ => bail!("FFT not implemented for type {:?}", tensor.datum_type()), @@ -75,8 +81,14 @@ impl EvalOp for Fft { impl TypedOp for Fft { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension"); - anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"); + anyhow::ensure!( + inputs[0].rank() 
>= 2, + "Expect rank 2 (one for fft dimension, one for complex dimension" + ); + anyhow::ensure!( + inputs[0].shape.last().unwrap() == &2.to_dim(), + "Fft operators expect inner (last) dimension to be 2 for real and imaginary part" + ); Ok(tvec!(inputs[0].without_value())) } @@ -91,14 +103,11 @@ pub struct Stft { pub window: Option>, } - - impl Stft { fn eval_t( &self, input: &Tensor, - ) -> TractResult - { + ) -> TractResult { let mut iterator_shape: TVec = input.shape().into(); iterator_shape.pop(); // [re,im] iterator_shape[self.axis] = 1; @@ -135,10 +144,19 @@ impl Stft { }); for f in 0..frames { v.clear(); - v.extend(islice.iter().tuples().skip(self.stride * f).take(self.frame).map(|(re,im)| Complex::new(*re, *im))); + v.extend( + islice + .iter() + .tuples() + .skip(self.stride * f) + .take(self.frame) + .map(|(re, im)| Complex::new(*re, *im)), + ); if let Some(win) = &self.window { let win = win.as_slice::()?; - v.iter_mut().zip(win.iter()).for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero())); + v.iter_mut() + .zip(win.iter()) + .for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero())); } fft.process(&mut v); oslice @@ -168,6 +186,10 @@ impl EvalOp for Stft { fn eval(&self, inputs: TVec) -> TractResult> { let input = args_1!(inputs); let output = match input.datum_type() { + DatumType::F16 => { + let temp = input.cast_to::()?; + self.eval_t::(&temp)?.cast_to::()?.into_owned() + } DatumType::F32 => self.eval_t::(&input)?, DatumType::F64 => self.eval_t::(&input)?, _ => bail!("FFT not implemented for type {:?}", input.datum_type()), @@ -178,8 +200,14 @@ impl EvalOp for Stft { impl TypedOp for Stft { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension"); - anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"); + anyhow::ensure!( + inputs[0].rank() >= 2, + "Expect rank 2 (one for fft dimension, one for complex dimension" + ); + anyhow::ensure!( + inputs[0].shape.last().unwrap() == &2.to_dim(), + "Fft operators expect inner (last) dimension to be 2 for real and imaginary part" + ); let mut shape = inputs[0].shape.to_tvec(); let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1; shape[self.axis] = frames; diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index c9461c3dab..b60fb0f2f0 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -8,6 +8,7 @@ use num_traits::bounds::Bounded; use num_traits::int::PrimInt; use num_traits::{Float, Zero}; use tract_data::internal::ClampCast; +use tract_data::itertools::Itertools; pub use tract_data::prelude::round_ties_to_even; use tract_linalg::{ScaleShiftAndRound, Scaler}; use tract_num_traits::AsPrimitive; @@ -182,7 +183,7 @@ bin_to_super_type!(max, Max, linalg:Max, bin_to_super_type!(pow, Pow, declutter: declutter_pow, - [f32, f64] => |c,a,b| *c = a.powf(*b), + [f16, f32, f64] => |c,a,b| *c = a.powf(*b), [i32, i64] => |c,a,b| *c = a.pow(*b as u32)); bin_to_super_type!(shift_left, ShiftLeft, @@ -493,9 +494,14 @@ element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| { Ok(()) }); -element_wise!(round_half_to_even, RoundHalfToEven,[ f32] => |_, xs| { +element_wise!(round_half_to_even, RoundHalfToEven, +[f32] => |_, xs| { xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x)); Ok(()) +}, +[f16] => |_, xs| { + xs.iter_mut().for_each(|x| *x = 
f16::from_f32(round_ties_to_even(x.to_f32()))); + Ok(()) }; q: [i8, u8, i32] => round_ties_to_even); @@ -556,7 +562,13 @@ element_wise!(tanh, Tanh, ); element_wise!(erf, Erf, - [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) }; + [f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) }, + [f16] => |_, xs| { + let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec(); + (tract_linalg::ops().erf_f32)().run(&mut f32s)?; + xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f)); + Ok(()) +}; cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))} ); diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs index 84fffa013a..1c53a8cfed 100644 --- a/core/src/ops/nn/mod.rs +++ b/core/src/ops/nn/mod.rs @@ -15,6 +15,7 @@ element_wise!(sigmoid, Sigmoid, ); element_wise!(hard_swish, HardSwish, +[f16] => |_, xs| { xs.iter_mut().for_each(|x| *x = *x * f16::from_f32(0.0).max(f16::from_f32(1.0).min(f16::from_f32(1. / 6.) * *x + f16::from_f32(0.5)))); Ok(()) }, [f32] => |_, xs| { xs.iter_mut().for_each(|x| *x = *x * 0f32.max(1f32.min((1. / 6.) * *x + 0.5))); Ok(()) } ); diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs index 62d184b968..dc082da7e4 100644 --- a/core/src/ops/nn/reduce.rs +++ b/core/src/ops/nn/reduce.rs @@ -15,6 +15,7 @@ macro_rules! r { DatumType::I16 => $($path)::*::($($args),*), DatumType::I32 => $($path)::*::($($args),*), DatumType::I64 => $($path)::*::($($args),*), + DatumType::F16 => $($path)::*::($($args),*), DatumType::F32 => $($path)::*::($($args),*), DatumType::F64 => $($path)::*::($($args),*), DatumType::QI8(_) => $($path)::*::($($args),*), @@ -30,6 +31,7 @@ macro_rules! r { DatumType::I16 => $($path)::*::($($args),*), DatumType::I32 => $($path)::*::($($args),*), DatumType::I64 => $($path)::*::($($args),*), + DatumType::F16 => $($path)::*::($($args),*), DatumType::F32 => $($path)::*::($($args),*), DatumType::F64 => $($path)::*::($($args),*), DatumType::QI8(_) => $($q_path)::*::($($q_args),*), diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index 03aec98aa9..c33d070de1 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -15,7 +15,7 @@ use ndarray::prelude::*; #[derive(Debug, Clone, new, Hash)] pub struct Softmax { pub axes: TVec, - pub output_dt: DatumType, + pub quant_output_dt: Option, } impl Op for Softmax { @@ -35,27 +35,25 @@ impl TypedOp for Softmax { let dt = inputs[0].datum_type; if dt.is_float() { ensure!( - dt == self.output_dt, - "Softmax input {:?} and output {:?} types in float case should be equal", - dt, - self.output_dt + self.quant_output_dt.is_none(), + "Float softmax should not have quant_output_dt, have {:?}", + self.quant_output_dt ); } else if dt.is_quantized() { ensure!( - self.output_dt.is_quantized(), - "Quantized softmax must have input {:?} and output {:?} quantized ", - dt, - self.output_dt + self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false), + "Quantized softmax should have a quantized output type (got {:?})", + self.quant_output_dt ); } else { bail!( "Unsupported datum type in softmax: input type {:?}, output type {:?}", dt, - self.output_dt + self.quant_output_dt ); } - let fact = self.output_dt.fact(inputs[0].shape.clone()); + let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone()); Ok(tvec!(fact)) } @@ -80,7 +78,7 @@ impl TypedOp for Softmax { Ok(Some(AxisChangeConsequence::new( model, node, - Some(Box::new(Softmax { axes, output_dt: self.output_dt })), + Some(Box::new(Softmax { axes, 
..self.clone() })), change, ))) } else { @@ -104,8 +102,8 @@ impl EvalOp for Softmax { DatumType::F64 => self.eval_t::(input)?, DatumType::F32 => self.eval_t::(input)?, DatumType::F16 => self.eval_t::(input)?, - DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant_t(input)?, - dt => bail!("Unsupported type {:?}", dt), + DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?, + dt => bail!("Unsupported type {dt:?}"), }; Ok(output) } @@ -139,8 +137,10 @@ impl Softmax { Ok(tvec!(output.into_tvalue())) } - fn eval_quant_t(&self, input: TValue) -> TractResult> { + fn eval_quant(&self, input: TValue) -> TractResult> { let mut iterating_shape: TVec = input.shape().into(); + let output_dt = + self.quant_output_dt.context("Quandized softmax eval with no output type")?; for i in 0..iterating_shape.len() { if self.axes.contains(&i) { @@ -150,9 +150,9 @@ impl Softmax { // All operations will be done in u8, we will cast the result appropriately afterward. let src_is_signed = input.datum_type().is_signed(); - let out_is_signed = self.output_dt.is_signed(); + let out_is_signed = output_dt.is_signed(); let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case - let out_qp = self.output_dt.qparams().unwrap(); // Checked as we are in the quant case + let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case let mut output = unsafe { input.into_tensor().into_array_unchecked::() }; for it_coords in tract_ndarray::indices(&*iterating_shape) { @@ -166,7 +166,7 @@ impl Softmax { } let mut output_tensor = output.into_tensor(); - unsafe { output_tensor.set_datum_type(self.output_dt) }; + unsafe { output_tensor.set_datum_type(output_dt) }; Ok(tvec!(output_tensor.into_tvalue())) } } @@ -327,7 +327,8 @@ mod test { impl SoftmaxProblem { fn check(&self) -> Result<()> { let inputs = tvec!(self.data.clone().into_tvalue()); - let softmax = Softmax { axes: self.axes.clone(), output_dt: self.output_dt }; + let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float()); + let softmax = Softmax { axes: self.axes.clone(), quant_output_dt }; // Compute quantized output let result = softmax.eval(inputs)?; @@ -337,7 +338,7 @@ mod test { // Compute reference output let input_float = self.data.cast_to::()?; let inputs_float = tvec!(input_float.into_owned().into_tvalue()); - let softmax_float = Softmax { axes: self.axes.clone(), output_dt: DatumType::F32 }; + let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None }; let reference_float = softmax_float.eval(inputs_float)?; let reference_array = args_1!(reference_float); let reference = reference_array.to_array_view::()?; diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs index a498e9dfcf..28854e5bf7 100644 --- a/core/src/ops/quant.rs +++ b/core/src/ops/quant.rs @@ -26,6 +26,12 @@ element_wise_oop!(quantize_linear_u8, scale: f32, zero_point: u8 }, + [f16] => u8 |op, xs, ys| { + xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| + *y = quantize_linear_f32_u8(x.to_f32(), op.scale, op.zero_point as i32) + ); + Ok(()) + }, [f32,i32] => u8 |op, xs, ys| { xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = quantize_linear_f32_u8(*x as f32, op.scale, op.zero_point as i32) @@ -273,32 +279,33 @@ impl crate::ops::binary::BinMiniOp for Scale { } fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult { - if a != f32::datum_type() { - bail!("Scale left operand must be f32, got {:?}", a); + if !a.is_float() { + bail!("Scale left operand must be float, got {:?}", a); } Ok(b) } fn 
operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult { - if a != f32::datum_type() { - bail!("Scale left operand must be f32, got {:?}", a); + if !a.is_float() { + bail!("Scale left operand must be float, got {:?}", a); } Ok(b) } fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { - let a = a.to_scalar::()?; + let a = a.cast_to_scalar::()?; unsafe fn eval_in_place_t>(a: f32, b: &mut Tensor) where f32: AsPrimitive, { b.as_slice_mut_unchecked::().iter_mut().for_each(|x| *x = scale_by(*x, a)); } - unsafe { dispatch_numbers!(eval_in_place_t(b.datum_type())(*a, b)) } + unsafe { dispatch_numbers!(eval_in_place_t(b.datum_type())(a, b)) } Ok(()) } fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> { + let a = a.cast_to::()?; let a = a.to_array_view::()?; unsafe fn eval_in_place_t>( a: &ndarray::ArrayViewD, @@ -314,6 +321,7 @@ impl crate::ops::binary::BinMiniOp for Scale { } fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> { + let a = a.cast_to::()?; let a = a.to_array_view::()?; unsafe fn eval_out_of_place_t>( c: &mut Tensor, @@ -334,7 +342,6 @@ impl crate::ops::binary::BinMiniOp for Scale { } fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> { - // a is f32 by construction (scaler). if we are here in mean c is also f32, so b is f32 let a = a.to_array_view_mut::()?; let b = b.to_array_view::()?; ndarray::Zip::from(a).and_broadcast(b).for_each(|a, b| *a = scale_by(*b, *a)); @@ -348,7 +355,7 @@ impl crate::ops::binary::BinMiniOp for Scale { ) -> TractResult> { let a = model.outlet_fact(node.inputs[0])?; if let Some(a) = &a.uniform { - if *a.to_scalar::()? == 1. { + if a.cast_to_scalar::()? == 1. { return Ok(Some(TypedModelPatch::rewire( model, &node.inputs[1..2], @@ -356,7 +363,7 @@ impl crate::ops::binary::BinMiniOp for Scale { &|_p, x| Ok(x.into()), )?)); } else if node.outputs[0].fact.datum_type == DatumType::I32 { - let factor = *a.to_scalar::()?; + let factor = a.cast_to_scalar::()?; let scaler = Scaler::new(factor, RoundingPolicy::Even); let op = ElementWiseOp(Box::new(QScale { scaler })); diff --git a/core/src/plan.rs b/core/src/plan.rs index 0930541b65..c073cd70b0 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -357,7 +357,7 @@ where .model() .input_outlets()? .get(input) - .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?; + .with_context(|| format!("Invalid input id for model ({input})."))?; let SimpleState { plan, session_state, .. 
} = self; let plan = (*plan).borrow(); let model = plan.model.borrow(); @@ -370,11 +370,7 @@ where ensure!( fact.matches(&t, Some(&self.session_state.resolved_symbols)) .with_context(|| format!("Setting input {input}"))?, - "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})", - input, - t.shape(), - t.datum_type(), - fact + "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})", ); self.session_state.inputs.insert(outlet.node, t); Ok(()) diff --git a/data/src/tensor.rs b/data/src/tensor.rs index 0950f992d0..ce047e0b84 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -24,6 +24,7 @@ pub enum Approximation { Exact, Close, Approximate, + SuperApproximate, } impl From for Approximation { @@ -46,6 +47,7 @@ impl Approximation { (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0.), (Close, _) => (1e-7, 1e-7), (Approximate, _) => (1e-4, 5e-4), + (SuperApproximate, _) => (5e-2, 1e-2), } } } diff --git a/hir/src/ops/nn/layer_max.rs b/hir/src/ops/nn/layer_max.rs index 60f2fd2c3f..9fc5161926 100644 --- a/hir/src/ops/nn/layer_max.rs +++ b/hir/src/ops/nn/layer_max.rs @@ -146,6 +146,7 @@ impl Expansion for LayerSoftmax { let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize; let reducing_axes = if self.coerce_to_2d { (axis..rank).collect::>() } else { tvec!(axis) }; + let dt = if dt.is_float() { None } else { Some(dt) }; target.wire_node(name, tract_core::ops::nn::Softmax::new(reducing_axes, dt), inputs) } } diff --git a/hir/src/ops/nn/softmax.rs b/hir/src/ops/nn/softmax.rs index fac5eacd2c..eaf64b5eca 100644 --- a/hir/src/ops/nn/softmax.rs +++ b/hir/src/ops/nn/softmax.rs @@ -44,17 +44,17 @@ impl Expansion for Softmax { let input = target.outlet_fact(inputs[0])?.clone(); let input_dt = input.datum_type; - let output_dt = if input_dt.is_quantized() { + let quant_output_dt = if input_dt.is_quantized() { // Quantization parameters are not specified in ONNX (v13) so we set this value as default // in order to maximize the precision of the output. 
- DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 0.0078125 }) + Some(DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 0.0078125 })) } else { - input_dt + None }; target.wire_node( name, - tract_core::ops::nn::Softmax { axes: tvec![axis], output_dt }, + tract_core::ops::nn::Softmax { axes: tvec![axis], quant_output_dt }, inputs, ) } diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index 321305d50f..97d264f0de 100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -680,8 +680,11 @@ pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> T let axes: TVec = invocation.named_arg_as(builder, "axes")?; let input_fact = builder.model.outlet_fact(x)?.clone(); - let output_dt = - invocation.dt_from_quant_file.get(0).cloned().flatten().unwrap_or(input_fact.datum_type); + let quant_output_dt = if input_fact.datum_type.is_float() { + None + } else { + invocation.dt_from_quant_file.get(0).cloned().flatten() + }; - builder.wire(ops::nn::Softmax { axes, output_dt }, &[x]) + builder.wire(ops::nn::Softmax { axes, quant_output_dt }, &[x]) } diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index c2265d29d0..3168b88871 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -12,9 +12,18 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> { "rewrite_deconv_with_n_axis", tract_core::ops::cnn::rewrite_deconv_with_n_axis, ) - .with_rule_for("rewrite_kernel_conv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw) - .with_rule_for("rewrite_kernel_deconv_in_oihw", crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw) - .with_rule_for("rewrite_consistent_quantized_conv", crate::ops::nnef::ser::rewrite_consistent_quantized_conv) + .with_rule_for( + "rewrite_kernel_conv_in_oihw", + crate::ops::nnef::ser::rewrite_kernel_conv_in_oihw, + ) + .with_rule_for( + "rewrite_kernel_deconv_in_oihw", + crate::ops::nnef::ser::rewrite_kernel_deconv_in_oihw, + ) + .with_rule_for( + "rewrite_consistent_quantized_conv", + crate::ops::nnef::ser::rewrite_consistent_quantized_conv, + ) .rewrite(&(), model) } @@ -342,6 +351,7 @@ impl<'a> IntoAst<'a> { force_variable: bool, ) -> TractResult> { let mut name: Identifier = name.as_ref().into(); + let have_tract_core = self.ensure_registry(&"tract_core".into()).is_ok(); if !force_variable && tensor.len() <= 8 { if tensor.datum_type() == String::datum_type() { return Ok(Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| { @@ -352,7 +362,10 @@ impl<'a> IntoAst<'a> { return Ok( Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| numeric(f)).into() ); - } else if self.ensure_registry(&"tract_core".into()).is_ok() { + } else if have_tract_core && tensor.datum_type() == DatumType::F16 { + let array = Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| numeric(f)).into(); + return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))])); + } else if have_tract_core && tensor.datum_type().is_integer() { if let Ok(value) = tensor.cast_to::() { let value = Self::dump_rec_tensor(&value.to_array_view::().unwrap(), |i| { diff --git a/onnx-opl/src/is_inf.rs b/onnx-opl/src/is_inf.rs index 905abae7ff..819eacae1b 100644 --- a/onnx-opl/src/is_inf.rs +++ b/onnx-opl/src/is_inf.rs @@ -7,6 +7,12 @@ tract_core::element_wise_oop!(is_inf, IsInf { detect_positive: bool, detect_nega *y = (op.detect_positive && *x == std::f32::INFINITY) || (op.detect_negative && *x == std::f32::NEG_INFINITY) ); Ok(()) + }, + [f16] => bool |op, xs, ys| { + xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| + *y = 
(op.detect_positive && *x == f16::INFINITY) || (op.detect_negative && *x == f16::NEG_INFINITY) + ); + Ok(()) }; prefix: "onnx." ); diff --git a/onnx-opl/src/is_nan.rs b/onnx-opl/src/is_nan.rs index ce507dea96..05f2d322ab 100644 --- a/onnx-opl/src/is_nan.rs +++ b/onnx-opl/src/is_nan.rs @@ -4,6 +4,10 @@ element_wise_oop!(is_nan, IsNan, [f32] => bool |_, xs, ys| { xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan()); Ok(()) + }, + [f16] => bool |_, xs, ys| { + xs.iter().zip(ys.iter_mut()).for_each(|(x,y)| *y = x.is_nan()); + Ok(()) }; prefix: "onnx." ); diff --git a/onnx-opl/src/non_max_suppression.rs b/onnx-opl/src/non_max_suppression.rs index 64555be336..0ea3d0498a 100644 --- a/onnx-opl/src/non_max_suppression.rs +++ b/onnx-opl/src/non_max_suppression.rs @@ -1,5 +1,6 @@ use std::cmp::Ordering; +use rustfft::num_traits::Float; use tract_nnef::{ internal::*, tract_ndarray::{s, ArrayView1}, @@ -7,10 +8,10 @@ use tract_nnef::{ pub fn register(registry: &mut Registry) { registry.register_primitive( - "tract_onnx_non_max_suppression", + "tract_onnx_non_max_suppression", ¶meters(), - &[("output", TypeName::Integer.tensor())], - load + &[("output", TypeName::Integer.tensor())], + load, ); registry.register_dumper(dump); } @@ -23,7 +24,7 @@ pub enum BoxRepr { CenterWidthHeight, } -fn get_min_max(lhs: f32, rhs: f32) -> (f32, f32) { +fn get_min_max(lhs: T, rhs: T) -> (T, T) { if lhs >= rhs { (rhs, lhs) } else { @@ -48,12 +49,13 @@ impl BoxRepr { } // iou: intersection over union - fn should_suppress_by_iou( + fn should_suppress_by_iou( &self, - box1: ArrayView1, - box2: ArrayView1, - iou_threshold: f32, + box1: ArrayView1, + box2: ArrayView1, + iou_threshold: T, ) -> bool { + let two = T::one() + T::one(); let (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max) = match self { BoxRepr::TwoPoints => { let (x1_min, x1_max) = get_min_max(box1[[1]], box1[[3]]); @@ -65,8 +67,8 @@ impl BoxRepr { (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max) } BoxRepr::CenterWidthHeight => { - let (box1_width_half, box1_height_half) = (box1[[2]] / 2.0, box1[[3]] / 2.0); - let (box2_width_half, box2_height_half) = (box2[[2]] / 2.0, box2[[3]] / 2.0); + let (box1_width_half, box1_height_half) = (box1[[2]] / two, box1[[3]] / two); + let (box2_width_half, box2_height_half) = (box2[[2]] / two, box2[[3]] / two); let (x1_min, x1_max) = (box1[[0]] - box1_width_half, box1[[0]] + box1_width_half); let (x2_min, x2_max) = (box2[[0]] - box2_width_half, box2[[0]] + box2_width_half); @@ -78,14 +80,14 @@ impl BoxRepr { } }; - let intersection_y_min = f32::max(y1_min, y2_min); - let intersection_y_max = f32::min(y1_max, y2_max); + let intersection_y_min = T::max(y1_min, y2_min); + let intersection_y_max = T::min(y1_max, y2_max); if intersection_y_max <= intersection_y_min { return false; } - let intersection_x_min = f32::max(x1_min, x2_min); - let intersection_x_max = f32::min(x1_max, x2_max); + let intersection_x_min = T::max(x1_min, x2_min); + let intersection_x_max = T::min(x1_max, x2_max); if intersection_x_max <= intersection_x_min { return false; } @@ -93,7 +95,7 @@ impl BoxRepr { let intersection_area = (intersection_x_max - intersection_x_min) * (intersection_y_max - intersection_y_min); - if intersection_area <= 0.0 { + if intersection_area.is_sign_negative() { return false; } @@ -102,7 +104,7 @@ impl BoxRepr { let union_area = area1 + area2 - intersection_area; - if area1 <= 0.0 || area2 <= 0.0 || union_area <= 0.0 { + if area1.is_sign_negative() || area2.is_sign_negative() || 
union_area.is_sign_negative() { return false; } @@ -119,22 +121,8 @@ pub struct NonMaxSuppression { pub has_score_threshold: bool, } - - -impl Op for NonMaxSuppression { - fn name(&self) -> Cow { - "NonMaxSuppression".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for NonMaxSuppression { - fn is_stateless(&self) -> bool { - true - } - - fn eval(&self, inputs: TVec) -> TractResult> { +impl NonMaxSuppression { + fn eval_t(&self, inputs: TVec) -> TractResult> { let (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) = if self.has_score_threshold { let (t1, t2, t3, t4, t5) = args_5!(inputs); @@ -145,21 +133,21 @@ impl EvalOp for NonMaxSuppression { }; let mut max_output_boxes_per_class = *max_output_boxes_per_class.to_scalar::()?; - let iou_threshold = *iou_threshold.to_scalar::()?; + let iou_threshold = *iou_threshold.to_scalar::()?; let score_threshold = score_threshold - .map_or(Ok::<_, TractError>(None), |val| Ok(Some(*val.to_scalar::()?)))?; + .map_or(Ok::<_, TractError>(None), |val| Ok(Some(*val.to_scalar::()?)))?; if max_output_boxes_per_class == 0 { max_output_boxes_per_class = i64::MAX; } - ensure!((0.0..=1.0).contains(&iou_threshold), "iou_threshold must be between 0 and 1"); + // ensure!((0.0..=1.0).contains(&iou_threshold), "iou_threshold must be between 0 and 1"); let num_batches = scores.shape()[0]; let num_classes = scores.shape()[1]; let num_dim = scores.shape()[2]; - let boxes = boxes.to_array_view::()?; - let scores = scores.to_array_view::()?; + let boxes = boxes.to_array_view::()?; + let scores = scores.to_array_view::()?; // items: (batch, class, index) let mut selected_global: TVec<(usize, usize, usize)> = tvec![]; @@ -167,7 +155,7 @@ impl EvalOp for NonMaxSuppression { for batch in 0..num_batches { for class in 0..num_classes { // items: (score, index) - let mut candidates: TVec<(f32, usize)> = + let mut candidates: TVec<(T, usize)> = if let Some(score_threshold) = score_threshold { (0..num_dim) .map(|i| (scores[[batch, class, i]], i)) @@ -180,7 +168,7 @@ impl EvalOp for NonMaxSuppression { candidates.sort_by(|(a, _), (b, _)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); // items: (score, index) - let mut selected_in_class: TVec<(f32, usize)> = tvec![]; + let mut selected_in_class: TVec<(T, usize)> = tvec![]; for (score, index) in candidates { if selected_in_class.len() as i64 >= max_output_boxes_per_class { @@ -212,6 +200,25 @@ impl EvalOp for NonMaxSuppression { } } +impl Op for NonMaxSuppression { + fn name(&self) -> Cow { + "NonMaxSuppression".into() + } + + op_as_typed_op!(); +} + +impl EvalOp for NonMaxSuppression { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + let dt = inputs[0].datum_type(); + dispatch_floatlike!(Self::eval_t(dt)(self, inputs)) + } +} + impl TypedOp for NonMaxSuppression { fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult> { Ok(tvec![i64::fact([self.num_selected_indices_symbol.to_dim(), 3usize.to_dim()])]) @@ -231,7 +238,11 @@ fn parameters() -> Vec { ] } -fn dump(ast: &mut IntoAst, node: &TypedNode, op: &NonMaxSuppression) -> TractResult>> { +fn dump( + ast: &mut IntoAst, + node: &TypedNode, + op: &NonMaxSuppression, +) -> TractResult>> { let boxes = ast.mapping[&node.inputs[0]].clone(); let scores = ast.mapping[&node.inputs[1]].clone(); let max_output_boxes_per_class = ast.mapping[&node.inputs[2]].clone(); @@ -255,10 +266,7 @@ fn dump(ast: &mut IntoAst, node: &TypedNode, op: &NonMaxSuppression) -> TractRes Ok(Some(inv)) } -fn load( - builder: &mut 
ModelBuilder, - invocation: &ResolvedInvocation, -) -> TractResult { +fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { let boxes = invocation.named_arg_as(builder, "boxes")?; let scores = invocation.named_arg_as(builder, "scores")?; let max_output_boxes_per_class = diff --git a/onnx/src/ops/quant.rs b/onnx/src/ops/quant.rs index aa28215e60..b0102eb6de 100644 --- a/onnx/src/ops/quant.rs +++ b/onnx/src/ops/quant.rs @@ -273,6 +273,7 @@ impl EvalOp for DynamicQuantizeLinearU8 { } fn eval(&self, inputs: TVec) -> TractResult> { let input = &inputs[0]; + let input = input.cast_to::()?; let a_input = input.to_array_view::()?; let (scale, zero_point) = scale_and_zero_point(a_input); diff --git a/test-rt/test-f16/Cargo.toml b/test-rt/test-f16/Cargo.toml new file mode 100644 index 0000000000..847009d57f --- /dev/null +++ b/test-rt/test-f16/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "test-f16" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[build-dependencies] +lazy_static.workspace = true +regex.workspace = true +infra = { path = "../infra" } +tract-core = { path = "../../core" } +suite-onnx = { path = "../suite-onnx" } +suite-unit = { path = "../suite-unit" } +tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.23-pre" } + +[dev-dependencies] +regex.workspace = true +lazy_static.workspace = true +log.workspace = true +tflitec.workspace = true +tract-core = { path = "../../core", version = "=0.20.23-pre" } +tract-nnef = { path = "../../nnef", version = "=0.20.23-pre" } +tract-onnx-opl = { path = "../../onnx-opl", version = "=0.20.23-pre" } +infra = { path = "../infra" } +suite-onnx = { path = "../suite-onnx" } +suite-unit = { path = "../suite-unit" } diff --git a/test-rt/test-f16/build.rs b/test-rt/test-f16/build.rs new file mode 100644 index 0000000000..1d07109787 --- /dev/null +++ b/test-rt/test-f16/build.rs @@ -0,0 +1,7 @@ +#[path="suite.rs"] +mod suite; + +fn main() { + suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::SuperApproximate"); +} + diff --git a/test-rt/test-f16/src/lib.rs b/test-rt/test-f16/src/lib.rs new file mode 100644 index 0000000000..cbec63a8db --- /dev/null +++ b/test-rt/test-f16/src/lib.rs @@ -0,0 +1,125 @@ +#![cfg(test)] + +#[path = "../suite.rs"] +mod suite; + +mod run_as_f16 { + use super::*; + use tract_core::internal::*; + use tract_core::model::translator::Translate; + + #[derive(Debug)] + struct RunAsF16; + + impl Runtime for RunAsF16 { + fn name(&self) -> Cow { + "run_as_f16".into() + } + + fn prepare(&self, model: TypedModel) -> TractResult> { + let outputs_dt = + model.outputs.iter().map(|o| model.outlet_fact(*o).unwrap().datum_type).collect(); + let tr = tract_core::floats::FloatPrecisionTranslator::::default(); + let model = tr.translate_model(&model)?; + Ok(Box::new(RunnableAsF16( + Arc::new(model.into_optimized()?.into_runnable()?), + outputs_dt, + ))) + } + } + + #[derive(Debug)] + pub struct RunnableAsF16(pub Arc>, pub TVec); + + impl Runnable for RunnableAsF16 { + fn spawn(&self) -> TractResult> { + Ok(Box::new(StateAsF16(SimpleState::new(self.0.clone())?, self.1.clone()))) + } + } + + #[derive(Debug)] + struct StateAsF16( + TypedSimpleState>>, + TVec, + ); + + impl State for StateAsF16 { + fn run(&mut self, inputs: TVec) -> TractResult> { + let inputs = inputs + .into_iter() + .map(|v| { + if v.datum_type() == DatumType::F32 { + v.into_tensor() + 
.cast_to_dt(f16::datum_type()) + .unwrap() + .into_owned() + .into_tvalue() + } else { + v + } + }) + .collect(); + let outputs = self.0.run(inputs)?; + Ok(outputs + .into_iter() + .zip(self.1.iter()) + .map(|(t, dt)| t.into_tensor().cast_to_dt(*dt).unwrap().into_owned().into_tvalue()) + .collect()) + } + } + + fn runtime() -> &'static RunAsF16 { + static RUN_AS_F16: RunAsF16 = RunAsF16; + &RUN_AS_F16 + } + + include!(concat!(env!("OUT_DIR"), "/tests/tests.rs")); +} + +mod nnef_f16 { + use std::fmt::Debug; + + use super::run_as_f16::RunnableAsF16; + use super::*; + use tract_core::internal::*; + use tract_core::model::translator::Translate; + use tract_nnef::internal::Nnef; + use tract_onnx_opl::WithOnnx; + + struct NnefF16(Nnef); + + impl Debug for NnefF16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NnefF16") + } + } + + impl Runtime for NnefF16 { + fn name(&self) -> Cow { + "nnef_f16".into() + } + + fn prepare(&self, model: TypedModel) -> TractResult> { + let outputs_dt = + model.outputs.iter().map(|o| model.outlet_fact(*o).unwrap().datum_type).collect(); + let tr = tract_core::floats::FloatPrecisionTranslator::::default(); + let model = tr.translate_model(&model)?; + let mut buf = vec![]; + self.0.write_to_tar(&model, &mut buf)?; + let reloaded = self.0.model_for_read(&mut &*buf)?; + Ok(Box::new(RunnableAsF16( + Arc::new(reloaded.into_optimized()?.into_runnable()?), + outputs_dt, + ))) + } + } + + fn runtime() -> &'static NnefF16 { + lazy_static::lazy_static! { + static ref RT: NnefF16 = NnefF16(tract_nnef::nnef().with_onnx()); + }; + &RT + } + + include!(concat!(env!("OUT_DIR"), "/tests/tests.rs")); +} diff --git a/test-rt/test-f16/src/main.rs b/test-rt/test-f16/src/main.rs new file mode 100644 index 0000000000..e7a11a969c --- /dev/null +++ b/test-rt/test-f16/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/test-rt/test-f16/suite.rs b/test-rt/test-f16/suite.rs new file mode 100644 index 0000000000..f98cbe15a3 --- /dev/null +++ b/test-rt/test-f16/suite.rs @@ -0,0 +1,94 @@ +use infra::Test; +use suite_unit::conv_q::{QConvProblem, QConvProblemParams}; + +pub fn suite() -> &'static infra::TestSuite { + lazy_static::lazy_static! 
{ + static ref SUITE: infra::TestSuite = mk_suite(); + }; + &SUITE +} + +#[allow(clippy::needless_update)] +fn mk_suite() -> infra::TestSuite { + let mut onnx = suite_onnx::suite().clone(); + onnx.ignore(&ignore_onnx); + let mut unit = suite_unit::suite().unwrap().clone(); + unit.ignore_case(&ignore_unit); + unit.get_sub_mut("conv_q").add_arbitrary_with_filter::( + "proptest", + QConvProblemParams::default(), + compatible_conv_q, + ); + infra::TestSuite::default().with("onnx", onnx).with("unit", unit) +} + +fn ignore_unit(t: &[String], case: &dyn Test) -> bool { + if let Some(qcp) = case.downcast_ref::() { + if !compatible_conv_q(qcp) { + return true; + } + } + let [section, _unit] = t else { return false }; + ["q_flavours"].contains(&&**section) +} + +fn ignore_onnx(t: &[String]) -> bool { + r#" +test_averagepool_2d_ceil +test_averagepool_2d_pads_count_include_pad +test_averagepool_2d_precomputed_pads_count_include_pad +test_averagepool_2d_same_lower +test_cast_STRING_to_FLOAT +test_castlike_STRING_to_FLOAT_expanded +test_constantlike_ones_with_input +test_constantlike_threes_with_shape_and_dtype +test_constantlike_zeros_without_input_dtype +test_cumsum_1d_exclusive +test_cumsum_1d_reverse_exclusive +test_cumsum_2d +test_dequantizelinear +test_dropout_random +test_dynamicquantizelinear +test_dynamicquantizelinear_max_adjusted +test_dynamicquantizelinear_min_adjusted +test_gemm_broadcast +test_gemm_nobroadcast +test_maxpool_2d_ceil +test_maxpool_2d_same_lower +test_maxpool_with_argmax_2d_precomputed_pads +test_mod_broadcast +test_mod_int64_fmod +test_mod_mixed_sign_float16 +test_mod_mixed_sign_float32 +test_mod_mixed_sign_float64 +test_mod_mixed_sign_int16 +test_mod_mixed_sign_int32 +test_mod_mixed_sign_int64 +test_mod_mixed_sign_int8 +test_mod_uint16 +test_mod_uint32 +test_mod_uint64 +test_mod_uint8 +test_matmulinteger +test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded +test_nonzero_example +test_quantizelinear +test_qlinearmatmul_2D +test_qlinearmatmul_3D +test_reduce_prod_default_axes_keepdims_example +test_reshape_reordered_dims +test_resize_upsample_scales_linear_align_corners +test_resize_downsample_scales_linear +test_unsqueeze +"# + .trim() + .lines() + .any(|s| t.last().unwrap() == s.trim()) + || t.last().unwrap().starts_with("test_logsoftmax_large_number") + || t.last().unwrap().starts_with("test_softmax_large_number") + || t.last().unwrap().starts_with("test_resize") +} + +fn compatible_conv_q(qcp: &QConvProblem) -> bool { + qcp.qp.iter().all(|t| t.len() == 1) +} diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs index 05abed26ba..2d9d4ccb1d 100644 --- a/tflite/src/ops/nn.rs +++ b/tflite/src/ops/nn.rs @@ -139,7 +139,8 @@ fn de_softmax(op: &mut DeserOp) -> TractResult> { let input = args_1!(op.facts()?); let options = builtin!(op, builtin_options_as_softmax_options); ensure!(options.beta() == 1.0); - let softmax = core::nn::Softmax { axes: tvec!(input.rank() - 1), output_dt: input.datum_type }; + let quant_output_dt = Some(input.datum_type).filter(|dt| !dt.is_float()); + let softmax = core::nn::Softmax { axes: tvec!(input.rank() - 1), quant_output_dt }; op.ctx.target.wire_node(op.prefix, softmax, op.inputs) }
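The central API change in this patch is Softmax's output_dt: DatumType field becoming quant_output_dt: Option<DatumType>: float (f16/f32/f64) softmax now carries None and takes its output type from its input fact, while only quantized softmax keeps an explicit output type. A minimal call-site sketch of that convention, not part of the patch (the helper name is hypothetical; the construction mirrors the tflite and hir wiring above):

use tract_core::internal::*;
use tract_core::ops::nn::Softmax;

// Hypothetical helper, not part of the patch: build a softmax over the last
// axis following the new convention. Quantized inputs keep an explicit output
// type; float inputs (including f16) leave quant_output_dt as None and the op
// infers its output type from the input.
fn softmax_over_last_axis(input_dt: DatumType, rank: usize) -> Softmax {
    let quant_output_dt = Some(input_dt).filter(|dt| !dt.is_float());
    Softmax { axes: tvec!(rank - 1), quant_output_dt }
}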