diff --git a/cli/src/main.rs b/cli/src/main.rs index 0b25fcf23b..6361758c89 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -127,6 +127,7 @@ fn main() -> tract_core::anyhow::Result<()> { .arg(Arg::new("f32-to-f16").long("f32-to-f16").alias("half-floats").long_help("Convert the decluttered network from f32 to f16")) .arg(arg!(--"f16-to-f32" "Convert the decluttered network from f16 to f32")) + .arg(Arg::new("transform").long("transform").multiple_occurrences(true).takes_value(true).help("Apply a built-in transformation to the model")) .arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true) .long_help("Set a symbol to a concrete value after decluttering")) diff --git a/cli/src/params.rs b/cli/src/params.rs index 22f3d465ef..9bbb863200 100644 --- a/cli/src/params.rs +++ b/cli/src/params.rs @@ -617,13 +617,13 @@ impl Parameters { macro_rules! stage { ($name:expr, $from:ident -> $to:ident, $block:expr) => { if let Some(from) = $from.take() { - info!(concat!("Running '", $name, "'")); + info!("Running {:?}", $name); let mut last_model: Option> = if keep_last { Some(Box::new(from.as_ref().clone())) } else { None }; let block: &dyn Fn(_) -> TractResult<_> = &$block; let owned_model = Arc::try_unwrap(from).unwrap_or_else(|from| from.as_ref().clone()); - match block(owned_model).context(concat!("Error at stage ", $name)) { + match block(owned_model).with_context(|| format!("Error at stage {:?}", $name)) { Ok(it) => { $to = Some(Arc::new(it)); } @@ -637,7 +637,7 @@ impl Parameters { } } } - info_usage(concat!("after ", $name), probe); + info_usage(&format!("after {:?}", $name), probe); if reference_stage.as_deref() == Some($name) { reference_model = Some($to.as_ref().unwrap().clone()); } @@ -724,6 +724,14 @@ impl Parameters { tract_core::floats::FloatPrecisionTranslator::::default().translate_model(&m) }); } + if let Some(transform) = matches.values_of("transform") { + for transform in transform { + stage!(transform, typed_model -> typed_model, 
|m:TypedModel| { + let transformer = tract_core::transform::get_transformer(transform).with_context(|| format!("Could not find transformer named {}", transform))?; + transformer.transform_into(&m) + }); + } + } if let Some(set) = matches.values_of("set") { let mut values = SymbolValues::default(); for set in set { diff --git a/core/src/floats.rs b/core/src/floats.rs index 31550b4840..c4228eb351 100644 --- a/core/src/floats.rs +++ b/core/src/floats.rs @@ -7,10 +7,23 @@ use crate::ops::einsum::EinSum; use crate::ops::konst::Const; use crate::ops::scan::Scan; use crate::ops::source::TypedSource; +use crate::transform::ModelTransformer; #[derive(Debug, Default)] pub struct FloatPrecisionTranslator(PhantomData<(T1, T2)>); +impl ModelTransformer for FloatPrecisionTranslator { + fn name(&self) -> Cow { + format!("{:?}-to-{:?}", T1::datum_type(), T2::datum_type()).into() + } + + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + let new = self.translate_model(model)?; + *model = new; + Ok(()) + } +} + impl Translate, TypedFact, Box> for FloatPrecisionTranslator diff --git a/core/src/lib.rs b/core/src/lib.rs index 54b7be035c..f7960fe1d3 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -79,6 +79,7 @@ pub mod model; pub mod optim; pub mod plan; pub mod runtime; +pub mod transform; pub mod value; pub use dyn_clone; diff --git a/core/src/model/typed.rs b/core/src/model/typed.rs index cd57e3973e..d03fcf79df 100644 --- a/core/src/model/typed.rs +++ b/core/src/model/typed.rs @@ -3,6 +3,7 @@ use crate::model::*; use crate::ops; use crate::optim::OptimizerSession; use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState}; +use crate::transform::ModelTransformer; /// A model with completely determined types and shapes. pub type TypedModel = Graph>; @@ -149,6 +150,11 @@ impl TypedModel { Ok(self) } + /// Apply the given model transformer to the network, in place. 
+ pub fn transform(&mut self, transformer: &dyn ModelTransformer) -> TractResult<()> { + transformer.transform(self) + } + /// Perform declutter passes on the network. pub fn declutter(&mut self) -> TractResult<()> { crate::optim::Optimizer::declutter().session().optimize(self) diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs index 1c53a8cfed..28593be282 100644 --- a/core/src/ops/nn/mod.rs +++ b/core/src/ops/nn/mod.rs @@ -4,7 +4,7 @@ mod softmax; pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape}; pub use self::reduce::{Reduce, Reducer}; -pub use self::softmax::Softmax; +pub use self::softmax::{Softmax, SoftmaxExp}; pub use crate::internal::*; diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs index dc082da7e4..f26cba0f5a 100644 --- a/core/src/ops/nn/reduce.rs +++ b/core/src/ops/nn/reduce.rs @@ -1,6 +1,7 @@ use crate::internal::Axis; use crate::internal::*; use std::convert::TryFrom; +use std::mem::transmute; use tract_data::internal::ClampCast; use tract_data::itertools::Itertools; use tract_ndarray::prelude::*; @@ -210,6 +211,12 @@ fn max_t(v: ArrayViewD, _: ()) -> T where T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd, { + if T::datum_type() == f32::datum_type() { + if let Some(slice) = v.as_slice() { + let slice = unsafe { transmute(slice) }; + (tract_linalg::ops().max_f32)().run(slice).unwrap(); + } + } v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v }) } @@ -297,19 +304,23 @@ impl TypedOp for Reduce { outputs: &[&TypedFact], ) -> TractResult { let mut letters = 'a'..; - let axes = (0..inputs[0].rank()).flat_map(|ix| { - if self.axes.contains(&ix) { - tvec!( - Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()).input(0, ix), - Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()).output(0, ix), - ) - } else { - tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) - .input(0, ix) - .output(0, ix)) - } - .into_iter() - }).collect_vec(); + 
let axes = (0..inputs[0].rank()) + .flat_map(|ix| { + if self.axes.contains(&ix) { + tvec!( + Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .input(0, ix), + Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .output(0, ix), + ) + } else { + tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len()) + .input(0, ix) + .output(0, ix)) + } + .into_iter() + }) + .collect_vec(); AxesMapping::new(1, 1, axes) } diff --git a/core/src/ops/nn/softmax/mod.rs b/core/src/ops/nn/softmax/mod.rs index c33d070de1..c0f46d6653 100644 --- a/core/src/ops/nn/softmax/mod.rs +++ b/core/src/ops/nn/softmax/mod.rs @@ -12,10 +12,19 @@ use std::fmt::Debug; use crate::internal::*; use ndarray::prelude::*; -#[derive(Debug, Clone, new, Hash)] +#[derive(Debug, Copy, Clone, Hash, Default, PartialEq)] +pub enum SoftmaxExp { + #[default] + Libc, + // https://nic.schraudolph.org/pubs/Schraudolph99.pdf + FastCompact, +} + +#[derive(Debug, Clone, new, Hash, Default)] pub struct Softmax { pub axes: TVec, pub quant_output_dt: Option, + pub exp: SoftmaxExp, } impl Op for Softmax { @@ -24,7 +33,7 @@ impl Op for Softmax { } fn info(&self) -> TractResult> { - Ok(vec![format!("Axis: {:?}", self.axes)]) + Ok(vec![format!("Axis: {:?}", self.axes), format!("Exp impl: {:?}", self.exp)]) } op_as_typed_op!(); @@ -122,16 +131,24 @@ impl Softmax { } } - let mut output = input.into_tensor().into_array::()?; + let mut output = input.into_tensor(); + let mut view = output.to_array_view_mut::()?; for it_coords in tract_ndarray::indices(&*iterating_shape) { - let mut view = output.view_mut(); + let mut view = view.view_mut(); for ix in 0..iterating_shape.len() { if !self.axes.contains(&ix) { view.collapse_axis(Axis(ix), it_coords[ix]); } } - softmax_inner(view); + if let Some(slice) = + view.as_slice_mut().filter(|_| T::datum_type() == f32::datum_type()) + { + let slice: &mut [f32] = unsafe { std::mem::transmute(slice) }; + self.softmax_inner_slice_f32(slice)?; + } else { + 
softmax_inner(view); + } } Ok(tvec!(output.into_tvalue())) @@ -169,6 +186,27 @@ impl Softmax { unsafe { output_tensor.set_datum_type(output_dt) }; Ok(tvec!(output_tensor.into_tvalue())) } + + fn softmax_inner_slice_f32(&self, slice: &mut [f32]) -> TractResult<()> { + let max = (tract_linalg::ops().max_f32)().run(slice)?; + let sum = match self.exp { + SoftmaxExp::Libc => { + let mut s = 0f32; + for x in slice.iter_mut() { + let y = (*x - max).exp(); + s += y; + *x = y; + } + s + } + SoftmaxExp::FastCompact => { + (tract_linalg::ops().softmax2_fastcompact_f32)().run_with_params(slice, max)? + } + }; + let rsum = sum.recip(); + (tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, rsum)?; + Ok(()) + } } fn softmax_inner(mut view: ArrayViewMut) { @@ -328,7 +366,8 @@ mod test { fn check(&self) -> Result<()> { let inputs = tvec!(self.data.clone().into_tvalue()); let quant_output_dt = Some(self.output_dt).filter(|dt| !dt.is_float()); - let softmax = Softmax { axes: self.axes.clone(), quant_output_dt }; + let softmax = + Softmax { axes: self.axes.clone(), quant_output_dt, ..Softmax::default() }; // Compute quantized output let result = softmax.eval(inputs)?; @@ -338,7 +377,7 @@ mod test { // Compute reference output let input_float = self.data.cast_to::()?; let inputs_float = tvec!(input_float.into_owned().into_tvalue()); - let softmax_float = Softmax { axes: self.axes.clone(), quant_output_dt: None }; + let softmax_float = Softmax { axes: self.axes.clone(), ..Softmax::default() }; let reference_float = softmax_float.eval(inputs_float)?; let reference_array = args_1!(reference_float); let reference = reference_array.to_array_view::()?; diff --git a/core/src/transform.rs b/core/src/transform.rs new file mode 100644 index 0000000000..5a3d082112 --- /dev/null +++ b/core/src/transform.rs @@ -0,0 +1,45 @@ +use crate::internal::*; +use std::borrow::Cow; +use std::fmt::Debug; + +use tract_data::TractResult; + +use crate::floats::FloatPrecisionTranslator; +use 
crate::ops::nn::{Softmax, SoftmaxExp, TypedModel}; + +pub fn get_transformer(name: &str) -> Option> { + match name { + "f32-to-f16" => Some(Box::>::default()), + "f16-to-f32" => Some(Box::>::default()), + "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)), + _ => None, + } +} + +pub trait ModelTransformer: Debug { + fn name(&self) -> Cow; + fn transform(&self, model: &mut TypedModel) -> TractResult<()>; + fn transform_into(&self, model: &TypedModel) -> TractResult { + let mut model = model.clone(); + self.transform(&mut model)?; + Ok(model) + } +} + +#[derive(Debug)] +struct SoftmaxFastCompact; + +impl ModelTransformer for SoftmaxFastCompact { + fn name(&self) -> Cow { + "softmax-fast-compact".into() + } + + fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + for node in &mut model.nodes { + if let Some(softmax) = node.op_as_mut::() { + softmax.exp = SoftmaxExp::FastCompact; + } + } + Ok(()) + } +} diff --git a/hir/src/ops/nn/layer_max.rs b/hir/src/ops/nn/layer_max.rs index 9fc5161926..a4e5ae6c81 100644 --- a/hir/src/ops/nn/layer_max.rs +++ b/hir/src/ops/nn/layer_max.rs @@ -1,3 +1,5 @@ +use tract_core::ops::nn::Softmax; + use crate::infer::*; use crate::internal::*; @@ -9,8 +11,6 @@ pub struct LayerHardmax { coerce_to_2d: bool, } - - impl Expansion for LayerHardmax { fn name(&self) -> Cow { "LayerHardmax".into() @@ -83,14 +83,11 @@ pub struct LayerLogSoftmax { pub coerce_to_2d: bool, } - - impl Expansion for LayerLogSoftmax { fn name(&self) -> Cow { "LayerLogSoftmax".into() } - fn rules<'r, 'p: 'r, 's: 'r>( &'s self, solver: &mut Solver<'r>, @@ -118,8 +115,6 @@ pub struct LayerSoftmax { coerce_to_2d: bool, } - - impl Expansion for LayerSoftmax { fn name(&self) -> Cow { "LayerSoftmax".into() @@ -144,10 +139,10 @@ impl Expansion for LayerSoftmax { let rank = target.outlet_fact(input)?.rank(); let dt = target.outlet_fact(input)?.datum_type; let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize; - let 
reducing_axes = + let axes = if self.coerce_to_2d { (axis..rank).collect::>() } else { tvec!(axis) }; - let dt = if dt.is_float() { None } else { Some(dt) }; - target.wire_node(name, tract_core::ops::nn::Softmax::new(reducing_axes, dt), inputs) + let quant_output_dt = if dt.is_float() { None } else { Some(dt) }; + target.wire_node(name, Softmax { axes, quant_output_dt, ..Softmax::default() }, inputs) } } diff --git a/hir/src/ops/nn/softmax.rs b/hir/src/ops/nn/softmax.rs index eaf64b5eca..9ba36e1304 100644 --- a/hir/src/ops/nn/softmax.rs +++ b/hir/src/ops/nn/softmax.rs @@ -1,4 +1,3 @@ -//use tract_core::ops::nn::Softmax; use crate::internal::*; #[derive(Debug, Clone, new, Hash)] @@ -6,8 +5,6 @@ pub struct Softmax { axis: isize, } - - impl Expansion for Softmax { fn name(&self) -> Cow { "Softmax".into() @@ -54,7 +51,11 @@ impl Expansion for Softmax { target.wire_node( name, - tract_core::ops::nn::Softmax { axes: tvec![axis], quant_output_dt }, + tract_core::ops::nn::Softmax { + axes: tvec![axis], + quant_output_dt, + ..tract_core::ops::nn::Softmax::default() + }, inputs, ) } diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 60342b7ca7..7f7b3d0f7e 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -79,6 +79,10 @@ harness = false name = "sigmoid" harness = false +[[bench]] +name = "softmax" +harness = false + [[bench]] bench = false name = "arm64simd" diff --git a/linalg/benches/arm64simd.rs b/linalg/benches/arm64simd.rs index 0fb458718f..1dd244ff19 100644 --- a/linalg/benches/arm64simd.rs +++ b/linalg/benches/arm64simd.rs @@ -1,5 +1,4 @@ #![allow(dead_code, non_upper_case_globals, unused_macros, non_snake_case, unused_assignments)] -#![feature(aarch64_target_feature)] use std::arch::asm; @@ -57,6 +56,16 @@ pub unsafe fn armv8(filter: Option<&str>) { out("v4") _, out("v5") _, out("v6") _, out("v7") _, ) }); + s128!("fmax", 4, { + asm!(" fmax v0.4s, v1.4s, v1.4s + fmax v2.4s, v3.4s, v3.4s + fmax v4.4s, v5.4s, v5.4s + fmax v6.4s, v7.4s, v7.4s ", + out("v0") 
_, out("v1") _, out("v2") _, out("v3") _, + out("v4") _, out("v5") _, out("v6") _, out("v7") _, + ) + }); + s128!("fmax_with_dep", 1, { asm!("fmax v0.4s, v0.4s, v0.4s", out("v0") _) }); s128!("fmla", 16, { asm!(" fmla v0.4s, v0.4s, v0.4s fmla v1.4s, v1.4s, v1.4s diff --git a/linalg/benches/softmax.rs b/linalg/benches/softmax.rs new file mode 100644 index 0000000000..b26efcb647 --- /dev/null +++ b/linalg/benches/softmax.rs @@ -0,0 +1,93 @@ +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; +use tract_linalg::frame::reduce::{MapReduceKer, ReduceKer}; + +#[inline(never)] +fn loop1_f32_naive(slice: &mut [f32]) -> f32 { + let mut max = std::f32::MIN; + for x in &*slice { + if *x > max { + max = *x; + } + } + max +} + +#[inline(never)] +fn loop2_f32(slice: &mut [f32], max: f32) -> f32 { + let mut sum = 0.; + for x in slice.iter_mut() { + *x = (*x - max).exp(); + sum = sum + *x; + } + sum +} + +#[inline(never)] +fn loop3_f32(slice: &mut [f32], sum: f32) { + let recip = sum.recip(); + for x in slice { + *x = *x * recip; + } +} + +#[inline(never)] +fn rust_f32(slice: &mut [f32]) { + let max = loop1_f32_naive(slice); + let sum = loop2_f32(slice, max); + loop3_f32(slice, sum); +} + +fn softmax_f32(c: &mut Criterion) { + let mut group = c.benchmark_group("softmax_f32"); + group.throughput(Throughput::Elements(1500)); + let mut input = unsafe { Tensor::uninitialized_aligned::(&[1500], 16).unwrap() }; + let input = input.as_slice_mut::().unwrap(); + group.bench_function("rust", |b| b.iter(|| rust_f32(input))); + group.bench_function("loop1/naive", |b| b.iter(|| loop1_f32_naive(input))); + group.bench_function("loop1/generic", |b| { + b.iter(|| tract_linalg::generic::max::SMax4::red().run(&input)) + }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop1/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::max::x86_64_fma_max_f32_32n::red().run(input).unwrap(); + }) + }); + #[cfg(target_arch = "aarch64")] + 
group.bench_function("loop1/intr", |b| { + b.iter(|| { + tract_linalg::arm64::arm64simd_max_f32_16n::red().run(input).unwrap(); + }) + }); + group.bench_function("loop2/naive", |b| b.iter(|| loop2_f32(input, 1.0))); + group.bench_function("loop2/generic", |b| { + b.iter(|| tract_linalg::generic::softmax::SSoftMaxL2::red().run_with_params(input, 10.)) + }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop2/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n::red() + .run_with_params(input, 10.) + .unwrap() + }); + }); + group.bench_function("loop3/naive", |b| b.iter(|| loop3_f32(input, 0.21))); + group.bench_function("loop3/generic", |b| { + b.iter(|| { + tract_linalg::generic::by_scalar::SMulByScalar4::ew().run_with_params(input, 0.21) + }) + }); + #[cfg(target_arch = "x86_64")] + group.bench_function("loop3/iasm", |b| { + b.iter(|| { + tract_linalg::x86_64_fma::by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew() + .run_with_params(input, 0.21) + .unwrap() + }); + }); +} + +criterion_group!(benches, softmax_f32); +criterion_main!(benches); diff --git a/linalg/proptest-regressions/generic/mmm.txt b/linalg/proptest-regressions/generic/mmm.txt deleted file mode 100644 index a6a851136a..0000000000 --- a/linalg/proptest-regressions/generic/mmm.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. 
-cc 121b28bde52462d8007d6aa3b57effba3c4f3ce05a57803cbbcaa4b7096d083f # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 1, 0, 0], mult: 0, boo: PhantomData } -cc d54877b98a8e851f187820853649131a4a0ad62e0967e5c5881577ab2ebc56f8 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 1], mult: 0, boo: PhantomData } -cc a80e1dd4a5bb0c6858ccdaf91a5aa4a8b9b71594473a91aa82583e0db4ec6ef5 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 1], mult: 0, boo: PhantomData } -cc 5aac55722aca2ca476544313178602ef71fbdda4960d3c711605cc70c6e93c06 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 462968c12ee0fcc1e6a78b7fc4f412d0c7f0766299893886b8ec49cf978eff49 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 8ac3a83562b28a2ee6fa3896f929b7210ee7d4f559d9e9c2c06c9dd7e7fa93b5 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } -cc d0369cc2a5339619ebabdd3dbdac60ea9ebf542b23ef536db03fba2de3651c21 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, -2], mult: 0, boo: PhantomData } -cc 9c92a9f97a6d24ffaa21370b1eb5aec9880b508a384d037287df8de7f8d4c608 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } -cc f3574fa61a7e297de8cf76c8268dee6c6d062dfc5f98738618bdeb69d5b9b595 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], mult: 0, boo: PhantomData } -cc bb5421af3bb30e0ac316287dc41563ce69eca8ca74abf5f5a7f8e82d7cab6628 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], mult: 0, boo: PhantomData } -cc 45859e240aae3fa0e7667d46dd475d9f3de42e87ace8b3678db322b4a18509af # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255], mult: 0, boo: PhantomData } -cc 
65d3a24331e3919a11d220118f97fe21caf769d1c590737cad9aef927959b4e9 # shrinks to pb = QWrappingMulHighDoublingProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], mult: 0, boo: PhantomData } -cc 378b328c0e46a5038f5876614a1ac16dc796482676f2ac81a670f5dbabb490ec # shrinks to pb = ConvProblem { ci: 1, co: 19, kt: 1, stride: 1, dilation: 4, filters: 19,1,F32 0, 0, 0, 0, 0, 0, 0, 0.026, -0.286, -0.627, -0.248, -0.291..., data: 1,5,F32 0.873, -0.377, -0.818, -0.415, 0.746, phantom: PhantomData } diff --git a/linalg/proptest-regressions/x86_64_fma/mmm.txt b/linalg/proptest-regressions/x86_64_fma/mmm.txt deleted file mode 100644 index 57df875ef5..0000000000 --- a/linalg/proptest-regressions/x86_64_fma/mmm.txt +++ /dev/null @@ -1,17 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. 
-cc 058e808c67cb65d0a9013ff2cc0d9f726f19bef9b21e205396db15871be2586d # shrinks to pb = ReturnCProblem { c: [0.0, 0.0, -0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1000702600000000000000.0], boo: PhantomData } -cc 65e8110dc380469eec961f5d4ef07bd31b7f44defe08af00e6dcc812228bcc7f # shrinks to pb = QRightShiftProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], shift: 0, policy: Away, boo: PhantomData } -cc 1f7e80660ef4800d7bbb6acafe9bbff1e396f0f78b296d20deb51b439a40736c # shrinks to pb = QScaleProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mult: 536870912, shift: 0, policy: Odd, boo: PhantomData } -cc 57c3e684c05eb9cc964eaaf22ccb620a2d9ef5ea5bd38c6dff6c5344b8192619 # shrinks to pb = QScaleProblem { c: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], mult: 536870912, shift: 0, policy: Odd, boo: PhantomData } -cc 9fce04b53146b7615e54728101dd8ad24ffd157d63c536eb1ca888070c8b1719 # shrinks to pb = PackedOffsetsProblem { a: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], b: [0.0, 0.0, 0.0, 
0.0, 1.0, 9.0, 7.0, 7.0, 0.0, 5.0, 2.0, 3.0, 0.0, 3.0, 7.0, 2.0, 2.0, 6.0, 4.0, 4.0, 1.0, 5.0, 5.0, 1.0, 3.0, 7.0, 3.0, 4.0, 7.0, 5.0, 8.0, 8.0, 1.0, 5.0, 6.0, 8.0, 7.0, 6.0], cols_offsets: [28], rows_offsets: [9], add_one: false, _phantom: PhantomData } -cc 708172a1bb5110aa22d15b4a0daa777fae18dda1ab481328ab200f54d72357b6 # shrinks to pb = ConvProblem { ci: 1, co: 17, kt: 1, stride: 3, dilation: 4, filters: 17,1,F32 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.044..., data: 1,3,F32 -0.416, -0.932, -0.21, phantom: PhantomData } -cc e62d087b655259ebda7622f40e0049a49fb803777a3a98343c6855e69944ccfc # shrinks to pb = ConvProblem { ci: 1, co: 17, kt: 1, stride: 4, dilation: 4, filters: 17,1,F32 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..., data: 1,4,F32 0.251, 0, 0, 0, phantom: PhantomData } -cc 15bfa5078b5479e12e74820888d477cbdd628faa03aa1bef26738dfa8cbff014 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc dde4ea4eafe9beb60bffe8307cdb62c6143849f247ad9edbb64acd586f57d722 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc 066cecc4f004061abab40982576e9f10c56dd090751a699e2cffeb3f7260b0e1 # shrinks to (m, k, n, ref a, ref b) = (1, 1, 2, 1,1,F32 0, 1,2,F32 0, 0) -cc 2ba4fcc74dd26a2e26c3c589d03961e5cb753a245a2ce5ec2ee7b9c53b30d8a2 # shrinks to pb = PackedPackedProblem { k: 16, a: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], b: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], trans_c: false, add_one: false, _phantom: PhantomData } diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index c0435d6c65..96a974e914 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -8,14 +8,17 @@ mod cortex_a55; //mod cortex_a73; pub use arm64simd::*; -mod leaky_relu; -pub use leaky_relu::*; +#[cfg(not(feature = "no_fp16"))] +mod arm64fp16; +#[cfg(not(feature = "no_fp16"))] +pub use arm64fp16::*; use crate::Ops; use crate::f16; use crate::frame::element_wise::ElementWiseKer; use crate::frame::mmm::kernel::MatMatMulKer; +use crate::frame::reduce::ReduceKer; // https://en.wikipedia.org/wiki/Comparison_of_ARMv8-A_cores const PART_A53: &str = "0xd03"; @@ -229,6 +232,7 @@ pub fn plug(ops: &mut Ops) { ops.leaky_relu_f32 = Box::new(|| arm64simd_leaky_relu_f32_8n::ew()); ops.sigmoid_f32 = Box::new(|| arm64simd_sigmoid_f32_4n::ew()); ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew()); + ops.max_f32 = Box::new(|| arm64simd_max_f32_16n::red()); #[cfg(not(feature = "no_fp16"))] if has_fp16() { log::info!("ARMv8.2 tanh_f16 and sigmoid_f16 activated"); diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs new file mode 100644 index 0000000000..82e010d6c9 --- /dev/null +++ b/linalg/src/arm64/arm64fp16.rs @@ -0,0 +1,15 @@ +use tract_data::half::f16; + 
+mod leaky_relu; +pub use leaky_relu::*; + +use crate::frame::mmm::*; +MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_16x8_a55; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_32x4_gen; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_32x4_a55; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_128x1_gen; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); +MMMKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); + +tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); +sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); diff --git a/linalg/src/arm64/leaky_relu.rs b/linalg/src/arm64/arm64fp16/leaky_relu.rs similarity index 54% rename from linalg/src/arm64/leaky_relu.rs rename to linalg/src/arm64/arm64fp16/leaky_relu.rs index 1ff968d96c..79b7f95174 100644 --- a/linalg/src/arm64/leaky_relu.rs +++ b/linalg/src/arm64/arm64fp16/leaky_relu.rs @@ -1,50 +1,5 @@ use tract_data::internal::f16; -ew_impl_wrap!( - f32, - arm64simd_leaky_relu_f32_8n, - 8, - 4, - f32, - #[inline(never)] - fn run(buf: &mut [f32], alpha: f32) { - assert!(buf.len() % 8 == 0); - assert!(buf.len() > 0); - unsafe { - let len = buf.len(); - let ptr = buf.as_ptr(); - std::arch::asm!(" - dup v0.4s, {alpha:v}.s[0] - dup v1.4s, {one:v}.s[0] - 1: - ldp q3, q4, [{ptr}] - - fcmgt v5.4s, v3.4s, #0.0 - fcmgt v6.4s, v4.4s, #0.0 - bsl v5.16b, v1.16b, v0.16b - bsl v6.16b, v1.16b, v0.16b - fmul v3.4s, v3.4s, v5.4s - fmul v4.4s, v4.4s, v6.4s - - stp q3, q4, [{ptr}], #32 - subs {len}, {len}, 8 - bne 1b - ", - one = in(vreg) 1.0f32, - alpha = in(vreg) alpha, - len = inout(reg) len => _, - ptr = inout(reg) ptr => _, - out("v0") _, - out("v1") _, - out("q3") _, - out("q4") _, - out("q5") _, 
- out("q6") _, - ); - } - } -); - ew_impl_wrap!( f16, arm64fp16_leaky_relu_f16_16n, @@ -92,14 +47,9 @@ ew_impl_wrap!( } ); -#[cfg(test)] -pub mod test_arm64simd_leaky_relu_f32_8n { - use super::*; - leaky_relu_frame_tests!(true, f32, arm64simd_leaky_relu_f32_8n); -} - #[cfg(test)] pub mod test_arm64simd_leaky_relu_f16_16n { use super::*; leaky_relu_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_leaky_relu_f16_16n); } + diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index 5ac30d731a..fd7694f235 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -1,6 +1,10 @@ +mod leaky_relu; +mod max; + use crate::frame::mmm::*; -#[cfg(not(feature = "no_fp16"))] -use tract_data::half::f16; + +pub use leaky_relu::arm64simd_leaky_relu_f32_8n; +pub use max::arm64simd_max_f32_16n; MMMKernel!(f32, arm64simd_mmm_f32_8x8_a55; 8, 8; 16, 16; 1, 1; no_prefetch, true); MMMKernel!(f32, arm64simd_mmm_f32_12x8_a55; 12, 8; 16, 16; 1, 1; no_prefetch, true); @@ -23,23 +27,6 @@ MMMKernel!(f32, arm64simd_mmm_f32_64x1_gen; 64, 1; 16, 16; 1, 1; no_prefetch, tr MMMKernel!(i32, arm64simd_mmm_i32_8x8; 8, 8; 16, 16; 0,0; no_prefetch, true); MMMKernel!(i32, arm64simd_mmm_i32_64x1; 64, 1; 16, 1; 0,0; no_prefetch, true); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_16x8_a55; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_32x4_gen; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_32x4_a55; 32, 4; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -MMMKernel!(f16, arm64fp16_mmm_f16_128x1_gen; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] 
-MMMKernel!(f16, arm64fp16_mmm_f16_128x1_a55; 128, 1; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); - tanh_impl!(f32, arm64simd_tanh_f32_4n, 4, 4, true); sigmoid_impl!(f32, arm64simd_sigmoid_f32_4n, 4, 4, true); -#[cfg(not(feature = "no_fp16"))] -tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); -#[cfg(not(feature = "no_fp16"))] -sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); diff --git a/linalg/src/arm64/arm64simd/leaky_relu.rs b/linalg/src/arm64/arm64simd/leaky_relu.rs new file mode 100644 index 0000000000..83901c87a6 --- /dev/null +++ b/linalg/src/arm64/arm64simd/leaky_relu.rs @@ -0,0 +1,50 @@ +ew_impl_wrap!( + f32, + arm64simd_leaky_relu_f32_8n, + 8, + 4, + f32, + #[inline(never)] + fn run(buf: &mut [f32], alpha: f32) { + assert!(buf.len() % 8 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.4s, {alpha:v}.s[0] + dup v1.4s, {one:v}.s[0] + 1: + ldp q3, q4, [{ptr}] + + fcmgt v5.4s, v3.4s, #0.0 + fcmgt v6.4s, v4.4s, #0.0 + bsl v5.16b, v1.16b, v0.16b + bsl v6.16b, v1.16b, v0.16b + fmul v3.4s, v3.4s, v5.4s + fmul v4.4s, v4.4s, v6.4s + + stp q3, q4, [{ptr}], #32 + subs {len}, {len}, 8 + bne 1b + ", + one = in(vreg) 1.0f32, + alpha = in(vreg) alpha, + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + out("v0") _, + out("v1") _, + out("q3") _, + out("q4") _, + out("q5") _, + out("q6") _, + ); + } + } +); + +#[cfg(test)] +pub mod test_arm64simd_leaky_relu_f32_8n { + use super::*; + leaky_relu_frame_tests!(true, f32, arm64simd_leaky_relu_f32_8n); +} diff --git a/linalg/src/arm64/arm64simd/max.rs b/linalg/src/arm64/arm64simd/max.rs new file mode 100644 index 0000000000..86cc1a2673 --- /dev/null +++ b/linalg/src/arm64/arm64simd/max.rs @@ -0,0 +1,52 @@ +use std::arch::aarch64::{float32x4_t, vdupq_n_f32, vgetq_lane_f32}; + +reduce_impl_wrap!( + f32, + arm64simd_max_f32_16n, + 16, + 4, + (), + f32::MIN, + #[inline(never)] + fn run(buf: 
&[f32], _: ()) -> f32 { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + unsafe { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut out: float32x4_t = vdupq_n_f32(f32::MIN); + std::arch::asm!(" + and v1.16b, v0.16b, v0.16b + and v2.16b, v0.16b, v0.16b + and v3.16b, v0.16b, v0.16b + 1: + ld1 {{v4.4s, v5.4s, v6.4s, v7.4s}}, [{ptr}], 64 + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v5.4s + fmax v2.4s, v2.4s, v6.4s + fmax v3.4s, v3.4s, v7.4s + subs {len}, {len}, 16 + bne 1b + fmax v0.4s, v0.4s, v1.4s + fmax v2.4s, v2.4s, v3.4s + fmax v0.4s, v0.4s, v2.4s + fmaxv s0, v0.4s + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("v0") out, out("v1") _, out("v2") _, out("v3") _, + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + vgetq_lane_f32(out, 0) + } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } +); + +#[cfg(test)] +mod test_arm64simd_max_f32_16n { + use super::*; + max_frame_tests!(true, f32, arm64simd_max_f32_16n); +} diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs index 5b20912abb..9810835e0d 100644 --- a/linalg/src/frame.rs +++ b/linalg/src/frame.rs @@ -1,15 +1,24 @@ #[macro_use] pub mod element_wise; + +#[macro_use] +pub mod by_scalar; #[macro_use] pub mod lut; #[macro_use] +pub mod max; +#[macro_use] pub mod mmm; pub mod pack; #[macro_use] pub mod leaky_relu; #[macro_use] +pub mod reduce; +#[macro_use] pub mod sigmoid; #[macro_use] +pub mod softmax; +#[macro_use] pub mod tanh; pub mod element_wise_helper; diff --git a/linalg/src/frame/by_scalar.rs b/linalg/src/frame/by_scalar.rs new file mode 100644 index 0000000000..528c3c3585 --- /dev/null +++ b/linalg/src/frame/by_scalar.rs @@ -0,0 +1,39 @@ +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::element_wise::ElementWiseKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! 
mul_by_scalar_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100), scalar in -25f32..25f32) { + if $cond { + $crate::frame::by_scalar::test::test_mul_by_scalar::<$ker, $t>(&*xs, scalar).unwrap() + } + } + } + }; + } + + pub fn test_mul_by_scalar, T: LADatum + Float>( + values: &[f32], + scalar: f32, + ) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + crate::setup_test_logger(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::element_wise::test::test_element_wise_params::( + &values, + |a| a * scalar.as_(), + scalar.as_(), + ) + } +} diff --git a/linalg/src/frame/element_wise.rs b/linalg/src/frame/element_wise.rs index 089769d064..d5bae4a0e9 100644 --- a/linalg/src/frame/element_wise.rs +++ b/linalg/src/frame/element_wise.rs @@ -5,7 +5,7 @@ use tract_data::TractResult; use crate::LADatum; -use super::element_wise_helper::run_over_slice_with_alignment; +use super::element_wise_helper::map_slice_with_alignment; macro_rules! 
ew_impl_wrap { ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $run: item) => { @@ -100,7 +100,7 @@ where K: ElementWiseKer + Clone, { fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<()> { - run_over_slice_with_alignment( + map_slice_with_alignment( vec, |data| K::run(data, params), K::nr(), diff --git a/linalg/src/frame/element_wise_helper.rs b/linalg/src/frame/element_wise_helper.rs index 3116217b11..f4b3084642 100644 --- a/linalg/src/frame/element_wise_helper.rs +++ b/linalg/src/frame/element_wise_helper.rs @@ -2,7 +2,7 @@ use crate::LADatum; use std::alloc::*; use tract_data::TractResult; -pub(crate) fn run_over_slice_with_alignment( +pub(crate) fn map_slice_with_alignment( vec: &mut [T], f: impl Fn(&mut [T]), nr: usize, @@ -40,6 +40,92 @@ where Ok(()) } +pub(crate) fn reduce_slice_with_alignment( + vec: &[T], + f: impl Fn(&[T]) -> T, + nr: usize, + alignment_bytes: usize, + neutral: T, + reduce: impl Fn(T, T) -> T, +) -> TractResult +where + T: LADatum, +{ + if vec.is_empty() { + return Ok(neutral); + } + let mut red = neutral; + unsafe { + TMP.with(|buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes); + let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); + let mut compute_via_temp_buffer = |slice: &[T], red: &mut T| { + tmp[..slice.len()].copy_from_slice(slice); + tmp[slice.len()..].fill(neutral); + *red = reduce(*red, f(tmp)); + }; + let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len()); + if prefix_len > 0 { + compute_via_temp_buffer(&vec[..prefix_len], &mut red); + } + let aligned_len = (vec.len() - prefix_len) / nr * nr; + if aligned_len > 0 { + let t = f(&vec[prefix_len..][..aligned_len]); + red = reduce(red, t); + } + if prefix_len + aligned_len < vec.len() { + compute_via_temp_buffer(&vec[prefix_len + aligned_len..], &mut red); + } + }) + } + Ok(red) +} + +pub(crate) fn 
map_reduce_slice_with_alignment( + vec: &mut [T], + f: impl Fn(&mut [T]) -> T, + nr: usize, + alignment_bytes: usize, + map_neutral: T, + neutral: T, + reduce: impl Fn(T, T) -> T, +) -> TractResult +where + T: LADatum, +{ + if vec.is_empty() { + return Ok(neutral); + } + let mut red = neutral; + unsafe { + TMP.with(|buffer| { + let mut buffer = buffer.borrow_mut(); + buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes); + let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr); + let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| { + tmp[..slice.len()].copy_from_slice(slice); + tmp[slice.len()..].fill(map_neutral); + *red = reduce(*red, f(tmp)); + slice.copy_from_slice(&tmp[..slice.len()]); + }; + let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len()); + if prefix_len > 0 { + compute_via_temp_buffer(&mut vec[..prefix_len], &mut red); + } + let aligned_len = (vec.len() - prefix_len) / nr * nr; + if aligned_len > 0 { + let t = f(&mut vec[prefix_len..][..aligned_len]); + red = reduce(red, t); + } + if prefix_len + aligned_len < vec.len() { + compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..], &mut red); + } + }) + } + Ok(red) +} + std::thread_local! { static TMP: std::cell::RefCell = std::cell::RefCell::new(TempBuffer::default()); } diff --git a/linalg/src/frame/max.rs b/linalg/src/frame/max.rs new file mode 100644 index 0000000000..12a451487c --- /dev/null +++ b/linalg/src/frame/max.rs @@ -0,0 +1,58 @@ + +/*macro_rules! max_impl { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $cond: expr) => { + reduce_impl!($ti, $func, $nr, $alignment_items); + #[cfg(test)] + paste! { + mod [] { + use super::*; + max_frame_tests!($cond, $ti, $func); + } + } + }; +} +*/ + +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::reduce::ReduceKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! 
max_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 0..100)) { + if $cond { + $crate::frame::max::test::test_max::<$ker, $t>(&*xs).unwrap() + } + } + } + + #[test] + fn empty() { + if $cond { + $crate::frame::max::test::test_max::<$ker, $t>(&[]).unwrap() + } + } + }; + } + + pub fn test_max, T: LADatum + Float>(values: &[f32]) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + crate::setup_test_logger(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::reduce::test::test_reduce::( + &values, + ::min_value(), + |a, b| a.max(b), + ) + } +} diff --git a/linalg/src/frame/mmm.rs b/linalg/src/frame/mmm.rs index 7060ba4d1b..df3bf18e30 100644 --- a/linalg/src/frame/mmm.rs +++ b/linalg/src/frame/mmm.rs @@ -37,7 +37,7 @@ macro_rules! MMMKernel { #[derive(Copy, Clone, Debug, new)] pub struct $func; - impl MatMatMulKer<$ti> for $func { + impl $crate::frame::mmm::MatMatMulKer<$ti> for $func { #[inline(always)] fn name() -> &'static str { stringify!($func) @@ -67,9 +67,9 @@ macro_rules! MMMKernel { $end_padding_packed_b } #[inline(always)] - fn kernel(spec: &[FusedKerSpec<$ti>]) -> isize { + fn kernel(spec: &[$crate::frame::mmm::FusedKerSpec<$ti>]) -> isize { debug_assert!(spec.len() > 0); - debug_assert!(matches!(spec[spec.len() - 1], FusedKerSpec::Done)); + debug_assert!(matches!(spec[spec.len() - 1], $crate::frame::mmm::FusedKerSpec::Done)); unsafe { []::$func(spec.as_ptr()) } } #[inline(always)] diff --git a/linalg/src/frame/reduce.rs b/linalg/src/frame/reduce.rs new file mode 100644 index 0000000000..edd2369963 --- /dev/null +++ b/linalg/src/frame/reduce.rs @@ -0,0 +1,309 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use tract_data::TractResult; + +use crate::LADatum; + +use super::element_wise_helper::{reduce_slice_with_alignment, map_reduce_slice_with_alignment}; + +macro_rules! 
reduce_impl_wrap { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $neutral: expr, $run: item, $reduce_two: item) => { + paste! { + #[derive(Copy, Clone, Debug)] + #[allow(non_camel_case_types)] + pub struct $func; + + impl crate::frame::reduce::ReduceKer<$ti, $params> for $func { + #[inline(always)] + fn name() -> &'static str { + stringify!($func) + } + #[inline(always)] + fn nr() -> usize { + $nr + } + #[inline(always)] + fn alignment_items() -> usize { + $alignment_items + } + #[inline(always)] + fn alignment_bytes() -> usize { + $alignment_items * std::mem::size_of::<$ti>() + } + #[inline(always)] + fn neutral() -> $ti { + $neutral + } + $run + $reduce_two + } + } + }; +} + +/* +macro_rules! reduce_impl { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => { + paste! { + mod [] { + #[allow(unused_imports)] + use tract_data::prelude::f16; + extern_kernel!(fn $func(ptr: *mut $ti, count: usize) -> ()); + } + reduce_impl_wrap!($ti, $func, $nr, $alignment_items, (), + #[inline(never)] + fn run(buf: &mut [$ti], _params: ()) { + unsafe { []::$func(buf.as_mut_ptr(), buf.len()) } + } + ); + } + }; + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty) => { + paste! 
{ + mod [] { + #[allow(unused_imports)] + use tract_data::prelude::f16; + extern_kernel!(fn $func(ptr: *mut $ti, count: usize, params: $params) -> ()); + } + ew_impl_wrap!($ti, $func, $nr, $alignment_items, $params, + #[inline(never)] + fn run(buf: &mut [$ti], params: $params) { + unsafe { []::$func(buf.as_mut_ptr(), buf.len(), params) } + } + ); + } + }; +} +*/ + +pub trait Reduce: Send + Sync + Debug + dyn_clone::DynClone +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: Copy + Debug + PartialEq + Send + Sync, +{ + fn run(&self, vec: &[T]) -> TractResult { + self.run_with_params(vec, Params::default()) + } + fn run_with_params(&self, vec: &[T], params: Params) -> TractResult; +} + +dyn_clone::clone_trait_object!( Reduce where T: Copy, Params: Copy); + +#[derive(Debug, Clone, new)] +pub struct ReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: ReduceKer + Clone, +{ + phantom: PhantomData<(K, T, Params)>, +} + +impl Reduce for ReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: ReduceKer + Clone, +{ + fn run_with_params(&self, vec: &[T], params: Params) -> TractResult { + reduce_slice_with_alignment( + vec, + |data| K::run(data, params), + K::nr(), + K::alignment_bytes(), + K::neutral(), + K::reduce_two, + ) + } +} + +pub trait ReduceKer: + Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: LADatum, +{ + fn name() -> &'static str; + fn alignment_bytes() -> usize; + fn alignment_items() -> usize; + fn nr() -> usize; + fn neutral() -> T; + fn reduce_two(a: T, b: T) -> T; + fn run(vec: &[T], params: Params) -> T; + fn red() -> Box> { + Box::new(ReduceImpl::::new()) + } +} + +macro_rules! map_reduce_impl_wrap { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $params: ty, $map_neutral: expr, $reduce_neutral: expr, $run: item, $reduce_two: item) => { + paste! 
{ + #[derive(Copy, Clone, Debug)] + #[allow(non_camel_case_types)] + pub struct $func; + + impl crate::frame::reduce::MapReduceKer<$ti, $params> for $func { + #[inline(always)] + fn name() -> &'static str { + stringify!($func) + } + #[inline(always)] + fn nr() -> usize { + $nr + } + #[inline(always)] + fn alignment_items() -> usize { + $alignment_items + } + #[inline(always)] + fn alignment_bytes() -> usize { + $alignment_items * std::mem::size_of::<$ti>() + } + #[inline(always)] + fn map_neutral() -> $ti { + $map_neutral + } + #[inline(always)] + fn reduce_neutral() -> $ti { + $reduce_neutral + } + $run + $reduce_two + } + } + }; +} + +pub trait MapReduce: Send + Sync + Debug + dyn_clone::DynClone +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: Copy + Debug + PartialEq + Send + Sync, +{ + fn run(&self, vec: &mut [T]) -> TractResult { + self.run_with_params(vec, Params::default()) + } + fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult; +} + +dyn_clone::clone_trait_object!( MapReduce where T: Copy, Params: Copy); + +#[derive(Debug, Clone, new)] +pub struct MapReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: MapReduceKer + Clone, +{ + phantom: PhantomData<(K, T, Params)>, +} + +impl MapReduce for MapReduceImpl +where + T: LADatum, + Params: Copy + Send + Sync + Debug + 'static + Default, + K: MapReduceKer + Clone, +{ + fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult { + map_reduce_slice_with_alignment( + vec, + |data| K::run(data, params), + K::nr(), + K::alignment_bytes(), + K::map_neutral(), + K::reduce_neutral(), + K::reduce_two, + ) + } +} + +pub trait MapReduceKer: + Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static +where + Params: Copy + Send + Sync + Debug + 'static + Default, + T: LADatum, +{ + fn name() -> &'static str; + fn alignment_bytes() -> usize; + fn alignment_items() -> usize; + fn nr() -> usize; + fn map_neutral() -> T; + fn 
reduce_neutral() -> T; + fn reduce_two(a: T, b: T) -> T; + fn run(vec: &mut [T], params: Params) -> T; + fn red() -> Box> { + Box::new(MapReduceImpl::::new()) + } +} + +#[cfg(test)] +pub mod test { + use super::*; + use proptest::test_runner::{TestCaseError, TestCaseResult}; + use tract_data::internal::*; + use tract_data::itertools::Itertools; + + pub fn test_reduce, T: LADatum>( + values: &[T], + neutral: T, + reference_reduce: impl Fn(T, T) -> T, + ) -> TestCaseResult { + test_reduce_params::(values, neutral, reference_reduce, ()) + } + + pub fn test_reduce_params, T: LADatum, Params>( + values: &[T], + neutral: T, + reference_reducer: impl Fn(T, T) -> T, + params: Params, + ) -> TestCaseResult + where + Params: Copy + Send + Sync + Debug + 'static + Default, + { + crate::setup_test_logger(); + let op = K::red(); + let expected = values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i)); + let mut found = values; + let red = op.run_with_params(&mut found, params).unwrap(); + tensor0(red) + .close_enough(&tensor0(expected), true) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + Ok(()) + } + + pub fn test_map_reduce, T: LADatum>( + values: &[T], + map_neutral: T, + neutral: T, + reference_map: impl Fn(T) -> T, + reference_reduce: impl Fn(T, T) -> T, + ) -> TestCaseResult { + test_map_reduce_params::(values, map_neutral, neutral, reference_map, reference_reduce, ()) + } + + pub fn test_map_reduce_params, T: LADatum, Params>( + values: &[T], + _neutral: T, + map_neutral: T, + reference_map: impl Fn(T) -> T, + reference_reducer: impl Fn(T, T) -> T, + params: Params, + ) -> TestCaseResult + where + Params: Copy + Send + Sync + Debug + 'static + Default, + { + crate::setup_test_logger(); + let op = K::red(); + let mut found = values.to_vec(); + let expected_values = values.iter().copied().map(reference_map).collect_vec(); + let expected_reduced = expected_values.iter().fold(map_neutral, |acc, i| reference_reducer(acc, *i)); + let red = 
op.run_with_params(&mut found, params).unwrap(); + tensor1(&found) + .close_enough(&tensor1(&expected_values), Approximation::SuperApproximate) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + tensor0(red) + .close_enough(&tensor0(expected_reduced), Approximation::SuperApproximate) + .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?; + Ok(()) + } +} diff --git a/linalg/src/frame/softmax.rs b/linalg/src/frame/softmax.rs new file mode 100644 index 0000000000..d6f613a04e --- /dev/null +++ b/linalg/src/frame/softmax.rs @@ -0,0 +1,69 @@ +#[cfg(test)] +#[macro_use] +pub mod test { + use crate::frame::reduce::MapReduceKer; + use crate::LADatum; + use num_traits::{AsPrimitive, Float}; + use proptest::test_runner::TestCaseResult; + + #[macro_export] + macro_rules! softmax_l2_frame_tests { + ($cond:expr, $t: ty, $ker:ty) => { + proptest::proptest! { + #[test] + fn prop(xs in proptest::collection::vec(-25f32..25.0, 1..100)) { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&*xs).unwrap() + } + } + } + #[test] + fn single() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[0.0]).unwrap() + } + } + #[test] + fn two() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[ + 16.62555, 21.950674, + ]) + .unwrap() + } + } + + #[test] + fn two_missing_max() { + if $cond { + $crate::frame::softmax::test::test_softmax_l2::<$ker, $t>(&[ + -46.15512, 42.875168 + ]) + .unwrap() + } + } + }; + } + + pub fn test_softmax_l2, T: LADatum + Float>( + values: &[f32], + ) -> TestCaseResult + where + f32: AsPrimitive, + T: AsPrimitive, + { + use crate::generic::softmax::fast_compact_exp_f32; + crate::setup_test_logger(); + let max = values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); + let values: Vec = values.iter().copied().map(|x| x.as_()).collect(); + crate::frame::reduce::test::test_map_reduce_params::( + &values, + ::min_value(), + T::zero(), + // |x| (x - max.as_()).exp(), + |x| 
fast_compact_exp_f32(dbg!(x).as_() - max).as_(), + |a, b| a + b, + max.as_(), + ) + } +} diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index ddbc23f6d4..ba28449b47 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -1,11 +1,15 @@ +pub mod by_scalar; pub mod erf; pub mod leaky_relu; pub mod lut; +pub mod max; pub mod mmm; pub mod rounding; pub mod sigmoid; +pub mod softmax; pub mod tanh; +pub use self::by_scalar::SMulByScalar4; pub use self::erf::SErf4; pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4}; pub use self::lut::GenericLut8; @@ -13,4 +17,5 @@ pub use self::mmm::GenericMmm4x1; pub use self::mmm::GenericMmm4x4; pub use self::rounding::{ScaleShiftAndRound, Scaler}; pub use self::sigmoid::{HSigmoid8, SSigmoid4}; +pub use self::softmax::SSoftMaxL2; pub use self::tanh::{HTanh8, STanh4}; diff --git a/linalg/src/generic/by_scalar.rs b/linalg/src/generic/by_scalar.rs new file mode 100644 index 0000000000..a780c9e69c --- /dev/null +++ b/linalg/src/generic/by_scalar.rs @@ -0,0 +1,35 @@ +use crate::element_wise::ElementWiseKer; + + +#[derive(Clone, Debug)] +pub struct SMulByScalar4; + +impl ElementWiseKer for SMulByScalar4 { + fn name() -> &'static str { + "generic" + } + + fn alignment_items() -> usize { + 16 + } + + fn alignment_bytes() -> usize { + 16 + } + + fn nr() -> usize { + 4 + } + + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + x.iter_mut().for_each(|px| *px *= s) + } +} + +#[cfg(test)] +#[macro_use] +pub mod mul_by_scalar_f32 { + mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4); +} diff --git a/linalg/src/generic/max.rs b/linalg/src/generic/max.rs new file mode 100644 index 0000000000..5d21badc86 --- /dev/null +++ b/linalg/src/generic/max.rs @@ -0,0 +1,42 @@ +use crate::frame::reduce::ReduceKer; + +#[derive(Clone, Debug)] +pub struct SMax4; + +impl ReduceKer for SMax4 { + fn name() -> &'static str 
{ + "generic" + } + + fn alignment_bytes() -> usize { + 16 + } + + fn alignment_items() -> usize { + 4 + } + + fn nr() -> usize { + 4 + } + + fn neutral() -> f32 { + f32::MIN + } + + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } + + fn run(x: &[f32], _: ()) -> f32 { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap() + } +} + +#[cfg(test)] +#[macro_use] +pub mod s { + max_frame_tests!(true, f32, crate::generic::max::SMax4); +} diff --git a/linalg/src/generic/softmax.rs b/linalg/src/generic/softmax.rs new file mode 100644 index 0000000000..56411f1f61 --- /dev/null +++ b/linalg/src/generic/softmax.rs @@ -0,0 +1,68 @@ +use crate::frame::reduce::MapReduceKer; + +#[derive(Clone, Debug)] +pub struct SSoftMaxL2; + +impl MapReduceKer for SSoftMaxL2 { + fn name() -> &'static str { + "generic" + } + + fn alignment_bytes() -> usize { + 16 + } + + fn alignment_items() -> usize { + 4 + } + + fn nr() -> usize { + 4 + } + + fn map_neutral() -> f32 { + f32::MIN + } + + fn reduce_neutral() -> f32 { + 0. 
+ } + + fn reduce_two(a: f32, b: f32) -> f32 { + a + b + } + + fn run(x: &mut [f32], max: f32) -> f32 { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + let mut sum = 0.; + for v in x.iter_mut() { + let y = *v - max; + let y = fast_compact_exp_f32(y); + *v = y; + sum += y; + } + sum + } +} + +// ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h +// probably inspired from https://nic.schraudolph.org/pubs/Schraudolph99.pdf +// not that the cast to u32 deals with negative right, while implem in volk code are wrong in some +// corner cases (need a max(0,x) before the u32 conversion) +pub fn fast_compact_exp_f32(v: f32) -> f32 { + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + f32::from_bits(((SLOPE * v) + OFFSET) as u32) +} + + +#[cfg(test)] +#[macro_use] +pub mod s { + softmax_l2_frame_tests!(true, f32, crate::generic::softmax::SSoftMaxL2); +} diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index 3dd60ff7d0..83de457c38 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -1,5 +1,8 @@ #![allow(clippy::missing_safety_doc)] #![allow(clippy::redundant_closure_call)] +#![allow(clippy::len_zero)] +#![allow(clippy::excessive_precision)] +#![allow(clippy::approx_constant)] #[macro_use] extern crate derive_new; extern crate lazy_static; @@ -16,7 +19,8 @@ include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs")); pub mod frame; pub mod generic; use frame::element_wise::ElementWiseKer; -use frame::MatMatMul; +use frame::reduce::{MapReduceKer, ReduceKer}; +use frame::{reduce, MatMatMul}; pub use generic::{ScaleShiftAndRound, Scaler}; #[cfg(target_arch = "x86_64")] pub mod x86_64_fma; @@ -55,6 +59,8 @@ pub struct Ops { pub leaky_relu_f16: Box Box> + Send + Sync>, pub leaky_relu_f32: Box Box> + Send + Sync>, + pub 
mul_by_scalar_f32: + Box Box> + Send + Sync>, pub sigmoid_f16: Box Box> + Send + Sync>, pub sigmoid_f32: Box Box> + Send + Sync>, @@ -62,6 +68,10 @@ pub struct Ops { pub tanh_f32: Box Box> + Send + Sync>, pub erf_f32: Box Box> + Send + Sync>, pub lut_u8: Box Box + Send + Sync>, + + pub max_f32: Box Box> + Send + Sync>, + + pub softmax2_fastcompact_f32: Box Box> + Send + Sync>, } impl Ops { @@ -113,15 +123,18 @@ pub fn generic() -> Ops { qmmv_i32: Box::new(|_, _| generic::GenericMmm4x1::::mmm()), leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()), leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()), + mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()), sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()), sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()), tanh_f16: Box::new(|| generic::HTanh8::ew()), tanh_f32: Box::new(|| generic::STanh4::ew()), erf_f32: Box::new(|| generic::SErf4::ew()), lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::::new(table))), + max_f32: Box::new(|| generic::max::SMax4::red()), /* activation_f32: Box::new(|microcode| generic::SActivation::new(microcode)) */ + softmax2_fastcompact_f32: Box::new(|| generic::softmax::SSoftMaxL2::red()), } } diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 12664bdcca..d5f6261661 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -1,10 +1,15 @@ use crate::frame::element_wise::ElementWiseKer; use crate::frame::mmm::kernel::MatMatMulKer; +use crate::frame::reduce::{MapReduceKer, ReduceKer}; +use crate::x86_64_fma::softmax::x86_64_fma_softmax2_fastcompact_f32_32n; use crate::Ops; pub mod mmm; +pub mod by_scalar; mod intel; +pub mod max; +pub mod softmax; tanh_impl!(f32, fma_tanh_f32, 8, 8, is_x86_feature_detected!("fma")); sigmoid_impl!(f32, fma_sigmoid_f32, 8, 8, is_x86_feature_detected!("fma")); @@ -90,6 +95,10 @@ fn plug_fma(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| fma_sigmoid_f32::ew()); ops.tanh_f32 = Box::new(|| fma_tanh_f32::ew()); + + 
ops.mul_by_scalar_f32 = Box::new(|| by_scalar::x86_64_avx_f32_mul_by_scalar_32n::ew()); + ops.max_f32 = Box::new(|| max::x86_64_fma_max_f32_32n::red()); + ops.softmax2_fastcompact_f32 = Box::new(|| x86_64_fma_softmax2_fastcompact_f32_32n::red()); log::info!("mmm_f32, mmv_f32, sigmoid_f32, tanh_f32: x86_64/fma activated"); } diff --git a/linalg/src/x86_64_fma/by_scalar.rs b/linalg/src/x86_64_fma/by_scalar.rs new file mode 100644 index 0000000000..c2e7c9abda --- /dev/null +++ b/linalg/src/x86_64_fma/by_scalar.rs @@ -0,0 +1,52 @@ +ew_impl_wrap!( + f32, + x86_64_avx_f32_mul_by_scalar_32n, + 32, + 8, + f32, + fn run(x: &mut [f32], s: f32) { + debug_assert!(x.len() % Self::nr() == 0); + debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0); + unsafe { x86_64_avx_f32_mul_by_scalar_32n_run(x, s) } + } +); + +#[target_feature(enable = "avx")] +unsafe fn x86_64_avx_f32_mul_by_scalar_32n_run(buf: &mut [f32], scalar: f32) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + vmulps ymm4, ymm4, ymm0 + vmulps ymm5, ymm5, ymm0 + vmulps ymm6, ymm6, ymm0 + vmulps ymm7, ymm7, ymm0 + vmovaps [{ptr}], ymm4 + vmovaps [{ptr} + 32], ymm5 + vmovaps [{ptr} + 64], ymm6 + vmovaps [{ptr} + 96], ymm7 + add {ptr}, 128 + sub {len}, 32 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("xmm0") scalar, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ + ); +} + +#[cfg(test)] +#[macro_use] +pub mod test_x86_64_avx_f32_mul_by_scalar_32n { + mul_by_scalar_frame_tests!( + is_x86_feature_detected!("avx2"), + f32, + crate::x86_64_fma::by_scalar::x86_64_avx_f32_mul_by_scalar_32n + ); +} diff --git a/linalg/src/x86_64_fma/max.rs b/linalg/src/x86_64_fma/max.rs new file mode 100644 index 0000000000..beec1deaa3 --- /dev/null +++ b/linalg/src/x86_64_fma/max.rs @@ -0,0 +1,65 @@ +reduce_impl_wrap!( + f32, 
+ x86_64_fma_max_f32_32n, + 32, + 8, + (), + f32::MIN, + #[inline(never)] + fn run(buf: &[f32], _: ()) -> f32 { + assert!(buf.len() % 32 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_fma_max_f32_32n_run(buf) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a.max(b) + } +); + +#[target_feature(enable = "avx")] +unsafe fn x86_64_fma_max_f32_32n_run(buf: &[f32]) -> f32 { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = f32::MIN; + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + vmovaps ymm1, ymm0 + vmovaps ymm2, ymm0 + vmovaps ymm3, ymm0 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + vmaxps ymm0, ymm0, ymm4 + vmaxps ymm1, ymm1, ymm5 + vmaxps ymm2, ymm2, ymm6 + vmaxps ymm3, ymm3, ymm7 + add {ptr}, 128 + sub {len}, 32 + jnz 2b + vmaxps ymm0, ymm0, ymm1 + vmaxps ymm2, ymm2, ymm3 + vmaxps ymm0, ymm0, ymm2 + vperm2f128 ymm1, ymm0, ymm0, 1 // copy second half (4xf32) of ymm0 to ymm1 + vmaxps xmm0, xmm0, xmm1 // xmm0 contains 4 values to max + vpermilps xmm1, xmm0, 2 + (3 << 2) // second 2x32 bit half moved to top + vmaxps xmm0, xmm0, xmm1 // xmm0 containes 2 values + vpermilps xmm1, xmm0, 1 // second f32 to top + vmaxps xmm0, xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("ymm0") acc, + out("ymm1") _, out("ymm2") _, out("ymm3") _, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _ + ); + acc +} + +#[cfg(test)] +mod test_x86_64_fma_max_f32_32n { + use super::*; + max_frame_tests!(is_x86_feature_detected!("avx2"), f32, x86_64_fma_max_f32_32n); +} diff --git a/linalg/src/x86_64_fma/softmax.rs b/linalg/src/x86_64_fma/softmax.rs new file mode 100644 index 0000000000..e4f8aeb94d --- /dev/null +++ b/linalg/src/x86_64_fma/softmax.rs @@ -0,0 +1,115 @@ +map_reduce_impl_wrap!( + f32, + x86_64_fma_softmax2_fastcompact_f32_32n, + 32, + 8, + f32, + f32::MIN, + 0f32, + #[inline(never)] + fn run(buf: &mut [f32], max: f32) -> f32 { + 
assert!(buf.len() % 32 == 0); + assert!(buf.len() > 0); + unsafe { x86_64_fma_softmax2_fastcompact_f32_32n_run(buf, max) } + }, + #[inline(never)] + fn reduce_two(a: f32, b: f32) -> f32 { + a + b + } +); + +#[target_feature(enable = "avx,fma")] +unsafe fn x86_64_fma_softmax2_fastcompact_f32_32n_run(buf: &mut [f32], max: f32) -> f32 { + let len = buf.len(); + let ptr = buf.as_ptr(); + let mut acc = 0f32; + const MLN2: f32 = 0.6931471805f32; + const A: f32 = 8388608.0f32; + const B: f32 = 1065353216.0f32; + const C: f32 = 60801.0f32; + const SLOPE: f32 = A / MLN2; + const OFFSET: f32 = B - C; + std::arch::asm!(" + vbroadcastss ymm0, xmm0 + vmovaps ymm1, ymm0 + vmovaps ymm2, ymm0 + vmovaps ymm3, ymm0 + + vpxor ymm12, ymm12, ymm12 + vbroadcastss ymm13, xmm13 + vbroadcastss ymm14, xmm14 + vbroadcastss ymm15, xmm15 + 2: + vmovaps ymm4, [{ptr}] + vmovaps ymm5, [{ptr} + 32] + vmovaps ymm6, [{ptr} + 64] + vmovaps ymm7, [{ptr} + 96] + + vsubps ymm4, ymm4, ymm13 + vsubps ymm5, ymm5, ymm13 + vsubps ymm6, ymm6, ymm13 + vsubps ymm7, ymm7, ymm13 + + vmovaps ymm8, ymm15 + vmovaps ymm9, ymm15 + vmovaps ymm10, ymm15 + vmovaps ymm11, ymm15 + + vfmadd231ps ymm8, ymm4, ymm14 + vfmadd231ps ymm9, ymm5, ymm14 + vfmadd231ps ymm10, ymm6, ymm14 + vfmadd231ps ymm11, ymm7, ymm14 + + vmaxps ymm8, ymm8, ymm12 + vmaxps ymm9, ymm9, ymm12 + vmaxps ymm10, ymm10, ymm12 + vmaxps ymm11, ymm11, ymm12 + + vcvttps2dq ymm8, ymm8 + vcvttps2dq ymm9, ymm9 + vcvttps2dq ymm10, ymm10 + vcvttps2dq ymm11, ymm11 + + vmovaps [{ptr}] , ymm8 + vmovaps [{ptr} + 32], ymm9 + vmovaps [{ptr} + 64], ymm10 + vmovaps [{ptr} + 96], ymm11 + + vaddps ymm0, ymm0, ymm8 + vaddps ymm1, ymm1, ymm9 + vaddps ymm2, ymm2, ymm10 + vaddps ymm3, ymm3, ymm11 + + add {ptr}, 128 + sub {len}, 32 + jnz 2b + + vaddps ymm0, ymm0, ymm1 + vaddps ymm2, ymm2, ymm3 + vaddps ymm0, ymm0, ymm2 + vperm2f128 ymm1, ymm0, ymm0, 1 + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 2 + (3 << 2) + vaddps xmm0, xmm0, xmm1 + vpermilps xmm1, xmm0, 1 + vaddps xmm0, 
xmm0, xmm1 + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + inout("ymm0") acc, + out("ymm1") _, out("ymm2") _, out("ymm3") _, + out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _, + out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _, + out("ymm12") _, + inout("ymm13") max => _, + inout("ymm14") SLOPE => _, + inout("ymm15") OFFSET => _, + ); + acc +} + +#[cfg(test)] +mod test_x86_64_fma_softmax2_fastcompact_f32_32n { + use super::*; + softmax_l2_frame_tests!(is_x86_feature_detected!("fma"), f32, x86_64_fma_softmax2_fastcompact_f32_32n); +} diff --git a/nnef/src/ops/core.rs b/nnef/src/ops/core.rs index c322ef9559..dec4aeec94 100644 --- a/nnef/src/ops/core.rs +++ b/nnef/src/ops/core.rs @@ -21,6 +21,7 @@ mod reduce; mod scan; mod scatter; mod shape_of; +mod softmax; mod source; mod store; mod submodel; @@ -59,6 +60,7 @@ pub fn register(registry: &mut Registry) { scan::register(registry); scatter::register(registry); shape_of::register(registry); + softmax::register(registry); source::register(registry); store::register(registry); submodel::register(registry); diff --git a/nnef/src/ops/core/softmax.rs b/nnef/src/ops/core/softmax.rs new file mode 100644 index 0000000000..8e1d9a0ac2 --- /dev/null +++ b/nnef/src/ops/core/softmax.rs @@ -0,0 +1,36 @@ +use tract_core::ops::nn::{Softmax, SoftmaxExp}; + +use crate::internal::*; + +pub fn register(registry: &mut Registry) { + registry.register_primitive( + "tract_core_softmax", + &[ + TypeName::Scalar.tensor().named("x"), + TypeName::Integer.tensor().named("axes"), + TypeName::String.named("exp"), + ], + &[("output", TypeName::Scalar.tensor())], + softmax, + ); +} + +pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { + let x = invocation.named_arg_as(builder, "x")?; + let axes: TVec = invocation.named_arg_as(builder, "axes")?; + + let input_fact = builder.model.outlet_fact(x)?.clone(); + let quant_output_dt = if input_fact.datum_type.is_float() { + None 
+ } else { + invocation.dt_from_quant_file.first().cloned().flatten() + }; + + let exp: Option = invocation.get_named_arg_as(builder, "exp")?; + let exp = match exp.as_deref() { + Some("fast_compact") => SoftmaxExp::FastCompact, + _ => SoftmaxExp::Libc + }; + + builder.wire(Softmax { axes, quant_output_dt, exp }, &[x]) +} diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index ca2fe82a63..a16094e65f 100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -7,7 +7,7 @@ use tract_core::ops::array::PadMode; use tract_core::ops::cnn::deconv::adjustments; use tract_core::ops::cnn::PaddingSpec; use tract_core::ops::cnn::PoolSpec; -use tract_core::ops::nn::DataFormat; +use tract_core::ops::nn::{DataFormat, Softmax, SoftmaxExp}; use tract_itertools::izip; use tract_itertools::Itertools; @@ -686,5 +686,5 @@ pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> T invocation.dt_from_quant_file.first().cloned().flatten() }; - builder.wire(ops::nn::Softmax { axes, quant_output_dt }, &[x]) + builder.wire(Softmax { axes, quant_output_dt, exp: SoftmaxExp::default() }, &[x]) } diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs index 17b980e36f..89a575fbea 100644 --- a/nnef/src/ops/nnef/ser.rs +++ b/nnef/src/ops/nnef/ser.rs @@ -11,6 +11,7 @@ use tract_core::ops::cnn::KernelFormat; use tract_core::ops::cnn::PoolSpec; use tract_core::ops::einsum::BasicMatMul; use tract_core::ops::nn::DataFormat; +use tract_core::ops::nn::SoftmaxExp; use tract_core::tract_data::itertools::Itertools; pub fn source( @@ -453,6 +454,9 @@ pub fn softmax( node: &TypedNode, op: &ops::nn::Softmax, ) -> TractResult>> { + if op.exp != SoftmaxExp::default() { + return Ok(None) + } let litteral_axes: Vec<_> = op.axes.iter().map(|&it| (it as i64).into()).collect(); Ok(Some(invocation( "softmax", diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs index 2d9d4ccb1d..c612475c08 100644 --- a/tflite/src/ops/nn.rs +++ b/tflite/src/ops/nn.rs 
@@ -140,7 +140,7 @@ fn de_softmax(op: &mut DeserOp) -> TractResult> { let options = builtin!(op, builtin_options_as_softmax_options); ensure!(options.beta() == 1.0); let quant_output_dt = Some(input.datum_type).filter(|dt| !dt.is_float()); - let softmax = core::nn::Softmax { axes: tvec!(input.rank() - 1), quant_output_dt }; + let softmax = Softmax { axes: tvec!(input.rank() - 1), quant_output_dt, ..Softmax::default() }; op.ctx.target.wire_node(op.prefix, softmax, op.inputs) }