Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test rt f16 #1287

Merged
merged 8 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ members = [
"test-rt/infra",
"test-rt/suite-unit",
"test-rt/suite-onnx",
"test-rt/test-f16",
"test-rt/test-unit-core",
"test-rt/test-onnx-core",
"test-rt/test-nnef-cycle",
Expand Down
6 changes: 3 additions & 3 deletions core/src/ops/cnn/conv/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,11 +819,11 @@ impl TypedOp for Conv {
}
if self.q_params.is_some() {
ensure!(inputs[3].datum_type == i32::datum_type());
ensure!(inputs[4].datum_type == f32::datum_type());
ensure!(inputs[4].datum_type.is_float());
ensure!(inputs[5].datum_type == i32::datum_type());
ensure!(inputs[6].datum_type == f32::datum_type());
ensure!(inputs[6].datum_type.is_float());
ensure!(inputs[7].datum_type == i32::datum_type());
ensure!(inputs[8].datum_type == f32::datum_type());
ensure!(inputs[8].datum_type.is_float());
}
ensure!(self.pool_spec.rank() + 2 == inputs[1].rank());
if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
Expand Down
64 changes: 46 additions & 18 deletions core/src/ops/fft.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
use crate::internal::*;
use num_complex::Complex;
use rustfft::num_traits::{Float, FromPrimitive};
use rustfft::{FftDirection, FftNum};
use tract_data::itertools::Itertools;
use tract_ndarray::Axis;
use num_complex::Complex;

/// Fast Fourier Transform applied along one axis of a tensor.
///
/// The input's last dimension is expected to hold the complex pair
/// [re, im] (the eval code iterates it in (re, im) tuples and the
/// TypedOp checks it equals 2).
#[derive(Clone, Debug, Hash)]
pub struct Fft {
/// Axis along which the transform is computed.
pub axis: usize,
/// When true, run the inverse transform instead of the forward one.
pub inverse: bool,
}



impl Fft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
tensor: &mut Tensor,
) -> TractResult<()>
{
) -> TractResult<()> {
let mut iterator_shape: TVec<usize> = tensor.shape().into();
iterator_shape.pop(); // last dim is [re, im]
iterator_shape[self.axis] = 1;
Expand All @@ -30,16 +27,20 @@ impl Fft {
for coords in tract_ndarray::indices(&*iterator_shape) {
v.clear();
let mut slice = array.slice_each_axis_mut(|ax| {
if ax.axis.index() == self.axis || ax.stride == 1 { // ax.stride == 1 => last dim
if ax.axis.index() == self.axis || ax.stride == 1 {
// ax.stride == 1 => last dim
(..).into()
} else {
let c = coords[ax.axis.index()] as isize;
(c..=c).into()
}
});
v.extend(slice.iter().tuples().map(|(r,i)| Complex::new(*r,*i)));
v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
fft.process(&mut v);
slice.iter_mut().zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())).for_each(|(s, v)| *s = v);
slice
.iter_mut()
.zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
.for_each(|(s, v)| *s = v);
}
Ok(())
}
Expand All @@ -65,6 +66,11 @@ impl EvalOp for Fft {
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let mut tensor = args_1!(inputs).into_tensor();
match tensor.datum_type() {
DatumType::F16 => {
let mut temp = tensor.cast_to::<f32>()?.into_owned();
self.eval_t::<f32>(&mut temp)?;
tensor = temp.cast_to::<f16>()?.into_owned();
}
DatumType::F32 => self.eval_t::<f32>(&mut tensor)?,
DatumType::F64 => self.eval_t::<f64>(&mut tensor)?,
_ => bail!("FFT not implemented for type {:?}", tensor.datum_type()),
Expand All @@ -75,8 +81,14 @@ impl EvalOp for Fft {

impl TypedOp for Fft {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension");
anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part");
anyhow::ensure!(
inputs[0].rank() >= 2,
"Expect rank 2 (one for fft dimension, one for complex dimension"
);
anyhow::ensure!(
inputs[0].shape.last().unwrap() == &2.to_dim(),
"Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
);
Ok(tvec!(inputs[0].without_value()))
}

Expand All @@ -91,14 +103,11 @@ pub struct Stft {
pub window: Option<Arc<Tensor>>,
}



impl Stft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
input: &Tensor,
) -> TractResult<Tensor>
{
) -> TractResult<Tensor> {
let mut iterator_shape: TVec<usize> = input.shape().into();
iterator_shape.pop(); // [re,im]
iterator_shape[self.axis] = 1;
Expand Down Expand Up @@ -135,10 +144,19 @@ impl Stft {
});
for f in 0..frames {
v.clear();
v.extend(islice.iter().tuples().skip(self.stride * f).take(self.frame).map(|(re,im)| Complex::new(*re, *im)));
v.extend(
islice
.iter()
.tuples()
.skip(self.stride * f)
.take(self.frame)
.map(|(re, im)| Complex::new(*re, *im)),
);
if let Some(win) = &self.window {
let win = win.as_slice::<T>()?;
v.iter_mut().zip(win.iter()).for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
v.iter_mut()
.zip(win.iter())
.for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
}
fft.process(&mut v);
oslice
Expand Down Expand Up @@ -168,6 +186,10 @@ impl EvalOp for Stft {
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs);
let output = match input.datum_type() {
DatumType::F16 => {
let temp = input.cast_to::<f32>()?;
self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
}
DatumType::F32 => self.eval_t::<f32>(&input)?,
DatumType::F64 => self.eval_t::<f64>(&input)?,
_ => bail!("FFT not implemented for type {:?}", input.datum_type()),
Expand All @@ -178,8 +200,14 @@ impl EvalOp for Stft {

impl TypedOp for Stft {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension");
anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part");
anyhow::ensure!(
inputs[0].rank() >= 2,
"Expect rank 2 (one for fft dimension, one for complex dimension"
);
anyhow::ensure!(
inputs[0].shape.last().unwrap() == &2.to_dim(),
"Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
);
let mut shape = inputs[0].shape.to_tvec();
let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
shape[self.axis] = frames;
Expand Down
18 changes: 15 additions & 3 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use num_traits::bounds::Bounded;
use num_traits::int::PrimInt;
use num_traits::{Float, Zero};
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
pub use tract_data::prelude::round_ties_to_even;
use tract_linalg::{ScaleShiftAndRound, Scaler};
use tract_num_traits::AsPrimitive;
Expand Down Expand Up @@ -182,7 +183,7 @@ bin_to_super_type!(max, Max, linalg:Max,

bin_to_super_type!(pow, Pow,
declutter: declutter_pow,
[f32, f64] => |c,a,b| *c = a.powf(*b),
[f16, f32, f64] => |c,a,b| *c = a.powf(*b),
[i32, i64] => |c,a,b| *c = a.pow(*b as u32));

bin_to_super_type!(shift_left, ShiftLeft,
Expand Down Expand Up @@ -493,9 +494,14 @@ element_wise!(q_scale, QScale{scaler: Scaler},[i32] => |op, xs| {
Ok(())
});

element_wise!(round_half_to_even, RoundHalfToEven,[ f32] => |_, xs| {
element_wise!(round_half_to_even, RoundHalfToEven,
[f32] => |_, xs| {
xs.iter_mut().for_each(|x| *x = round_ties_to_even(*x));
Ok(())
},
[f16] => |_, xs| {
xs.iter_mut().for_each(|x| *x = f16::from_f32(round_ties_to_even(x.to_f32())));
Ok(())
};
q: [i8, u8, i32] => round_ties_to_even);

Expand Down Expand Up @@ -556,7 +562,13 @@ element_wise!(tanh, Tanh,
);

element_wise!(erf, Erf,
[f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) };
[f32] => |_, xs| { (tract_linalg::ops().erf_f32)().run(xs) },
[f16] => |_, xs| {
let mut f32s = xs.iter().map(|x| x.to_f32()).collect_vec();
(tract_linalg::ops().erf_f32)().run(&mut f32s)?;
xs.iter_mut().zip(f32s.into_iter()).for_each(|(x, f)| *x = f16::from_f32(f));
Ok(())
};
cost: |dt| {tvec!((Cost::FMA(dt), 11), (Cost::Div(dt), 1))}
);

Expand Down
1 change: 1 addition & 0 deletions core/src/ops/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ element_wise!(sigmoid, Sigmoid,
);

// Hard-swish activation: x * max(0, min(1, x/6 + 1/2)).
// Implemented natively for both f16 and f32 element types; the f16 arm
// performs the whole computation in half precision via f16::from_f32
// constants rather than round-tripping through f32.
element_wise!(hard_swish, HardSwish,
[f16] => |_, xs| { xs.iter_mut().for_each(|x| *x = *x * f16::from_f32(0.0).max(f16::from_f32(1.0).min(f16::from_f32(1. / 6.) * *x + f16::from_f32(0.5)))); Ok(()) },
[f32] => |_, xs| { xs.iter_mut().for_each(|x| *x = *x * 0f32.max(1f32.min((1. / 6.) * *x + 0.5))); Ok(()) }
);

Expand Down
2 changes: 2 additions & 0 deletions core/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ macro_rules! r {
DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
DatumType::QI8(_) => $($path)::*::<i8,_,_,_>($($args),*),
Expand All @@ -30,6 +31,7 @@ macro_rules! r {
DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
DatumType::QI8(_) => $($q_path)::*::<i8,_,_,_>($($q_args),*),
Expand Down
41 changes: 21 additions & 20 deletions core/src/ops/nn/softmax/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use ndarray::prelude::*;
#[derive(Debug, Clone, new, Hash)]
pub struct Softmax {
pub axes: TVec<usize>,
pub output_dt: DatumType,
pub quant_output_dt: Option<DatumType>,
}

impl Op for Softmax {
Expand All @@ -35,27 +35,25 @@ impl TypedOp for Softmax {
let dt = inputs[0].datum_type;
if dt.is_float() {
ensure!(
dt == self.output_dt,
"Softmax input {:?} and output {:?} types in float case should be equal",
dt,
self.output_dt
self.quant_output_dt.is_none(),
"Float softmax should not have quant_output_dt, have {:?}",
self.quant_output_dt
);
} else if dt.is_quantized() {
ensure!(
self.output_dt.is_quantized(),
"Quantized softmax must have input {:?} and output {:?} quantized ",
dt,
self.output_dt
self.quant_output_dt.map(|q| q.is_quantized()).unwrap_or(false),
"Quantized softmax should have a quantized output type (got {:?})",
self.quant_output_dt
);
} else {
bail!(
"Unsupported datum type in softmax: input type {:?}, output type {:?}",
dt,
self.output_dt
self.quant_output_dt
);
}

let fact = self.output_dt.fact(inputs[0].shape.clone());
let fact = self.quant_output_dt.unwrap_or(dt).fact(inputs[0].shape.clone());
Ok(tvec!(fact))
}

Expand All @@ -80,7 +78,7 @@ impl TypedOp for Softmax {
Ok(Some(AxisChangeConsequence::new(
model,
node,
Some(Box::new(Softmax { axes, output_dt: self.output_dt })),
Some(Box::new(Softmax { axes, ..self.clone() })),
change,
)))
} else {
Expand All @@ -104,8 +102,8 @@ impl EvalOp for Softmax {
DatumType::F64 => self.eval_t::<f64>(input)?,
DatumType::F32 => self.eval_t::<f32>(input)?,
DatumType::F16 => self.eval_t::<f16>(input)?,
DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant_t(input)?,
dt => bail!("Unsupported type {:?}", dt),
DatumType::QI8(_) | DatumType::QU8(_) => self.eval_quant(input)?,
dt => bail!("Unsupported type {dt:?}"),
};
Ok(output)
}
Expand Down Expand Up @@ -139,8 +137,10 @@ impl Softmax {
Ok(tvec!(output.into_tvalue()))
}

fn eval_quant_t(&self, input: TValue) -> TractResult<TVec<TValue>> {
fn eval_quant(&self, input: TValue) -> TractResult<TVec<TValue>> {
let mut iterating_shape: TVec<usize> = input.shape().into();
let output_dt =
self.quant_output_dt.context("Quandized softmax eval with no output type")?;

for i in 0..iterating_shape.len() {
if self.axes.contains(&i) {
Expand All @@ -150,9 +150,9 @@ impl Softmax {

// All operations will be done in u8, we will cast the result appropriately afterward.
let src_is_signed = input.datum_type().is_signed();
let out_is_signed = self.output_dt.is_signed();
let out_is_signed = output_dt.is_signed();
let in_qp = input.datum_type().qparams().unwrap(); // Checked as we are in the quant case
let out_qp = self.output_dt.qparams().unwrap(); // Checked as we are in the quant case
let out_qp = output_dt.qparams().unwrap(); // Checked as we are in the quant case
let mut output = unsafe { input.into_tensor().into_array_unchecked::<u8>() };

for it_coords in tract_ndarray::indices(&*iterating_shape) {
Expand All @@ -166,7 +166,7 @@ impl Softmax {
}

let mut output_tensor = output.into_tensor();
unsafe { output_tensor.set_datum_type(self.output_dt) };
unsafe { output_tensor.set_datum_type(output_dt) };
Ok(tvec!(output_tensor.into_tvalue()))
}
}
Expand Down Expand Up @@ -327,7 +327,8 @@ mod test {
impl SoftmaxProblem {
fn check(&self) -> Result<()> {
let inputs = tvec!(self.data.clone().into_tvalue());
let softmax = Softmax { axes: self.axes.clone(), output_dt: self.output_dt };
let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float());
let softmax = Softmax { axes: self.axes.clone(), quant_output_dt };

// Compute quantized output
let result = softmax.eval(inputs)?;
Expand All @@ -337,7 +338,7 @@ mod test {
// Compute reference output
let input_float = self.data.cast_to::<f32>()?;
let inputs_float = tvec!(input_float.into_owned().into_tvalue());
let softmax_float = Softmax { axes: self.axes.clone(), output_dt: DatumType::F32 };
let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None };
let reference_float = softmax_float.eval(inputs_float)?;
let reference_array = args_1!(reference_float);
let reference = reference_array.to_array_view::<f32>()?;
Expand Down
Loading