Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avx512 perf improvements #996

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ core_affinity.workspace = true
no_fp16 = []
default = []
complex = [ "tract-data/complex" ]
# Internal feature for benchmarking matmul kernels
compile_all_kernels = []

[[bench]]
bench = false
Expand Down Expand Up @@ -99,3 +101,8 @@ harness = false
[[bench]]
name = "intel"
harness = false

[[bench]]
bench = false
name = "kernel_test"
harness = false
4 changes: 2 additions & 2 deletions linalg/benches/intel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tract_linalg::mmm::OutputStoreKer;

fn ruin_cache() {
// return;
let _a = (0..1000000).collect::<Vec<i32>>();
let _a = std::hint::black_box((0..10000000).collect::<Vec<i32>>());
}

pub fn reference<T, K>(mr: usize, k: usize, nr: usize) -> Vec<f32>
Expand Down Expand Up @@ -63,7 +63,7 @@ fn bench_to_nanos<
FusedSpec::AddMatMul {
k,
a: kernel.a_packed(4, k).wrap(&a.view()),
b: kernel.b_packed(4, k).wrap(&b.view()).unwrap(),
b: kernel.b_packed(4, k).wrap(&b.view()),
},
// FusedSpec::AddUnicast(kernel.c_view(1, 0).wrap(&c.view_mut())),
FusedSpec::Store(kernel.c_view(1, 0).wrap(&c.view_mut())),
Expand Down
84 changes: 84 additions & 0 deletions linalg/benches/kernel_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use criterion::*;

mod utils;
use tract_data::prelude::DatumType;
use tract_linalg::mmm::MatMatMul;
use tract_linalg::mmm::MatMatMulKer;
use utils::*;

/// Criterion benchmark entry point: unpacks the input tuple and delegates to
/// `mat_mat_with_mm` for the given kernel and problem size.
///
/// The `cold` flag is bound as `_cold` because it is not forwarded to
/// `mat_mat_with_mm` (only `(dt, m, k, n)` is passed on); binding it with an
/// underscore avoids an unused-variable warning.
/// NOTE(review): confirm that dropping `cold` here is intended — the caller
/// distinguishes cold/hot benchmarks only by benchmark id.
pub fn mat_mat_mm(
    be: &mut Bencher,
    &(mm, dt, m, k, n, _cold): &(&dyn MatMatMul, DatumType, usize, usize, usize, bool),
) {
    mat_mat_with_mm(be, mm, &(dt, m, k, n));
}

/// Benchmark one kernel at a fixed (m, k, n) problem size.
///
/// Throughput is reported as `m * k * n` elements (one per multiply-accumulate)
/// so results are comparable across kernel shapes. Only the "cold"
/// (cache-ruined) variant is currently enabled; the "hot" variant is kept
/// commented out for quick re-enabling during investigation.
fn cold_and_hot(c: &mut Criterion, mm: &dyn MatMatMul, m: usize, k: usize, n: usize) {
    // `format!("{}", x)` was a useless_format; `to_string` is the idiomatic form.
    let mut group = c.benchmark_group(mm.kernel_name().to_string());
    group.throughput(Throughput::Elements((m * k * n) as u64));
    let id = format!("{m}x{k}x{n}");
    group.bench_with_input(
        BenchmarkId::new("f32/cold", &id),
        &(mm, DatumType::F32, m, k, n, false),
        mat_mat_mm,
    );
    // group.bench_with_input(
    //     BenchmarkId::new("f32/hot", &id),
    //     &(mm, DatumType::F32, m, k, n, true),
    //     mat_mat_mm,
    // );
}

/// Run the benchmark suite for one kernel at the canonical problem size.
///
/// The sweep over several `m` values is intentionally left disabled; every
/// kernel is currently measured at m=1024, k=1000 with the caller-supplied n.
fn mm(be: &mut Criterion, mm: impl AsRef<dyn MatMatMul>, n: usize) {
    // Canonical problem dimensions shared by all kernel benchmarks.
    let (m, k) = (1024, 1000);
    // for m in (0..1024).step_by(128).skip(1) {
    cold_and_hot(be, mm.as_ref(), m, k, n);
    // }
}

