Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlas Gemm 4bit avx2, avx512, and avx512vnni kernels #20163

Merged
merged 67 commits into from
Apr 26, 2024
Merged

Conversation

liqunfu
Copy link
Contributor

@liqunfu liqunfu commented Mar 31, 2024

Description

Perf data from (21a892b)

Avx2:
Int8

               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	90.96	96.10		5%		7.65	12.55		64%
Blklen32:	90.73	79.84		-12%		7.86	15.11		92%
Blklen64:	89.49	98.01		9%		8.30	16.04		93%
Blklen128:  	87.38	102.04		16%		7.90	16.21		105%
Blklen256:  	89.45	94.13	 	5%		8.30	16.60		100%

Fp32		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	91.36	102.60		12%		7.57	9.07		19%
Blklen32:	89.30	83.08		-6%		7.65	10.27		34%
Blklen64:	89.53	102.24		14%		7.97	10.24		28%
Blklen128:	85.23	102.94		20%		7.86	10.41		32%
Blklen256:	88.46	102.62		16%		8.32	10.72		28%

Avx512vnni:
Int8		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	132.18	105.47		-20%		10.34	13.09		26%
Blklen32:	168.28	106.43		-36%		11.85	16.35		37%
Blklen64:	201.81	104.47		-48%		12.36	17.48		41%
Blklen128:	194.92	104.69		-46%		13.03	17.35		33%
Blklen256:	218.76	112.06		-48%		13.33	16.99		27%

Fp32		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	102.81	117.29		14%		8.41	12.17		44%
Blklen32:	109.49	112.87		3%		8.83	13.40		51%
Blklen64:	104.13	111.07		6%		9.32	11.17		19%
Blklen128:	108.45	113.08		4%		9.58	11.39		18%
Blklen256:	109.43	113.46		3%		9.19	11.97		30%

(followings are perf results from 1d88398. leave it here for reference. Mlas Prompt compute for Int8 has then been speed up by routing it to fp32. perf results shown above)

Avx2:
Int8

               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16: 	90.96		25.15	-72%			7.65		11.71	53%
Blklen32:	90.73		48.55	-46%			7.86		14.28	81%
Blklen64:	89.49		68.84	-23%			8.30		15.78	90%
Blklen128:  	87.38		78.37	-10%			7.90		16.05	103%
Blklen256:  	89.45		82.36	 -7%			8.30		16.56	99%

Fp32		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	91.36	105.18		15%		7.57	9.52		25%
Blklen32:	89.30	105.99		18%		7.65	9.68		26%
Blklen64:	89.53	101.41		13%		7.97	9.84		23%
Blklen128:	85.23	99.71		16%		7.86	10.39		32%
Blklen256:	88.46	97.94		10%		8.32	10.23		22%

Avx512vnni:
Int8		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	132.18	21.56		-83%		10.34	11.48		11%
Blklen32:	168.28	43.69		-74%		11.85	14.73		24%
Blklen64:	201.81	60.29		-70%		12.36	15.47		25%
Blklen128:	194.92	57.04		-71%		13.03	14.67		12%
Blklen256:	218.76	70.20		-68%		13.33	16.31		22%

Fp32		
               NS(P)	MLAS(P)    MLASGain/Loss(P)	NS(T)	MLAS(T)  MLASGain/Loss(T)
Blklen16:	102.81	92.74		-9%		8.41	9.18		9%
Blklen32:	109.49	97.08		-11%		8.83	11.51		30%
Blklen64:	104.13	101.57		-2%		9.32	12.00		28%
Blklen128:	108.45	103.69		-4%		9.58	12.45		29%
Blklen256:	109.43	106.43		-2%		9.19	12.2		32%

@liqunfu liqunfu requested a review from a team as a code owner March 31, 2024 23:05
@liqunfu liqunfu marked this pull request as draft March 31, 2024 23:06
@liqunfu
Copy link
Contributor Author

liqunfu commented Apr 10, 2024

2 implementations:
USE_NCOLs = false: process one row from A and one column from B,
USE_NCOLs = true : process one row from A and NCols(4) column from B.

