diff --git a/core/src/ops/fft.rs b/core/src/ops/fft.rs index da62849654..1fa697e0b7 100644 --- a/core/src/ops/fft.rs +++ b/core/src/ops/fft.rs @@ -1,9 +1,9 @@ use crate::internal::*; +use num_complex::Complex; use rustfft::num_traits::{Float, FromPrimitive}; use rustfft::{FftDirection, FftNum}; use tract_data::itertools::Itertools; use tract_ndarray::Axis; -use num_complex::Complex; #[derive(Clone, Debug, Hash)] pub struct Fft { @@ -11,14 +11,11 @@ pub struct Fft { pub inverse: bool, } - - impl Fft { fn eval_t( &self, tensor: &mut Tensor, - ) -> TractResult<()> - { + ) -> TractResult<()> { let mut iterator_shape: TVec = tensor.shape().into(); iterator_shape.pop(); // last dim is [re, im] iterator_shape[self.axis] = 1; @@ -30,16 +27,20 @@ impl Fft { for coords in tract_ndarray::indices(&*iterator_shape) { v.clear(); let mut slice = array.slice_each_axis_mut(|ax| { - if ax.axis.index() == self.axis || ax.stride == 1 { // ax.stride == 1 => last dim + if ax.axis.index() == self.axis || ax.stride == 1 { + // ax.stride == 1 => last dim (..).into() } else { let c = coords[ax.axis.index()] as isize; (c..=c).into() } }); - v.extend(slice.iter().tuples().map(|(r,i)| Complex::new(*r,*i))); + v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i))); fft.process(&mut v); - slice.iter_mut().zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())).for_each(|(s, v)| *s = v); + slice + .iter_mut() + .zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())) + .for_each(|(s, v)| *s = v); } Ok(()) } @@ -80,8 +81,14 @@ impl EvalOp for Fft { impl TypedOp for Fft { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension"); - anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"); + anyhow::ensure!( + inputs[0].rank() >= 2, + "Expect rank 2 (one for fft dimension, one for complex dimension" + ); + anyhow::ensure!( + inputs[0].shape.last().unwrap() == &2.to_dim(), + "Fft operators expect inner (last) dimension to be 2 for real and imaginary part" + ); Ok(tvec!(inputs[0].without_value())) } @@ -96,14 +103,11 @@ pub struct Stft { pub window: Option>, } - - impl Stft { fn eval_t( &self, input: &Tensor, - ) -> TractResult - { + ) -> TractResult { let mut iterator_shape: TVec = input.shape().into(); iterator_shape.pop(); // [re,im] iterator_shape[self.axis] = 1; @@ -140,10 +144,19 @@ impl Stft { }); for f in 0..frames { v.clear(); - v.extend(islice.iter().tuples().skip(self.stride * f).take(self.frame).map(|(re,im)| Complex::new(*re, *im))); + v.extend( + islice + .iter() + .tuples() + .skip(self.stride * f) + .take(self.frame) + .map(|(re, im)| Complex::new(*re, *im)), + ); if let Some(win) = &self.window { let win = win.as_slice::()?; - v.iter_mut().zip(win.iter()).for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero())); + v.iter_mut() + .zip(win.iter()) + .for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero())); } fft.process(&mut v); oslice @@ -174,8 +187,8 @@ impl EvalOp for Stft { let input = args_1!(inputs); let output = match input.datum_type() { DatumType::F16 => { - let mut temp = input.cast_to::()?.into_owned(); - self.eval_t::(&mut temp)?.cast_to::()?.into_owned() + let temp = input.cast_to::()?; + self.eval_t::(&temp)?.cast_to::()?.into_owned() } DatumType::F32 => self.eval_t::(&input)?, DatumType::F64 => self.eval_t::(&input)?, @@ -187,8 +200,14 @@ impl EvalOp for Stft { impl TypedOp for Stft { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension"); - anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part"); + anyhow::ensure!( + inputs[0].rank() >= 2, + "Expect rank 2 (one for fft dimension, one for complex dimension" + ); + anyhow::ensure!( + inputs[0].shape.last().unwrap() == &2.to_dim(), + "Fft operators expect inner (last) dimension to be 2 for real and imaginary part" + ); let mut shape = inputs[0].shape.to_tvec(); let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1; shape[self.axis] = frames;