Avx512 perf improvements #996
Awesome! Yes; I'm quite busy for the rest of the day but I can look tomorrow! I haven't had time to test the AVX512 changes yet, so I'll try to measure the base->AVX512->this deltas!
FYI started looking at benchmarks but since we hadn't bumped from version 0.17 I've got a bit of work to do. Symbolication/batch dimensions have changed a bit so it's crashing in concretization. I'll do a review after lunch at least. In case you have a thought @kali this is the issue:
Happens here on this upgrade PR: https://github.com/EmbarkStudios/cervo/pull/44/files#diff-4a7fb40e77a22ccba64ba761a0c31ab388127a6309b79e1c7832602c1755d3dcR39. If you want to try the code,
Will try batch-sizes 1..24 of all the various batching mechanisms cervo has.
It looks to me like a typical case of overspecification of inputs and outputs in ONNX. (Prior to 0.19, tract was ignoring them; this was a bug. Now it's trying to make sense of them.) Try to cancel the output_shapes:
OK; now I've managed to get this to run. I've compared my current production (0.17) to 0.19.21, main, and this branch. You can see this below. This is a conv-stack of 32x32 followed by an MLP [1024,1024]. This branch is definitely an improvement. But there are some weird ones compared to avx256, like getting much worse in the 6-wide case and then much better in the 7-wide case? Maybe I'll see the reason why in the code. Starting review now!
ah, that's kinda bad
Just ignore the pre-version tag on main. main will become 0.20. There is a maintenance branch for 0.19 called 0.19.x
Awesome work! This looks so much cleaner, and only a tiny bit more complex.
As noted, I'm a bit saddened by the performance. I've pointed out a few concerns, but I'm really not seeing anything super-obvious that would lead to this kind of slowdown. I'll see if I can repro your measurements with the script too :-)
@@ -96,14 +100,71 @@ fn plug_fma(ops: &mut Ops) {
 fn plug_avx512f(ops: &mut Ops) {
     ops.mmv_f32 = Box::new(|m, _k| match m {
         Some(m) if m < 31 => mmm::avx512_mmm_f32_16x1::mmm(),
-        _ => mmm::avx512_mmm_f32_128x1::mmm(),
+        _ => mmm::avx512_mmm_f32_96x1::mmm(),
Why 96x1 over 128x1? :-)
(I see it looks better in the table now, which I guess explains why it is like this. It's unexpected to me.)
For MMV, we're hitting very low throughput across the board, regardless of M.
My thinking was that, according to my benchmarks, lowering 128 to 96 does not cause harm, and it could help with border kernels on matrices whose M is not a multiple of 128.
{% for cur_unroll_count in (0..unroll_min_1) %}

    {% for i in (0..prefetches_to_issue_min_1) %}
        prefetcht0 [rax + {{i | times:64}} + {{m_total_bytes | times:prefetch_dist}} + {{cur_unroll_count | times:m_total_bytes}}]
Have you measured that this is worth it? My tests on AVX256 saw perf losses from any type of prefetching.
I haven't measured it, but it did not seem to cause harm according to vtune, and onnxruntime also does it.
I'll measure it.
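For reference, here is how the displacement in the prefetcht0 template above expands. This is only a sketch: the function name and the example values used below (m_total_bytes = 128, prefetch_dist = 4, an unroll of 2, two prefetches per step) are assumptions for illustration, not the values the kernels actually use. It relies on Liquid ranges being inclusive, so a loop over (0..x_min_1) runs x times.

```rust
/// Byte displacements (relative to rax) produced by the prefetcht0 template.
/// The values passed in are hypothetical; the real ones come from the
/// kernel's template variables.
pub fn prefetch_displacements(
    m_total_bytes: usize,
    prefetch_dist: usize,
    unroll: usize,
    prefetches_to_issue: usize,
) -> Vec<usize> {
    let mut out = Vec::with_capacity(unroll * prefetches_to_issue);
    // Outer loop mirrors `cur_unroll_count`, inner loop mirrors `i`.
    for cur_unroll_count in 0..unroll {
        for i in 0..prefetches_to_issue {
            out.push(
                i * 64                                // one 64-byte cache line per prefetch
                    + m_total_bytes * prefetch_dist   // how far ahead of the current panel
                    + cur_unroll_count * m_total_bytes, // panel consumed by this unrolled step
            );
        }
    }
    out
}
```

With those assumed values, prefetch_displacements(128, 4, 2, 2) gives 512, 576, 640 and 704, i.e. each unrolled step prefetches the panel that sits prefetch_dist panels ahead of the one it is reading.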
    {% endfor %}

    {% for i in (0..nr_min_1) %}
        vbroadcastss zmm{{col_reg}}, dword ptr [rcx + {{i | times:4}} + {{cur_unroll_count | times:n_total_bytes}}]
Have you tested using two alternating column regs? It'd maybe alleviate some register pressure.
I'll try. I thought this couldn't be an issue because of register renaming.
It shouldn't, but you never know. :-)
{% for row in (0..mr_arch_min_1) %}
    kxnorw k1,k1,k1 // set writemask to ones
    vscatterdps [r9 + zmm31]{k1}, zmm{{col | times:mr_arch | plus:row}}
I'm eyeing this line as a potential slowdown. vscatterdps is hugely expensive. Each row here is 43 uops!
Woah, I did not know about that
It's only 36 on Cannon Lake and Ice Lake, 43 on SKX; fwiw. https://www.agner.org/optimize/instruction_tables.pdf
Out of curiosity, have you run this under anything like VTune or Advisor? I'm curious what they'd say and identify as concerns.
Yes, vtune. Obviously there is something I'm missing here though, so I'll get back to this. Do you have a script so that I can try and reproduce your results?
Clone cervo on the Then you can run it like this:
Once you've tested a bunch of PRs you can do
$ python3 ../python/compare_batchsize.py \
    some-path/first.csv,some-path/second.csv,... \
    test1label,test2label,... \
    10000 \
    some-path/output.png
The script should handle any number of comparisons, they're all relative to the first file.
Okay, I did not have that much time to look at this today, but I managed to replicate the regression locally, and looking at vtune, it looks like I'm spending way too much time in the MMV kernels. I was very worried that this was something I wouldn't be able to replicate, seeing that your kernel times are so different. Fortunately it seems that's not the case. I'll look further tomorrow, but something is fishy here and it's my fault :)
How easy would it be to remove loop unrolling? We could be cramping the instruction decoder/cache.
Very easy, it's just setting the loop unroll variable to 1 and commenting out the jumps.
Hey folks, I'm planning on cutting 0.20.x in a week or so. Do we have a path towards making these optimisations a part of it?
This PR brings a few things to the avx512 linalg kernels added in #989:
(0..10000000).collect::<Vec<_>>() gets optimized out completely by llvm now, meaning all my cold cache results were wrong :)
Kernel results
Results are in Gelem/s, measured from a cold start (the cache is ruined beforehand), with M=1024 and K=1000.
Intel(R) Xeon(R) Gold 6334 CPU @ 3.60GHz
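On the cold-cache point above: a minimal sketch of a cache-ruining helper that the optimizer should not be able to drop, assuming std::hint::black_box is acceptable in the bench harness. The name evict_caches and the idea of sizing the buffer from a byte count are illustrative assumptions, not the code used in this PR. Note that black_box is documented as best-effort, so this is a strong hint rather than a guarantee, and whether the caches really end up cold depends on passing a size well above the last-level cache.

```rust
use std::hint::black_box;

/// Allocate and walk a buffer of roughly `bytes` bytes, touching one u64 per
/// 64-byte cache line. Hypothetical helper: pick `bytes` above the LLC size.
pub fn evict_caches(bytes: usize) -> u64 {
    let len = bytes / std::mem::size_of::<u64>();
    // Filling the buffer writes every line once.
    let buf: Vec<u64> = (0..len as u64).collect();
    // Hide the buffer from the optimizer so the allocation is not elided.
    let buf = black_box(buf);
    let mut sum = 0u64;
    // 8 u64 values per 64-byte line, so step_by(8) reads each line once.
    for x in buf.iter().step_by(8) {
        sum = sum.wrapping_add(*x);
    }
    // Returning through black_box keeps the read loop observable as well.
    black_box(sum)
}
```

Calling something like evict_caches(256 << 20) before each timed kernel run would be one way to use it; the 256 MiB figure is a guess and should be checked against the cache sizes of the machine under test.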
Future work
I think I'm pretty much done on the asm kernel part; I don't think I'll be able to squeeze out any more perf there.
On a slightly higher level, border tile handling is still suboptimal: we can still improve the perf when N is low.
As you can see from the graph:
[Graph: Gelem/s]
Between N=14 and N=15, we have a big drop from 95 Gelem/s with the 32x14 kernel to 63 Gelem/s with the 23x8 kernel.
This is because with the x14 kernel, we would compute 13 useless elements when N=15, coming from the border tile of the C matrix.
This gets better as N gets bigger, as the waste ratio is lower.
[Graph: Gelem/s]
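The border-tile arithmetic can be sketched as follows. The function names are made up for illustration and this is not how tract picks kernels; it only reproduces the waste computation, counting wasted columns per row of the border tile.

```rust
/// Columns actually computed when `n` columns are covered by tiles `nr` wide.
pub fn padded_columns(n: usize, nr: usize) -> usize {
    assert!(nr > 0, "tile width must be non-zero");
    // Round n up to the next multiple of nr.
    (n + nr - 1) / nr * nr
}

/// Columns computed by the border tile that fall outside the real C matrix.
pub fn wasted_columns(n: usize, nr: usize) -> usize {
    padded_columns(n, nr) - n
}

/// Share of the computed columns that is thrown away.
pub fn waste_ratio(n: usize, nr: usize) -> f64 {
    wasted_columns(n, nr) as f64 / padded_columns(n, nr) as f64
}
```

For N=15, wasted_columns(15, 14) is 13 (two tiles cover 28 columns), which matches the figure above, while wasted_columns(15, 8) is only 1. The ratio shrinks as N grows: 13 wasted out of 28 at N=15 becomes 13 out of 42 at N=29 with the same x14 kernel.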
On a higher level, @kali is working on collapsing consecutive dimensions in input matrices in matmuls.
I think this may fix whatever is still giving me a run time of 6s instead of 300ms on my transformer models, so I don't think I'll start bothering with the border tiles just yet.
Other than that, I don't think we're that far from being on par with onnxruntime perf! :)