Skip to content

Commit

Permalink
Merge pull request #411 from robertknight/arg-min-max-fast-path
Browse files Browse the repository at this point in the history
Add fast path for ArgMin / ArgMax when axis is contiguous
  • Loading branch information
robertknight authored Nov 16, 2024
2 parents d4a7808 + 90a8604 commit cbcfa9f
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 45 deletions.
36 changes: 35 additions & 1 deletion rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,20 @@ pub struct Lane<'a, T> {
size: usize,
}

impl<'a, T> Lane<'a, T> {
    /// Borrow the not-yet-consumed tail of this lane as a plain slice.
    ///
    /// Returns `Some` only when the lane's stride is 1, which means the
    /// elements are contiguous. For any other stride this returns `None`.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        if self.stride != 1 {
            return None;
        }

        // Only the elements from the current position to the end of the lane
        // are exposed.
        let tail = self.data.slice(self.index..self.size);

        // Safety: A stride of 1 means the lane is contiguous.
        Some(unsafe { tail.as_slice() })
    }
}

impl<'a, T> Iterator for Lane<'a, T> {
type Item = &'a T;

Expand Down Expand Up @@ -1094,7 +1108,7 @@ pub fn for_each_mut<T, F: Fn(&mut T)>(mut view: TensorViewMut<T>, f: F) {
// tests on tensor methods.
#[cfg(test)]
mod tests {
use crate::{AsView, AxisChunks, AxisChunksMut, Lanes, LanesMut, Tensor};
use crate::{AsView, AxisChunks, AxisChunksMut, Lanes, LanesMut, NdTensor, Tensor};

#[test]
fn test_axis_chunks_empty() {
Expand Down Expand Up @@ -1129,6 +1143,26 @@ mod tests {
assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
}

#[test]
fn test_lane_as_slice() {
    let empty: &[i32] = &[];

    // Contiguous lane: the slice shrinks from the front as items are consumed.
    let vector = NdTensor::from([0, 1, 2]);
    let mut contiguous = vector.lanes(0).next().unwrap();
    assert_eq!(contiguous.as_slice(), Some([0, 1, 2].as_slice()));

    contiguous.next();
    assert_eq!(contiguous.as_slice(), Some([1, 2].as_slice()));

    // Consume the two remaining items.
    for _ in 0..2 {
        contiguous.next();
    }
    assert_eq!(contiguous.as_slice(), Some(empty));

    // Advancing an already exhausted lane still yields an empty slice.
    contiguous.next();
    assert_eq!(contiguous.as_slice(), Some(empty));

    // Non-contiguous lane: lane along axis 0 of a 2x2 matrix.
    let matrix = NdTensor::from([[1i32, 2], [3, 4]]);
    let strided = matrix.lanes(0).next().unwrap();
    assert_eq!(strided.as_slice(), None);
}

#[test]
fn test_lanes_mut_empty() {
let mut x = Tensor::<i32>::zeros(&[5, 0]);
Expand Down
1 change: 1 addition & 0 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
///
/// See [`AsView::lanes`].
///
/// # Panics
///
/// Panics if `dim` is not less than `self.ndim()`.
pub fn lanes(&self, dim: usize) -> Lanes<'a, T> {
// Reject an out-of-range axis here, before constructing the iterator.
assert!(dim < self.ndim());
Lanes::new(self.view_ref(), dim)
}

Expand Down
33 changes: 33 additions & 0 deletions src/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,39 @@ impl Identities for i32 {
}
}

/// Detect floating-point NaN ("Not a number") values.
pub trait IsNaN {
    /// Report whether this value is a NaN, mirroring [`f32::is_nan`].
    ///
    /// Implementations for integer types always return false.
    #[allow(clippy::wrong_self_convention)] // Match `f32::is_nan` etc.
    fn is_nan(self) -> bool;
}

// Implement `IsNaN` for float types by delegating to the type's own `is_nan`.
macro_rules! impl_isnan_float {
    ($($float:ty),+ $(,)?) => {
        $(
            impl IsNaN for $float {
                fn is_nan(self) -> bool {
                    <$float>::is_nan(self)
                }
            }
        )+
    };
}

// Implement `IsNaN` for integer types, for which the answer is always false.
macro_rules! impl_isnan_int {
    ($($int:ty),+ $(,)?) => {
        $(
            impl IsNaN for $int {
                fn is_nan(self) -> bool {
                    false
                }
            }
        )+
    };
}

impl_isnan_float!(f32);
impl_isnan_int!(i32, i8, u8);

/// Convert between a primitive type and an array of bytes in little-endian
/// order.
pub trait LeBytes {
Expand Down
9 changes: 6 additions & 3 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use rten_tensor::prelude::*;
use rten_tensor::{to_slice_items, NdTensorView, SliceItem, Tensor, TensorView, TensorViewMut};
use smallvec::SmallVec;

use crate::number::IsNaN;
use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
use crate::ops::{
resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
Expand Down Expand Up @@ -392,7 +393,9 @@ pub enum ScatterReduction {
Max,
}

fn scatter_reduce<T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>>(
fn scatter_reduce<
T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
current: T,
update: T,
reduction: Option<ScatterReduction>,
Expand All @@ -416,7 +419,7 @@ fn scatter_reduce<T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::M
}

pub fn scatter_elements<
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
pool: &TensorPool,
data: TensorView<T>,
Expand Down Expand Up @@ -499,7 +502,7 @@ impl Operator for ScatterElements {
}

pub fn scatter_nd<
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
pool: &TensorPool,
data: TensorView<T>,
Expand Down
18 changes: 9 additions & 9 deletions src/ops/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;
use rten_tensor::prelude::*;
use rten_tensor::{MutLayout, NdTensorView, Storage, Tensor, TensorBase, TensorView};

use crate::number::{Identities, IsInt};
use crate::number::{Identities, IsInt, IsNaN};
use crate::ops::OpError;
use crate::ops::{
arg_max, div, matmul, mul, pad, reduce_l2, reduce_max, reduce_mean, reduce_min, reduce_sum,
Expand All @@ -22,7 +22,7 @@ pub trait Operators {

fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn div(&self, other: TensorView<Self::Elem>) -> Result<Tensor<Self::Elem>, OpError>
where
Expand All @@ -44,15 +44,15 @@ pub trait Operators {
keep_dims: bool,
) -> Result<Tensor<Self::Elem>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn reduce_min(
&self,
axes: Option<&[i32]>,
keep_dims: bool,
) -> Result<Tensor<Self::Elem>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn reduce_sum(
&self,
Expand All @@ -78,7 +78,7 @@ pub trait Operators {
sorted: bool,
) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
where
Self::Elem: Copy + Default + PartialOrd;
Self::Elem: Copy + Default + PartialOrd + IsNaN;
}

/// Trait which exposes ONNX operators as methods of tensors.
Expand Down Expand Up @@ -112,7 +112,7 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>

fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| arg_max(&TensorPool::new(), view, axis, keep_dims))
Expand Down Expand Up @@ -142,15 +142,15 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>

fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| reduce_max(&TensorPool::new(), view, axes, keep_dims))
}

fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| reduce_min(&TensorPool::new(), view, axes, keep_dims))
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>
sorted: bool,
) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
where
T: Copy + Default + PartialOrd,
T: Copy + Default + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| topk(&TensorPool::new(), view, k, axis, largest, sorted))
Expand Down
71 changes: 41 additions & 30 deletions src/ops/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rten_tensor;
use rten_tensor::prelude::*;
use rten_tensor::{DynIndices, NdTensor, NdTensorView, SliceItem, Tensor, TensorView};

use crate::number::Identities;
use crate::number::{Identities, IsNaN};
use crate::ops::layout::squeeze_in_place;
use crate::ops::{
resolve_axes, resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
Expand Down Expand Up @@ -37,11 +37,24 @@ fn select_max_index<T, Cmp: Fn(&T, &T) -> std::cmp::Ordering>(
.collect();
let mut reduced_data = pool.alloc(reduced_shape.iter().product());

/// Return the position of the greatest item yielded by `iter`, as ordered
/// by `compare`.
///
/// Panics if `iter` yields no items.
fn max_position_by<'a, T: 'a>(
    iter: impl Iterator<Item = &'a T>,
    compare: impl Fn(&'a T, &'a T) -> std::cmp::Ordering,
) -> usize {
    iter.enumerate()
        .max_by(|&(_, lhs), &(_, rhs)| compare(lhs, rhs))
        .map(|(position, _)| position)
        .unwrap() // Ok because we checked tensor is not empty.
}

if !input.is_empty() {
for slice in input.lanes(resolved_axis) {
let (index, _) = slice.enumerate().max_by(|a, b| compare(a.1, b.1)).unwrap(); // Ok because we checked tensor is not empty.
reduced_data.push(index as i32);
}
reduced_data.extend(input.lanes(resolved_axis).map(|lane| {
let index = if let Some(slice) = lane.as_slice() {
// Fast path for contiguous lanes.
max_position_by(slice.iter(), &compare)
} else {
max_position_by(lane, &compare)
};
index as i32
}));
}

let mut reduced = Tensor::<i32>::from_data(&reduced_shape, reduced_data);
Expand Down Expand Up @@ -96,7 +109,7 @@ macro_rules! dispatch_single_axis_reduce_op {
/// Return the index of the maximum value along a given axis.
///
/// NaN values are propagated by treating NaNs as greater than other values.
pub fn arg_max<T: Copy + PartialOrd>(
pub fn arg_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axis: isize,
Expand Down Expand Up @@ -125,7 +138,7 @@ impl Operator for ArgMax {
/// Return the index of the minimum value along a given axis.
///
/// NaN values are propagated by treating NaNs as smaller than other values.
pub fn arg_min<T: Copy + PartialOrd>(
pub fn arg_min<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axis: isize,
Expand All @@ -134,7 +147,7 @@ pub fn arg_min<T: Copy + PartialOrd>(
select_max_index(pool, input, axis, keep_dims, |a, b| {
match a.partial_cmp(b) {
Some(ordering) => ordering.reverse(),
None => cmp_nan_greater(a, b),
None => cmp_nan_greater(*a, *b),
}
})
}
Expand Down Expand Up @@ -493,16 +506,12 @@ impl Operator for ReduceL2 {
}
}

fn is_nan<T: PartialOrd>(a: &T) -> bool {
a.partial_cmp(a).is_none()
}

/// Compare `a` and `b`, treating all NaN values as greater than non-NaN values.
pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
pub fn cmp_nan_greater<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
match a.partial_cmp(&b) {
Some(ordering) => ordering,
None => {
if is_nan(&a) {
if a.is_nan() {
std::cmp::Ordering::Greater
} else {
std::cmp::Ordering::Less
Expand All @@ -512,11 +521,11 @@ pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
}

/// Compare `a` and `b`, treating all NaN values as less than non-NaN values.
pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
pub fn cmp_nan_less<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
match a.partial_cmp(&b) {
Some(ordering) => ordering,
None => {
if is_nan(&a) {
if a.is_nan() {
std::cmp::Ordering::Less
} else {
std::cmp::Ordering::Greater
Expand All @@ -525,7 +534,7 @@ pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
}
}

fn reduce_min_max<T: Copy + PartialOrd>(
fn reduce_min_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand All @@ -535,7 +544,7 @@ fn reduce_min_max<T: Copy + PartialOrd>(
struct MinMaxReducer {
max: bool,
}
impl<T: Copy + PartialOrd> Reducer<T> for MinMaxReducer {
impl<T: Copy + PartialOrd + IsNaN> Reducer<T> for MinMaxReducer {
fn reduce<I: ExactSizeIterator<Item = T>>(&self, iter: I) -> T {
let reduced = if self.max {
iter.max_by(|a, b| cmp_nan_greater(*a, *b))
Expand Down Expand Up @@ -563,7 +572,7 @@ fn get_axes<'a>(
Ok(axes)
}

pub fn reduce_min<T: Copy + PartialOrd>(
pub fn reduce_min<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand All @@ -590,7 +599,7 @@ impl Operator for ReduceMin {
}
}

pub fn reduce_max<T: Copy + PartialOrd>(
pub fn reduce_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand Down Expand Up @@ -716,7 +725,7 @@ impl Operator for ReduceSumSquare {
}
}

pub fn topk<T: Copy + Default + PartialOrd>(
pub fn topk<T: Copy + Default + PartialOrd + IsNaN>(
pool: &TensorPool,
values: TensorView<T>,
k: usize,
Expand Down Expand Up @@ -858,15 +867,12 @@ mod tests {
// Common use case of a tensor of (batch, item, prob) where
// `item` is eg. a token index in a sequence or box ID for object
// detection.
let seq_probs = Tensor::from_data(
&[1, 4, 3],
vec![
0.1, 0.2, 0.9, // First item
0.9, 0.1, 0.2, // Second item
0.3, 0.8, 0.4, // Third item
0.1, 0.01, 0.2, // Fourth item
],
);
let seq_probs = Tensor::from([[
[0.1, 0.2, 0.9],
[0.9, 0.1, 0.2],
[0.3, 0.8, 0.4],
[0.1, 0.01, 0.2],
]]);
let seq_classes = arg_max(&pool, seq_probs.view(), 2, false /* keep_dims */).unwrap();
assert_eq!(seq_classes.shape(), &[1, 4]);
assert_eq!(seq_classes.to_vec(), &[2, 0, 1, 2]);
Expand All @@ -891,6 +897,11 @@ mod tests {
"Cannot select index from empty sequence"
))
);

// Non-contiguous lanes
let mat = Tensor::from([[1, 2], [4, 8], [5, 6]]);
let col_max = arg_max(&pool, mat.view(), 0, false /* keep_dims */).unwrap();
assert_eq!(col_max, NdTensor::from([2, 1]));
}

// We only have basic tests for ArgMin since most of the implementation is
Expand Down
Loading

0 comments on commit cbcfa9f

Please sign in to comment.