here are the benchmark run results:
start /B /HIGH onnxruntime_mlas_benchmark.exe --benchmark_filter="SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time" --benchmark_repetitions=10

options:name mean_real mean_cpu
USE_NCOLs = false SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1857795610ns
USE_NCOLs = true SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4 1731348910ns

start /B /HIGH onnxruntime_mlas_benchmark.exe --benchmark_filter="SQNBITGEMM<4>/BlkLen:128/M:1024/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time" --benchmark_repetitions=10

options:name mean_real mean_cpu
USE_NCOLs = false SQNBITGEMM<4>/BlkLen:128/M:1024/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4 614658820ns
USE_NCOLs = true SQNBITGEMM<4>/BlkLen:128/M:1024/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4 559375000ns

liqunfu added 4 commits April 12, 2024 21:58
…M>1 20% improvement by using implementing simd dequantization. int8 blklen=16 significantly improved

Signed-off-by: liqunfu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
@liqunfu liqunfu marked this pull request as ready for review April 18, 2024 03:30
@liqunfu liqunfu changed the title Liqun/mlas 4bit cpu Liqun/mlas Gemm 4bit avx2, avx512, and avx512vnni kernels Apr 18, 2024
@liqunfu
Copy link
Contributor Author

liqunfu commented Apr 24, 2024

Description

Avx2: Int8
NS(Prompt) MLAS(Prompt) MLAS(Prompt)Gain/Loss NS(TokenGen) MLAS(TokenGen) MLAS(TokenGen)Gain/Loss Blklen16: 90.96 25.15 -72% 7.65 11.71 53% Blklen32: 90.73 48.55 -46% 7.86 14.28 81% Blklen64: 89.49 68.84 -23% 8.30 15.78 90% Blklen128: 87.38 78.37 -10% 7.90 16.05 103% Blklen256: 89.45 82.36 -7% 8.30 16.56 99%
Fp32 NS(Prompt) MLAS(Prompt) MLAS(Prompt)Gain/Loss NS(TokenGen) MLAS(TokenGen) MLAS(TokenGen)Gain/Loss Blklen16: 91.36 105.18 15% 7.57 9.52 25% Blklen32: 89.30 105.99 18% 7.65 9.68 26% Blklen64: 89.53 101.41 13% 7.97 9.84 23% Blklen128: 85.23 99.71 16% 7.86 10.39 32% Blklen256: 88.46 97.94 10% 8.32 10.23 22%
Avx512vnni: Int8 NS(Prompt) MLAS(Prompt) MLAS(Prompt)Gain/Loss NS(TokenGen) MLAS(TokenGen) MLAS(TokenGen)Gain/Loss Blklen16: 132.18 21.56 -83% 10.34 11.48 11% Blklen32: 168.28 43.69 -74% 11.85 14.73 24% Blklen64: 201.81 60.29 -70% 12.36 15.47 25% Blklen128: 194.92 57.04 -71% 13.03 14.67 12% Blklen256: 218.76 70.20 -68% 13.33 16.31 22%
Fp32 NS(Prompt) MLAS(Prompt) MLAS(Prompt)Gain/Loss NS(TokenGen) MLAS(TokenGen) MLAS(TokenGen)Gain/Loss Blklen16: 102.81 92.74 -9% 8.41 9.18 9% Blklen32: 109.49 97.08 -11% 8.83 11.51 30% Blklen64: 104.13 101.57 -2% 9.32 12.00 28% Blklen128: 108.45 103.69 -4% 9.58 12.45 29% Blklen256: 109.43 106.43 -2% 9.19 12.2 32%

The prompt performance for fp32 is much better than the int8. I think we can use the fp32 prompt kernel for the int8 case.

good point! routed M>2 cases int8 compute to fp32. Will role bask to in8 compute after int8 compute for M>1 is improved.

