Skip to content

Commit

Permalink
generic topk
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 7, 2023
1 parent e6ad6d0 commit b06faa7
Showing 1 changed file with 52 additions and 23 deletions.
75 changes: 52 additions & 23 deletions core/src/ops/array/topk.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::cmp::Ordering;

use tract_data::itertools::Itertools;
use tract_ndarray::{Axis, Dimension};
use tract_ndarray::{ArrayViewMutD, Axis, Dimension};

use crate::internal::*;

Expand Down Expand Up @@ -28,39 +30,66 @@ impl EvalOp for Topk {
let mut output_shape: TVec<usize> = input.shape().into();
let k = k.cast_to_scalar::<i64>()? as usize;
output_shape[self.axis] = k;
let mut output_values = Tensor::zero::<f32>(&output_shape)?;
let dt = input.datum_type();
let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
let mut iterating_shape = output_shape.clone();
iterating_shape[self.axis] = 1;
let mut output_values_view = output_values.to_array_view_mut::<f32>()?;
let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
for coords in tract_ndarray::indices(&*iterating_shape) {
let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
let mut view = input.to_array_view::<f32>()?;
for (ix, x) in coords.iter().enumerate() {
if ix != self.axis {
view.index_axis_inplace(Axis(ix), *x);
}
}
for (ix, (argmax, max)) in view
.iter()
.cloned()
.map(|x| if self.largest { -x } else { x })
.enumerate()
.sorted_by(|a, b| a.1.total_cmp(&b.1))
.take(k)
.map(|(pos, val)| if self.largest { (pos, -val) } else { (pos, val) })
.enumerate()
{
coords[self.axis] = ix;
output_values_view[&*coords] = max;
output_indices_view[&*coords] = argmax as i64;
}
dispatch_numbers!(Self::inner_loop_t(dt)(
self,
&mut coords,
&input,
&mut output_values,
&mut output_indices_view,
k
))?;
}
Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
}
}

impl Topk {
fn inner_loop_t<T: Datum + PartialOrd>(
&self,
coords: &mut [usize],
input: &Tensor,
output_values: &mut Tensor,
output_indices_view: &mut ArrayViewMutD<i64>,
k: usize,
) -> TractResult<()> {
let mut output_values_view = output_values.to_array_view_mut::<T>()?;
let mut view = input.to_array_view::<T>()?;
for (ix, x) in coords.iter().enumerate() {
if ix != self.axis {
view.index_axis_inplace(Axis(ix), *x);
}
}
for (ix, (argmax, max)) in view
.iter()
.cloned()
.enumerate()
.sorted_by(|a, b| {
let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
if self.largest {
ord.reverse()
} else {
ord
}
})
.take(k)
.enumerate()
{
coords[self.axis] = ix;
output_values_view[&*coords] = max;
output_indices_view[&*coords] = argmax as i64;
}
Ok(())
}
}

impl TypedOp for Topk {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut fact_values = inputs[0].without_value();
Expand Down

0 comments on commit b06faa7

Please sign in to comment.