/// Register benchmarks for every compiled AVX-512 f32 matmul kernel.
///
/// For each listed m at a given n, the macro expands (via `paste!`) to
/// `mm(c, avx512_mmm_f32_{m}x{n}::mmm(), n)`. The m-lists below mirror the
/// kernel sizes emitted by linalg/build.rs — presumably the full set when the
/// `compile_all_kernels` feature is enabled; confirm against build.rs.
fn all(c: &mut Criterion) {
    use tract_linalg::x86_64_fma::mmm::*;
    // Recursive helper: the first arm handles a single m; the second arm
    // peels one m off the comma-separated list and recurses on the rest.
    macro_rules! benches_for_n {
        ($c:expr ; $n:expr ; $m:expr) => (
            paste::paste! {
                mm($c, [<avx512_mmm_f32_ $m x $n>]::mmm(), $n);
            }
        );
        ($c:expr ; $x:expr ; $m1:expr, $($y:expr),+) => (
            benches_for_n!($c ; $x ; $m1);
            benches_for_n!($c ; $x ; $($y),+);
        );
    }

    // Second argument is n (column tile width); the list holds the available
    // m values (row tile heights) for that n.
    benches_for_n!(c; 1 ; 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240);
    benches_for_n!(c; 2 ; 16, 32, 48, 64, 80, 96, 112, 128, 144, 160);
    benches_for_n!(c; 3 ; 16, 32, 48, 64, 80, 96, 112);
    benches_for_n!(c; 4 ; 16, 32, 48, 64, 80, 96);
    benches_for_n!(c; 5 ; 16, 32, 48, 64, 80);
    benches_for_n!(c; 6 ; 16, 32, 48, 64);
    benches_for_n!(c; 7 ; 16, 32, 48);
    benches_for_n!(c; 8 ; 16, 32, 48);
    benches_for_n!(c; 9 ; 16, 32, 48);
    benches_for_n!(c; 10 ; 16, 32);
    benches_for_n!(c; 11 ; 16, 32);
    benches_for_n!(c; 12 ; 16, 32);
    benches_for_n!(c; 13 ; 16, 32);
    benches_for_n!(c; 14 ; 16, 32);
    benches_for_n!(c; 15 ; 16);
    benches_for_n!(c; 16 ; 16);
    benches_for_n!(c; 17 ; 16);
    benches_for_n!(c; 18 ; 16);
    benches_for_n!(c; 19 ; 16);
    benches_for_n!(c; 20 ; 16);
    benches_for_n!(c; 21 ; 16);
    benches_for_n!(c; 22 ; 16);
    benches_for_n!(c; 23 ; 16);
    benches_for_n!(c; 24 ; 16);
    benches_for_n!(c; 25 ; 16);
    benches_for_n!(c; 26 ; 16);
    benches_for_n!(c; 27 ; 16);
    benches_for_n!(c; 28 ; 16);
    benches_for_n!(c; 29 ; 16);
}

// Register the `all` suite as this bench binary's criterion entry point.
criterion_group!(benches, all);
criterion_main!(benches);
3 changes: 2 additions & 1 deletion linalg/benches/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ pub fn packed_vec(c: &mut Criterion, name: &str, m: usize, k: usize, n: usize) {
}

pub fn ruin_cache() {
let _a = (0..1000000).collect::<Vec<i32>>();
// the collect gets optimized out by LLVM without black_box
let _a = std::hint::black_box((0..10000000).collect::<Vec<i32>>());
}

#[allow(clippy::too_many_arguments)]
Expand Down
82 changes: 74 additions & 8 deletions linalg/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ impl ConfigForHalf {
}
}

/// Describes a family of matmul kernels to generate from one liquid template.
struct GenerateKernelsSpec {
    // (mr, nr) tile-size pairs; one kernel is generated per pair, with the
    // values exposed to the template as the `mr`/`nr` liquid globals.
    sizes: Vec<(usize, usize)>,
    // Path to the `.tmpliq` template file to instantiate.
    file: path::PathBuf,
}

