-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement 2d tiled matmulnbits specialized for prefill (#23058)
### Description This change implements matmul4bits with tiling both for A and B. This is beneficial for prefill scenarios on Intel integrated GPUs, because each row of A has to run through the same set of shared rows of B. This change should improve core occupancy and model_benchmark does indicate improvements for prefill. The same shader is not used for generation because when A has just a single row, the other threads in the workgroup get unused and that hurts performance. ``` -- Baseline run on an Alderlake GPU -- C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500 Batch size: 1, prompt tokens: 501, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.72338e+07 avg (tokens/s): 29.0707 << p50 (us): 1.72548e+07 stddev (us): 57012.8 n: 5 * 501 token(s) Token generation: avg (us): 79227.5 avg (tokens/s): 12.6219 p50 (us): 79284.4 stddev (us): 2109.72 n: 635 * 1 token(s) Token sampling: avg (us): 15.8198 avg (tokens/s): 63211.8 p50 (us): 14.3 stddev (us): 8.67178 n: 640 * 1 token(s) E2E generation (entire generation loop): avg (ms): 27297.8 p50 (ms): 27269.8 stddev (ms): 89.4322 n: 5 Peak working set size (bytes): 5490987008 WebGPU device lost (2): Device was destroyed. ----------------------------------- With Prefill Optimization ---- C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500 Batch size: 1, prompt tokens: 501, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.2135e+07 avg (tokens/s): 41.2856 << p50 (us): 1.21288e+07 stddev (us): 21282.1 n: 5 * 501 token(s) Token generation: avg (us): 78945.3 avg (tokens/s): 12.667 p50 (us): 78900.7 stddev (us): 2232.43 n: 635 * 1 token(s) Token sampling: avg (us): 20.5608 avg (tokens/s): 48636.3 p50 (us): 18.7 stddev (us): 19.0409 n: 640 * 1 token(s) E2E generation (entire generation loop): avg (ms): 22163.8 p50 (ms): 22160.1 stddev (ms): 31.3122 n: 5 Peak working set size (bytes): 5478862848 WebGPU device lost (2): Device was destroyed. ```
- Loading branch information
1 parent
d8de3c4
commit 8800830
Showing
2 changed files
with
186 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters