diff --git a/cli/src/main.rs b/cli/src/main.rs index 15b3f0a977..48c8306da2 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -587,7 +587,7 @@ fn handle(matches: clap::ArgMatches, probe: Option<&Probe>) -> TractResult<()> { for m in tract_linalg::ops().mmm_impls() { println!("{}", Green.paint(format!(" * {}", m.name()))); for packings in m.packings() { - println!(" - {} • {}", packings.0, packings.1); + println!(" - {:?} • {:?}", packings.0, packings.1); } } println!("{}", White.bold().paint("# MatMul kits")); diff --git a/core/src/ops/einsum/kernel_selection.rs b/core/src/ops/einsum/kernel_selection.rs index 99732acd14..649b5486f7 100644 --- a/core/src/ops/einsum/kernel_selection.rs +++ b/core/src/ops/einsum/kernel_selection.rs @@ -110,7 +110,7 @@ pub fn wire_linear( kit.weight == weight && kit.accumulator == accumulator && kit.activation == activation }) .min_by_key(|kit| kit.generic_fallback as usize) - .with_context(|| format!("No kit found for matmul {a:?} • {b_fact:?}"))?; + .with_context(|| format!("No kit found for matmul {weight:?} {accumulator:?} {activation:?}"))?; let configs = [kit.item_for_mv(), kit.item_for_squarish()]; let packed: Box = if let Some(a_payload) = a_as_bqv { let packed = kit diff --git a/data/src/datum.rs b/data/src/datum.rs index ef8d781a6c..768deccb04 100644 --- a/data/src/datum.rs +++ b/data/src/datum.rs @@ -492,6 +492,7 @@ pub trait Datum: { fn name() -> &'static str; fn datum_type() -> DatumType; + fn is() -> bool; } macro_rules! datum { @@ -510,6 +511,10 @@ macro_rules! datum { fn datum_type() -> DatumType { DatumType::$v } + + fn is() -> bool { + Self::datum_type() == D::datum_type() + } } }; } diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs index e69d64d9f5..286f67f3bb 100644 --- a/linalg/src/arm64/arm64fp16.rs +++ b/linalg/src/arm64/arm64fp16.rs @@ -50,3 +50,13 @@ pub fn plug(ops: &mut Ops) { tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); + +#[cfg(test)] +mod test { + + #[test] + fn kits() { + let mut ops = crate::generic(); + super::plug(&mut ops); + } +} diff --git a/linalg/src/arm64/arm64fp16/panel_extract.rs b/linalg/src/arm64/arm64fp16/panel_extract.rs index 21145f3fbe..bea69a9968 100644 --- a/linalg/src/arm64/arm64fp16/panel_extract.rs +++ b/linalg/src/arm64/arm64fp16/panel_extract.rs @@ -1,8 +1,8 @@ +use super::FP16; use crate::frame::block_quant::{PackedBlockQuantFormat, Q4_0}; -use crate::frame::PackedFormat; +use crate::mmm::Packing; use crate::Ops; use tract_data::internal::*; -use super::FP16; pub fn plug(ops: &mut Ops) { ops.panel_extractors.push(packed_64_q40_to_f16.clone()); @@ -10,7 +10,7 @@ pub fn plug(ops: &mut Ops) { panel_extractor!(kernel_packed_64_q40_to_f16 as packed_64_q40_to_f16( Box::new(PackedBlockQuantFormat::new(&Q4_0, 64, 16, true)), - PackedFormat::new(f16::datum_type(), 64, 16) + f16::packing(64).align(16) ) where(FP16)); #[target_feature(enable = "fp16")] diff --git a/linalg/src/frame/mmm/kernel.rs b/linalg/src/frame/mmm/kernel.rs index 1018483887..eeb671c7db 100644 --- a/linalg/src/frame/mmm/kernel.rs +++ b/linalg/src/frame/mmm/kernel.rs @@ -34,7 +34,6 @@ type Kernel = unsafe fn(&[FusedKerSpec]) -> isize; pub struct DynKernel { pub name: String, pub kernel: Kernel, - pub default_packing_alignments: (usize, usize), pub packings: Vec<(Box, Box)>, pub stores: Vec, pub supported_predicate: fn() -> bool, @@ -45,7 +44,8 @@ impl DynKernel { pub fn new( name: &str, kernel: Kernel, - default_packing_alignments: (usize, usize), + packing_a: PackedFormat, + packing_b: PackedFormat, ) -> Self { let kernel = DynKernel { name: name.to_string(), @@ -53,12 +53,9 @@ impl DynKernel { packings: vec![], stores: vec![Acc::datum_type()], supported_predicate: || true, - default_packing_alignments, can_fuse: |_| true, }; - let a = kernel.regular_pack_a(); - let b = kernel.regular_pack_b(); - kernel.with_packing(a, b) + kernel.with_packing(packing_a, packing_b) } pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self { @@ -77,11 +74,11 @@ impl DynKernel { } pub fn regular_pack_a(&self) -> PackedFormat { - PackedFormat::new(Acc::datum_type(), MR, self.default_packing_alignments.0) + *self.packings[0].0.clone().downcast::().unwrap() } pub fn regular_pack_b(&self) -> PackedFormat { - PackedFormat::new(Acc::datum_type(), NR, self.default_packing_alignments.1) + *self.packings[0].1.clone().downcast::().unwrap() } pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self { diff --git a/linalg/src/frame/mmm/kit.rs b/linalg/src/frame/mmm/kit.rs index 3739cf7943..9c92566ce8 100644 --- a/linalg/src/frame/mmm/kit.rs +++ b/linalg/src/frame/mmm/kit.rs @@ -187,7 +187,7 @@ impl MMMKit { .1 .downcast_ref::() .is_some_and(|pf| KitDatumType::from(pf.dt) == self.activation), - "Activation packecd mismatch {self:?} {:?}", + "Activation packed dt mismatch {self:?} {:?}", mmm.packings()[packing].1 ); self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor }); diff --git a/linalg/src/frame/mmm/macros.rs b/linalg/src/frame/mmm/macros.rs index 3e47ab1135..8714bed583 100644 --- a/linalg/src/frame/mmm/macros.rs +++ b/linalg/src/frame/mmm/macros.rs @@ -1,6 +1,7 @@ macro_rules! MMMExternKernel { ( - $func:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $func:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -21,7 +22,8 @@ macro_rules! MMMExternKernel { } } - MMMKernel!([]::rusty as $func<$ti>($mr, $nr)@($align_a, $align_b) + MMMKernel!([]::rusty as $func<$ti>($mr, $nr) + $(@($align_a, $align_b))? $(where($where))? $(can_fuse($can_fuse))? $(packing[$pnum] = $pid => $packing;)* @@ -32,7 +34,8 @@ macro_rules! MMMExternKernel { } macro_rules! MMMRustKernel { ( $func: path => - $id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $id:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -49,7 +52,8 @@ macro_rules! MMMRustKernel { $func(op.as_ptr()) } } - MMMKernel!([]::rusty as $id<$ti>($mr, $nr)@($align_a, $align_b) + MMMKernel!([]::rusty as $id<$ti>($mr, $nr) + $(@($align_a, $align_b))? $(where($where))? $(can_fuse($can_fuse))? $(packing[$pnum] = $pid => $packing;)* @@ -62,7 +66,8 @@ macro_rules! MMMRustKernel { macro_rules! MMMKernel { ( $func: path as - $id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $id:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -75,8 +80,15 @@ macro_rules! MMMKernel { use $crate::mmm::DynKernel; #[allow(unused_imports)] use tract_data::prelude::*; + use $crate::frame::mmm::Packing; #[allow(unused_mut)] - let mut k = DynKernel::<$mr, $nr, $ti>::new(stringify!($id), $func, ($align_a, $align_b)); + let (mut packing_a, mut packing_b) = ($ti::packing($mr), $ti::packing($nr)); + $( + packing_a = packing_a.align($align_a); + packing_b = packing_b.align($align_b); + )? + #[allow(unused_mut)] + let mut k = DynKernel::<$mr, $nr, $ti>::new(stringify!($id), $func, packing_a, packing_b); $(k = k.with_platform_condition($where);)? $( assert!(k.packings.len() == $pnum); @@ -102,4 +114,3 @@ macro_rules! MMMKernel { } }; } - diff --git a/linalg/src/frame/mmm/mod.rs b/linalg/src/frame/mmm/mod.rs index 9c63aa915a..797dbfda1e 100644 --- a/linalg/src/frame/mmm/mod.rs +++ b/linalg/src/frame/mmm/mod.rs @@ -30,6 +30,8 @@ pub use kit::*; pub use scratch::*; pub use storage::*; +pub use pack::Packing; + pub fn no_prefetch(_ptr: *const u8, _len: usize) {} pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any { diff --git a/linalg/src/frame/mmm/pack.rs b/linalg/src/frame/mmm/pack.rs index 770bf38b58..73c59b34e9 100644 --- a/linalg/src/frame/mmm/pack.rs +++ b/linalg/src/frame/mmm/pack.rs @@ -12,7 +12,7 @@ use super::MMMInputFormat; pub struct PackedFormat { pub dt: DatumType, pub r: usize, - pub alignment: usize, + pub alignment_bytes: usize, pub end_padding_record: usize, } @@ -47,22 +47,31 @@ impl Display for PackedFormat { impl Debug for PackedFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - ::fmt(self, f) + write!( + f, + "Packed{:?}[{}]@{}+{}", + self.dt, self.r, self.alignment_bytes, self.end_padding_record + ) } } impl PackedFormat { - pub const fn new(dt: DatumType, nr: usize, alignment: usize) -> PackedFormat { - PackedFormat { dt, r: nr, alignment, end_padding_record: 1 } + pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat { + PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 } } pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self { PackedFormat { end_padding_record, ..self } } + #[inline] + pub fn align(self, alignment: usize) -> Self { + Self { alignment_bytes: alignment, ..self } + } + #[inline] pub fn alignment(&self) -> usize { - self.alignment + self.alignment_bytes } #[inline] @@ -101,8 +110,10 @@ impl PackedFormat { let panel_bytes = panel_len * t.datum_type().size_of(); let strides = t.strides(); unsafe { - let mut packed = - Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + let mut packed = Blob::new_for_size_and_align( + t.datum_type().size_of() * packed_len, + self.alignment_bytes, + ); if cfg!(debug_assertions) { packed.as_bytes_mut().fill(0u8); } @@ -140,8 +151,10 @@ impl PackedFormat { let panel_bytes = panel_len * t.datum_type().size_of(); let strides = t.strides(); unsafe { - let mut packed = - Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + let mut packed = Blob::new_for_size_and_align( + t.datum_type().size_of() * packed_len, + self.alignment_bytes, + ); if cfg!(debug_assertions) { packed.as_bytes_mut().fill(0u8); } @@ -509,6 +522,16 @@ unsafe fn pack_mn_major( } } +pub trait Packing { + fn packing(r: usize) -> PackedFormat; +} + +impl Packing for D { + fn packing(r: usize) -> PackedFormat { + PackedFormat::new(Self::datum_type(), r, Self::datum_type().alignment()) + } +} + #[cfg(test)] mod test { use std::ops::Range; diff --git a/linalg/src/generic/mmm.rs b/linalg/src/generic/mmm.rs index cc984b2d68..e14a995a03 100644 --- a/linalg/src/generic/mmm.rs +++ b/linalg/src/generic/mmm.rs @@ -1,8 +1,8 @@ #![allow(clippy::needless_range_loop)] use num_traits::AsPrimitive; -use pack::PackedFormat; use tract_data::prelude::f16; +use tract_data::prelude::DatumType::*; use tract_data::prelude::*; use super::*; @@ -280,14 +280,20 @@ where FusedKerSpec::AddMatMul { k, pa, pb, packing } => { use std::mem::transmute; if TI::datum_type().is_float() { - if packing == 0 { - add_mat_mul::(pa, pb, k, &mut ab); - } else if packing == 1 { - add_mat_mul_pq40::(pa, pb, k, &mut ab); - } else if packing == 2 { - add_mat_mul_pq40_scales_at_end::(pa, pb, k, &mut ab) - } else if packing == 3 { - add_mat_mul_pq40::(pa, pb, k, &mut ab); + match packing { + 0 => add_mat_mul::(pa, pb, k, &mut ab), + 1 if TI::is::() => { + add_mat_mul::(pa, pb, k, &mut ab) + } + 1 if TI::is::() => { + add_mat_mul::(pa, pb, k, &mut ab) + } + 2 => add_mat_mul_pq40::(pa, pb, k, &mut ab), + 3 => add_mat_mul_pq40_scales_at_end::( + pa, pb, k, &mut ab, + ), + 4 => add_mat_mul_pq40::(pa, pb, k, &mut ab), + _ => unreachable!(), } } else if TI::datum_type() == i32::datum_type() { // transmute to allow using explicitly i3 in add_mat_mul generic params @@ -331,58 +337,79 @@ where fn pq40_r4() -> PackedBlockQuantFormat { PackedBlockQuantFormat::new(&Q4_0, 4, 0, false) } + fn pq40_r4_se() -> PackedBlockQuantFormat { PackedBlockQuantFormat::new(&Q4_0, 4, 0, true) } // f16 kernels -MMMRustKernel!(kernel:: => generic_f16_4x4(4,4)@(4,4) store(f32, f64)); -MMMRustKernel! {kernel:: => generic_f16_4x1(4,1)@(4,1) - packing[1] = q40f16 => |k| k.with_packing_a(pq40_r4()); - packing[2] = q40f16se => |k| k.with_packing_a(pq40_r4_se()); - packing[3] = q40f32 => |k| k.with_packing(pq40_r4(), PackedFormat::new(DatumType::F32, 1, 4)); +MMMRustKernel!(kernel:: => generic_f16_4x4(4,4) store(f32, f64)); +MMMRustKernel! {kernel:: => generic_f16_4x1(4,1) + packing[1] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(1)); + packing[2] = q40f16 => |k| k.with_packing_a(pq40_r4()); + packing[3] = q40f16se => |k| k.with_packing_a(pq40_r4_se()); + packing[4] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1)); store(f32, f64) } // f32 kernels -MMMRustKernel!(kernel:: => generic_f32_4x4(4,4)@(4,4) store(f16, f64)); -MMMRustKernel! {kernel:: => generic_f32_4x1(4,1)@(4,1) - packing[1] = q40f16 => |k| k.with_packing(pq40_r4(), PackedFormat::new(DatumType::F16, 1, 4)); - packing[2] = q40f16se => |k| k.with_packing(pq40_r4_se(), PackedFormat::new(DatumType::F16, 1, 4)); - packing[3] = q40f32 => |k| k.with_packing_a(pq40_r4()); +MMMRustKernel!(kernel:: => generic_f32_4x4(4,4) + packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(4)); + store(f16, f64) +); +MMMRustKernel! {kernel:: => generic_f32_4x1(4,1) + packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(1)); + packing[2] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1)); + packing[3] = q40f16se => |k| k.with_packing(pq40_r4_se(), f16::packing(1)); + packing[4] = q40f32 => |k| k.with_packing_a(pq40_r4()); store(f16, f64) } // f64 kernels -MMMRustKernel!(kernel:: => generic_f64_4x4(4,4)@(4,4) store(f16, f32)); -MMMRustKernel!(kernel:: => generic_f64_4x1(4,1)@(4,1) store(f16, f32)); +MMMRustKernel!(kernel:: => generic_f64_4x4(4,4) store(f16, f32)); +MMMRustKernel!(kernel:: => generic_f64_4x1(4,1) store(f16, f32)); // I32 kernels -MMMRustKernel! {kernel:: => generic_i32_4x4(4,4)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 4, 4), PackedFormat::new(DatumType::I8, 4, 4)); +MMMRustKernel! {kernel:: => generic_i32_4x4(4,4) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(4)); store(i8) } -MMMRustKernel! {kernel:: => generic_i32_4x1(4,1)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 4, 4), PackedFormat::new(DatumType::I8, 1, 4)); +MMMRustKernel! {kernel:: => generic_i32_4x1(4,1) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(1)); store(i8) } // extra tests kernels - #[cfg(test)] -MMMRustKernel!(kernel:: => generic_f32_3x2(3,2)@(4,4) store(f16, f64)); +MMMRustKernel!(kernel:: => generic_f32_3x2(3,2) store(f16, f64)); #[cfg(test)] -MMMRustKernel! {kernel:: => generic_i32_3x2(3,2)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 3, 4), PackedFormat::new(DatumType::I8, 2, 4)); +MMMRustKernel! {kernel:: => generic_i32_3x2(3,2) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(3), i8::packing(2)); store(i8) } pub fn plug(ops: &mut Ops) { ops.mmm_kits.push( - MMMKit::new(Q4_0, f32::datum_type(), f32::datum_type(), &pq40_r4()) - .with_native(generic_f32_4x1.mmm(), 3) + MMMKit::new(Q4_0, F32, F32, &pq40_r4()) + .with_native(generic_f32_4x1.mmm(), 4) .with_generic_fallback(true), ); + ops.mmm_kits.push( + MMMKit::new(F16, F32, F16, &f16::packing(4)) + .with_native(generic_f32_4x1.mmm(), 1) + .with_native(generic_f32_4x4.mmm(), 1) + .with_generic_fallback(true), + ); +} + +#[cfg(test)] +mod test { + + #[test] + fn kits() { + let mut ops = crate::generic(); + super::plug(&mut ops); + } } diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 89fc56f41b..6de89bb2f1 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -27,7 +27,7 @@ MMMExternKernel! {fma_mmm_f32_32x1(32,1)@(32,4) where(FMA) store(f16) } MMMExternKernel!(fma_mmm_f32_32x3(32,3)@(32,4) where(FMA) - packing[1] = f32f16 => |k| k.with_packing(PackedFormat::new(F32, 32, 32), PackedFormat::new(F16, 3, 2)); + packing[1] = f32f16 => |k| k.with_packing(f32::packing(32).align(32), f16::packing(3)); store(f16) ); diff --git a/linalg/src/x86_64_fma/panel_extract.rs b/linalg/src/x86_64_fma/panel_extract.rs index be976844c8..ca91059050 100644 --- a/linalg/src/x86_64_fma/panel_extract.rs +++ b/linalg/src/x86_64_fma/panel_extract.rs @@ -1,5 +1,6 @@ use super::*; use crate::frame::PackedFormat; +use crate::frame::mmm::Packing; use crate::Ops; use tract_data::internal::*; @@ -9,12 +10,12 @@ pub fn plug(ops: &mut Ops) { panel_extractor!(kernel_packed_32_q40_to_f32 as packed_32_q40_to_f32( Box::new(super::mmm::pq40_r32()), - PackedFormat::new(f32::datum_type(), 32, 32) + f32::packing(32).align(32) ) where(AVX2)); panel_extractor!(kernel_packed_32_f16_to_f32 as packed_32_f16_to_f32( Box::new(PackedFormat::new(f16::datum_type(), 32, 32)), - PackedFormat::new(f32::datum_type(), 32, 32) + f32::packing(32).align(32) ) where(AVX2)); #[target_feature(enable = "avx2")]