Skip to content

Commit

Permalink
mul_by_scalar_f16
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jan 31, 2024
1 parent c057192 commit 310ffc8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
2 changes: 2 additions & 0 deletions linalg/src/arm64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ pub fn plug(ops: &mut Ops) {
ops.leaky_relu_f16 = Box::new(|| arm64fp16_leaky_relu_f16_16n::ew());
ops.tanh_f16 = Box::new(|| arm64fp16_tanh_f16_8n::ew());
ops.sigmoid_f16 = Box::new(|| arm64fp16_sigmoid_f16_8n::ew());
ops.max_f16 = Box::new(|| arm64fp16_max_f16_32n::red());
ops.mul_by_scalar_f16 = Box::new(|| arm64fp16_mul_by_scalar_f16_32n::ew());
} else {
log::info!("No native fp16 support");
}
Expand Down
4 changes: 3 additions & 1 deletion linalg/src/arm64/arm64fp16.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use tract_data::half::f16;

mod by_scalar;
mod leaky_relu;
mod max;
pub use by_scalar::*;
pub use leaky_relu::*;
pub use max::arm64fp16_max_f16_32n;
pub use max::*;

use crate::frame::mmm::*;
MMMKernel!(f16, arm64fp16_mmm_f16_16x8_gen; 16, 8; 16, 16; 1, 1; no_prefetch, crate::arm64::has_fp16());
Expand Down
41 changes: 41 additions & 0 deletions linalg/src/arm64/arm64fp16/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use crate::f16;

ew_impl_wrap!(
f16,
arm64fp16_mul_by_scalar_f16_32n,
32,
4,
f16,
fn run(buf: &mut [f16], s: f16) {
assert!(buf.len() % 16 == 0);
assert!(buf.len() > 0);
#[target_feature(enable = "fp16")]
unsafe fn run(buf: &mut[f16], s: f16) {
let len = buf.len();
let ptr = buf.as_ptr();
std::arch::asm!("
dup v0.8h, v0.h[0]
1:
ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}]
fmul v4.8h, v4.8h, v0.8h
fmul v5.8h, v5.8h, v0.8h
fmul v6.8h, v6.8h, v0.8h
fmul v7.8h, v7.8h, v0.8h
st1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{ptr}], 64
subs {len}, {len}, 32
bne 1b
",
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
in("v0") s.to_bits(),
out("v4") _, out("v5") _, out("v6") _, out("v7") _,);
}
unsafe { run(buf, s) }
}
);

#[cfg(test)]
mod test_arm64fp16_mul_by_scalar_f16_32n {
use super::*;
mul_by_scalar_frame_tests!(crate::arm64::has_fp16(), f16, arm64fp16_mul_by_scalar_f16_32n);
}

0 comments on commit 310ffc8

Please sign in to comment.