Skip to content

Latest commit

 

History

History

rms-norm

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

RMSNorm

0x00 说明

包含以下内容:

  • rms_norm_f32_kernel
  • rms_norm_f32x4_kernel
  • rms_norm_f16_f16_kernel
  • rms_norm_f16x2_f16_kernel
  • rms_norm_f16x8_f16_kernel
  • rms_norm_f16x8_f32_kernel
  • rms_norm_f16x8_pack_f16_kernel
  • rms_norm_f16x8_pack_f32_kernel
  • rms_norm_f16_f32_kernel
  • PyTorch bindings

测试

# 只测试 Ada 架构;若不指定 TORCH_CUDA_ARCH_LIST,则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 rms_norm.py

输出:

-------------------------------------------------------------------------------------
                                        N=4096, K=512
          out_f32: ['0.04078517  ', '0.74503314  ', '0.87149841  '], time:0.01198173ms
        out_f32x4: ['0.04078517  ', '0.74503314  ', '0.87149841  '], time:0.00517488ms
       out_f32_th: ['0.04078539  ', '0.74503714  ', '0.87150306  '], time:0.04351616ms
-------------------------------------------------------------------------------------
       out_f16f16: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.01200986ms
       out_f16f32: ['0.040802    ', '0.74511719  ', '0.87109375  '], time:0.01180410ms
     out_f16x2f16: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.00670171ms
     out_f16x8f16: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.00411820ms
     out_f16x8f32: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.00411677ms
 out_f16x8packf16: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.00411630ms
 out_f16x8packf32: ['0.040802    ', '0.74511719  ', '0.87109375  '], time:0.00399137ms
       out_f16_th: ['0.040802    ', '0.74511719  ', '0.87158203  '], time:0.04383564ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=1024
          out_f32: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.03398657ms
        out_f32x4: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.00862885ms
       out_f32_th: ['-0.76329684 ', '-0.62112319 ', '-1.4553194  '], time:0.04355550ms
-------------------------------------------------------------------------------------
       out_f16f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.03526235ms
       out_f16f32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.03302288ms
     out_f16x2f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.01215649ms
     out_f16x8f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00632071ms
     out_f16x8f32: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00631690ms
 out_f16x8packf16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00528240ms
 out_f16x8packf32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.00519514ms
       out_f16_th: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.04399920ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=2048
        out_f32x4: ['-0.17984088 ', '-1.76387513 ', '-0.32782754 '], time:0.01650691ms
       out_f32_th: ['-0.17984176 ', '-1.76388371 ', '-0.32782915 '], time:0.09451318ms
-------------------------------------------------------------------------------------
     out_f16x2f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.03497124ms
     out_f16x8f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01254177ms
     out_f16x8f32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01253581ms
 out_f16x8packf16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00903535ms
 out_f16x8packf32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00894380ms
       out_f16_th: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.04889655ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=4096
        out_f32x4: ['-1.14100003 ', '-0.71529448 ', '2.26544118  '], time:0.18783689ms
       out_f32_th: ['-1.14100587 ', '-0.71529812 ', '2.26545286  '], time:0.52556086ms
-------------------------------------------------------------------------------------
     out_f16x8f16: ['-1.140625   ', '-0.71484375 ', '2.26367188  '], time:0.03605795ms
     out_f16x8f32: ['-1.140625   ', '-0.71484375 ', '2.26367188  '], time:0.03605533ms
 out_f16x8packf16: ['-1.140625   ', '-0.71484375 ', '2.26367188  '], time:0.01718473ms
 out_f16x8packf32: ['-1.140625   ', '-0.71533203 ', '2.26367188  '], time:0.01735568ms
       out_f16_th: ['-1.140625   ', '-0.71484375 ', '2.26367188  '], time:0.11150384ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=8192
     out_f16x8f16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19292974ms
     out_f16x8f32: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19298863ms
 out_f16x8packf16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.18497562ms
 out_f16x8packf32: ['-0.40844727 ', '-0.14294434 ', '-0.93310547 '], time:0.18479729ms
       out_f16_th: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.59557104ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=8192, K=8192
     out_f16x8f16: ['-0.35253906 ', '-1.04101562 ', '0.17358398  '], time:0.38169765ms
     out_f16x8f32: ['-0.35253906 ', '-1.04101562 ', '0.17358398  '], time:0.38264203ms
 out_f16x8packf16: ['-0.35253906 ', '-1.04101562 ', '0.17358398  '], time:0.40794849ms
 out_f16x8packf32: ['-0.35229492 ', '-1.04003906 ', '0.17346191  '], time:0.40747380ms
       out_f16_th: ['-0.35229492 ', '-1.04003906 ', '0.17346191  '], time:1.35807014ms
-------------------------------------------------------------------------------------