Skip to content

Commit

Permalink
clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 14, 2023
1 parent 6a06d6b commit e877fd2
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions core/src/ops/fft.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
use crate::internal::*;
use num_complex::Complex;
use rustfft::num_traits::{Float, FromPrimitive};
use rustfft::{FftDirection, FftNum};
use tract_data::itertools::Itertools;
use tract_ndarray::Axis;
use num_complex::Complex;

#[derive(Clone, Debug, Hash)]
pub struct Fft {
pub axis: usize,
pub inverse: bool,
}



impl Fft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
tensor: &mut Tensor,
) -> TractResult<()>
{
) -> TractResult<()> {
let mut iterator_shape: TVec<usize> = tensor.shape().into();
iterator_shape.pop(); // last dim is [re, im]
iterator_shape[self.axis] = 1;
Expand All @@ -30,16 +27,20 @@ impl Fft {
for coords in tract_ndarray::indices(&*iterator_shape) {
v.clear();
let mut slice = array.slice_each_axis_mut(|ax| {
if ax.axis.index() == self.axis || ax.stride == 1 { // ax.stride == 1 => last dim
if ax.axis.index() == self.axis || ax.stride == 1 {
// ax.stride == 1 => last dim
(..).into()
} else {
let c = coords[ax.axis.index()] as isize;
(c..=c).into()
}
});
v.extend(slice.iter().tuples().map(|(r,i)| Complex::new(*r,*i)));
v.extend(slice.iter().tuples().map(|(r, i)| Complex::new(*r, *i)));
fft.process(&mut v);
slice.iter_mut().zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter())).for_each(|(s, v)| *s = v);
slice
.iter_mut()
.zip(v.iter().flat_map(|cmpl| [cmpl.re, cmpl.im].into_iter()))
.for_each(|(s, v)| *s = v);
}
Ok(())
}
Expand Down Expand Up @@ -80,8 +81,14 @@ impl EvalOp for Fft {

impl TypedOp for Fft {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension");
anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part");
anyhow::ensure!(
inputs[0].rank() >= 2,
"Expect rank 2 (one for fft dimension, one for complex dimension"
);
anyhow::ensure!(
inputs[0].shape.last().unwrap() == &2.to_dim(),
"Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
);
Ok(tvec!(inputs[0].without_value()))
}

Expand All @@ -96,14 +103,11 @@ pub struct Stft {
pub window: Option<Arc<Tensor>>,
}



impl Stft {
fn eval_t<T: Datum + FftNum + FromPrimitive + Float>(
&self,
input: &Tensor,
) -> TractResult<Tensor>
{
) -> TractResult<Tensor> {
let mut iterator_shape: TVec<usize> = input.shape().into();
iterator_shape.pop(); // [re,im]
iterator_shape[self.axis] = 1;
Expand Down Expand Up @@ -140,10 +144,19 @@ impl Stft {
});
for f in 0..frames {
v.clear();
v.extend(islice.iter().tuples().skip(self.stride * f).take(self.frame).map(|(re,im)| Complex::new(*re, *im)));
v.extend(
islice
.iter()
.tuples()
.skip(self.stride * f)
.take(self.frame)
.map(|(re, im)| Complex::new(*re, *im)),
);
if let Some(win) = &self.window {
let win = win.as_slice::<T>()?;
v.iter_mut().zip(win.iter()).for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
v.iter_mut()
.zip(win.iter())
.for_each(|(v, w)| *v = *v * Complex::new(*w, T::zero()));
}
fft.process(&mut v);
oslice
Expand Down Expand Up @@ -174,8 +187,8 @@ impl EvalOp for Stft {
let input = args_1!(inputs);
let output = match input.datum_type() {
DatumType::F16 => {
let mut temp = input.cast_to::<f32>()?.into_owned();
self.eval_t::<f32>(&mut temp)?.cast_to::<f16>()?.into_owned()
let temp = input.cast_to::<f32>()?;
self.eval_t::<f32>(&temp)?.cast_to::<f16>()?.into_owned()
}
DatumType::F32 => self.eval_t::<f32>(&input)?,
DatumType::F64 => self.eval_t::<f64>(&input)?,
Expand All @@ -187,8 +200,14 @@ impl EvalOp for Stft {

impl TypedOp for Stft {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
anyhow::ensure!(inputs[0].rank() >= 2, "Expect rank 2 (one for fft dimension, one for complex dimension");
anyhow::ensure!(inputs[0].shape.last().unwrap() == &2.to_dim(), "Fft operators expect inner (last) dimension to be 2 for real and imaginary part");
anyhow::ensure!(
inputs[0].rank() >= 2,
"Expect rank 2 (one for fft dimension, one for complex dimension"
);
anyhow::ensure!(
inputs[0].shape.last().unwrap() == &2.to_dim(),
"Fft operators expect inner (last) dimension to be 2 for real and imaginary part"
);
let mut shape = inputs[0].shape.to_tvec();
let frames = (inputs[0].shape[self.axis].clone() - self.frame) / self.stride + 1;
shape[self.axis] = frames;
Expand Down

0 comments on commit e877fd2

Please sign in to comment.