@liqunfu liqunfu merged commit cc26b2d into main Apr 26, 2024
89 of 94 checks passed
@liqunfu liqunfu deleted the liqun/mlas-4bit-cpu branch April 26, 2024 04:30
@sophies927 sophies927 added release:1.18.0 triage:approved Approved for cherrypicks for release labels May 1, 2024
@yihonglyu yihonglyu added the cherry-picked Cherry-picked for a cherrypicks branch label May 4, 2024
yihonglyu pushed a commit that referenced this pull request May 4, 2024
### Description

```
Avx2:
Int8

NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16: 	90.96			25.15			-72%					7.65				11.71			53%
Blklen32:	90.73			48.55			-46%					7.86				14.28			81%
Blklen64:	89.49			68.84			-23%					8.30				15.78			90%
Blklen128:	87.38			78.37			-10%					7.90				16.05			103%
Blklen256:	89.45			82.36			-7%					8.30				16.56			99%

Fp32		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	91.36			105.18		15%				7.57			9.52		25%
Blklen32:	89.30			105.99			18%					7.65				9.68			26%
Blklen64:	89.53			101.41			13%					7.97				9.84			23%
Blklen128:	85.23			99.71			16%					7.86				10.39			32%
Blklen256:	88.46			97.94			10%					8.32				10.23			22%

Avx512vnni:
Int8		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	132.18			21.56			-83%					10.34				11.48			11%
Blklen32:	168.28			43.69			-74%					11.85				14.73			24%
Blklen64:	201.81			60.29			-70%					12.36				15.47			25%
Blklen128:	194.92			57.04			-71%					13.03				14.67			12%
Blklen256:	218.76			70.20			-68%					13.33				16.31			22%

Fp32		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	102.81			92.74			-9%					8.41				9.18			9%
Blklen32:	109.49			97.08			-11%					8.83				11.51			30%
Blklen64:	104.13			101.57			-2%					9.32				12.00			28%
Blklen128:	108.45			103.69			-4%					9.58				12.45			29%
Blklen256:	109.43			106.43			-2%					9.19				12.2			32%

```

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Co-authored-by: edgchen1 <[email protected]>
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024
### Description

```
Avx2:
Int8

NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16: 	90.96			25.15			-72%					7.65				11.71			53%
Blklen32:	90.73			48.55			-46%					7.86				14.28			81%
Blklen64:	89.49			68.84			-23%					8.30				15.78			90%
Blklen128:	87.38			78.37			-10%					7.90				16.05			103%
Blklen256:	89.45			82.36			-7%					8.30				16.56			99%

Fp32		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	91.36			105.18		15%				7.57			9.52		25%
Blklen32:	89.30			105.99			18%					7.65				9.68			26%
Blklen64:	89.53			101.41			13%					7.97				9.84			23%
Blklen128:	85.23			99.71			16%					7.86				10.39			32%
Blklen256:	88.46			97.94			10%					8.32				10.23			22%

Avx512vnni:
Int8		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	132.18			21.56			-83%					10.34				11.48			11%
Blklen32:	168.28			43.69			-74%					11.85				14.73			24%
Blklen64:	201.81			60.29			-70%					12.36				15.47			25%
Blklen128:	194.92			57.04			-71%					13.03				14.67			12%
Blklen256:	218.76			70.20			-68%					13.33				16.31			22%

Fp32		
NS(Prompt)		MLAS(Prompt)  	MLAS(Prompt)Gain/Loss		NS(TokenGen)		MLAS(TokenGen)  	MLAS(TokenGen)Gain/Loss
Blklen16:	102.81			92.74			-9%					8.41				9.18			9%
Blklen32:	109.49			97.08			-11%					8.83				11.51			30%
Blklen64:	104.13			101.57			-2%					9.32				12.00			28%
Blklen128:	108.45			103.69			-4%					9.58				12.45			29%
Blklen256:	109.43			106.43			-2%					9.19				12.2			32%

```

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Co-authored-by: edgchen1 <[email protected]>
@yihonglyu yihonglyu added the rel-merged Cherrypicks merged into release label May 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cherry-picked Cherry-picked for a cherrypicks branch rel-merged Cherrypicks merged into release release:1.18.0 triage:approved Approved for cherrypicks for release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants