diff --git a/core/src/ops/array/topk.rs b/core/src/ops/array/topk.rs index 8e37e72920..2982615c88 100644 --- a/core/src/ops/array/topk.rs +++ b/core/src/ops/array/topk.rs @@ -1,5 +1,7 @@ +use std::cmp::Ordering; + use tract_data::itertools::Itertools; -use tract_ndarray::{Axis, Dimension}; +use tract_ndarray::{ArrayViewMutD, Axis, Dimension}; use crate::internal::*; @@ -28,39 +30,66 @@ impl EvalOp for Topk { let mut output_shape: TVec = input.shape().into(); let k = k.cast_to_scalar::()? as usize; output_shape[self.axis] = k; - let mut output_values = Tensor::zero::(&output_shape)?; + let dt = input.datum_type(); + let mut output_values = Tensor::zero_dt(dt, &output_shape)?; let mut output_indices = Tensor::zero::(&output_shape)?; let mut iterating_shape = output_shape.clone(); iterating_shape[self.axis] = 1; - let mut output_values_view = output_values.to_array_view_mut::()?; let mut output_indices_view = output_indices.to_array_view_mut::()?; for coords in tract_ndarray::indices(&*iterating_shape) { let mut coords: TVec = coords.as_array_view().as_slice().unwrap().into(); - let mut view = input.to_array_view::()?; - for (ix, x) in coords.iter().enumerate() { - if ix != self.axis { - view.index_axis_inplace(Axis(ix), *x); - } - } - for (ix, (argmax, max)) in view - .iter() - .cloned() - .map(|x| if self.largest { -x } else { x }) - .enumerate() - .sorted_by(|a, b| a.1.total_cmp(&b.1)) - .take(k) - .map(|(pos, val)| if self.largest { (pos, -val) } else { (pos, val) }) - .enumerate() - { - coords[self.axis] = ix; - output_values_view[&*coords] = max; - output_indices_view[&*coords] = argmax as i64; - } + dispatch_numbers!(Self::inner_loop_t(dt)( + self, + &mut coords, + &input, + &mut output_values, + &mut output_indices_view, + k + ))?; } Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue())) } } +impl Topk { + fn inner_loop_t( + &self, + coords: &mut [usize], + input: &Tensor, + output_values: &mut Tensor, + output_indices_view: &mut ArrayViewMutD, + k: usize, + ) -> TractResult<()> { + let mut output_values_view = output_values.to_array_view_mut::()?; + let mut view = input.to_array_view::()?; + for (ix, x) in coords.iter().enumerate() { + if ix != self.axis { + view.index_axis_inplace(Axis(ix), *x); + } + } + for (ix, (argmax, max)) in view + .iter() + .cloned() + .enumerate() + .sorted_by(|a, b| { + let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) }; + if self.largest { + ord.reverse() + } else { + ord + } + }) + .take(k) + .enumerate() + { + coords[self.axis] = ix; + output_values_view[&*coords] = max; + output_indices_view[&*coords] = argmax as i64; + } + Ok(()) + } +} + impl TypedOp for Topk { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { let mut fact_values = inputs[0].without_value();