Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine softmax arm #1316

Merged
merged 7 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions core/src/ops/nn/softmax/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use math::{
};
use num_traits::Float;
use std::fmt::Debug;
use tract_num_traits::Zero;

use crate::internal::*;
use ndarray::prelude::*;
Expand All @@ -33,16 +34,7 @@ impl Op for Softmax {
}

fn info(&self) -> TractResult<Vec<String>> {
let mut infos = vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)];
if self.exp == SoftmaxExp::FastCompact {
infos.push(format!(
"{}/{}/{}",
(tract_linalg::ops().max_f32)().name(),
(tract_linalg::ops().softmax2_fastcompact_f32)().name(),
(tract_linalg::ops().mul_by_scalar_f32)().name()
))
}
Ok(infos)
Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)])
}

op_as_typed_op!();
Expand Down Expand Up @@ -155,6 +147,11 @@ impl Softmax {
{
let slice: &mut [f32] = unsafe { std::mem::transmute(slice) };
self.softmax_inner_slice_f32(slice)?;
} else if let Some(slice) =
view.as_slice_mut().filter(|_| T::datum_type() == f16::datum_type())
{
let slice: &mut [f16] = unsafe { std::mem::transmute(slice) };
self.softmax_inner_slice_f16(slice)?;
} else {
softmax_inner(view);
}
Expand Down Expand Up @@ -196,6 +193,27 @@ impl Softmax {
Ok(tvec!(output_tensor.into_tvalue()))
}

fn softmax_inner_slice_f16(&self, slice: &mut [f16]) -> TractResult<()> {
let max = (tract_linalg::ops().max_f16)().run(slice)?;
let sum = match self.exp {
SoftmaxExp::Libc => {
let mut s = f16::zero();
for x in slice.iter_mut() {
let y = (*x - max).exp();
s += y;
*x = y;
}
s
}
SoftmaxExp::FastCompact => {
(tract_linalg::ops().softmax2_fastcompact_f16)().run_with_params(slice, max)?
}
};
let rsum = sum.recip();
(tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, rsum)?;
Ok(())
}

fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> {
let max = (tract_linalg::ops().max_f32)().run(slice)?;
let sum = match self.exp {
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/arm64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ pub fn plug(ops: &mut Ops) {
ops.leaky_relu_f16 = Box::new(|| arm64fp16_leaky_relu_f16_16n::ew());
ops.tanh_f16 = Box::new(|| arm64fp16_tanh_f16_8n::ew());
ops.sigmoid_f16 = Box::new(|| arm64fp16_sigmoid_f16_8n::ew());
ops.max_f16 = Box::new(|| arm64fp16_max_f16_32n::red());
ops.mul_by_scalar_f16 = Box::new(|| arm64fp16_mul_by_scalar_f16_32n::ew());
} else {
log::info!("No native fp16 support");
}
Expand Down
4 changes: 4 additions & 0 deletions linalg/src/arm64/arm64fp16.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use tract_data::half::f16;

mod by_scalar;
mod leaky_relu;
mod max;
pub use by_scalar::*;
pub use leaky_relu::*;
pub use max::*;

use crate::frame::mmm::*;
MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16());
Expand Down
41 changes: 41 additions & 0 deletions linalg/src/arm64/arm64fp16/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use crate::f16;

ew_impl_wrap!(
f16,
arm64fp16_mul_by_scalar_f16_32n,
32,
4,
f16,
fn run(buf: &mut [f16], s: f16) {
assert!(buf.len() % 16 == 0);
assert!(buf.len() > 0);
#[target_feature(enable = "fp16")]
unsafe fn run(buf: &mut[f16], s: f16) {
let len = buf.len();
let ptr = buf.as_ptr();
std::arch::asm!("
dup v0.8h, v0.h[0]
1:
ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}]
fmul v4.8h, v4.8h, v0.8h
fmul v5.8h, v5.8h, v0.8h
fmul v6.8h, v6.8h, v0.8h
fmul v7.8h, v7.8h, v0.8h
st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64
subs {len}, {len}, 32
bne 1b
",
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
in("v0") s.to_bits(),
out("v4") _, out("v5") _, out("v6") _, out("v7") _,);
}
unsafe { run(buf, s) }
}
);

#[cfg(test)]
mod test_arm64fp16_mul_by_scalar_f16_32n {
use super::*;
mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n);
}
61 changes: 61 additions & 0 deletions linalg/src/arm64/arm64fp16/max.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use tract_data::half::f16;

reduce_impl_wrap!(
f16,
arm64fp16_max_f16_32n,
32,
8,
(),
f16::MIN,
#[inline(never)]
fn run(buf: &[f16], _: ()) -> f16 {
assert!(buf.len() % 32 == 0);
assert!(buf.len() > 0);
#[target_feature(enable = "fp16")]
unsafe fn run(buf: &[f16]) -> f16 {
let len = buf.len();
let ptr = buf.as_ptr();
let mut out: u16;
std::arch::asm!("
ins v0.h[0], {min:w}
dup v0.8h, v0.h[0]
dup v1.8h, v0.h[0]
dup v2.8h, v0.h[0]
dup v3.8h, v0.h[0]

2:
ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64
fmax v0.8h, v0.8h, v4.8h
fmax v1.8h, v1.8h, v5.8h
fmax v2.8h, v2.8h, v6.8h
fmax v3.8h, v3.8h, v7.8h

subs {len}, {len}, 32
bne 2b

fmax v0.8h, v0.8h, v1.8h
fmax v2.8h, v2.8h, v3.8h
fmax v0.8h, v0.8h, v2.8h
fmaxv h0, v0.8h
",
// using v0 as inout triggers https://github.com/rust-lang/rust/issues/120374
min = in(reg) f16::MIN.to_bits(),
ptr = inout(reg) ptr => _,
len = inout(reg) len => _,
out("v0") out, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _, out("v7") _,);
f16::from_bits(out)
}
unsafe { run(buf) }
},
#[inline(never)]
fn reduce_two(a: f16, b: f16) -> f16 {
a.max(b)
}
);

