From 570d1b3ca83838afd69ec1270a181635faea2183 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 15 Jan 2024 09:44:32 +0100 Subject: [PATCH 01/26] introduce reduce --- linalg/Cargo.toml | 4 ++ linalg/src/frame.rs | 4 ++ linalg/src/frame/element_wise.rs | 4 +- linalg/src/frame/element_wise_helper.rs | 46 ++++++++++++++++++++++- linalg/src/frame/max.rs | 50 +++++++++++++++++++++++++ linalg/src/generic.rs | 1 + linalg/src/generic/max.rs | 42 +++++++++++++++++++++ 7 files changed, 148 insertions(+), 3 deletions(-) create mode 100644 linalg/src/frame/max.rs create mode 100644 linalg/src/generic/max.rs diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 60342b7ca7..7f7b3d0f7e 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -79,6 +79,10 @@ harness = false name = "sigmoid" harness = false +[[bench]] +name = "softmax" +harness = false + [[bench]] bench = false name = "arm64simd" diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs index 5b20912abb..5370cbb631 100644 --- a/linalg/src/frame.rs +++ b/linalg/src/frame.rs @@ -3,11 +3,15 @@ pub mod element_wise; #[macro_use] pub mod lut; #[macro_use] +pub mod max; +#[macro_use] pub mod mmm; pub mod pack; #[macro_use] pub mod leaky_relu; #[macro_use] +pub mod reduce; +#[macro_use] pub mod sigmoid; #[macro_use] pub mod tanh; diff --git a/linalg/src/frame/element_wise.rs b/linalg/src/frame/element_wise.rs index 089769d064..d5bae4a0e9 100644 --- a/linalg/src/frame/element_wise.rs +++ b/linalg/src/frame/element_wise.rs @@ -5,7 +5,7 @@ use tract_data::TractResult; use crate::LADatum; -use super::element_wise_helper::run_over_slice_with_alignment; +use super::element_wise_helper::map_slice_with_alignment; macro_rules! 
ew_impl_wrap { ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $run: item) => { @@ -100,7 +100,7 @@ where K: ElementWiseKer + Clone, { fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<()> { - run_over_slice_with_alignment( + map_slice_with_alignment( vec, |data| K::run(data, params), K::nr(), diff --git a/linalg/src/frame/element_wise_helper.rs b/linalg/src/frame/element_wise_helper.rs index 3116217b11..8eca348a85 100644 --- a/linalg/src/frame/element_wise_helper.rs +++ b/linalg/src/frame/element_wise_helper.rs @@ -2,7 +2,7 @@ use crate::LADatum; use std::alloc::*; use tract_data::TractResult; -pub(crate) fn run_over_slice_with_alignment( +pub(crate) fn map_slice_with_alignment( vec: &mut [T], f: impl Fn(&mut [T]), nr: usize, @@ -40,6 +40,50 @@ where Ok(()) } +pub(crate) fn map_reduce_slice_with_alignment( + vec: &mut [T], + f: impl Fn(&mut [T]) -> T, + nr: usize, + alignment_bytes: usize, + neutral: T, + reduce: impl Fn(T, T) -> T, +) -> TractResult +where + T: LADatum, +{ + if vec.is_empty() { + return Ok(neutral); + } + let mut red = neutral; + unsafe { + TMP.with(|buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes); + let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); + let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| { + tmp[..slice.len()].copy_from_slice(slice); + tmp[slice.len()..].fill(neutral); + let t = f(tmp); + *red = reduce(*red, t); + slice.copy_from_slice(&tmp[..slice.len()]) + }; + let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len()); + if prefix_len > 0 { + compute_via_temp_buffer(&mut vec[..prefix_len], &mut red); + } + let aligned_len = (vec.len() - prefix_len) / nr * nr; + if aligned_len > 0 { + let t = f(&mut vec[prefix_len..][..aligned_len]); + red = reduce(red, t); + } + if prefix_len + aligned_len < vec.len() { + compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], 
&mut red); + } + }) + } + Ok(red) +} + std::thread_local! { static TMP: std::cell::RefCell = std::cell::RefCell::new(TempBuffer::default()); } diff --git a/linalg/src/frame/max.rs b/linalg/src/frame/max.rs new file mode 100644 index 0000000000..e4cb34331d --- /dev/null +++ b/linalg/src/frame/max.rs @@ -0,0 +1,50 @@ +macro_rules! max_impl { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $cond: expr) => { + ew_impl!($ti, $func, $nr, $alignment_items); + #[cfg(test)] + paste! { + mod [] { + use super::*; + sigmoid_frame_tests!($cond, $ti, $func); + } + } + }; +} + +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::reduce::ReduceKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! max_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100)) { + if $cond { + $crate::frame::max::test::test_max::<$ker, $t>(&*xs).unwrap() + } + } + } + }; + } + + pub fn test_max, T: LADatum + Float>(values: &[f32]) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + crate::setup_test_logger(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::reduce::test::test_reducer::( + &values, + ::min_value(), + |x| x, + |a, b| a.max(b), + ) + } +} diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index ddbc23f6d4..42420caef8 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -1,6 +1,7 @@ pub mod erf; pub mod leaky_relu; pub mod lut; +pub mod max; pub mod mmm; pub mod rounding; pub mod sigmoid; diff --git a/linalg/src/generic/max.rs b/linalg/src/generic/max.rs new file mode 100644 index 0000000000..c7c3fc2aa1 --- /dev/null +++ b/linalg/src/generic/max.rs @@ -0,0 +1,42 @@ +use crate::frame::reduce::ReduceKer; + +#[derive(Clone, Debug)] +pub struct SMax4; + +impl ReduceKer for SMax4 { + fn name() -> &'static str 
{ + "generic" + } + + fn alignment_bytes() -> usize { + 16 + } + + fn alignment_items() -> usize { + 4 + } + + fn nr() -> usize { + 4 + } + + fn neutral() -> f32 { + f32::MIN + } + + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } + + fn run(x: &mut [f32], _: ()) -> f32 { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap() + } +} + +#[cfg(test)] +#[macro_use] +pub mod tmax { + max_frame_tests!(true, f32, crate::generic::max::SMax4); +} From 67d73eb36667ee24e64a5382a86b3304a3a6f15b Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 15 Jan 2024 10:09:04 +0100 Subject: [PATCH 02/26] plug in max from linalg --- core/src/ops/nn/reduce.rs | 37 ++++++++++++++++--------- linalg/src/frame/element_wise_helper.rs | 18 ++++++------ linalg/src/generic/max.rs | 2 +- linalg/src/lib.rs | 6 +++- 4 files changed, 38 insertions(+), 25 deletions(-) diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs index dc082da7e4..f26cba0f5a 100644 --- a/core/src/ops/nn/reduce.rs +++ b/core/src/ops/nn/reduce.rs @@ -1,6 +1,7 @@ use crate::internal::Axis; use crate::internal::*; use std::convert::TryFrom; +use std::mem::transmute; use tract_data::internal::ClampCast; use tract_data::itertools::Itertools; use tract_ndarray::prelude::*; @@ -210,6 +211,12 @@ fn max_t(v: ArrayViewD, _: ()) -> T where T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd, { + if T::datum_type() == f32::datum_type() { + if let Some(slice) = v.as_slice() { + let slice = unsafe { transmute(slice) }; + (tract_linalg::ops().max_f32)().run(slice).unwrap(); + } + } v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v }) } @@ -297,19 +304,23 @@ impl TypedOp for Reduce { outputs: &[&TypedFact], ) -> TractResult { let mut letters = 'a'..; - let axes = (0..inputs[0].rank()).flat_map(|ix| { - if self.axes.contains(&ix) { - tvec!( - Axis::new(letters.next().unwrap(), 
inputs.len(), outputs.len()).input(0, ix), - Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()).output(0, ix), - ) - } else { - tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) - .input(0, ix) - .output(0, ix)) - } - .into_iter() - }).collect_vec(); + let axes = (0..inputs[0].rank()) + .flat_map(|ix| { + if self.axes.contains(&ix) { + tvec!( + Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .input(0, ix), + Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .output(0, ix), + ) + } else { + tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .input(0, ix) + .output(0, ix)) + } + .into_iter() + }) + .collect_vec(); AxesMapping::new(1, 1, axes) } diff --git a/linalg/src/frame/element_wise_helper.rs b/linalg/src/frame/element_wise_helper.rs index 8eca348a85..a0ae7de2a0 100644 --- a/linalg/src/frame/element_wise_helper.rs +++ b/linalg/src/frame/element_wise_helper.rs @@ -40,9 +40,9 @@ where Ok(()) } -pub(crate) fn map_reduce_slice_with_alignment( - vec: &mut [T], - f: impl Fn(&mut [T]) -> T, +pub(crate) fn reduce_slice_with_alignment( + vec: &[T], + f: impl Fn(&[T]) -> T, nr: usize, alignment_bytes: usize, neutral: T, @@ -60,24 +60,22 @@ where let mut buffer = buffer.borrow_mut(); buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes); let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); - let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| { + let mut compute_via_temp_buffer = |slice: &[T], red: &mut T| { tmp[..slice.len()].copy_from_slice(slice); tmp[slice.len()..].fill(neutral); - let t = f(tmp); - *red = reduce(*red, t); - slice.copy_from_slice(&tmp[..slice.len()]) + *red = reduce(*red, f(tmp)); }; let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len()); if prefix_len > 0 { - compute_via_temp_buffer(&mut vec[..prefix_len], &mut red); + compute_via_temp_buffer(&vec[..prefix_len], &mut red); } let aligned_len = (vec.len() - 
prefix_len) / nr * nr; if aligned_len > 0 { - let t = f(&mut vec[prefix_len..][..aligned_len]); + let t = f(&vec[prefix_len..][..aligned_len]); red = reduce(red, t); } if prefix_len + aligned_len < vec.len() { - compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], &mut red); + compute_via_temp_buffer(&vec[prefix_len + aligned_len..], &mut red); } }) } diff --git a/linalg/src/generic/max.rs b/linalg/src/generic/max.rs index c7c3fc2aa1..c21aa2b7ca 100644 --- a/linalg/src/generic/max.rs +++ b/linalg/src/generic/max.rs @@ -28,7 +28,7 @@ impl ReduceKer for SMax4 { a.max(b) } - fn run(x: &mut [f32], _: ()) -> f32 { + fn run(x: &[f32], _: ()) -> f32 { debug_assert!(x.len() % Self::nr() == 0); debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap() diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 3dd60ff7d0..3e15f041d8 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -16,7 +16,8 @@ include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs")); pub mod frame; pub mod generic; use frame::element_wise::ElementWiseKer; -use frame::MatMatMul; +use frame::reduce::ReduceKer; +use frame::{reduce, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; #[cfg(target_arch = "x86_64")] pub mod x86_64_fma; @@ -62,6 +63,8 @@ pub struct Ops { pub tanh_f32: Box Box> + Send + Sync>, pub erf_f32: Box Box> + Send + Sync>, pub lut_u8: Box Box + Send + Sync>, + + pub max_f32: Box Box> + Send + Sync>, } impl Ops { @@ -119,6 +122,7 @@ pub fn generic() -> Ops { tanh_f32: Box::new(|| generic::STanh4::ew()), erf_f32: Box::new(|| generic::SErf4::ew()), lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::::new(table))), + max_f32: Box::new(|| generic::max::SMax4::red()), /* activation_f32: Box::new(|microcode| generic::SActivation::new(microcode)) */ From db6ffd34c0731c7e2535fbc0c96fdc5affa07d3d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 15 Jan 2024 15:01:51 +0100 Subject: [PATCH 03/26] fast 
max impl --- linalg/benches/softmax.rs | 59 +++++++++++++++++++++++++ linalg/src/arm64.rs | 4 ++ linalg/src/frame/element_wise_helper.rs | 43 ++++++++++++++++++ linalg/src/frame/max.rs | 16 +++++-- 4 files changed, 118 insertions(+), 4 deletions(-) create mode 100644 linalg/benches/softmax.rs diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs new file mode 100644 index 0000000000..8b0a7b089c --- /dev/null +++ b/linalg/benches/softmax.rs @@ -0,0 +1,59 @@ +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::frame::reduce::ReduceKer; + +#[inline(never)] +fn loop1_f32(slice: &mut [f32]) -> f32 { + let mut max = std::f32::MIN; + for x in &*slice { + if *x > max { + max = *x; + } + } + max +} + +#[inline(never)] +fn loop2_f32(slice: &mut [f32], max: f32) -> f32 { + let mut sum = 0.; + for x in slice.iter_mut() { + *x = (*x - max).exp(); + sum = sum + *x; + } + sum +} + +#[inline(never)] +fn loop3_f32(slice: &mut [f32], sum: f32) { + let recip = sum.recip(); + for x in slice { + *x = *x * recip; + } +} + +#[inline(never)] +fn rust_f32(slice: &mut [f32]) { + let max = loop1_f32(slice); + let sum = loop2_f32(slice, max); + loop3_f32(slice, sum); +} + +fn softmax_f32(c: &mut Criterion) { + let mut group = c.benchmark_group("softmax_f32"); + group.throughput(Throughput::Elements(1500)); + let mut input = unsafe { Tensor::uninitialized_aligned::(&[1500], 16).unwrap() }; + let input = input.as_slice_mut::().unwrap(); + group.bench_function("rust", |b| b.iter(|| rust_f32(input))); + group.bench_function("loop1/rust", |b| b.iter(|| loop1_f32(input))); + #[cfg(target_arch = "aarch64")] + group.bench_function("loop1/intr", |b| { + b.iter(|| { + tract_linalg::arm64::arm64simd_max_f32_16n::red().run(input).unwrap(); + }) + }); + group.bench_function("rust_loop2", |b| b.iter(|| loop2_f32(input, 1.0))); + group.bench_function("rust_loop3", |b| b.iter(|| loop3_f32(input, 0.21))); +} + +criterion_group!(benches, softmax_f32); +criterion_main!(benches); diff 
--git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index c0435d6c65..22f8c916a0 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -10,12 +10,15 @@ pub use arm64simd::*; mod leaky_relu; pub use leaky_relu::*; +mod max; +pub use max::*; use crate::Ops; use crate::f16; use crate::frame::element_wise::ElementWiseKer; use crate::frame::mmm::kernel::MatMatMulKer; +use crate::frame::reduce::ReduceKer; // https://en.wikipedia.org/wiki/Comparison_of_ARMv8-A_cores const PART_A53: &str = "0xd03"; @@ -229,6 +232,7 @@ pub fn plug(ops: &mut Ops) { ops.leaky_relu_f32 = Box::new(|| arm64simd_leaky_relu_f32_8n::ew()); ops.sigmoid_f32 = Box::new(|| arm64simd_sigmoid_f32_4n::ew()); ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew()); + ops.max_f32 = Box::new(|| arm64simd_max_f32_16n::red()); #[cfg(not(feature = "no_fp16"))] if has_fp16() { log::info!("ARMv8.2 tanh_f16 and sigmoid_f16 activated"); diff --git a/linalg/src/frame/element_wise_helper.rs b/linalg/src/frame/element_wise_helper.rs index a0ae7de2a0..de2223cbf8 100644 --- a/linalg/src/frame/element_wise_helper.rs +++ b/linalg/src/frame/element_wise_helper.rs @@ -82,6 +82,49 @@ where Ok(red) } +pub(crate) fn map_reduce_slice_with_alignment( + vec: &mut [T], + f: impl Fn(&mut [T]) -> T, + nr: usize, + alignment_bytes: usize, + neutral: T, + reduce: impl Fn(T, T) -> T, +) -> TractResult +where + T: LADatum, +{ + if vec.is_empty() { + return Ok(neutral); + } + let mut red = neutral; + unsafe { + TMP.with(|buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes); + let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); + let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| { + tmp[..slice.len()].copy_from_slice(slice); + tmp[slice.len()..].fill(neutral); + *red = reduce(*red, f(tmp)); + slice.copy_from_slice(&tmp[..slice.len()]); + }; + let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len()); + if prefix_len > 0 { 
+ compute_via_temp_buffer(&mut vec[..prefix_len], &mut red); + } + let aligned_len = (vec.len() - prefix_len) / nr * nr; + if aligned_len > 0 { + let t = f(&mut vec[prefix_len..][..aligned_len]); + red = reduce(red, t); + } + if prefix_len + aligned_len < vec.len() { + compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], &mut red); + } + }) + } + Ok(red) +} + std::thread_local! { static TMP: std::cell::RefCell = std::cell::RefCell::new(TempBuffer::default()); } diff --git a/linalg/src/frame/max.rs b/linalg/src/frame/max.rs index e4cb34331d..977a1e5e06 100644 --- a/linalg/src/frame/max.rs +++ b/linalg/src/frame/max.rs @@ -1,15 +1,17 @@ -macro_rules! max_impl { + +/*macro_rules! max_impl { ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $cond: expr) => { - ew_impl!($ti, $func, $nr, $alignment_items); + reduce_impl!($ti, $func, $nr, $alignment_items); #[cfg(test)] paste! { mod [] { use super::*; - sigmoid_frame_tests!($cond, $ti, $func); + max_frame_tests!($cond, $ti, $func); } } }; } +*/ #[cfg(test)] #[macro_use] @@ -30,6 +32,13 @@ pub mod test { } } } + + #[test] + fn empty() { + if $cond { + $crate::frame::max::test::test_max::<$ker, $t>(&[]).unwrap() + } + } }; } @@ -43,7 +52,6 @@ pub mod test { crate::frame::reduce::test::test_reducer::( &values, ::min_value(), - |x| x, |a, b| a.max(b), ) } From 66bc77472b876ee30d7a8fe72a95e53f5a3d284b Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 15 Jan 2024 15:43:41 +0100 Subject: [PATCH 04/26] arm64 max --- linalg/benches/arm64simd.rs | 11 ++- linalg/src/arm64/max.rs | 52 +++++++++++ linalg/src/frame/reduce.rs | 172 ++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 linalg/src/arm64/max.rs create mode 100644 linalg/src/frame/reduce.rs diff --git a/linalg/benches/arm64simd.rs b/linalg/benches/arm64simd.rs index 0fb458718f..1dd244ff19 100644 --- a/linalg/benches/arm64simd.rs +++ b/linalg/benches/arm64simd.rs @@ -1,5 +1,4 @@ 
#![allow(dead_code, non_upper_case_globals, unused_macros, non_snake_case, unused_assignments)] -#![feature(aarch64_target_feature)] use std::arch::asm; @@ -57,6 +56,16 @@ pub unsafe fn armv8(filter: Option<&str>) { out("v4") _, out("v5") _, out("v6") _, out("v7") _, ) }); + s128!("fmax", 4, { + asm!(" fmax v0.4s, v1.4s, v1.4s + fmax v2.4s, v3.4s, v3.4s + fmax v4.4s, v5.4s, v5.4s + fmax v6.4s, v7.4s, v7.4s ", + out("v0") _, out("v1") _, out("v2") _, out("v3") _, + out("v4") _, out("v5") _, out("v6") _, out("v7") _, + ) + }); + s128!("fmax_with_dep", 1, { asm!("fmax v0.4s, v0.4s, v0.4s", out("v0") _) }); s128!("fmla", 16, { asm!(" fmla v0.4s, v0.4s, v0.4s fmla v1.4s, v1.4s, v1.4s diff --git a/linalg/src/arm64/max.rs b/linalg/src/arm64/max.rs new file mode 100644 index 0000000000..b4f816d9fe --- /dev/null +++ b/linalg/src/arm64/max.rs @@ -0,0 +1,52 @@ +use std::arch::aarch64::{float32x4_t, vdupq_n_f32, vgetq_lane_f32}; + +reduce_impl_wrap!( + f32, + arm64simd_max_f32_16n, + 16, + 4, + (), + f32::MIN, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut out: float32x4_t = vdupq_n_f32(f32::MIN); + std::arch::asm!(" + and v1.16b, v0.16b, v0.16b + and v2.16b, v0.16b, v0.16b + and v3.16b, v0.16b, v0.16b + 1: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v6.4s + fmax v3.4s, v3.4s, v7.4s + subs {len}, {len}, 16 + bne 1b + fmax v0.4s, v0.4s, v1.4s + fmax v2.4s, v2.4s, v3.4s + fmax v0.4s, v0.4s, v2.4s + fmaxv s0, v0.4s + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("v0") out, out("v1") _, out("v2") _, out("v3") _, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + vgetq_lane_f32(out, 0) + } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } +); + +#[cfg(test)] +mod test_arm64simd_max_f32_16n { + use super::*; 
+ max_frame_tests!(true, f32, arm64simd_max_f32_16n); +} diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs new file mode 100644 index 0000000000..cf16020827 --- /dev/null +++ b/linalg/src/frame/reduce.rs @@ -0,0 +1,172 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use tract_data::TractResult; + +use crate::LADatum; + +use super::element_wise_helper::reduce_slice_with_alignment; + +macro_rules! reduce_impl_wrap { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $neutral: expr, $run: item, $reduce_two: item) => { + paste! { + #[derive(Copy, Clone, Debug)] + #[allow(non_camel_case_types)] + pub struct $func; + + impl crate::frame::reduce::ReduceKer<$ti, $params> for $func { + #[inline(always)] + fn name() -> &'static str { + stringify!($func) + } + #[inline(always)] + fn nr() -> usize { + $nr + } + #[inline(always)] + fn alignment_items() -> usize { + $alignment_items + } + #[inline(always)] + fn alignment_bytes() -> usize { + $alignment_items * std::mem::size_of::<$ti>() + } + #[inline(always)] + fn neutral() -> $ti { + $neutral + } + $run + $reduce_two + } + } + }; +} + +/* +macro_rules! reduce_impl { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => { + paste! { + mod [] { + #[allow(unused_imports)] + use tract_data::prelude::f16; + extern_kernel!(fn $func(ptr: *mut $ti, count: usize) -> ()); + } + reduce_impl_wrap!($ti, $func, $nr, $alignment_items, (), + #[inline(never)] + fn run(buf: &mut [$ti], _params: ()) { + unsafe { []::$func(buf.as_mut_ptr(), buf.len()) } + } + ); + } + }; + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty) => { + paste! 
{ + mod [] { + #[allow(unused_imports)] + use tract_data::prelude::f16; + extern_kernel!(fn $func(ptr: *mut $ti, count: usize, params: $params) -> ()); + } + ew_impl_wrap!($ti, $func, $nr, $alignment_items, $params, + #[inline(never)] + fn run(buf: &mut [$ti], params: $params) { + unsafe { []::$func(buf.as_mut_ptr(), buf.len(), params) } + } + ); + } + }; +} +*/ + +pub trait Reduce: Send + Sync + Debug + dyn_clone::DynClone +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: Copy + Debug + PartialEq + Send + Sync, +{ + fn run(&self, vec: &[T]) -> TractResult { + self.run_with_params(vec, Params::default()) + } + fn run_with_params(&self, vec: &[T], params: Params) -> TractResult; +} + +dyn_clone::clone_trait_object!( Reduce where T: Copy, Params: Copy); + +#[derive(Debug, Clone, new)] +pub struct ReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: ReduceKer + Clone, +{ + phantom: PhantomData<(K, T, Params)>, +} + +impl Reduce for ReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: ReduceKer + Clone, +{ + fn run_with_params(&self, vec: &[T], params: Params) -> TractResult { + reduce_slice_with_alignment( + vec, + |data| K::run(data, params), + K::nr(), + K::alignment_bytes(), + K::neutral(), + K::reduce_two, + ) + } +} + +pub trait ReduceKer: + Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: LADatum, +{ + fn name() -> &'static str; + fn alignment_bytes() -> usize; + fn alignment_items() -> usize; + fn nr() -> usize; + fn neutral() -> T; + fn reduce_two(a: T, b: T) -> T; + fn run(vec: &[T], params: Params) -> T; + fn red() -> Box> { + Box::new(ReduceImpl::::new()) + } +} + +#[cfg(test)] +pub mod test { + use super::*; + use proptest::test_runner::{TestCaseError, TestCaseResult}; + use tract_data::internal::*; + + pub fn test_reducer, T: LADatum>( + values: &[T], + neutral: T, + 
reference_reducer: impl Fn(T, T) -> T, + ) -> TestCaseResult { + test_reducer_params::(values, neutral, reference_reducer, ()) + } + + pub fn test_reducer_params, T: LADatum, Params>( + values: &[T], + neutral: T, + reference_reducer: impl Fn(T, T) -> T, + params: Params, + ) -> TestCaseResult + where + Params: Copy + Send + Sync + Debug + 'static + Default, + { + crate::setup_test_logger(); + let op = K::red(); + let expected = values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i)); + let mut found = values; + let red = op.run_with_params(&mut found, params).unwrap(); + tensor0(red) + .close_enough(&tensor0(expected), true) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + Ok(()) + } +} From d52b9b4323098bf9a29d652fc5453172b5bd6ebe Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 16 Jan 2024 17:06:30 +0100 Subject: [PATCH 05/26] max on intel --- linalg/benches/softmax.rs | 22 +++- linalg/proptest-regressions/generic/mmm.txt | 19 ---- .../proptest-regressions/x86_64_fma/mmm.txt | 17 --- linalg/src/arm64/max.rs | 2 +- linalg/src/frame.rs | 2 + linalg/src/frame/max.rs | 2 +- linalg/src/frame/reduce.rs | 104 +++++++++++++++++- linalg/src/frame/softmax.rs | 47 ++++++++ linalg/src/generic.rs | 2 + linalg/src/generic/softmax.rs | 47 ++++++++ linalg/src/x86_64_fma.rs | 1 + linalg/src/x86_64_fma/max.rs | 66 +++++++++++ 12 files changed, 283 insertions(+), 48 deletions(-) delete mode 100644 linalg/proptest-regressions/generic/mmm.txt delete mode 100644 linalg/proptest-regressions/x86_64_fma/mmm.txt create mode 100644 linalg/src/frame/softmax.rs create mode 100644 linalg/src/generic/softmax.rs create mode 100644 linalg/src/x86_64_fma/max.rs diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs index 8b0a7b089c..0b87b37eb5 100644 --- a/linalg/benches/softmax.rs +++ b/linalg/benches/softmax.rs @@ -1,9 +1,9 @@ use criterion::*; use tract_data::prelude::*; -use tract_linalg::frame::reduce::ReduceKer; +use 
tract_linalg::frame::reduce::{MapReduceKer, ReduceKer}; #[inline(never)] -fn loop1_f32(slice: &mut [f32]) -> f32 { +fn loop1_f32_naive(slice: &mut [f32]) -> f32 { let mut max = std::f32::MIN; for x in &*slice { if *x > max { @@ -33,7 +33,7 @@ fn loop3_f32(slice: &mut [f32], sum: f32) { #[inline(never)] fn rust_f32(slice: &mut [f32]) { - let max = loop1_f32(slice); + let max = loop1_f32_naive(slice); let sum = loop2_f32(slice, max); loop3_f32(slice, sum); } @@ -44,14 +44,26 @@ fn softmax_f32(c: &mut Criterion) { let mut input = unsafe { Tensor::uninitialized_aligned::(&[1500], 16).unwrap() }; let input = input.as_slice_mut::().unwrap(); group.bench_function("rust", |b| b.iter(|| rust_f32(input))); - group.bench_function("loop1/rust", |b| b.iter(|| loop1_f32(input))); + group.bench_function("loop1/naive", |b| b.iter(|| loop1_f32_naive(input))); + group.bench_function("loop1/generic", |b| { + b.iter(|| tract_linalg::generic::max::SMax4::red().run(&input)) + }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop1/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::max::x86_64_fma_max_f32_32n::red().run(input).unwrap(); + }) + }); #[cfg(target_arch = "aarch64")] group.bench_function("loop1/intr", |b| { b.iter(|| { tract_linalg::arm64::arm64simd_max_f32_16n::red().run(input).unwrap(); }) }); - group.bench_function("rust_loop2", |b| b.iter(|| loop2_f32(input, 1.0))); + group.bench_function("loop2/naive", |b| b.iter(|| loop2_f32(input, 1.0))); + group.bench_function("loop2/generic", |b| { + b.iter(|| tract_linalg::generic::softmax::SSoftMaxL2::red().run_with_params(input, 10.)) + }); group.bench_function("rust_loop3", |b| b.iter(|| loop3_f32(input, 0.21))); } diff --git a/linalg/proptest-regressions/generic/mmm.txt b/linalg/proptest-regressions/generic/mmm.txt deleted file mode 100644 index a6a851136a..0000000000 --- a/linalg/proptest-regressions/generic/mmm.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. 
It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. -cc 121b28bde52462d8007d6aa3b57effba3c4f3ce05a57803cbbcaa4b7096d083f # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 1, 0, 0], mult: 0, boo: PhantomData } -cc d54877b98a8e851f187820853649131a4a0ad62e0967e5c5881577ab2ebc56f8 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 1], mult: 0, boo: PhantomData } -cc a80e1dd4a5bb0c6858ccdaf91a5aa4a8b9b71594473a91aa82583e0db4ec6ef5 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 1], mult: 0, boo: PhantomData } -cc 5aac55722aca2ca476544313178602ef71fbdda4960d3c711605cc70c6e93c06 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 462968c12ee0fcc1e6a78b7fc4f412d0c7f0766299893886b8ec49cf978eff49 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 8ac3a83562b28a2ee6fa3896f929b7210ee7d4f559d9e9c2c06c9dd7e7fa93b5 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } -cc d0369cc2a5339619ebabdd3dbdac60ea9ebf542b23ef536db03fba2de3651c21 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, -2], mult: 0, boo: PhantomData } -cc 9c92a9f97a6d24ffaa21370b1eb5aec9880b508a384d037287df8de7f8d4c608 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } -cc f3574fa61a7e297de8cf76c8268dee6c6d062dfc5f98738618bdeb69d5b9b595 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], mult: 0, boo: PhantomData } -cc bb5421af3bb30e0ac316287dc41563ce69eca8ca74abf5f5a7f8e82d7cab6628 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } 
-cc 45859e240aae3fa0e7667d46dd475d9f3de42e87ace8b3678db322b4a18509af # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255], mult: 0, boo: PhantomData } -cc 65d3a24331e3919a11d220118f97fe21caf769d1c590737cad9aef927959b4e9 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 378b328c0e46a5038f5876614a1ac16dc796482676f2ac81a670f5dbabb490ec # shrinks to pb = ConvProblem { ci: 1, co: 19, kt: 1, stride: 1, dilation: 4, filters: 19,1,F32 0, 0, 0, 0, 0, 0, 0, 0.026, -0.286, -0.627, -0.248, -0.291..., data: 1,5,F32 0.873, -0.377, -0.818, -0.415, 0.746, phantom: PhantomData } diff --git a/linalg/proptest-regressions/x86_64_fma/mmm.txt b/linalg/proptest-regressions/x86_64_fma/mmm.txt deleted file mode 100644 index 57df875ef5..0000000000 --- a/linalg/proptest-regressions/x86_64_fma/mmm.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. 
-cc 058e808c67cb65d0a9013ff2cc0d9f726f19bef9b21e205396db15871be2586d # shrinks to pb = ReturnCProblem { c: [0.0, 0.0, -0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1000702600000000000000.0], boo: PhantomData } -cc 65e8110dc380469eec961f5d4ef07bd31b7f44defe08af00e6dcc812228bcc7f # shrinks to pb = QRightShiftProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], shift: 0, policy: Away, boo: PhantomData } -cc 1f7e80660ef4800d7bbb6acafe9bbff1e396f0f78b296d20deb51b439a40736c # shrinks to pb = QScaleProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mult: 536870912, shift: 0, policy: Odd, boo: PhantomData } -cc 57c3e684c05eb9cc964eaaf22ccb620a2d9ef5ea5bd38c6dff6c5344b8192619 # shrinks to pb = QScaleProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mult: 536870912, shift: 0, policy: Odd, boo: PhantomData } -cc 9fce04b53146b7615e54728101dd8ad24ffd157d63c536eb1ca888070c8b1719 # shrinks to pb = PackedOffsetsProblem { a: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], b: [0.0, 0.0, 0.0, 
0.0, 1.0, 9.0, 7.0, 7.0, 0.0, 5.0, 2.0, 3.0, 0.0, 3.0, 7.0, 2.0, 2.0, 6.0, 4.0, 4.0, 1.0, 5.0, 5.0, 1.0, 3.0, 7.0, 3.0, 4.0, 7.0, 5.0, 8.0, 8.0, 1.0, 5.0, 6.0, 8.0, 7.0, 6.0], cols_offsets: [28], rows_offsets: [9], add_one: false, _phantom: PhantomData } -cc 708172a1bb5110aa22d15b4a0daa777fae18dda1ab481328ab200f54d72357b6 # shrinks to pb = ConvProblem { ci: 1, co: 17, kt: 1, stride: 3, dilation: 4, filters: 17,1,F32 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.044..., data: 1,3,F32 -0.416, -0.932, -0.21, phantom: PhantomData } -cc e62d087b655259ebda7622f40e0049a49fb803777a3a98343c6855e69944ccfc # shrinks to pb = ConvProblem { ci: 1, co: 17, kt: 1, stride: 4, dilation: 4, filters: 17,1,F32 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..., data: 1,4,F32 0.251, 0, 0, 0, phantom: PhantomData } -cc 15bfa5078b5479e12e74820888d477cbdd628faa03aa1bef26738dfa8cbff014 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc dde4ea4eafe9beb60bffe8307cdb62c6143849f247ad9edbb64acd586f57d722 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc 066cecc4f004061abab40982576e9f10c56dd090751a699e2cffeb3f7260b0e1 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc 2ba4fcc74dd26a2e26c3c589d03961e5cb753a245a2ce5ec2ee7b9c53b30d8a2 # shrinks to pb = PackedPackedProblem { k: 16, a: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], b: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], trans_c: false, add_one: false, _phantom: PhantomData } diff --git a/linalg/src/arm64/max.rs b/linalg/src/arm64/max.rs index b4f816d9fe..bae613d957 100644 --- a/linalg/src/arm64/max.rs +++ b/linalg/src/arm64/max.rs @@ -29,7 +29,7 @@ reduce_impl_wrap!( bne 1b fmax v0.4s, v0.4s, v1.4s fmax v2.4s, v2.4s, v3.4s - fmax v0.4s, v0.4s, v1.4s + fmax v0.4s, v0.4s, v1.4s // TODO: TEST ME AND FIND THE BUG! 
fmaxv s0, v0.4s ", len = inout(reg) len => _, diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs index 5370cbb631..837e35351a 100644 --- a/linalg/src/frame.rs +++ b/linalg/src/frame.rs @@ -14,6 +14,8 @@ pub mod reduce; #[macro_use] pub mod sigmoid; #[macro_use] +pub mod softmax; +#[macro_use] pub mod tanh; pub mod element_wise_helper; diff --git a/linalg/src/frame/max.rs b/linalg/src/frame/max.rs index 977a1e5e06..12a451487c 100644 --- a/linalg/src/frame/max.rs +++ b/linalg/src/frame/max.rs @@ -49,7 +49,7 @@ pub mod test { { crate::setup_test_logger(); let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); - crate::frame::reduce::test::test_reducer::( + crate::frame::reduce::test::test_reduce::( &values, ::min_value(), |a, b| a.max(b), diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs index cf16020827..ab97626ce1 100644 --- a/linalg/src/frame/reduce.rs +++ b/linalg/src/frame/reduce.rs @@ -5,7 +5,7 @@ use tract_data::TractResult; use crate::LADatum; -use super::element_wise_helper::reduce_slice_with_alignment; +use super::element_wise_helper::{reduce_slice_with_alignment, map_reduce_slice_with_alignment}; macro_rules! 
reduce_impl_wrap { ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $neutral: expr, $run: item, $reduce_two: item) => { @@ -136,21 +136,81 @@ where } } +pub trait MapReduce: Send + Sync + Debug + dyn_clone::DynClone +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: Copy + Debug + PartialEq + Send + Sync, +{ + fn run(&self, vec: &mut [T]) -> TractResult { + self.run_with_params(vec, Params::default()) + } + fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult; +} + +dyn_clone::clone_trait_object!( MapReduce where T: Copy, Params: Copy); + +#[derive(Debug, Clone, new)] +pub struct MapReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: MapReduceKer + Clone, +{ + phantom: PhantomData<(K, T, Params)>, +} + +impl MapReduce for MapReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: MapReduceKer + Clone, +{ + fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult { + map_reduce_slice_with_alignment( + vec, + |data| K::run(data, params), + K::nr(), + K::alignment_bytes(), + K::neutral(), + K::reduce_two, + ) + } +} + +pub trait MapReduceKer: + Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: LADatum, +{ + fn name() -> &'static str; + fn alignment_bytes() -> usize; + fn alignment_items() -> usize; + fn nr() -> usize; + fn neutral() -> T; + fn reduce_two(a: T, b: T) -> T; + fn run(vec: &mut [T], params: Params) -> T; + fn red() -> Box> { + Box::new(MapReduceImpl::::new()) + } +} + #[cfg(test)] pub mod test { use super::*; use proptest::test_runner::{TestCaseError, TestCaseResult}; use tract_data::internal::*; + use tract_data::itertools::Itertools; - pub fn test_reducer, T: LADatum>( + pub fn test_reduce, T: LADatum>( values: &[T], neutral: T, - reference_reducer: impl Fn(T, T) -> T, + reference_reduce: impl Fn(T, T) -> T, ) -> 
TestCaseResult { - test_reducer_params::(values, neutral, reference_reducer, ()) + test_reduce_params::(values, neutral, reference_reduce, ()) } - pub fn test_reducer_params, T: LADatum, Params>( + pub fn test_reduce_params, T: LADatum, Params>( values: &[T], neutral: T, reference_reducer: impl Fn(T, T) -> T, @@ -169,4 +229,38 @@ pub mod test { .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; Ok(()) } + + pub fn test_map_reduce, T: LADatum>( + values: &[T], + neutral: T, + reference_map: impl Fn(T) -> T, + reference_reduce: impl Fn(T, T) -> T, + ) -> TestCaseResult { + test_map_reduce_params::(values, neutral, reference_map, reference_reduce, ()) + } + + pub fn test_map_reduce_params, T: LADatum, Params>( + values: &[T], + neutral: T, + reference_map: impl Fn(T) -> T, + reference_reducer: impl Fn(T, T) -> T, + params: Params, + ) -> TestCaseResult + where + Params: Copy + Send + Sync + Debug + 'static + Default, + { + crate::setup_test_logger(); + let op = K::red(); + let mut found = values.to_vec(); + let expected_values = values.iter().copied().map(reference_map).collect_vec(); + let expected_reduced = expected_values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i)); + let red = op.run_with_params(&mut found, params).unwrap(); + tensor1(&found) + .close_enough(&tensor1(&expected_values), true) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + tensor0(red) + .close_enough(&tensor0(expected_reduced), true) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + Ok(()) + } } diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs new file mode 100644 index 0000000000..b14d882e86 --- /dev/null +++ b/linalg/src/frame/softmax.rs @@ -0,0 +1,47 @@ +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::reduce::MapReduceKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! 
softmax_l2_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 1..100)) { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&*xs).unwrap() + } + } + } + #[test] + fn single() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[0.0]).unwrap() + } + } + }; + } + + pub fn test_softmax_l2, T: LADatum + Float>( + values: &[f32], + ) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + crate::setup_test_logger(); + let max = values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::reduce::test::test_map_reduce_params::( + &values, + ::min_value(), + |x| (x - max.as_()).exp(), + |a, b| a + b, + max.as_(), + ) + } +} diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 42420caef8..36aa5d8559 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -5,6 +5,7 @@ pub mod max; pub mod mmm; pub mod rounding; pub mod sigmoid; +pub mod softmax; pub mod tanh; pub use self::erf::SErf4; @@ -14,4 +15,5 @@ pub use self::mmm::GenericMmm4x1; pub use self::mmm::GenericMmm4x4; pub use self::rounding::{ScaleShiftAndRound, Scaler}; pub use self::sigmoid::{HSigmoid8, SSigmoid4}; +pub use self::softmax::SSoftMaxL2; pub use self::tanh::{HTanh8, STanh4}; diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs new file mode 100644 index 0000000000..75c2e8264a --- /dev/null +++ b/linalg/src/generic/softmax.rs @@ -0,0 +1,47 @@ +use crate::frame::reduce::MapReduceKer; + +#[derive(Clone, Debug)] +pub struct SSoftMaxL2; + +impl MapReduceKer for SSoftMaxL2 { + fn name() -> &'static str { + "generic" + } + + fn alignment_bytes() -> usize { + 16 + } + + fn alignment_items() -> usize { + 4 + } + + fn nr() -> usize { + 4 + } + + fn neutral() -> f32 { + f32::MIN + } + + fn reduce_two(a: f32, b: f32) -> f32 { + a + b + } + + fn run(x: 
&mut [f32], max: f32) -> f32 { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + let mut sum = 0.; + for x in x.iter_mut() { + *x = (*x - max).exp(); + sum += *x; + } + sum + } +} + +#[cfg(test)] +#[macro_use] +pub mod tsoftmax { + softmax_l2_frame_tests!(true, f32, crate::generic::softmax::SSoftMaxL2); +} diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 12664bdcca..907d9b6ef0 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -5,6 +5,7 @@ use crate::Ops; pub mod mmm; mod intel; +pub mod max; tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma")); sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma")); diff --git a/linalg/src/x86_64_fma/max.rs b/linalg/src/x86_64_fma/max.rs new file mode 100644 index 0000000000..69f89d6436 --- /dev/null +++ b/linalg/src/x86_64_fma/max.rs @@ -0,0 +1,66 @@ +reduce_impl_wrap!( + f32, + x86_64_fma_max_f32_32n, + 32, + 32, + (), + f32::MIN, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 32 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_fma_max_f32_32n_run(buf) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } +); + +#[target_feature(enable = "avx")] +unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = f32::MIN; + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + vmovaps ymm1, ymm0 + vmovaps ymm2, ymm0 + vmovaps ymm3, ymm0 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + vmaxps ymm0, ymm0, ymm4 + vmaxps ymm1, ymm1, ymm5 + vmaxps ymm2, ymm2, ymm6 + vmaxps ymm3, ymm3, ymm7 + add {ptr}, 128 + sub {len}, 32 + jnz 2b + vmaxps ymm0, ymm0, ymm1 + vmaxps ymm2, ymm2, ymm3 + vmaxps ymm0, ymm0, ymm2 + vperm2f128 ymm1, ymm0, ymm0, 1 // copy second half (4xf32) of ymm0 to ymm1 + vmaxps xmm0, xmm0, xmm1 // xmm0 
contains 4 values to max + vpermilps xmm1, xmm0, 2 + (3 << 2) // second 2x32 bit half moved to top + vmaxps xmm0, xmm0, xmm1 // xmm0 containes 2 values + vpermilps xmm1, xmm0, 1 // second f32 to top + vmaxps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("ymm0") acc, + out("ymm1") _, out("ymm2") _, out("ymm3") _, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ + ); + acc + +} + +#[cfg(test)] +mod test_x86_64_fma_max_f32_32n { + use super::*; + max_frame_tests!(true, f32, x86_64_fma_max_f32_32n); +} From fec4d104ec8dfb3d6d6f83f044aaf28d6114ad80 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 16 Jan 2024 17:44:06 +0100 Subject: [PATCH 06/26] fix register confusion --- linalg/src/arm64/max.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg/src/arm64/max.rs b/linalg/src/arm64/max.rs index bae613d957..86cc1a2673 100644 --- a/linalg/src/arm64/max.rs +++ b/linalg/src/arm64/max.rs @@ -29,7 +29,7 @@ reduce_impl_wrap!( bne 1b fmax v0.4s, v0.4s, v1.4s fmax v2.4s, v2.4s, v3.4s - fmax v0.4s, v0.4s, v1.4s // TODO: TEST ME AND FIND THE BUG! 
+ fmax v0.4s, v0.4s, v2.4s fmaxv s0, v0.4s ", len = inout(reg) len => _, From fb5262cf164fcc5e1b6b9e67705236217187fc38 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 16 Jan 2024 18:01:07 +0100 Subject: [PATCH 07/26] shut clippy up --- linalg/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 3e15f041d8..d89bb7cd2d 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -1,5 +1,6 @@ #![allow(clippy::missing_safety_doc)] #![allow(clippy::redundant_closure_call)] +#![allow(clippy::len_zero)] #[macro_use] extern crate derive_new; extern crate lazy_static; From e92200da2fbf93fb0ecc371a717cc1b976b412c1 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 17 Jan 2024 11:34:57 +0100 Subject: [PATCH 08/26] more consistent test naming in generic --- linalg/src/generic/max.rs | 2 +- linalg/src/generic/softmax.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/linalg/src/generic/max.rs b/linalg/src/generic/max.rs index c21aa2b7ca..5d21badc86 100644 --- a/linalg/src/generic/max.rs +++ b/linalg/src/generic/max.rs @@ -37,6 +37,6 @@ impl ReduceKer for SMax4 { #[cfg(test)] #[macro_use] -pub mod tmax { +pub mod s { max_frame_tests!(true, f32, crate::generic::max::SMax4); } diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 75c2e8264a..0311aa6a60 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -42,6 +42,6 @@ impl MapReduceKer for SSoftMaxL2 { #[cfg(test)] #[macro_use] -pub mod tsoftmax { +pub mod s { softmax_l2_frame_tests!(true, f32, crate::generic::softmax::SSoftMaxL2); } From a460f9a8b0a5499239beecc44f7205a75474e82a Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 17 Jan 2024 11:57:41 +0100 Subject: [PATCH 09/26] move inline impl around a bit --- linalg/src/arm64.rs | 8 +-- linalg/src/arm64/arm64fp16.rs | 15 ++++++ .../src/arm64/{ => arm64fp16}/leaky_relu.rs | 52 +------------------ linalg/src/arm64/arm64simd.rs 
| 25 +++------ linalg/src/arm64/arm64simd/leaky_relu.rs | 50 ++++++++++++++++++ linalg/src/arm64/{ => arm64simd}/max.rs | 0 linalg/src/frame/mmm.rs | 6 +-- 7 files changed, 79 insertions(+), 77 deletions(-) create mode 100644 linalg/src/arm64/arm64fp16.rs rename linalg/src/arm64/{ => arm64fp16}/leaky_relu.rs (54%) create mode 100644 linalg/src/arm64/arm64simd/leaky_relu.rs rename linalg/src/arm64/{ => arm64simd}/max.rs (100%) diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index 22f8c916a0..96a974e914 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -8,10 +8,10 @@ mod cortex_a55; //mod cortex_a73; pub use arm64simd::*; -mod leaky_relu; -pub use leaky_relu::*; -mod max; -pub use max::*; +#[cfg(not(feature = "no_fp16"))] +mod arm64fp16; +#[cfg(not(feature = "no_fp16"))] +pub use arm64fp16::*; use crate::Ops; use crate::f16; diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs new file mode 100644 index 0000000000..82e010d6c9 --- /dev/null +++ b/linalg/src/arm64/arm64fp16.rs @@ -0,0 +1,15 @@ +use tract_data::half::f16; + +mod leaky_relu; +pub use leaky_relu::*; + +use crate::frame::mmm::*; +MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_16x8_a55; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_32x4_gen; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_32x4_a55; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_128x1_gen; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); + +tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); +sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); diff --git a/linalg/src/arm64/leaky_relu.rs 
b/linalg/src/arm64/arm64fp16/leaky_relu.rs similarity index 54% rename from linalg/src/arm64/leaky_relu.rs rename to linalg/src/arm64/arm64fp16/leaky_relu.rs index 1ff968d96c..79b7f95174 100644 --- a/linalg/src/arm64/leaky_relu.rs +++ b/linalg/src/arm64/arm64fp16/leaky_relu.rs @@ -1,50 +1,5 @@ use tract_data::internal::f16; -ew_impl_wrap!( - f32, - arm64simd_leaky_relu_f32_8n, - 8, - 4, - f32, - #[inline(never)] - fn run(buf: &mut [f32], alpha: f32) { - assert!(buf.len() % 8 == 0); - assert!(buf.len() > 0); - unsafe { - let len = buf.len(); - let ptr = buf.as_ptr(); - std::arch::asm!(" - dup v0.4s, {alpha:v}.s[0] - dup v1.4s, {one:v}.s[0] - 1: - ldp q3, q4, [{ptr}] - - fcmgt v5.4s, v3.4s, #0.0 - fcmgt v6.4s, v4.4s, #0.0 - bsl v5.16b, v1.16b, v0.16b - bsl v6.16b, v1.16b, v0.16b - fmul v3.4s, v3.4s, v5.4s - fmul v4.4s, v4.4s, v6.4s - - stp q3, q4, [{ptr}], #32 - subs {len}, {len}, 8 - bne 1b - ", - one = in(vreg) 1.0f32, - alpha = in(vreg) alpha, - len = inout(reg) len => _, - ptr = inout(reg) ptr => _, - out("v0") _, - out("v1") _, - out("q3") _, - out("q4") _, - out("q5") _, - out("q6") _, - ); - } - } -); - ew_impl_wrap!( f16, arm64fp16_leaky_relu_f16_16n, @@ -92,14 +47,9 @@ ew_impl_wrap!( } ); -#[cfg(test)] -pub mod test_arm64simd_leaky_relu_f32_8n { - use super::*; - leaky_relu_frame_tests!(true, f32, arm64simd_leaky_relu_f32_8n); -} - #[cfg(test)] pub mod test_arm64simd_leaky_relu_f16_16n { use super::*; leaky_relu_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_leaky_relu_f16_16n); } + diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index 5ac30d731a..fd7694f235 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -1,6 +1,10 @@ +mod leaky_relu; +mod max; + use crate::frame::mmm::*; -#[cfg(not(feature = "no_fp16"))] -use tract_data::half::f16; + +pub use leaky_relu::arm64simd_leaky_relu_f32_8n; +pub use max::arm64simd_max_f32_16n; MMMKernel!(f32, arm64simd_mmm_f32_8x8_a55; 8, 8; 16, 16; 1, 1; 
no_prefetch, true); MMMKernel!(f32, arm64simd_mmm_f32_12x8_a55; 12, 8; 16, 16; 1, 1; no_prefetch, true); @@ -23,23 +27,6 @@ MMMKernel!(f32, arm64simd_mmm_f32_64x1_gen; 64, 1; 16, 16; 1, 1; no_prefetch, tr MMMKernel!(i32, arm64simd_mmm_i32_8x8; 8, 8; 16, 16; 0,0; no_prefetch, true); MMMKernel!(i32, arm64simd_mmm_i32_64x1; 64, 1; 16, 1; 0,0; no_prefetch, true); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_16x8_a55; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_32x4_gen; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_32x4_a55; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_128x1_gen; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); - tanh_impl!(f32, arm64simd_tanh_f32_4n, 4, 4, true); sigmoid_impl!(f32, arm64simd_sigmoid_f32_4n, 4, 4, true); -#[cfg(not(feature = "no_fp16"))] -tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); diff --git a/linalg/src/arm64/arm64simd/leaky_relu.rs b/linalg/src/arm64/arm64simd/leaky_relu.rs new file mode 100644 index 0000000000..83901c87a6 --- /dev/null +++ b/linalg/src/arm64/arm64simd/leaky_relu.rs @@ -0,0 +1,50 @@ +ew_impl_wrap!( + f32, + arm64simd_leaky_relu_f32_8n, + 8, + 4, + f32, + #[inline(never)] + fn run(buf: &mut [f32], alpha: f32) { + assert!(buf.len() % 8 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = 
buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, {alpha:v}.s[0] + dup v1.4s, {one:v}.s[0] + 1: + ldp q3, q4, [{ptr}] + + fcmgt v5.4s, v3.4s, #0.0 + fcmgt v6.4s, v4.4s, #0.0 + bsl v5.16b, v1.16b, v0.16b + bsl v6.16b, v1.16b, v0.16b + fmul v3.4s, v3.4s, v5.4s + fmul v4.4s, v4.4s, v6.4s + + stp q3, q4, [{ptr}], #32 + subs {len}, {len}, 8 + bne 1b + ", + one = in(vreg) 1.0f32, + alpha = in(vreg) alpha, + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + out("v0") _, + out("v1") _, + out("q3") _, + out("q4") _, + out("q5") _, + out("q6") _, + ); + } + } +); + +#[cfg(test)] +pub mod test_arm64simd_leaky_relu_f32_8n { + use super::*; + leaky_relu_frame_tests!(true, f32, arm64simd_leaky_relu_f32_8n); +} diff --git a/linalg/src/arm64/max.rs b/linalg/src/arm64/arm64simd/max.rs similarity index 100% rename from linalg/src/arm64/max.rs rename to linalg/src/arm64/arm64simd/max.rs diff --git a/linalg/src/frame/mmm.rs b/linalg/src/frame/mmm.rs index 7060ba4d1b..df3bf18e30 100644 --- a/linalg/src/frame/mmm.rs +++ b/linalg/src/frame/mmm.rs @@ -37,7 +37,7 @@ macro_rules! MMMKernel { #[derive(Copy, Clone, Debug, new)] pub struct $func; - impl MatMatMulKer<$ti> for $func { + impl $crate::frame::mmm::MatMatMulKer<$ti> for $func { #[inline(always)] fn name() -> &'static str { stringify!($func) @@ -67,9 +67,9 @@ macro_rules! 
MMMKernel { $end_padding_packed_b } #[inline(always)] - fn kernel(spec: &[FusedKerSpec<$ti>]) -> isize { + fn kernel(spec: &[$crate::frame::mmm::FusedKerSpec<$ti>]) -> isize { debug_assert!(spec.len() > 0); - debug_assert!(matches!(spec[spec.len() - 1], FusedKerSpec::Done)); + debug_assert!(matches!(spec[spec.len() - 1], $crate::frame::mmm::FusedKerSpec::Done)); unsafe { []::$func(spec.as_ptr()) } } #[inline(always)] From 43e6f66d9fd7bfbfbae0232d8ea1f8c664c2565f Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 18 Jan 2024 10:59:00 +0100 Subject: [PATCH 10/26] need avx2 --- linalg/src/x86_64_fma/max.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg/src/x86_64_fma/max.rs b/linalg/src/x86_64_fma/max.rs index 69f89d6436..b273427d75 100644 --- a/linalg/src/x86_64_fma/max.rs +++ b/linalg/src/x86_64_fma/max.rs @@ -62,5 +62,5 @@ unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 { #[cfg(test)] mod test_x86_64_fma_max_f32_32n { use super::*; - max_frame_tests!(true, f32, x86_64_fma_max_f32_32n); + max_frame_tests!(is_x86_feature_detected!("avx2"), f32, x86_64_fma_max_f32_32n); } From 462ccd000094a1507d370417923211f4dbef6e11 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 18 Jan 2024 17:37:00 +0100 Subject: [PATCH 11/26] softmax loop2 scaffholding --- linalg/benches/softmax.rs | 8 +++- linalg/src/frame.rs | 3 ++ linalg/src/frame/element_wise_helper.rs | 3 +- linalg/src/frame/reduce.rs | 14 +++--- linalg/src/frame/softmax.rs | 7 +++ linalg/src/generic.rs | 2 + linalg/src/generic/softmax.rs | 58 +++++++++++++++++++++++-- linalg/src/lib.rs | 8 +++- 8 files changed, 91 insertions(+), 12 deletions(-) diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs index 0b87b37eb5..c1d735170d 100644 --- a/linalg/benches/softmax.rs +++ b/linalg/benches/softmax.rs @@ -1,5 +1,6 @@ use criterion::*; use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; use tract_linalg::frame::reduce::{MapReduceKer, 
ReduceKer}; #[inline(never)] @@ -64,7 +65,12 @@ fn softmax_f32(c: &mut Criterion) { group.bench_function("loop2/generic", |b| { b.iter(|| tract_linalg::generic::softmax::SSoftMaxL2::red().run_with_params(input, 10.)) }); - group.bench_function("rust_loop3", |b| b.iter(|| loop3_f32(input, 0.21))); + group.bench_function("loop3/naive", |b| b.iter(|| loop3_f32(input, 0.21))); + group.bench_function("loop3/generic", |b| { + b.iter(|| { + tract_linalg::generic::by_scalar::SMulByScalar4::ew().run_with_params(input, 0.21) + }) + }); } criterion_group!(benches, softmax_f32); diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs index 837e35351a..9810835e0d 100644 --- a/linalg/src/frame.rs +++ b/linalg/src/frame.rs @@ -1,5 +1,8 @@ #[macro_use] pub mod element_wise; + +#[macro_use] +pub mod by_scalar; #[macro_use] pub mod lut; #[macro_use] diff --git a/linalg/src/frame/element_wise_helper.rs b/linalg/src/frame/element_wise_helper.rs index de2223cbf8..f4b3084642 100644 --- a/linalg/src/frame/element_wise_helper.rs +++ b/linalg/src/frame/element_wise_helper.rs @@ -87,6 +87,7 @@ pub(crate) fn map_reduce_slice_with_alignment( f: impl Fn(&mut [T]) -> T, nr: usize, alignment_bytes: usize, + map_neutral: T, neutral: T, reduce: impl Fn(T, T) -> T, ) -> TractResult @@ -104,7 +105,7 @@ where let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| { tmp[..slice.len()].copy_from_slice(slice); - tmp[slice.len()..].fill(neutral); + tmp[slice.len()..].fill(map_neutral); *red = reduce(*red, f(tmp)); slice.copy_from_slice(&tmp[..slice.len()]); }; diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs index ab97626ce1..daf8dfe82c 100644 --- a/linalg/src/frame/reduce.rs +++ b/linalg/src/frame/reduce.rs @@ -171,6 +171,7 @@ where |data| K::run(data, params), K::nr(), K::alignment_bytes(), + K::map_neutral(), K::neutral(), K::reduce_two, ) @@ -187,6 +188,7 @@ where fn alignment_bytes() -> usize; fn 
alignment_items() -> usize; fn nr() -> usize; + fn map_neutral() -> T; fn neutral() -> T; fn reduce_two(a: T, b: T) -> T; fn run(vec: &mut [T], params: Params) -> T; @@ -232,16 +234,18 @@ pub mod test { pub fn test_map_reduce, T: LADatum>( values: &[T], + map_neutral: T, neutral: T, reference_map: impl Fn(T) -> T, reference_reduce: impl Fn(T, T) -> T, ) -> TestCaseResult { - test_map_reduce_params::(values, neutral, reference_map, reference_reduce, ()) + test_map_reduce_params::(values, map_neutral, neutral, reference_map, reference_reduce, ()) } pub fn test_map_reduce_params, T: LADatum, Params>( values: &[T], - neutral: T, + _neutral: T, + map_neutral: T, reference_map: impl Fn(T) -> T, reference_reducer: impl Fn(T, T) -> T, params: Params, @@ -253,13 +257,13 @@ pub mod test { let op = K::red(); let mut found = values.to_vec(); let expected_values = values.iter().copied().map(reference_map).collect_vec(); - let expected_reduced = expected_values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i)); + let expected_reduced = expected_values.iter().fold(map_neutral, |acc, i| reference_reducer(acc, *i)); let red = op.run_with_params(&mut found, params).unwrap(); tensor1(&found) - .close_enough(&tensor1(&expected_values), true) + .close_enough(&tensor1(&expected_values), Approximation::SuperApproximate) .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; tensor0(red) - .close_enough(&tensor0(expected_reduced), true) + .close_enough(&tensor0(expected_reduced), Approximation::SuperApproximate) .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; Ok(()) } diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs index b14d882e86..ebe5e1fe57 100644 --- a/linalg/src/frame/softmax.rs +++ b/linalg/src/frame/softmax.rs @@ -23,6 +23,12 @@ pub mod test { $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[0.0]).unwrap() } } + #[test] + fn two() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[16.62555, 
21.950674]).unwrap() + } + } }; } @@ -39,6 +45,7 @@ pub mod test { crate::frame::reduce::test::test_map_reduce_params::( &values, ::min_value(), + T::zero(), |x| (x - max.as_()).exp(), |a, b| a + b, max.as_(), diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 36aa5d8559..ba28449b47 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -1,3 +1,4 @@ +pub mod by_scalar; pub mod erf; pub mod leaky_relu; pub mod lut; @@ -8,6 +9,7 @@ pub mod sigmoid; pub mod softmax; pub mod tanh; +pub use self::by_scalar::SMulByScalar4; pub use self::erf::SErf4; pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4}; pub use self::lut::GenericLut8; diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 0311aa6a60..e83b2e43d5 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -20,10 +20,14 @@ impl MapReduceKer for SSoftMaxL2 { 4 } - fn neutral() -> f32 { + fn map_neutral() -> f32 { f32::MIN } + fn neutral() -> f32 { + 0. + } + fn reduce_two(a: f32, b: f32) -> f32 { a + b } @@ -32,14 +36,60 @@ impl MapReduceKer for SSoftMaxL2 { debug_assert!(x.len() % Self::nr() == 0); debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); let mut sum = 0.; - for x in x.iter_mut() { - *x = (*x - max).exp(); - sum += *x; + for v in x.iter_mut() { + let y = *v - max; + let y = y.exp(); + *v = y; + sum += y; } sum } } +// ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h +#[inline] +#[allow(dead_code)] +fn very_fast_exp(v: f32) -> f32 { + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + f32::from_bits(((SLOPE * v) + OFFSET) as u32) +} + +#[inline] +#[allow(dead_code)] +fn exp2_p7(v: f32) -> f32 { + const EXP2P: [f32; 7] = [ + 1.535336188319500e-4, + 1.339887440266574e-3, + 9.618437357674640e-3, + 5.550332471162809e-2, + 
2.402264791363012e-1, + 6.931472028550421e-1, + 1.000000000000000, + ]; + let v = v.min(127f32).max(-127f32); + + let ipart = (v + 0.5).floor(); + let fpart = v - ipart; + + // 2^ipart + let two_pow_ipart = f32::from_bits((((ipart as i32) + 127) as u32) << 23); + + let mut v = EXP2P[0]; + v = v * fpart + EXP2P[1]; + v = v * fpart + EXP2P[2]; + v = v * fpart + EXP2P[3]; + v = v * fpart + EXP2P[4]; + v = v * fpart + EXP2P[5]; + v = v * fpart + EXP2P[6]; + v = v * two_pow_ipart; + v +} + #[cfg(test)] #[macro_use] pub mod s { diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index d89bb7cd2d..b783a82b75 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -17,7 +17,7 @@ include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs")); pub mod frame; pub mod generic; use frame::element_wise::ElementWiseKer; -use frame::reduce::ReduceKer; +use frame::reduce::{MapReduceKer, ReduceKer}; use frame::{reduce, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; #[cfg(target_arch = "x86_64")] @@ -57,6 +57,8 @@ pub struct Ops { pub leaky_relu_f16: Box Box> + Send + Sync>, pub leaky_relu_f32: Box Box> + Send + Sync>, + pub mul_by_scalar_f32: + Box Box> + Send + Sync>, pub sigmoid_f16: Box Box> + Send + Sync>, pub sigmoid_f32: Box Box> + Send + Sync>, @@ -66,6 +68,8 @@ pub struct Ops { pub lut_u8: Box Box + Send + Sync>, pub max_f32: Box Box> + Send + Sync>, + + pub softmax_loop2_f32: Box Box> + Send + Sync>, } impl Ops { @@ -117,6 +121,7 @@ pub fn generic() -> Ops { qmmv_i32: Box::new(|_, _| generic::GenericMmm4x1::::mmm()), leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()), leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()), + mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()), sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()), sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()), tanh_f16: Box::new(|| generic::HTanh8::ew()), @@ -127,6 +132,7 @@ pub fn generic() -> Ops { /* activation_f32: Box::new(|microcode| generic::SActivation::new(microcode)) 
*/ + softmax_loop2_f32: Box::new(|| generic::softmax::SSoftMaxL2::red()), } } From bc8ce24e692be7f81f61c1b9996aa169500b7216 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 18 Jan 2024 17:54:52 +0100 Subject: [PATCH 12/26] forgotten files --- linalg/src/frame/by_scalar.rs | 39 +++++++++++++++++++++++++++++++++ linalg/src/generic/by_scalar.rs | 35 +++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 linalg/src/frame/by_scalar.rs create mode 100644 linalg/src/generic/by_scalar.rs diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs new file mode 100644 index 0000000000..528c3c3585 --- /dev/null +++ b/linalg/src/frame/by_scalar.rs @@ -0,0 +1,39 @@ +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::element_wise::ElementWiseKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! mul_by_scalar_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest!
{ + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) { + if $cond { + $crate::frame::by_scalar::test::test_mul_by_scalar::<$ker, $t>(&*xs, scalar).unwrap() + } + } + } + }; + } + + pub fn test_mul_by_scalar, T: LADatum + Float>( + values: &[f32], + scalar: f32, + ) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + crate::setup_test_logger(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::element_wise::test::test_element_wise_params::( + &values, + |a| a * scalar.as_(), + scalar.as_(), + ) + } +} diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs new file mode 100644 index 0000000000..d19fbc5a4e --- /dev/null +++ b/linalg/src/generic/by_scalar.rs @@ -0,0 +1,35 @@ +use crate::element_wise::ElementWiseKer; + + +#[derive(Clone, Debug)] +pub struct SMulByScalar4; + +impl ElementWiseKer for SMulByScalar4 { + fn name() -> &'static str { + "generic" + } + + fn alignment_items() -> usize { + 16 + } + + fn alignment_bytes() -> usize { + 16 + } + + fn nr() -> usize { + 4 + } + + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px *= s) + } +} + +#[cfg(test)] +#[macro_use] +pub mod s { + mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4); +} From f1930c004a6674dec14b76507721ee2b89dbfab4 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 18 Jan 2024 17:55:14 +0100 Subject: [PATCH 13/26] also forgotten --- core/src/ops/nn/softmax/mod.rs | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index c33d070de1..34e2724baf 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -131,7 +131,14 @@ impl Softmax { view.collapse_axis(Axis(ix), 
it_coords[ix]); } } - softmax_inner(view); + if let Some(slice) = + view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type()) + { + let slice: &mut [f32] = unsafe { std::mem::transmute(slice) }; + softmax_inner_slice_f32(slice)?; + } else { + softmax_inner(view); + } } Ok(tvec!(output.into_tvalue())) @@ -171,6 +178,31 @@ impl Softmax { } } +fn softmax_inner_slice_f32(slice: &mut [f32]) -> TractResult<()> { + eprintln!("input: {slice:?}"); + /* + let max = + *slice.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap(); + */ + let max = (tract_linalg::ops().max_f32)().run(slice)?; + eprintln!("max: {max}"); + /* + slice.iter_mut().for_each(|x| *x = (*x - max).exp()); + let exp_sum = slice.iter().copied().sum::(); + slice.iter_mut().for_each(|x| *x = *x / exp_sum); + */ + /* + let max = (tract_linalg::ops().max_f32)().run(slice)?; + */ + let sum = (tract_linalg::ops().softmax_loop2_f32)().run_with_params(slice, max)?; + eprintln!("loop 2: {slice:?}"); + eprintln!("sum: {sum:?}"); + let rsum = sum.recip(); + let rsum = slice.iter().copied().sum::().recip(); + (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?; + Ok(()) +} + fn softmax_inner(mut view: ArrayViewMut) { let max = *view.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap(); From 040655b4c7271731aaf935e3a459c3da5597f831 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 19 Jan 2024 13:35:15 +0100 Subject: [PATCH 14/26] play with exp impls --- Cargo.toml | 1 + linalg/Cargo.toml | 1 + linalg/src/generic/softmax.rs | 76 ++++++++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index fcc4340bc5..7ba5089bc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ flate2 = "1.0.20" fs2 = "0.4.3" getrandom = "0.2" half = { version="2.2.1", features = [ "std", "num-traits" ] } +hexf = "0.2.1" image = "0.24.1" itertools = "0.10.1" home = "0.5.5" diff 
--git a/linalg/Cargo.toml b/linalg/Cargo.toml index 7f7b3d0f7e..d5c04f0f5e 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -17,6 +17,7 @@ maintenance = { status = "actively-developed" } derive-new.workspace = true downcast-rs.workspace = true dyn-clone.workspace = true +hexf.workspace = true lazy_static.workspace = true log.workspace = true num-traits.workspace = true diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index e83b2e43d5..33e1807f13 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -1,3 +1,5 @@ +use tract_data::internal::{tensor0, Approximation}; + use crate::frame::reduce::MapReduceKer; #[derive(Clone, Debug)] @@ -38,7 +40,8 @@ impl MapReduceKer for SSoftMaxL2 { let mut sum = 0.; for v in x.iter_mut() { let y = *v - max; - let y = y.exp(); + // let y = y.exp(); + let y = xnnpack_loop2_exp(y); *v = y; sum += y; } @@ -46,6 +49,77 @@ impl MapReduceKer for SSoftMaxL2 { } } +// https://github.com/google/XNNPACK/blob/3bc4ef01bbdf488556c54584fc2419dd77c39c85/src/f32-raddstoreexpminusmax/scalar-rr2-p5.c.in#L131 +// https://github.com/google/XNNPACK/blob/8951decff5114f70bae7cc2e23b732812e73acc7/src/microparams-init.c#L4121 +#[inline] +#[allow(dead_code, non_upper_case_globals)] +fn xnnpack_loop2_exp(vx: f32) -> f32 { + debug_assert!(vx <= 0f32); + const log2e: f32 = hexf::hexf32!("0x1.715476p+0"); + const magic_bias: f32 = hexf::hexf32!("0x1.8000FEp23"); + const minus_ln2_hi: f32 = hexf::hexf32!("-0x1.62E400p-1"); + const minus_ln2_lo: f32 = hexf::hexf32!("-0x1.7F7D1Cp-20"); + const c5: f32 = hexf::hexf32!("0x1.0F9F9Cp-7"); + const c4: f32 = hexf::hexf32!("0x1.573A1Ap-5"); + const c3: f32 = hexf::hexf32!("0x1.555A80p-3"); + const c2: f32 = hexf::hexf32!("0x1.FFFDC6p-2"); + const c1: f32 = hexf::hexf32!("0x1.FFFFF6p-1"); + const denorm_cutoff: f32 = hexf::hexf32!("-0x1.5D589Ep6"); + + // Compute reduced argument n := round(x / log(2)). 
+ // We do it by adding a large number (magic bias) to the product x * (1/log(2)), which causes rounding of the result + // to an integer, then subtracting the large number back. The trick with adding large number is valid only within + // certain bounds (|x| <= 2**22), but that's ok, because inputs outside of [-87.336540, 0.0] underflow expf(x) + // anyway. We fixup the result for such inputs at the very end of the algorithm. + // float vn = vx * vlog2e + vmagic_bias; + let mut vn: f32 = vx * log2e + magic_bias; + + // Create a floating-point number s (scale) such that s == 2**n for inputs which don't cause underflow, i.e. + // -87.33642 <= x <= 0.0, and -126 <= n <= 0 accordingly. + // const float vs = uint32_as_float(float_as_uint32(vn) << 23); + let vs = f32::from_bits(vn.to_bits() << 23); + + // Subtract the large number back to get final n := round(x / log(2)). + // vn -= vmagic_bias; + vn -= magic_bias; + + // Compute reduced argument t := x - n * log(2). + // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. + // float vt = vn * vminus_ln2_hi + vx; + let mut vt = vn * minus_ln2_hi + vx; + //vt = vn * vminus_ln2_lo + vt; + vt = vn * minus_ln2_lo + vt; + + // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2] (Horner form: c5, c4, c3, c2, c1). + let mut vp = c5 * vt + c4; + vp = vp * vt + c3; + vp = vp * vt + c2; + vp = vp * vt + c1; + + // Reconstruct the final f value: + // f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))) + // = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))) + // = s + (t * s) * p + vt *= vs; + // float vf = vt * vp + vs; + let mut vf = vt * vp + vs; + + // For inputs below denormal cutoff, replace output with +0.0f. + // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. + if vx < denorm_cutoff { + vf = 0.0; + } + vf +} + +#[cfg(test)] +proptest::proptest! { + #[test] + fn t_xnnpack(x in -100f32..0.)
{ + tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::Approximate).unwrap(); + } +} + // ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h #[inline] #[allow(dead_code)] From 2df51b2fd39dd54b875be8591f1c9f10d658e4be Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 22 Jan 2024 09:17:18 +0100 Subject: [PATCH 15/26] clarify --- linalg/src/frame/reduce.rs | 4 ++-- linalg/src/generic/softmax.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs index daf8dfe82c..0c6235261e 100644 --- a/linalg/src/frame/reduce.rs +++ b/linalg/src/frame/reduce.rs @@ -172,7 +172,7 @@ where K::nr(), K::alignment_bytes(), K::map_neutral(), - K::neutral(), + K::reduce_neutral(), K::reduce_two, ) } @@ -189,7 +189,7 @@ where fn alignment_items() -> usize; fn nr() -> usize; fn map_neutral() -> T; - fn neutral() -> T; + fn reduce_neutral() -> T; fn reduce_two(a: T, b: T) -> T; fn run(vec: &mut [T], params: Params) -> T; fn red() -> Box> { diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 33e1807f13..05ec38e41b 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -26,7 +26,7 @@ impl MapReduceKer for SSoftMaxL2 { f32::MIN } - fn neutral() -> f32 { + fn reduce_neutral() -> f32 { 0. 
} From 90ce496f0e73fb697966c9530da6c36118657df2 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 08:59:16 +0100 Subject: [PATCH 16/26] plug softmax loop2 on intel --- core/src/ops/nn/softmax/mod.rs | 17 ---- linalg/src/frame/reduce.rs | 39 ++++++++ linalg/src/frame/softmax.rs | 18 +++- linalg/src/generic/softmax.rs | 161 ++++++++++++++++++++++++++++++++- linalg/src/x86_64_fma.rs | 5 + 5 files changed, 216 insertions(+), 24 deletions(-) diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index 34e2724baf..d3ecc21707 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -179,26 +179,9 @@ impl Softmax { } fn softmax_inner_slice_f32(slice: &mut [f32]) -> TractResult<()> { - eprintln!("input: {slice:?}"); - /* - let max = - *slice.iter().max_by(|i, j| i.partial_cmp(j).unwrap_or(std::cmp::Ordering::Less)).unwrap(); - */ - let max = (tract_linalg::ops().max_f32)().run(slice)?; - eprintln!("max: {max}"); - /* - slice.iter_mut().for_each(|x| *x = (*x - max).exp()); - let exp_sum = slice.iter().copied().sum::(); - slice.iter_mut().for_each(|x| *x = *x / exp_sum); - */ - /* let max = (tract_linalg::ops().max_f32)().run(slice)?; - */ let sum = (tract_linalg::ops().softmax_loop2_f32)().run_with_params(slice, max)?; - eprintln!("loop 2: {slice:?}"); - eprintln!("sum: {sum:?}"); let rsum = sum.recip(); - let rsum = slice.iter().copied().sum::().recip(); (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?; Ok(()) } diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs index 0c6235261e..edd2369963 100644 --- a/linalg/src/frame/reduce.rs +++ b/linalg/src/frame/reduce.rs @@ -136,6 +136,45 @@ where } } +macro_rules! map_reduce_impl_wrap { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $map_neutral: expr, $reduce_neutral: expr, $run: item, $reduce_two: item) => { + paste! 
{ + #[derive(Copy, Clone, Debug)] + #[allow(non_camel_case_types)] + pub struct $func; + + impl crate::frame::reduce::MapReduceKer<$ti, $params> for $func { + #[inline(always)] + fn name() -> &'static str { + stringify!($func) + } + #[inline(always)] + fn nr() -> usize { + $nr + } + #[inline(always)] + fn alignment_items() -> usize { + $alignment_items + } + #[inline(always)] + fn alignment_bytes() -> usize { + $alignment_items * std::mem::size_of::<$ti>() + } + #[inline(always)] + fn map_neutral() -> $ti { + $map_neutral + } + #[inline(always)] + fn reduce_neutral() -> $ti { + $reduce_neutral + } + $run + $reduce_two + } + } + }; +} + pub trait MapReduce: Send + Sync + Debug + dyn_clone::DynClone where Params: Copy + Send + Sync + Debug + 'static + Default, diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs index ebe5e1fe57..d4a8166c0b 100644 --- a/linalg/src/frame/softmax.rs +++ b/linalg/src/frame/softmax.rs @@ -26,7 +26,19 @@ pub mod test { #[test] fn two() { if $cond { - $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[16.62555, 21.950674]).unwrap() + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[ + 16.62555, 21.950674, + ]) + .unwrap() + } + } + #[test] + fn two_weird() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[ + -46.15512, 42.875168 + ]) + .unwrap() } } }; @@ -39,6 +51,7 @@ pub mod test { f32: AsPrimitive, T: AsPrimitive, { + use crate::generic::softmax::very_fast_exp; crate::setup_test_logger(); let max = values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); @@ -46,7 +59,8 @@ pub mod test { &values, ::min_value(), T::zero(), - |x| (x - max.as_()).exp(), + // |x| (x - max.as_()).exp(), + |x| very_fast_exp(dbg!(x).as_() - max).as_(), |a, b| a + b, max.as_(), ) diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 05ec38e41b..632d4f8bbe 100644 --- a/linalg/src/generic/softmax.rs +++ 
b/linalg/src/generic/softmax.rs @@ -1,5 +1,3 @@ -use tract_data::internal::{tensor0, Approximation}; - use crate::frame::reduce::MapReduceKer; #[derive(Clone, Debug)] @@ -41,7 +39,9 @@ impl MapReduceKer for SSoftMaxL2 { for v in x.iter_mut() { let y = *v - max; // let y = y.exp(); - let y = xnnpack_loop2_exp(y); + // let y = xnnpack_loop2_exp(y); + // let y = expf(y); + let y = very_fast_exp(y); *v = y; sum += y; } @@ -116,14 +116,15 @@ fn xnnpack_loop2_exp(vx: f32) -> f32 { proptest::proptest! { #[test] fn t_xnnpack(x in -100f32..0.) { + use tract_data::internal::{tensor0, Approximation}; tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::Approximate).unwrap(); } } // ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h -#[inline] +#[inline(never)] #[allow(dead_code)] -fn very_fast_exp(v: f32) -> f32 { +pub fn very_fast_exp(v: f32) -> f32 { const MLN2: f32 = 0.6931471805f32; const A: f32 = 8388608.0f32; const B: f32 = 1065353216.0f32; @@ -164,6 +165,156 @@ fn exp2_p7(v: f32) -> f32 { v } +#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] +pub fn scalbnf(mut x: f32, mut n: i32) -> f32 { + let x1p127 = f32::from_bits(0x7f000000); // 0x1p127f === 2 ^ 127 + let x1p_126 = f32::from_bits(0x800000); // 0x1p-126f === 2 ^ -126 + let x1p24 = f32::from_bits(0x4b800000); // 0x1p24f === 2 ^ 24 + + if n > 127 { + x *= x1p127; + n -= 127; + if n > 127 { + x *= x1p127; + n -= 127; + if n > 127 { + n = 127; + } + } + } else if n < -126 { + x *= x1p_126 * x1p24; + n += 126 - 24; + if n < -126 { + x *= x1p_126 * x1p24; + n += 126 - 24; + if n < -126 { + n = -126; + } + } + } + x * f32::from_bits(((0x7f + n) as u32) << 23) +} + +macro_rules! force_eval { + ($e:expr) => { + unsafe { ::core::ptr::read_volatile(&$e) } + }; +} + +macro_rules! 
i { + ($array:expr, $index:expr) => { + unsafe { *$array.get_unchecked($index) } + }; + ($array:expr, $index:expr, = , $rhs:expr) => { + unsafe { + *$array.get_unchecked_mut($index) = $rhs; + } + }; + ($array:expr, $index:expr, += , $rhs:expr) => { + unsafe { + *$array.get_unchecked_mut($index) += $rhs; + } + }; + ($array:expr, $index:expr, -= , $rhs:expr) => { + unsafe { + *$array.get_unchecked_mut($index) -= $rhs; + } + }; + ($array:expr, $index:expr, &= , $rhs:expr) => { + unsafe { + *$array.get_unchecked_mut($index) &= $rhs; + } + }; + ($array:expr, $index:expr, == , $rhs:expr) => { + unsafe { *$array.get_unchecked_mut($index) == $rhs } + }; +} + +/// Exponential, base *e* (f32) +/// +/// Calculate the exponential of `x`, that is, *e* raised to the power `x` +/// (where *e* is the base of the natural system of logarithms, approximately 2.71828). +#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] +pub fn expf(mut x: f32) -> f32 { + const HALF: [f32; 2] = [0.5, -0.5]; + const LN2_HI: f32 = 6.9314575195e-01; /* 0x3f317200 */ + const LN2_LO: f32 = 1.4286067653e-06; /* 0x35bfbe8e */ + const INV_LN2: f32 = 1.4426950216e+00; /* 0x3fb8aa3b */ + /* + * Domain [-0.34568, 0.34568], range ~[-4.278e-9, 4.447e-9]: + * |x*(exp(x)+1)/(exp(x)-1) - p(x)| < 2**-27.74 + */ + const P1: f32 = 1.6666625440e-1; /* 0xaaaa8f.0p-26 */ + const P2: f32 = -2.7667332906e-3; /* -0xb55215.0p-32 */ + + let x1p127 = f32::from_bits(0x7f000000); // 0x1p127f === 2 ^ 127 + let x1p_126 = f32::from_bits(0x800000); // 0x1p-126f === 2 ^ -126 /*original 0x1p-149f ??????????? 
*/ + let mut hx = x.to_bits(); + let sign = (hx >> 31) as i32; /* sign bit of x */ + let signb: bool = sign != 0; + hx &= 0x7fffffff; /* high word of |x| */ + + /* special cases */ + if hx >= 0x42aeac50 { + /* if |x| >= -87.33655f or NaN */ + if hx > 0x7f800000 { + /* NaN */ + return x; + } + if (hx >= 0x42b17218) && (!signb) { + /* x >= 88.722839f */ + /* overflow */ + x *= x1p127; + return x; + } + if signb { + /* underflow */ + force_eval!(-x1p_126 / x); + if hx >= 0x42cff1b5 { + /* x <= -103.972084f */ + return 0.; + } + } + } + + /* argument reduction */ + let k: i32; + let hi: f32; + let lo: f32; + if hx > 0x3eb17218 { + /* if |x| > 0.5 ln2 */ + if hx > 0x3f851592 { + /* if |x| > 1.5 ln2 */ + k = (INV_LN2 * x + i!(HALF, sign as usize)) as i32; + } else { + k = 1 - sign - sign; + } + let kf = k as f32; + hi = x - kf * LN2_HI; /* k*ln2hi is exact here */ + lo = kf * LN2_LO; + x = hi - lo; + } else if hx > 0x39000000 { + /* |x| > 2**-14 */ + k = 0; + hi = x; + lo = 0.; + } else { + /* raise inexact */ + force_eval!(x1p127 + x); + return 1. + x; + } + + /* x is now in primary range */ + let xx = x * x; + let c = x - xx * (P1 + xx * P2); + let y = 1. + (x * c / (2. 
- c) - lo + hi); + if k == 0 { + y + } else { + scalbnf(y, k) + } +} + #[cfg(test)] #[macro_use] pub mod s { diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 907d9b6ef0..99331bb5a5 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -1,11 +1,14 @@ use crate::frame::element_wise::ElementWiseKer; use crate::frame::mmm::kernel::MatMatMulKer; +use crate::frame::reduce::MapReduceKer; +use crate::x86_64_fma::softmax::x86_64_fma_softmaxl2_f32_32n; use crate::Ops; pub mod mmm; mod intel; pub mod max; +pub mod softmax; tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma")); sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma")); @@ -91,6 +94,8 @@ fn plug_fma(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew()); ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew()); + + ops.softmax_loop2_f32 = Box::new(|| x86_64_fma_softmaxl2_f32_32n::red()); log::info!("mmm_f32, mmv_f32, sigmoid_f32, tanh_f32: x86_64/fma activated"); } From 92e86b05d5976dab4f5a445ef5e3d294eb88444a Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 09:05:39 +0100 Subject: [PATCH 17/26] forgot key file --- linalg/src/x86_64_fma/softmax.rs | 115 +++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 linalg/src/x86_64_fma/softmax.rs diff --git a/linalg/src/x86_64_fma/softmax.rs b/linalg/src/x86_64_fma/softmax.rs new file mode 100644 index 0000000000..5f1de841dd --- /dev/null +++ b/linalg/src/x86_64_fma/softmax.rs @@ -0,0 +1,115 @@ +map_reduce_impl_wrap!( + f32, + x86_64_fma_softmaxl2_f32_32n, + 32, + 8, + f32, + f32::MIN, + 0f32, + #[inline(never)] + fn run(buf: &mut [f32], max: f32) -> f32 { + assert!(buf.len() % 32 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_fma_softmaxl2_f32_32n_run(buf, max) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a + b + } +); + +#[target_feature(enable = "avx,fma")] +unsafe fn x86_64_fma_softmaxl2_f32_32n_run(buf: &mut 
[f32], max: f32) -> f32 { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = 0f32; + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + vmovaps ymm1, ymm0 + vmovaps ymm2, ymm0 + vmovaps ymm3, ymm0 + + vpxor ymm12, ymm12, ymm12 + vbroadcastss ymm13, xmm13 + vbroadcastss ymm14, xmm14 + vbroadcastss ymm15, xmm15 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + + vsubps ymm4, ymm4, ymm13 + vsubps ymm5, ymm5, ymm13 + vsubps ymm6, ymm6, ymm13 + vsubps ymm7, ymm7, ymm13 + + vmovaps ymm8, ymm15 + vmovaps ymm9, ymm15 + vmovaps ymm10, ymm15 + vmovaps ymm11, ymm15 + + vfmadd231ps ymm8, ymm4, ymm14 + vfmadd231ps ymm9, ymm5, ymm14 + vfmadd231ps ymm10, ymm6, ymm14 + vfmadd231ps ymm11, ymm7, ymm14 + + vmaxps ymm8, ymm8, ymm12 + vmaxps ymm9, ymm9, ymm12 + vmaxps ymm10, ymm10, ymm12 + vmaxps ymm11, ymm11, ymm12 + + vcvttps2dq ymm8, ymm8 + vcvttps2dq ymm9, ymm9 + vcvttps2dq ymm10, ymm10 + vcvttps2dq ymm11, ymm11 + + vmovaps [{ptr}] , ymm8 + vmovaps [{ptr} + 32], ymm9 + vmovaps [{ptr} + 64], ymm10 + vmovaps [{ptr} + 96], ymm11 + + vaddps ymm0, ymm0, ymm8 + vaddps ymm1, ymm1, ymm9 + vaddps ymm2, ymm2, ymm10 + vaddps ymm3, ymm3, ymm11 + + add {ptr}, 128 + sub {len}, 32 + jnz 2b + + vaddps ymm0, ymm0, ymm1 + vaddps ymm2, ymm2, ymm3 + vaddps ymm0, ymm0, ymm2 + vperm2f128 ymm1, ymm0, ymm0, 1 + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 2 + (3 << 2) + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 1 + vaddps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("ymm0") acc, + out("ymm1") _, out("ymm2") _, out("ymm3") _, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _, + out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _, + out("ymm12") _, + inout("ymm13") max => 
_, + inout("ymm14") SLOPE => _, + inout("ymm15") OFFSET => _, + ); + acc +} + +#[cfg(test)] +mod test_x86_64_fma_softmaxl2_f32_32n { + use super::*; + softmax_l2_frame_tests!(is_x86_feature_detected!("fma"), f32, x86_64_fma_softmaxl2_f32_32n); +} From 487aadccf4c9f325aeac3fb94fe2add16318b9a4 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 09:56:21 +0100 Subject: [PATCH 18/26] better test name --- linalg/src/frame/softmax.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs index d4a8166c0b..51a18f1a9d 100644 --- a/linalg/src/frame/softmax.rs +++ b/linalg/src/frame/softmax.rs @@ -32,8 +32,9 @@ pub mod test { .unwrap() } } + #[test] - fn two_weird() { + fn two_missing_max() { if $cond { $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[ -46.15512, 42.875168 From 548dc25f033c7d555334eb186d73c450bb9c04a5 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 11:36:06 +0100 Subject: [PATCH 19/26] softmax tweak and transform --- cli/src/main.rs | 1 + cli/src/params.rs | 14 ++++++++--- core/src/floats.rs | 13 ++++++++++ core/src/lib.rs | 1 + core/src/model/typed.rs | 6 +++++ core/src/ops/nn/mod.rs | 2 +- core/src/ops/nn/softmax/mod.rs | 42 ++++++++++++++++++++++++-------- core/src/transform.rs | 44 ++++++++++++++++++++++++++++++++++ hir/src/ops/nn/layer_max.rs | 15 ++++-------- hir/src/ops/nn/softmax.rs | 9 +++---- linalg/src/lib.rs | 4 ++-- linalg/src/x86_64_fma.rs | 2 +- nnef/src/ops/core.rs | 2 ++ nnef/src/ops/nnef/deser.rs | 4 ++-- nnef/src/ops/nnef/ser.rs | 4 ++++ tflite/src/ops/nn.rs | 2 +- 16 files changed, 131 insertions(+), 34 deletions(-) create mode 100644 core/src/transform.rs diff --git a/cli/src/main.rs b/cli/src/main.rs index 0b25fcf23b..6361758c89 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -127,6 +127,7 @@ fn main() -> tract_core::anyhow::Result<()> { 
.arg(Arg::new("f32-to-f16").long("f32-to-f16").alias("half-floats").long_help("Convert the decluttered network from f32 to f16")) .arg(arg!(--"f16-to-f32" "Convert the decluttered network from f16 to f32")) + .arg(Arg::new("transform").long("transform").multiple_occurrences(true).takes_value(true).help("Apply a built-in transformation to the model")) .arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true) .long_help("Set a symbol to a concrete value after decluttering")) diff --git a/cli/src/params.rs b/cli/src/params.rs index 22f3d465ef..9bbb863200 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -617,13 +617,13 @@ impl Parameters { macro_rules! stage { ($name:expr, $from:ident -> $to:ident, $block:expr) => { if let Some(from) = $from.take() { - info!(concat!("Running '", $name, "'")); + info!("Running {:?}", $name); let mut last_model: Option> = if keep_last { Some(Box::new(from.as_ref().clone())) } else { None }; let block: &dyn Fn(_) -> TractResult<_> = &$block; let owned_model = Arc::try_unwrap(from).unwrap_or_else(|from| from.as_ref().clone()); - match block(owned_model).context(concat!("Error at stage ", $name)) { + match block(owned_model).with_context(|| format!("Error at stage {:?}", $name)) { Ok(it) => { $to = Some(Arc::new(it)); } @@ -637,7 +637,7 @@ impl Parameters { } } } - info_usage(concat!("after ", $name), probe); + info_usage(&format!("after {:?}", $name), probe); if reference_stage.as_deref() == Some($name) { reference_model = Some($to.as_ref().unwrap().clone()); } @@ -724,6 +724,14 @@ impl Parameters { tract_core::floats::FloatPrecisionTranslator::::default().translate_model(&m) }); } + if let Some(transform) = matches.values_of("transform") { + for transform in transform { + stage!(transform, typed_model -> typed_model, |m:TypedModel| { + let transformer = tract_core::transform::get_transformer(transform).with_context(|| format!("Could not find transformer named {}", transform))?; + transformer.transform_into(&m) 
+ }); + } + } if let Some(set) = matches.values_of("set") { let mut values = SymbolValues::default(); for set in set { diff --git a/core/src/floats.rs b/core/src/floats.rs index 31550b4840..c4228eb351 100644 --- a/core/src/floats.rs +++ b/core/src/floats.rs @@ -7,10 +7,23 @@ use crate::ops::einsum::EinSum; use crate::ops::konst::Const; use crate::ops::scan::Scan; use crate::ops::source::TypedSource; +use crate::transform::ModelTransformer; #[derive(Debug, Default)] pub struct FloatPrecisionTranslator(PhantomData<(T1, T2)>); +impl ModelTransformer for FloatPrecisionTranslator { + fn name(&self) -> Cow { + format!("{:?}-to-{:?}", T1::datum_type(), T2::datum_type()).into() + } + + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + let new = self.translate_model(model)?; + *model = new; + Ok(()) + } +} + impl Translate, TypedFact, Box> for FloatPrecisionTranslator diff --git a/core/src/lib.rs b/core/src/lib.rs index 54b7be035c..f7960fe1d3 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -79,6 +79,7 @@ pub mod model; pub mod optim; pub mod plan; pub mod runtime; +pub mod transform; pub mod value; pub use dyn_clone; diff --git a/core/src/model/typed.rs b/core/src/model/typed.rs index cd57e3973e..d03fcf79df 100644 --- a/core/src/model/typed.rs +++ b/core/src/model/typed.rs @@ -3,6 +3,7 @@ use crate::model::*; use crate::ops; use crate::optim::OptimizerSession; use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState}; +use crate::transform::ModelTransformer; /// A model with completely determined types and shapes. pub type TypedModel = Graph>; @@ -149,6 +150,11 @@ impl TypedModel { Ok(self) } + /// Apply the given model transformer to the network, in place. + pub fn transform(&mut self, transformer: &dyn ModelTransformer) -> TractResult<()> { + transformer.transform(self) + } + /// Perform declutter passes on the network.
pub fn declutter(&mut self) -> TractResult<()> { crate::optim::Optimizer::declutter().session().optimize(self) diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs index 1c53a8cfed..28593be282 100644 --- a/core/src/ops/nn/mod.rs +++ b/core/src/ops/nn/mod.rs @@ -4,7 +4,7 @@ mod softmax; pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape}; pub use self::reduce::{Reduce, Reducer}; -pub use self::softmax::Softmax; +pub use self::softmax::{Softmax, SoftmaxExp}; pub use crate::internal::*; diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index d3ecc21707..d30654c63b 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -12,10 +12,19 @@ use std::fmt::Debug; use crate::internal::*; use ndarray::prelude::*; -#[derive(Debug, Clone, new, Hash)] +#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)] +pub enum SoftmaxExp { + #[default] + Libc, + // https://nic.schraudolph.org/pubs/Schraudolph99.pdf + FastCompact, +} + +#[derive(Debug, Clone, new, Hash, Default)] pub struct Softmax { pub axes: TVec, pub quant_output_dt: Option, + pub exp: SoftmaxExp, } impl Op for Softmax { @@ -24,7 +33,7 @@ impl Op for Softmax { } fn info(&self) -> TractResult> { - Ok(vec![format!("Axis: {:?}", self.axes)]) + Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)]) } op_as_typed_op!(); @@ -135,7 +144,7 @@ impl Softmax { view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type()) { let slice: &mut [f32] = unsafe { std::mem::transmute(slice) }; - softmax_inner_slice_f32(slice)?; + self.softmax_inner_slice_f32(slice)?; } else { softmax_inner(view); } @@ -176,14 +185,27 @@ impl Softmax { unsafe { output_tensor.set_datum_type(output_dt) }; Ok(tvec!(output_tensor.into_tvalue())) } -} -fn softmax_inner_slice_f32(slice: &mut [f32]) -> TractResult<()> { - let max = (tract_linalg::ops().max_f32)().run(slice)?; - let sum = 
(tract_linalg::ops().softmax_loop2_f32)().run_with_params(slice, max)?; - let rsum = sum.recip(); - (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?; - Ok(()) + fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> { + let max = (tract_linalg::ops().max_f32)().run(slice)?; + let sum = match self.exp { + SoftmaxExp::Libc => { + let mut s = 0f32; + for x in slice.iter_mut() { + let y = (*x - max).exp(); + s += y; + *x = y; + } + s + } + SoftmaxExp::FastCompact => { + (tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)? + } + }; + let rsum = sum.recip(); + (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?; + Ok(()) + } } fn softmax_inner(mut view: ArrayViewMut) { diff --git a/core/src/transform.rs b/core/src/transform.rs new file mode 100644 index 0000000000..d89c08f071 --- /dev/null +++ b/core/src/transform.rs @@ -0,0 +1,44 @@ +use crate::internal::*; +use std::borrow::Cow; +use std::fmt::Debug; + +use tract_data::TractResult; + +use crate::floats::FloatPrecisionTranslator; +use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel}; + +pub fn get_transformer(name: &str) -> Option> { + match name { + "f32-to-f16" => Some(Box::new(FloatPrecisionTranslator::::default())), + "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)), + _ => None, + } +} + +pub trait ModelTransformer: Debug { + fn name(&self) -> Cow; + fn transform(&self, model: &mut TypedModel) -> TractResult<()>; + fn transform_into(&self, model: &TypedModel) -> TractResult { + let mut model = model.clone(); + self.transform(&mut model)?; + Ok(model) + } +} + +#[derive(Debug)] +struct SoftmaxFastCompact; + +impl ModelTransformer for SoftmaxFastCompact { + fn name(&self) -> Cow { + "softmax-fast-compact".into() + } + + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + for node in &mut model.nodes { + if let Some(softmax) = node.op_as_mut::() { + softmax.exp = SoftmaxExp::FastCompact; + } + } + Ok(()) + 
} +} diff --git a/hir/src/ops/nn/layer_max.rs b/hir/src/ops/nn/layer_max.rs index 9fc5161926..a4e5ae6c81 100644 --- a/hir/src/ops/nn/layer_max.rs +++ b/hir/src/ops/nn/layer_max.rs @@ -1,3 +1,5 @@ +use tract_core::ops::nn::Softmax; + use crate::infer::*; use crate::internal::*; @@ -9,8 +11,6 @@ pub struct LayerHardmax { coerce_to_2d: bool, } - - impl Expansion for LayerHardmax { fn name(&self) -> Cow { "LayerHardmax".into() @@ -83,14 +83,11 @@ pub struct LayerLogSoftmax { pub coerce_to_2d: bool, } - - impl Expansion for LayerLogSoftmax { fn name(&self) -> Cow { "LayerLogSoftmax".into() } - fn rules<'r, 'p: 'r, 's: 'r>( &'s self, solver: &mut Solver<'r>, @@ -118,8 +115,6 @@ pub struct LayerSoftmax { coerce_to_2d: bool, } - - impl Expansion for LayerSoftmax { fn name(&self) -> Cow { "LayerSoftmax".into() @@ -144,10 +139,10 @@ impl Expansion for LayerSoftmax { let rank = target.outlet_fact(input)?.rank(); let dt = target.outlet_fact(input)?.datum_type; let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize; - let reducing_axes = + let axes = if self.coerce_to_2d { (axis..rank).collect::>() } else { tvec!(axis) }; - let dt = if dt.is_float() { None } else { Some(dt) }; - target.wire_node(name, tract_core::ops::nn::Softmax::new(reducing_axes, dt), inputs) + let quant_output_dt = if dt.is_float() { None } else { Some(dt) }; + target.wire_node(name, Softmax { axes, quant_output_dt, ..Softmax::default() }, inputs) } } diff --git a/hir/src/ops/nn/softmax.rs b/hir/src/ops/nn/softmax.rs index eaf64b5eca..9ba36e1304 100644 --- a/hir/src/ops/nn/softmax.rs +++ b/hir/src/ops/nn/softmax.rs @@ -1,4 +1,3 @@ -//use tract_core::ops::nn::Softmax; use crate::internal::*; #[derive(Debug, Clone, new, Hash)] @@ -6,8 +5,6 @@ pub struct Softmax { axis: isize, } - - impl Expansion for Softmax { fn name(&self) -> Cow { "Softmax".into() @@ -54,7 +51,11 @@ impl Expansion for Softmax { target.wire_node( name, - tract_core::ops::nn::Softmax { axes: tvec![axis], 
quant_output_dt }, + tract_core::ops::nn::Softmax { + axes: tvec![axis], + quant_output_dt, + ..tract_core::ops::nn::Softmax::default() + }, inputs, ) } diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index b783a82b75..000f615661 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -69,7 +69,7 @@ pub struct Ops { pub max_f32: Box Box> + Send + Sync>, - pub softmax_loop2_f32: Box Box> + Send + Sync>, + pub softmax2_fastcompact_f32: Box Box> + Send + Sync>, } impl Ops { @@ -132,7 +132,7 @@ pub fn generic() -> Ops { /* activation_f32: Box::new(|microcode| generic::SActivation::new(microcode)) */ - softmax_loop2_f32: Box::new(|| generic::softmax::SSoftMaxL2::red()), + softmax2_fastcompact_f32: Box::new(|| generic::softmax::SSoftMaxL2::red()), } } diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 99331bb5a5..e2d61ca568 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -95,7 +95,7 @@ fn plug_fma(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew()); ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew()); - ops.softmax_loop2_f32 = Box::new(|| x86_64_fma_softmaxl2_f32_32n::red()); + ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmaxl2_f32_32n::red()); log::info!("mmm_f32, mmv_f32, sigmoid_f32, tanh_f32: x86_64/fma activated"); } diff --git a/nnef/src/ops/core.rs b/nnef/src/ops/core.rs index c322ef9559..dec4aeec94 100644 --- a/nnef/src/ops/core.rs +++ b/nnef/src/ops/core.rs @@ -21,6 +21,7 @@ mod reduce; mod scan; mod scatter; mod shape_of; +mod softmax; mod source; mod store; mod submodel; @@ -59,6 +60,7 @@ pub fn register(registry: &mut Registry) { scan::register(registry); scatter::register(registry); shape_of::register(registry); + softmax::register(registry); source::register(registry); store::register(registry); submodel::register(registry); diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index ca2fe82a63..a16094e65f 100644 --- a/nnef/src/ops/nnef/deser.rs +++ 
b/nnef/src/ops/nnef/deser.rs @@ -7,7 +7,7 @@ use tract_core::ops::array::PadMode; use tract_core::ops::cnn::deconv::adjustments; use tract_core::ops::cnn::PaddingSpec; use tract_core::ops::cnn::PoolSpec; -use tract_core::ops::nn::DataFormat; +use tract_core::ops::nn::{DataFormat, Softmax, SoftmaxExp}; use tract_itertools::izip; use tract_itertools::Itertools; @@ -686,5 +686,5 @@ pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> T invocation.dt_from_quant_file.first().cloned().flatten() }; - builder.wire(ops::nn::Softmax { axes, quant_output_dt }, &[x]) + builder.wire(Softmax { axes, quant_output_dt, exp: SoftmaxExp::default() }, &[x]) } diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs index 17b980e36f..89a575fbea 100644 --- a/nnef/src/ops/nnef/ser.rs +++ b/nnef/src/ops/nnef/ser.rs @@ -11,6 +11,7 @@ use tract_core::ops::cnn::KernelFormat; use tract_core::ops::cnn::PoolSpec; use tract_core::ops::einsum::BasicMatMul; use tract_core::ops::nn::DataFormat; +use tract_core::ops::nn::SoftmaxExp; use tract_core::tract_data::itertools::Itertools; pub fn source( @@ -453,6 +454,9 @@ pub fn softmax( node: &TypedNode, op: &ops::nn::Softmax, ) -> TractResult>> { + if op.exp != SoftmaxExp::default() { + return Ok(None) + } let litteral_axes: Vec<_> = op.axes.iter().map(|&it| (it as i64).into()).collect(); Ok(Some(invocation( "softmax", diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs index 2d9d4ccb1d..c612475c08 100644 --- a/tflite/src/ops/nn.rs +++ b/tflite/src/ops/nn.rs @@ -140,7 +140,7 @@ fn de_softmax(op: &mut DeserOp) -> TractResult> { let options = builtin!(op, builtin_options_as_softmax_options); ensure!(options.beta() == 1.0); let quant_output_dt = Some(input.datum_type).filter(|dt| !dt.is_float()); - let softmax = core::nn::Softmax { axes: tvec!(input.rank() - 1), quant_output_dt }; + let softmax = Softmax { axes: tvec!(input.rank() - 1), quant_output_dt, ..Softmax::default() }; op.ctx.target.wire_node(op.prefix, 
softmax, op.inputs) } From 1c6f1306aced695136fe74ad03f4eba1eefb7fa3 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 13:44:50 +0100 Subject: [PATCH 20/26] fix approx test --- core/src/ops/nn/softmax/mod.rs | 5 +++-- linalg/src/generic/softmax.rs | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index d30654c63b..8a9f7f8c03 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -365,7 +365,8 @@ mod test { fn check(&self) -> Result<()> { let inputs = tvec!(self.data.clone().into_tvalue()); let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float()); - let softmax = Softmax { axes: self.axes.clone(), quant_output_dt }; + let softmax = + Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() }; // Compute quantized output let result = softmax.eval(inputs)?; @@ -375,7 +376,7 @@ mod test { // Compute reference output let input_float = self.data.cast_to::()?; let inputs_float = tvec!(input_float.into_owned().into_tvalue()); - let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None }; + let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() }; let reference_float = softmax_float.eval(inputs_float)?; let reference_array = args_1!(reference_float); let reference = reference_array.to_array_view::()?; diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index 632d4f8bbe..cdbda86216 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -117,7 +117,7 @@ proptest::proptest! { #[test] fn t_xnnpack(x in -100f32..0.) 
{ use tract_data::internal::{tensor0, Approximation}; - tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::Approximate).unwrap(); + tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::SuperApproximate).unwrap(); } } From 9705beafffa7bbd24e24dc17c05bf92c3b2453f8 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 13:50:06 +0100 Subject: [PATCH 21/26] forgot file --- nnef/src/ops/core/softmax.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 nnef/src/ops/core/softmax.rs diff --git a/nnef/src/ops/core/softmax.rs b/nnef/src/ops/core/softmax.rs new file mode 100644 index 0000000000..8e1d9a0ac2 --- /dev/null +++ b/nnef/src/ops/core/softmax.rs @@ -0,0 +1,36 @@ +use tract_core::ops::nn::{Softmax, SoftmaxExp}; + +use crate::internal::*; + +pub fn register(registry: &mut Registry) { + registry.register_primitive( + "tract_core_softmax", + &[ + TypeName::Scalar.tensor().named("x"), + TypeName::Integer.tensor().named("axes"), + TypeName::String.named("exp"), + ], + &[("output", TypeName::Scalar.tensor())], + softmax, + ); +} + +pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { + let x = invocation.named_arg_as(builder, "x")?; + let axes: TVec = invocation.named_arg_as(builder, "axes")?; + + let input_fact = builder.model.outlet_fact(x)?.clone(); + let quant_output_dt = if input_fact.datum_type.is_float() { + None + } else { + invocation.dt_from_quant_file.first().cloned().flatten() + }; + + let exp: Option = invocation.get_named_arg_as(builder, "exp")?; + let exp = match exp.as_deref() { + Some("fast_compact") => SoftmaxExp::FastCompact, + _ => SoftmaxExp::Libc + }; + + builder.wire(Softmax { axes, quant_output_dt, exp }, &[x]) +} From 52633739816f67a8f9203a2d46332fdbf6ddc1e0 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 14:29:33 +0100 Subject: [PATCH 22/26] cleanup --- linalg/Cargo.toml | 1 - 
linalg/src/frame/softmax.rs | 4 +- linalg/src/generic/softmax.rs | 264 +--------------------------------- 3 files changed, 7 insertions(+), 262 deletions(-) diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index d5c04f0f5e..7f7b3d0f7e 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -17,7 +17,6 @@ maintenance = { status = "actively-developed" } derive-new.workspace = true downcast-rs.workspace = true dyn-clone.workspace = true -hexf.workspace = true lazy_static.workspace = true log.workspace = true num-traits.workspace = true diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs index 51a18f1a9d..d6f613a04e 100644 --- a/linalg/src/frame/softmax.rs +++ b/linalg/src/frame/softmax.rs @@ -52,7 +52,7 @@ pub mod test { f32: AsPrimitive, T: AsPrimitive, { - use crate::generic::softmax::very_fast_exp; + use crate::generic::softmax::fast_compact_exp_f32; crate::setup_test_logger(); let max = values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); @@ -61,7 +61,7 @@ pub mod test { ::min_value(), T::zero(), // |x| (x - max.as_()).exp(), - |x| very_fast_exp(dbg!(x).as_() - max).as_(), + |x| fast_compact_exp_f32(dbg!(x).as_() - max).as_(), |a, b| a + b, max.as_(), ) diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs index cdbda86216..56411f1f61 100644 --- a/linalg/src/generic/softmax.rs +++ b/linalg/src/generic/softmax.rs @@ -38,10 +38,7 @@ impl MapReduceKer for SSoftMaxL2 { let mut sum = 0.; for v in x.iter_mut() { let y = *v - max; - // let y = y.exp(); - // let y = xnnpack_loop2_exp(y); - // let y = expf(y); - let y = very_fast_exp(y); + let y = fast_compact_exp_f32(y); *v = y; sum += y; } @@ -49,82 +46,11 @@ impl MapReduceKer for SSoftMaxL2 { } } -// https://github.com/google/XNNPACK/blob/3bc4ef01bbdf488556c54584fc2419dd77c39c85/src/f32-raddstoreexpminusmax/scalar-rr2-p5.c.in#L131 -// 
https://github.com/google/XNNPACK/blob/8951decff5114f70bae7cc2e23b732812e73acc7/src/microparams-init.c#L4121 -#[inline] -#[allow(dead_code, non_upper_case_globals)] -fn xnnpack_loop2_exp(vx: f32) -> f32 { - debug_assert!(vx <= 0f32); - const log2e: f32 = hexf::hexf32!("0x1.715476p+0"); - const magic_bias: f32 = hexf::hexf32!("0x1.8000FEp23"); - const minus_ln2_hi: f32 = hexf::hexf32!("-0x1.62E400p-1"); - const minus_ln2_lo: f32 = hexf::hexf32!("-0x1.7F7D1Cp-20"); - const c5: f32 = hexf::hexf32!("0x1.0F9F9Cp-7"); - const c4: f32 = hexf::hexf32!("0x1.573A1Ap-5"); - const c3: f32 = hexf::hexf32!("0x1.555A80p-3"); - const c2: f32 = hexf::hexf32!("0x1.FFFDC6p-2"); - const c1: f32 = hexf::hexf32!("0x1.FFFFF6p-1"); - const denorm_cutoff: f32 = hexf::hexf32!("-0x1.5D589Ep6"); - - // Compute reduced argument n := round(x / log(2)). - // We do it by adding a large number (magic bias) to the product x * (1/log(2)), which cause rounding of the result - // to an integer, then subtracing the large number back. The trick with adding large number is valid only within - // certain bounds (|x| <= 2**22), but that's ok, because inputs outside of [-87.336540, 0.0] underflow expf(x) - // anyway. We fixup the result for such inputs at the very end of the algorithm. - // float vn = vx * vlog2e + vmagic_bias; - let mut vn: f32 = vx * log2e + magic_bias; - - // Create a floating-point number s (scale) such that s == 2**n for inputs which don't cause underflow, i.e. - // -87.33642 <= x <= 0.0, and -126 <= n <= 0 accordingly. - // const float vs = uint32_as_float(float_as_uint32(vn) << 23); - let vs = f32::from_bits(vn.to_bits() << 23); - - // Subtract the large number back to get final n := round(x / log(2)). - // vn -= vmagic_bias; - vn -= magic_bias; - - // Compute reduced argument t := x - n * log(2). - // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 
- // float vt = vn * vminus_ln2_hi + vx; - let mut vt = vn * minus_ln2_hi + vx; - //vt = vn * vminus_ln2_lo + vt; - vt = vn * minus_ln2_lo + vt; - - // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. - let mut vp = c5 * vt + c4; - vp = vp * vt + c2; - vp = vp * vt + c2; - vp = vp * vt + c1; - - // Reconstruct the final f value: - // f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))) - // = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))) - // = s + (t * s) * p - vt *= vs; - // float vf = vt * vp + vs; - let mut vf = vt * vp + vs; - - // For inputs below denormal cutoff, replace output with +0.0f. - // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. - if vx < denorm_cutoff { - vf = 0.0; - } - vf -} - -#[cfg(test)] -proptest::proptest! { - #[test] - fn t_xnnpack(x in -100f32..0.) { - use tract_data::internal::{tensor0, Approximation}; - tensor0(xnnpack_loop2_exp(x)).close_enough(&tensor0(x.exp()), Approximation::SuperApproximate).unwrap(); - } -} - // ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h -#[inline(never)] -#[allow(dead_code)] -pub fn very_fast_exp(v: f32) -> f32 { +// probably inspired from https://nic.schraudolph.org/pubs/Schraudolph99.pdf +// not that the cast to u32 deals with negative right, while implem in volk code are wrong in some +// corner cases (need a max(0,x) before the u32 conversion) +pub fn fast_compact_exp_f32(v: f32) -> f32 { const MLN2: f32 = 0.6931471805f32; const A: f32 = 8388608.0f32; const B: f32 = 1065353216.0f32; @@ -134,186 +60,6 @@ pub fn very_fast_exp(v: f32) -> f32 { f32::from_bits(((SLOPE * v) + OFFSET) as u32) } -#[inline] -#[allow(dead_code)] -fn exp2_p7(v: f32) -> f32 { - const EXP2P: [f32; 7] = [ - 1.535336188319500e-4, - 1.339887440266574e-3, - 9.618437357674640e-3, - 5.550332471162809e-2, - 2.402264791363012e-1, - 6.931472028550421e-1, - 1.000000000000000, - ]; - let v = 
v.min(127f32).max(-127f32); - - let ipart = (v + 0.5).floor(); - let fpart = v - ipart; - - // 2^ipart - let two_pow_ipart = f32::from_bits((((ipart as i32) + 127) as u32) << 23); - - let mut v = EXP2P[0]; - v = v * fpart + EXP2P[1]; - v = v * fpart + EXP2P[2]; - v = v * fpart + EXP2P[3]; - v = v * fpart + EXP2P[4]; - v = v * fpart + EXP2P[5]; - v = v * fpart + EXP2P[6]; - v = v * two_pow_ipart; - v -} - -#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] -pub fn scalbnf(mut x: f32, mut n: i32) -> f32 { - let x1p127 = f32::from_bits(0x7f000000); // 0x1p127f === 2 ^ 127 - let x1p_126 = f32::from_bits(0x800000); // 0x1p-126f === 2 ^ -126 - let x1p24 = f32::from_bits(0x4b800000); // 0x1p24f === 2 ^ 24 - - if n > 127 { - x *= x1p127; - n -= 127; - if n > 127 { - x *= x1p127; - n -= 127; - if n > 127 { - n = 127; - } - } - } else if n < -126 { - x *= x1p_126 * x1p24; - n += 126 - 24; - if n < -126 { - x *= x1p_126 * x1p24; - n += 126 - 24; - if n < -126 { - n = -126; - } - } - } - x * f32::from_bits(((0x7f + n) as u32) << 23) -} - -macro_rules! force_eval { - ($e:expr) => { - unsafe { ::core::ptr::read_volatile(&$e) } - }; -} - -macro_rules! 
i { - ($array:expr, $index:expr) => { - unsafe { *$array.get_unchecked($index) } - }; - ($array:expr, $index:expr, = , $rhs:expr) => { - unsafe { - *$array.get_unchecked_mut($index) = $rhs; - } - }; - ($array:expr, $index:expr, += , $rhs:expr) => { - unsafe { - *$array.get_unchecked_mut($index) += $rhs; - } - }; - ($array:expr, $index:expr, -= , $rhs:expr) => { - unsafe { - *$array.get_unchecked_mut($index) -= $rhs; - } - }; - ($array:expr, $index:expr, &= , $rhs:expr) => { - unsafe { - *$array.get_unchecked_mut($index) &= $rhs; - } - }; - ($array:expr, $index:expr, == , $rhs:expr) => { - unsafe { *$array.get_unchecked_mut($index) == $rhs } - }; -} - -/// Exponential, base *e* (f32) -/// -/// Calculate the exponential of `x`, that is, *e* raised to the power `x` -/// (where *e* is the base of the natural system of logarithms, approximately 2.71828). -#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] -pub fn expf(mut x: f32) -> f32 { - const HALF: [f32; 2] = [0.5, -0.5]; - const LN2_HI: f32 = 6.9314575195e-01; /* 0x3f317200 */ - const LN2_LO: f32 = 1.4286067653e-06; /* 0x35bfbe8e */ - const INV_LN2: f32 = 1.4426950216e+00; /* 0x3fb8aa3b */ - /* - * Domain [-0.34568, 0.34568], range ~[-4.278e-9, 4.447e-9]: - * |x*(exp(x)+1)/(exp(x)-1) - p(x)| < 2**-27.74 - */ - const P1: f32 = 1.6666625440e-1; /* 0xaaaa8f.0p-26 */ - const P2: f32 = -2.7667332906e-3; /* -0xb55215.0p-32 */ - - let x1p127 = f32::from_bits(0x7f000000); // 0x1p127f === 2 ^ 127 - let x1p_126 = f32::from_bits(0x800000); // 0x1p-126f === 2 ^ -126 /*original 0x1p-149f ??????????? 
*/ - let mut hx = x.to_bits(); - let sign = (hx >> 31) as i32; /* sign bit of x */ - let signb: bool = sign != 0; - hx &= 0x7fffffff; /* high word of |x| */ - - /* special cases */ - if hx >= 0x42aeac50 { - /* if |x| >= -87.33655f or NaN */ - if hx > 0x7f800000 { - /* NaN */ - return x; - } - if (hx >= 0x42b17218) && (!signb) { - /* x >= 88.722839f */ - /* overflow */ - x *= x1p127; - return x; - } - if signb { - /* underflow */ - force_eval!(-x1p_126 / x); - if hx >= 0x42cff1b5 { - /* x <= -103.972084f */ - return 0.; - } - } - } - - /* argument reduction */ - let k: i32; - let hi: f32; - let lo: f32; - if hx > 0x3eb17218 { - /* if |x| > 0.5 ln2 */ - if hx > 0x3f851592 { - /* if |x| > 1.5 ln2 */ - k = (INV_LN2 * x + i!(HALF, sign as usize)) as i32; - } else { - k = 1 - sign - sign; - } - let kf = k as f32; - hi = x - kf * LN2_HI; /* k*ln2hi is exact here */ - lo = kf * LN2_LO; - x = hi - lo; - } else if hx > 0x39000000 { - /* |x| > 2**-14 */ - k = 0; - hi = x; - lo = 0.; - } else { - /* raise inexact */ - force_eval!(x1p127 + x); - return 1. + x; - } - - /* x is now in primary range */ - let xx = x * x; - let c = x - xx * (P1 + xx * P2); - let y = 1. + (x * c / (2. 
- c) - lo + hi); - if k == 0 { - y - } else { - scalbnf(y, k) - } -} #[cfg(test)] #[macro_use] From 91522e5a7a69046afae7b8d98bd4ecb9313fd47d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 14:33:52 +0100 Subject: [PATCH 23/26] more cleanup --- Cargo.toml | 1 - core/src/transform.rs | 3 ++- linalg/src/lib.rs | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7ba5089bc4..fcc4340bc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,6 @@ flate2 = "1.0.20" fs2 = "0.4.3" getrandom = "0.2" half = { version="2.2.1", features = [ "std", "num-traits" ] } -hexf = "0.2.1" image = "0.24.1" itertools = "0.10.1" home = "0.5.5" diff --git a/core/src/transform.rs b/core/src/transform.rs index d89c08f071..5a3d082112 100644 --- a/core/src/transform.rs +++ b/core/src/transform.rs @@ -9,7 +9,8 @@ use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel}; pub fn get_transformer(name: &str) -> Option> { match name { - "f32-to-f16" => Some(Box::new(FloatPrecisionTranslator::::default())), + "f32-to-f16" => Some(Box::>::default()), + "f16-to-f32" => Some(Box::>::default()), "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)), _ => None, } diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 000f615661..83de457c38 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -1,6 +1,8 @@ #![allow(clippy::missing_safety_doc)] #![allow(clippy::redundant_closure_call)] #![allow(clippy::len_zero)] +#![allow(clippy::excessive_precision)] +#![allow(clippy::approx_constant)] #[macro_use] extern crate derive_new; extern crate lazy_static; From f489f6d46476c0c8e1cacac6542457f1fbf60fe2 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 17:17:08 +0100 Subject: [PATCH 24/26] mul_by_scalar intel --- linalg/src/generic/by_scalar.rs | 2 +- linalg/src/x86_64_fma.rs | 9 ++++-- linalg/src/x86_64_fma/by_scalar.rs | 52 ++++++++++++++++++++++++++++++ linalg/src/x86_64_fma/max.rs | 3 +- 
linalg/src/x86_64_fma/softmax.rs | 10 +++--- 5 files changed, 65 insertions(+), 11 deletions(-) create mode 100644 linalg/src/x86_64_fma/by_scalar.rs diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs index d19fbc5a4e..a780c9e69c 100644 --- a/linalg/src/generic/by_scalar.rs +++ b/linalg/src/generic/by_scalar.rs @@ -30,6 +30,6 @@ impl ElementWiseKer for SMulByScalar4 { #[cfg(test)] #[macro_use] -pub mod s { +pub mod mul_by_scalar_f32 { mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4); } diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index e2d61ca568..d5f6261661 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -1,11 +1,12 @@ use crate::frame::element_wise::ElementWiseKer; use crate::frame::mmm::kernel::MatMatMulKer; -use crate::frame::reduce::MapReduceKer; -use crate::x86_64_fma::softmax::x86_64_fma_softmaxl2_f32_32n; +use crate::frame::reduce::{MapReduceKer, ReduceKer}; +use crate::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n; use crate::Ops; pub mod mmm; +pub mod by_scalar; mod intel; pub mod max; pub mod softmax; @@ -95,7 +96,9 @@ fn plug_fma(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew()); ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew()); - ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmaxl2_f32_32n::red()); + ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew()); + ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red()); + ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red()); log::info!("mmm_f32, mmv_f32, sigmoid_f32, tanh_f32: x86_64/fma activated"); } diff --git a/linalg/src/x86_64_fma/by_scalar.rs b/linalg/src/x86_64_fma/by_scalar.rs new file mode 100644 index 0000000000..c2e7c9abda --- /dev/null +++ b/linalg/src/x86_64_fma/by_scalar.rs @@ -0,0 +1,52 @@ +ew_impl_wrap!( + f32, + x86_64_avx_f32_mul_by_scalar_32n, + 32, + 8, + f32, + fn 
run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + unsafe { x86_64_avx_f32_mul_by_scalar_32n_run(x, s) } + } +); + +#[target_feature(enable = "avx")] +unsafe fn x86_64_avx_f32_mul_by_scalar_32n_run(buf: &mut [f32], scalar: f32) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + vmulps ymm4, ymm4, ymm0 + vmulps ymm5, ymm5, ymm0 + vmulps ymm6, ymm6, ymm0 + vmulps ymm7, ymm7, ymm0 + vmovaps [{ptr}], ymm4 + vmovaps [{ptr} + 32], ymm5 + vmovaps [{ptr} + 64], ymm6 + vmovaps [{ptr} + 96], ymm7 + add {ptr}, 128 + sub {len}, 32 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("xmm0") scalar, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ + ); +} + +#[cfg(test)] +#[macro_use] +pub mod test_x86_64_avx_f32_mul_by_scalar_32n { + mul_by_scalar_frame_tests!( + is_x86_feature_detected!("avx2"), + f32, + crate::x86_64_fma::by_scalar::x86_64_avx_f32_mul_by_scalar_32n + ); +} diff --git a/linalg/src/x86_64_fma/max.rs b/linalg/src/x86_64_fma/max.rs index b273427d75..beec1deaa3 100644 --- a/linalg/src/x86_64_fma/max.rs +++ b/linalg/src/x86_64_fma/max.rs @@ -2,7 +2,7 @@ reduce_impl_wrap!( f32, x86_64_fma_max_f32_32n, 32, - 32, + 8, (), f32::MIN, #[inline(never)] @@ -56,7 +56,6 @@ unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 { out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ ); acc - } #[cfg(test)] diff --git a/linalg/src/x86_64_fma/softmax.rs b/linalg/src/x86_64_fma/softmax.rs index 5f1de841dd..e4f8aeb94d 100644 --- a/linalg/src/x86_64_fma/softmax.rs +++ b/linalg/src/x86_64_fma/softmax.rs @@ -1,6 +1,6 @@ map_reduce_impl_wrap!( f32, - x86_64_fma_softmaxl2_f32_32n, + x86_64_fma_softmax2_fastcompact_f32_32n, 32, 8, f32, @@ -10,7 +10,7 @@ map_reduce_impl_wrap!( fn run(buf: 
&mut [f32], max: f32) -> f32 { assert!(buf.len() % 32 == 0); assert!(buf.len() > 0); - unsafe { x86_64_fma_softmaxl2_f32_32n_run(buf, max) } + unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) } }, #[inline(never)] fn reduce_two(a: f32, b: f32) -> f32 { @@ -19,7 +19,7 @@ map_reduce_impl_wrap!( ); #[target_feature(enable = "avx,fma")] -unsafe fn x86_64_fma_softmaxl2_f32_32n_run(buf: &mut [f32], max: f32) -> f32 { +unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 { let len = buf.len(); let ptr = buf.as_ptr(); let mut acc = 0f32; @@ -109,7 +109,7 @@ unsafe fn x86_64_fma_softmaxl2_f32_32n_run(buf: &mut [f32], max: f32) -> f32 { } #[cfg(test)] -mod test_x86_64_fma_softmaxl2_f32_32n { +mod test_x86_64_fma_softmax2_fastcompact_f32_32n { use super::*; - softmax_l2_frame_tests!(is_x86_feature_detected!("fma"), f32, x86_64_fma_softmaxl2_f32_32n); + softmax_l2_frame_tests!(is_x86_feature_detected!("fma"), f32, x86_64_fma_softmax2_fastcompact_f32_32n); } From 7d9775ddb628c8d9c95b7ce1c15b7753b34ce758 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 Jan 2024 17:17:19 +0100 Subject: [PATCH 25/26] mul_by_scalar intel --- core/src/ops/nn/softmax/mod.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index 8a9f7f8c03..c0f46d6653 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -131,10 +131,11 @@ impl Softmax { } } - let mut output = input.into_tensor().into_array::()?; + let mut output = input.into_tensor(); + let mut view = output.to_array_view_mut::()?; for it_coords in tract_ndarray::indices(&*iterating_shape) { - let mut view = output.view_mut(); + let mut view = view.view_mut(); for ix in 0..iterating_shape.len() { if !self.axes.contains(&ix) { view.collapse_axis(Axis(ix), it_coords[ix]); From 048a64128d0aff23ae50ac3a0b449c49da8c4b1a Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 
23 Jan 2024 17:22:53 +0100 Subject: [PATCH 26/26] complete softmax bench --- linalg/benches/softmax.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs index c1d735170d..b26efcb647 100644 --- a/linalg/benches/softmax.rs +++ b/linalg/benches/softmax.rs @@ -65,12 +65,28 @@ fn softmax_f32(c: &mut Criterion) { group.bench_function("loop2/generic", |b| { b.iter(|| tract_linalg::generic::softmax::SSoftMaxL2::red().run_with_params(input, 10.)) }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop2/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n::red() + .run_with_params(input, 10.) + .unwrap() + }); + }); group.bench_function("loop3/naive", |b| b.iter(|| loop3_f32(input, 0.21))); group.bench_function("loop3/generic", |b| { b.iter(|| { tract_linalg::generic::by_scalar::SMulByScalar4::ew().run_with_params(input, 0.21) }) }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop3/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew() + .run_with_params(input, 0.21) + .unwrap() + }); + }); } criterion_group!(benches, softmax_f32);