fn main() {
let target = var("TARGET");
let arch = var("CARGO_CFG_TARGET_ARCH");
Expand All @@ -83,8 +88,51 @@ fn main() {

match arch.as_ref() {
"x86_64" => {
let mut files = preprocess_files("x86_64/fma", &[], &suffix, false);
files.extend(preprocess_files("x86_64/avx512", &[], &suffix, false));
let mut files = preprocess_files("x86_64/fma", &[], &suffix, false, None);

let avx512_kernels: Vec<_> = if cfg!(feature = "compile_all_kernels") {
// limits of the max M size of the kernels in avx512; index is n-1
let avx512_kernels_max = [
240, 160, 112, 96, 80, 64, 48, 48, 48, 32, 32, 32, 32, 32, 16, 16, 16, 16, 16,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
];
avx512_kernels_max
.iter()
.enumerate()
.flat_map(|(n_min_1, &max)| {
(16..=max).step_by(16).map(move |m| (m, n_min_1 + 1))
})
.collect()
} else {
vec![
(16, 1),
(96, 1),
(96, 2),
(80, 3),
(64, 4),
(32, 5),
(32, 6),
(32, 7),
(32, 8),
(32, 9),
(32, 10),
(32, 11),
(32, 12),
(32, 13),
(32, 14),
]
};

files.extend(preprocess_files(
"x86_64/avx512",
&[],
&suffix,
false,
Some(GenerateKernelsSpec {
sizes: avx512_kernels,
file: "x86_64/avx512/avx512_mmm_f32.tmpliq".into(),
}),
));

if os == "windows" {
if use_masm() {
Expand Down Expand Up @@ -136,7 +184,7 @@ fn main() {
}
}
"arm" | "armv7" => {
let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false);
let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false, None);
cc::Build::new()
.files(files)
.flag("-marm")
Expand All @@ -148,6 +196,7 @@ fn main() {
&[("core", vec!["cortexa7", "cortexa9", "generic"])],
&suffix,
false,
None,
);
cc::Build::new()
.files(files)
Expand All @@ -162,11 +211,12 @@ fn main() {
&[("core", vec!["a53", "a55", "gen"])],
&suffix,
false,
None,
);
cc::Build::new().files(files).static_flag(true).compile("arm64simd");
if os == "macos" {
// aarch64 darwin => M1
let files = preprocess_files("arm64/apple_amx", &[], &suffix, false);
let files = preprocess_files("arm64/apple_amx", &[], &suffix, false, None);
cc::Build::new().files(files).static_flag(true).compile("appleamx");
}
if std::env::var("CARGO_FEATURE_NO_FP16").is_err() {
Expand All @@ -177,6 +227,7 @@ fn main() {
&[("core", vec!["a55", "gen"])],
&suffix,
config.needs_pragma,
None,
);
config.cc().files(files).static_flag(true).compile("arm64fp16")
}
Expand All @@ -192,9 +243,24 @@ fn preprocess_files(
variants: &[Variant],
suffix: &str,
needs_pragma: bool,
generate_kernels_spec: Option<GenerateKernelsSpec>,
) -> Vec<path::PathBuf> {
let out_dir = path::PathBuf::from(var("OUT_DIR"));
let mut files = vec![];

if let Some(spec) = generate_kernels_spec {
let tmpl_file = spec.file.file_stem().unwrap().to_str().unwrap();
for (m, n) in spec.sizes {
let globals = vec![
("mr", liquid::model::Value::scalar(format!("{m}"))),
("nr", liquid::model::Value::scalar(format!("{n}"))),
];
let file = out_dir.join(format!("{tmpl_file}_{m}x{n}.S"));
println!("{}", file.display());
preprocess_file(&spec.file, &file, &globals, suffix, needs_pragma);
files.push(file);
}
}
let dir_entries = {
let mut dir_entries: Vec<fs::DirEntry> =
input.as_ref().read_dir().unwrap().map(|f| f.unwrap()).collect();
Expand All @@ -214,7 +280,7 @@ fn preprocess_files(
for variable in variants {
let key = variable.0;
let value = variable.1[id % variable.1.len()];
globals.push((key, value));
globals.push((key, liquid::model::Value::scalar(value)));
tmpl_file = tmpl_file.replace(key, value);
id /= variable.1.len();
}
Expand All @@ -239,7 +305,7 @@ fn strip_comments(s: String, msvc: bool) -> String {
fn preprocess_file(
template: impl AsRef<path::Path>,
output: impl AsRef<path::Path>,
variants: &[(&'static str, &'static str)],
added_globals: &[(&'static str, liquid::model::Value)],
suffix: &str,
needs_pragma: bool,
) {
Expand Down Expand Up @@ -277,8 +343,8 @@ fn preprocess_file(
"jump_table": jump_table(),
"align": align,
});
for (k, v) in variants {
globals.insert(k.to_string().into(), liquid::model::Value::scalar(*v));
for (k, v) in added_globals {
globals.insert(k.to_string().into(), v.clone());
}
let partials = load_partials(template.as_ref().parent().unwrap(), msvc);
let mut parser = liquid::ParserBuilder::with_stdlib()
Expand Down
73 changes: 67 additions & 6 deletions linalg/src/x86_64_fma.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::cmp::Ordering;

use tract_data::internal::num_integer::Integer;

use crate::frame::element_wise::ElementWiseKer;
use crate::frame::mmm::kernel::MatMatMulKer;
use crate::Ops;
Expand Down Expand Up @@ -96,14 +100,71 @@ fn plug_fma(ops: &mut Ops) {
fn plug_avx512f(ops: &mut Ops) {
ops.mmv_f32 = Box::new(|m, _k| match m {
Some(m) if m < 31 => mmm::avx512_mmm_f32_16x1::mmm(),
_ => mmm::avx512_mmm_f32_128x1::mmm(),
_ => mmm::avx512_mmm_f32_96x1::mmm(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 96x1 over 128x1? :-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I see it looks better in the table now; which I guess makes sense why this is like it is. It's unexpected to me.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For MMV, we're hitting very low throughput all across the board, regardless of M
My thinking was that according to my benchmarks, lowering 128 to 96 does not cause harm, and it could help with border kernels on matrices that are not multiples of 128

});

ops.mmm_f32 = Box::new(|_, _, n| match n {
Some(1) => unreachable!("should've been mmv"),
Some(2) => mmm::avx512_mmm_f32_80x2::mmm(),
Some(n) if n % 4 == 0 && n % 3 != 0 => mmm::avx512_mmm_f32_48x4::mmm(),
_ => mmm::avx512_mmm_f32_64x3::mmm(),
ops.mmm_f32 = Box::new(|_, _, n| {
if n.is_none() {
return mmm::avx512_mmm_f32_32x12::mmm();
}
let mut n = n.unwrap();

if n > 14 {
// throughputs are measured using the kernel_throughput.py script
let scaling_baseline = 98.0;
let kernel_throughputs = [
(2, 18.0 / scaling_baseline),
(3, 28.0 / scaling_baseline),
(4, 36.5 / scaling_baseline),
(5, 44.0 / scaling_baseline),
(6, 49.0 / scaling_baseline),
(7, 58.0 / scaling_baseline),
(8, 65.0 / scaling_baseline),
(9, 72.5 / scaling_baseline),
(10, 82.0 / scaling_baseline),
(11, 84.0 / scaling_baseline),
(12, 88.5 / scaling_baseline),
(13, 95.0 / scaling_baseline),
(14, 98.0 / scaling_baseline),
];

let throughputs = kernel_throughputs.map(|(kernel_width, thrpt): (usize, f32)| {
let n_tiles = Integer::div_ceil(&n, &kernel_width);

let n_elem_total = n_tiles * kernel_width;
let n_elem_on_border_tile = n_elem_total - n;
let wasted_ratio = n_elem_on_border_tile as f32 / n_elem_total as f32;

let final_thrpt = thrpt * (1.0 - wasted_ratio);

(kernel_width, final_thrpt)
});

let best_ker = *throughputs
.iter()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.map(|(ker_width, _)| ker_width)
.unwrap();

n = best_ker;
}

match n {
2 => mmm::avx512_mmm_f32_96x2::mmm(),
3 => mmm::avx512_mmm_f32_80x3::mmm(),
4 => mmm::avx512_mmm_f32_64x4::mmm(),
5 => mmm::avx512_mmm_f32_32x5::mmm(),
6 => mmm::avx512_mmm_f32_32x6::mmm(),
7 => mmm::avx512_mmm_f32_32x7::mmm(),
8 => mmm::avx512_mmm_f32_32x8::mmm(),
9 => mmm::avx512_mmm_f32_32x9::mmm(),
10 => mmm::avx512_mmm_f32_32x10::mmm(),
11 => mmm::avx512_mmm_f32_32x11::mmm(),
12 => mmm::avx512_mmm_f32_32x12::mmm(),
13 => mmm::avx512_mmm_f32_32x13::mmm(),
14 => mmm::avx512_mmm_f32_32x14::mmm(),
_ => unreachable!("not a valid index"),
}
});
log::info!("mmm_f32, mmv_f32: x86_64/avx512f activated");
}
Expand Down
Loading