#[cfg(test)]
mod test_arm64fp16_max_f16_32n {
use super::*;
max_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_max_f16_32n);
}
8 changes: 3 additions & 5 deletions linalg/src/frame/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ macro_rules! ew_impl_wrap {
fn alignment_items() -> usize {
$alignment_items
}
#[inline(always)]
fn alignment_bytes() -> usize {
$alignment_items * std::mem::size_of::<$ti>()
}
$run
}
}
Expand Down Expand Up @@ -120,7 +116,9 @@ where
T: LADatum,
{
fn name() -> &'static str;
fn alignment_bytes() -> usize;
fn alignment_bytes() -> usize {
Self::alignment_items() * T::datum_type().size_of()
}
fn alignment_items() -> usize;
fn nr() -> usize;
fn run(vec: &mut [T], params: Params);
Expand Down
23 changes: 18 additions & 5 deletions linalg/src/frame/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tract_data::TractResult;

use crate::LADatum;

use super::element_wise_helper::{reduce_slice_with_alignment, map_reduce_slice_with_alignment};
use super::element_wise_helper::{map_reduce_slice_with_alignment, reduce_slice_with_alignment};

macro_rules! reduce_impl_wrap {
($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $neutral: expr, $run: item, $reduce_two: item) => {
Expand Down Expand Up @@ -130,7 +130,9 @@ where
T: LADatum,
{
fn name() -> &'static str;
fn alignment_bytes() -> usize;
fn alignment_bytes() -> usize {
Self::alignment_items() * T::datum_type().size_of()
}
fn alignment_items() -> usize;
fn nr() -> usize;
fn neutral() -> T;
Expand All @@ -141,6 +143,7 @@ where
}
}

#[allow(unused_macros)]
macro_rules! map_reduce_impl_wrap {
($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $map_neutral: expr, $reduce_neutral: expr, $run: item, $reduce_two: item) => {
paste! {
Expand Down Expand Up @@ -233,7 +236,9 @@ where
T: LADatum,
{
fn name() -> &'static str;
fn alignment_bytes() -> usize;
fn alignment_bytes() -> usize {
Self::alignment_items() * T::datum_type().size_of()
}
fn alignment_items() -> usize;
fn nr() -> usize;
fn map_neutral() -> T;
Expand Down Expand Up @@ -287,7 +292,14 @@ pub mod test {
reference_map: impl Fn(T) -> T,
reference_reduce: impl Fn(T, T) -> T,
) -> TestCaseResult {
test_map_reduce_params::<K, T, ()>(values, map_neutral, neutral, reference_map, reference_reduce, ())
test_map_reduce_params::<K, T, ()>(
values,
map_neutral,
neutral,
reference_map,
reference_reduce,
(),
)
}

pub fn test_map_reduce_params<K: MapReduceKer<T, Params>, T: LADatum, Params>(
Expand All @@ -305,7 +317,8 @@ pub mod test {
let op = K::red();
let mut found = values.to_vec();
let expected_values = values.iter().copied().map(reference_map).collect_vec();
let expected_reduced = expected_values.iter().fold(map_neutral, |acc, i| reference_reducer(acc, *i));
let expected_reduced =
expected_values.iter().fold(map_neutral, |acc, i| reference_reducer(acc, *i));
let red = op.run_with_params(&mut found, params).unwrap();
tensor1(&found)
.close_enough(&tensor1(&expected_values), Approximation::SuperApproximate)
Expand Down
2 changes: 1 addition & 1 deletion linalg/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod sigmoid;
pub mod softmax;
pub mod tanh;

pub use self::by_scalar::SMulByScalar4;
pub use self::by_scalar::{HMulByScalar8, SMulByScalar4};
pub use self::erf::SErf4;
pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4};
pub use self::lut::GenericLut8;
Expand Down
38 changes: 33 additions & 5 deletions linalg/src/generic/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use tract_data::internal::f16;

use crate::element_wise::ElementWiseKer;


Expand All @@ -10,11 +12,7 @@ impl ElementWiseKer<f32, f32> for SMulByScalar4 {
}

fn alignment_items() -> usize {
16
}

fn alignment_bytes() -> usize {
16
4
}

fn nr() -> usize {
Expand All @@ -33,3 +31,33 @@ impl ElementWiseKer<f32, f32> for SMulByScalar4 {
pub mod mul_by_scalar_f32 {
mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4);
}

#[derive(Clone, Debug)]
pub struct HMulByScalar8;

impl ElementWiseKer<f16, f16> for HMulByScalar8 {
fn name() -> &'static str {
"generic"
}

fn alignment_items() -> usize {
8
}

fn nr() -> usize {
8
}

fn run(x: &mut [f16], s: f16) {
debug_assert!(x.len() % Self::nr() == 0);
debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
x.iter_mut().for_each(|px| *px *= s)
}
}

#[cfg(test)]
#[macro_use]
pub mod mul_by_scalar_f16 {
use super::*;
mul_by_scalar_frame_tests!(true, f16, crate::generic::by_scalar::HMulByScalar8);
}
Loading
Loading