diff --git a/linalg/src/arm64.rs b/linalg/src/arm64.rs index 6c7f4fede4..76782ee005 100644 --- a/linalg/src/arm64.rs +++ b/linalg/src/arm64.rs @@ -241,6 +241,8 @@ pub fn plug(ops: &mut Ops) { ops.leaky_relu_f16 = Box::new(|| arm64fp16_leaky_relu_f16_16n::ew()); ops.tanh_f16 = Box::new(|| arm64fp16_tanh_f16_8n::ew()); ops.sigmoid_f16 = Box::new(|| arm64fp16_sigmoid_f16_8n::ew()); + ops.max_f16 = Box::new(|| arm64fp16_max_f16_32n::red()); + ops.mul_by_scalar_f16 = Box::new(|| arm64fp16_mul_by_scalar_f16_32n::ew()); } else { log::info!("No native fp16 support"); } diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs index bfdcc8bd21..3873c6404f 100644 --- a/linalg/src/arm64/arm64fp16.rs +++ b/linalg/src/arm64/arm64fp16.rs @@ -1,9 +1,11 @@ use tract_data::half::f16; +mod by_scalar; mod leaky_relu; mod max; +pub use by_scalar::*; pub use leaky_relu::*; -pub use max::arm64fp16_max_f16_32n; +pub use max::*; use crate::frame::mmm::*; MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16()); diff --git a/linalg/src/arm64/arm64fp16/by_scalar.rs b/linalg/src/arm64/arm64fp16/by_scalar.rs new file mode 100644 index 0000000000..204ce212e1 --- /dev/null +++ b/linalg/src/arm64/arm64fp16/by_scalar.rs @@ -0,0 +1,41 @@ +use crate::f16; + +ew_impl_wrap!( + f16, + arm64fp16_mul_by_scalar_f16_32n, + 32, + 4, + f16, + fn run(buf: &mut [f16], s: f16) { + assert!(buf.len() % 16 == 0); + assert!(buf.len() > 0); + #[target_feature(enable = "fp16")] + unsafe fn run(buf: &mut[f16], s: f16) { + let len = buf.len(); + let ptr = buf.as_ptr(); + std::arch::asm!(" + dup v0.8h, v0.h[0] + 1: + ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}] + fmul v4.8h, v4.8h, v0.8h + fmul v5.8h, v5.8h, v0.8h + fmul v6.8h, v6.8h, v0.8h + fmul v7.8h, v7.8h, v0.8h + st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64 + subs {len}, {len}, 32 + bne 1b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("v0") s.to_bits(), + out("v4") _, out("v5") _, out("v6") _, out("v7") _,); + } + unsafe { run(buf, s) } + } +); + +#[cfg(test)] +mod test_arm64fp16_mul_by_scalar_f16_32n { + use super::*; + mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n); +}