More f16 #1221

Merged 3 commits on Oct 3, 2023. Changes from all commits shown below.
core/src/half.rs (4 additions, 0 deletions)

@@ -4,6 +4,7 @@ use crate::ops::array::{Pad, PadMode};
 use crate::ops::cnn::{ConvUnary, DeconvUnary};
 use crate::ops::einsum::EinSum;
 use crate::ops::konst::Const;
+use crate::ops::scan::Scan;
 use crate::ops::source::TypedSource;
 
 #[derive(Debug)]
@@ -27,6 +28,9 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for HalfTranslator
                 bias: op.bias.as_ref().map(tensor_f32_to_f16),
                 ..op.clone()
             })
+        } else if let Some(op) = node.op_as::<Scan>() {
+            let body = HalfTranslator.translate_model(&op.body)?;
+            Box::new(Scan { body, .. op.clone() })
         } else if let Some(op) = node.op_as::<EinSum>() {
             Box::new(EinSum { operating_dt: dt_f32_to_f16(op.operating_dt), ..op.clone() })
         } else if let Some(op) = node.op_as::<DeconvUnary>() {
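The Scan arm is the interesting part of this hunk: a Scan op owns an entire nested TypedModel, so HalfTranslator has to recurse into the body with translate_model rather than only rewriting the op's own fields. A minimal, self-contained sketch of that pattern, with illustrative stand-in types rather than tract's actual API:

    #[derive(Clone, Copy, PartialEq, Debug)]
    enum DType { F32, F16 }

    #[derive(Clone, Debug)]
    enum Op {
        // A leaf op only carries its own datum type...
        EinSum { operating_dt: DType },
        // ...but a scan/loop op carries a whole sub-model.
        Scan { body: Vec<Op> },
    }

    // A dtype-lowering pass must recurse into nested models,
    // which is what the new Scan arm above does via translate_model.
    fn lower_to_f16(op: &Op) -> Op {
        match op {
            Op::EinSum { operating_dt: DType::F32 } => Op::EinSum { operating_dt: DType::F16 },
            Op::Scan { body } => Op::Scan { body: body.iter().map(lower_to_f16).collect() },
            other => other.clone(),
        }
    }

Without the recursion, a network whose recurrent body still computes in f32 would silently defeat the point of the half-precision translation.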
core/src/ops/nn/mod.rs (4 additions, 1 deletion)

@@ -2,6 +2,8 @@ mod data_formats;
 mod reduce;
 mod softmax;
 
+use tract_num_traits::{AsPrimitive, Zero};
+
 pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
 pub use self::reduce::{Reduce, Reducer};
 pub use self::softmax::Softmax;
@@ -19,5 +21,6 @@ element_wise!(hard_swish, HardSwish,
 );
 
 element_wise!(leaky_relu, LeakyRelu { alpha: f32 },
-    [f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) }
+    [f32] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < 0. { op.alpha } else { 1.0 }); Ok(()) },
+    [f16] => |op, xs| { xs.iter_mut().for_each(|x| *x *= if *x < f16::zero() { AsPrimitive::<f16>::as_(op.alpha) } else { (1.0).as_() }); Ok(()) }
 );
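Two f16-specific details in the new arm are worth noting: the comparison needs an f16 zero (hence the Zero import above), and alpha, which the op stores as f32, is narrowed to f16 through AsPrimitive inside the closure. Outside the element_wise! macro, the same computation could be sketched with the half crate directly; this is a hedged restatement, not the macro expansion:

    use half::f16;

    fn leaky_relu_f16(xs: &mut [f16], alpha: f32) {
        // Narrow alpha once, rather than once per element as the closure above does.
        let alpha = f16::from_f32(alpha);
        for x in xs.iter_mut() {
            if *x < f16::ZERO {
                *x = *x * alpha;
            }
        }
    }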
core/src/ops/scan/mir.rs (1 addition, 1 deletion)

@@ -12,7 +12,7 @@ pub struct Scan {
     pub skip: usize,
     pub reset_every_turn: bool,
     pub body: TypedModel,
-    decluttered: bool,
+    pub decluttered: bool,
     pub input_mapping: Vec<InputMapping>,
     pub output_mapping: Vec<OutputMapping<TDim>>,
 }
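This one-word visibility change is what lets the new Scan arm in core/src/half.rs compile: Scan { body, .. op.clone() } uses struct-update syntax from outside the ops::scan module, and Rust requires every field filled in that way, decluttered included, to be visible at the use site. A toy illustration with placeholder types:

    mod scan {
        #[derive(Clone)]
        pub struct Scan {
            pub body: String,
            pub decluttered: bool, // were this private, rebuild() would fail with E0451
        }
    }

    fn rebuild(op: &scan::Scan) -> scan::Scan {
        // Struct-update syntax implicitly names every remaining field,
        // so all fields must be public from this vantage point.
        scan::Scan { body: op.body.clone(), ..op.clone() }
    }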
linalg/Cargo.toml (1 addition, 0 deletions)

@@ -36,6 +36,7 @@ walkdir.workspace = true
 
 [dev-dependencies]
 criterion.workspace = true
+env_logger.workspace = true
 nu-ansi-term.workspace = true
 proptest.workspace = true
 core_affinity.workspace = true
linalg/src/arm64.rs (33 additions, 5 deletions)

@@ -13,10 +13,6 @@ use crate::Ops;
 use crate::frame::element_wise::ElementWiseKer;
 use crate::frame::mmm::kernel::MatMatMulKer;
 
-lazy_static::lazy_static! {
-    static ref KIND: Kind = Kind::choose();
-}
-
 // https://en.wikipedia.org/wiki/Comparison_of_ARMv8-A_cores
 const PART_A53: &str = "0xd03";
 const PART_A55: &str = "0xd05";
@@ -37,9 +33,34 @@ fn max_cpuid() -> std::io::Result<String> {
     Ok(max.unwrap_or("").to_string())
 }
 
+lazy_static::lazy_static! {
+    static ref KIND: Kind = Kind::choose();
+
+    static ref CPU_FEATURES: Vec<String> = {
+        #[cfg(test)] crate::setup_test_logger();
+        let Ok(cpu_info) = std::fs::read_to_string("/proc/cpuinfo") else {
+            log::warn!("Could not read /proc/cpuinfo. CPU Features detection may be impaired.");
+            return vec!();
+        };
+        if let Some(line) = cpu_info
+            .lines()
+            .filter(|line| line.starts_with("Features"))
+            .next() {
+            line.split_once(":").unwrap().1.split_whitespace().map(|s| s.to_string()).collect()
+        } else {
+            log::warn!("Could not find \"Features :\" lines in /proc/cpuinfo. CPU Features detection may be impaired.");
+            vec!()
+        }
+    };
+
+    static ref HAS_FP16: bool = {
+        CPU_FEATURES.iter().find(|s| &**s == "asimdhp").is_some()
+    };
+}
+
 #[inline]
 pub fn has_fp16() -> bool {
-    cfg!(feature_cpu = "fp16")
+    cfg!(feature_cpu = "fp16") || *KIND == Kind::CortexA55 || *KIND == Kind::CortexA75 || *HAS_FP16
 }
 
 #[derive(Debug, PartialEq, Eq, Copy, Clone)]
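On Linux, the "Features :" line in /proc/cpuinfo holds one token per CPU capability, and asimdhp is the flag for half-precision SIMD arithmetic. A standalone restatement of the parse with a sample input (helper name and excerpt are illustrative; on aarch64, std's is_aarch64_feature_detected!("fp16") would be an alternative way to probe the same capability):

    fn parse_features(cpu_info: &str) -> Vec<String> {
        cpu_info
            .lines()
            .find(|line| line.starts_with("Features"))
            .and_then(|line| line.split_once(':'))
            .map(|(_, list)| list.split_whitespace().map(str::to_string).collect())
            .unwrap_or_default()
    }

    #[test]
    fn detects_asimdhp_in_sample() {
        // Representative excerpt; real feature lists vary by SoC.
        let sample = "processor\t: 0\nFeatures\t: fp asimd evtstrm aes sha1 asimdhp\n";
        assert!(parse_features(sample).iter().any(|f| f == "asimdhp"));
    }

The KIND check in has_fp16 still covers Cortex-A55 and Cortex-A75 when /proc/cpuinfo is unavailable: the code treats those two cores as fp16-capable regardless of the feature parse.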
@@ -55,6 +76,8 @@ enum Kind {
 
 impl Kind {
     fn choose() -> Kind {
+        #[cfg(test)]
+        crate::setup_test_logger();
         let kind = if let Ok(kind) = std::env::var("TRACT_CPU_AARCH64_KIND") {
             log::info!("CPU kind forced with TRACT_CPU_AARCH64_KIND: {}", kind);
             let kind = kind.to_lowercase();
@@ -150,6 +173,7 @@ pub fn plug(ops: &mut Ops) {
     #[cfg(not(feature = "no_fp16"))]
     if has_fp16() {
         if *KIND == Kind::CortexA55 {
+            log::info!("Cortex-A55 mmm_f16 and mmv_f16 activated");
             ops.mmm_f16 = Box::new(|_, _, n| {
                 use tract_data::internal::DimLike;
                 if n.unwrap_or(1024).divceil(4) * 4 < n.unwrap_or(1024).divceil(8) * 8 {
@@ -160,6 +184,7 @@ pub fn plug(ops: &mut Ops) {
             });
             ops.mmv_f16 = Box::new(|_, _| arm64fp16_mmm_f16_128x1_a55::mmm());
         } else {
+            log::info!("ARMv8.2 mmm_f16 and mmv_f16 activated");
             ops.mmm_f16 = Box::new(|_, _, n| {
                 use tract_data::internal::DimLike;
                 if n.unwrap_or(1024).divceil(4) * 4 < n.unwrap_or(1024).divceil(8) * 8 {
@@ -175,8 +200,11 @@ pub fn plug(ops: &mut Ops) {
     ops.tanh_f32 = Box::new(|| arm64simd_tanh_f32_4n::ew());
     #[cfg(not(feature = "no_fp16"))]
     if has_fp16() {
+        log::info!("ARMv8.2 tanh_f16 and sigmoid_f16 activated");
         ops.tanh_f16 = Box::new(|| arm64fp16_tanh_f16_8n::ew());
         ops.sigmoid_f16 = Box::new(|| arm64fp16_sigmoid_f16_8n::ew());
+    } else {
+        log::info!("No native fp16 support");
     }
     #[cfg(target_os = "macos")]
     {
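Both mmm_f16 selectors above share one heuristic: round n up to a multiple of 4 and to a multiple of 8, and prefer the narrower 4-column kernel only when it strictly wastes less padding. A worked restatement (the free function stands in for tract's DimLike::divceil):

    fn div_ceil(a: usize, b: usize) -> usize {
        (a + b - 1) / b
    }

    // True when padding n columns up to a multiple of 4 is strictly cheaper
    // than padding up to a multiple of 8: only then does the 4-wide kernel win.
    fn prefer_4_wide(n: usize) -> bool {
        div_ceil(n, 4) * 4 < div_ceil(n, 8) * 8
    }

    // prefer_4_wide(12) == true  (pads to 12 columns instead of 16)
    // prefer_4_wide(16) == false (both pad to 16; take the wider kernel)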
linalg/src/frame/element_wise.rs (1 addition, 0 deletions)

@@ -164,6 +164,7 @@ pub mod test {
         values: &[T],
         reference: F,
     ) -> TestCaseResult {
+        crate::setup_test_logger();
        let op = ElementWiseImpl::<K, T>::new();
        let mut values = values.to_vec();
        while values.len() < K::nr() {
linalg/src/frame/mmm/tests.rs (3 additions, 0 deletions)

@@ -233,6 +233,7 @@ where
     i32: AsPrimitive<TI>,
     usize: AsPrimitive<TI>,
 {
+    crate::setup_test_logger();
     assert_eq!(a.datum_type(), TA::datum_type());
     let op = MatMatMulImpl::<K, TI>::default();
     unsafe {
@@ -281,6 +282,7 @@ where
     i32: AsPrimitive<TI>,
     usize: AsPrimitive<TI>,
 {
+    crate::setup_test_logger();
     unsafe {
         let op = MatMatMulImpl::<K, TI>::default();
         let mut packed_a =
@@ -327,6 +329,7 @@ where
     i32: AsPrimitive<TI>,
     usize: AsPrimitive<TI>,
 {
+    crate::setup_test_logger();
     let op = MatMatMulImpl::<K, TI>::default();
 
     let mut found = Tensor::zero::<TC>(&[m, n]).unwrap();
linalg/src/frame/sigmoid.rs (1 addition, 0 deletions)

@@ -86,6 +86,7 @@ pub mod test {
     f32: AsPrimitive<T>,
     T: AsPrimitive<f32>,
 {
+    crate::setup_test_logger();
     let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
     crate::frame::element_wise::test::test_element_wise::<K, _, _>(&values, |x| {
         (1f32).as_() / (1f32.as_() + (-x).exp())
linalg/src/frame/tanh.rs (1 addition, 0 deletions)

@@ -95,6 +95,7 @@ pub mod test {
     f32: AsPrimitive<T>,
     T: AsPrimitive<f32>,
 {
+    crate::setup_test_logger();
     let values: Vec<T> = values.iter().copied().map(|x| x.as_()).collect();
     crate::frame::element_wise::test::test_element_wise::<K, _, _>(&values, |x| x.tanh())
 }
linalg/src/lib.rs (6 additions, 0 deletions)

@@ -208,3 +208,9 @@ impl LADatum for i32 {
         any::<i32>().boxed()
     }
 }
+
+#[cfg(test)]
+#[allow(dead_code)]
+fn setup_test_logger() {
+    let _ = env_logger::Builder::from_env("TRACT_LOG").try_init();
+}
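With this helper and the env_logger dev-dependency added in linalg/Cargo.toml above, any linalg test can surface the log::info!/log::warn! diagnostics introduced in this PR; verbosity is controlled by the TRACT_LOG environment variable, playing the role RUST_LOG usually does for env_logger. A hedged usage sketch (test name is illustrative):

    #[test]
    fn sees_kernel_selection_logs() {
        // Safe to call from many tests: try_init() succeeds once and the
        // Err returned by later calls is deliberately discarded.
        crate::setup_test_logger();
        log::info!("visible when the suite runs with TRACT_LOG=info or more verbose");
    }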