Mat Transpose

0x00 Overview

This directory contains:

  • mat_transpose_f32_col2row_kernel
  • mat_transpose_f32_row2col_kernel
  • mat_transpose_f32x4_col2row_kernel (float4 vectorized)
  • mat_transpose_f32x4_row2col_kernel (float4 vectorized)
  • mat_transpose_f32_diagnonal (diagonal-based transpose, applies only when S=K)
  • mat_transpose_f32x4_shared_col2row_kernel (float4 vectorized, shared memory)
  • mat_transpose_f32x4_shared_row2col_kernel (float4 vectorized, shared memory)
  • mat_transpose_f32x4_shared_bcf_col2row_kernel (float4 vectorized, shared memory, bank-conflict-free)
  • mat_transpose_f32x4_shared_bcf_row2col_kernel (float4 vectorized, shared memory, bank-conflict-free)
  • PyTorch bindings
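
The kernel bodies are not reproduced in this README, so the following is only a CPU-side sketch of one plausible reading of the `col2row` and `row2col` names: in `col2row` each thread walks the input linearly and scatters into the output, while in `row2col` each thread walks the output linearly and gathers from the input. The function names `transpose_col2row` and `transpose_row2col` are hypothetical, and each loop iteration stands in for one GPU thread.

```python
# CPU emulation of two 1D index mappings for transposing a row-major
# (rows x cols) matrix. Assumption: this mirrors the intent of the
# col2row / row2col kernels; the real CUDA code may differ.

def transpose_col2row(x, rows, cols):
    """Thread idx reads x[idx] (contiguous) and writes a strided slot of y."""
    y = [0.0] * (rows * cols)
    for idx in range(rows * cols):      # idx plays the role of the global thread id
        r, c = divmod(idx, cols)        # position of this element in the input
        y[c * rows + r] = x[idx]        # strided write into the (cols x rows) output
    return y


def transpose_row2col(x, rows, cols):
    """Thread idx writes y[idx] (contiguous) and reads a strided slot of x."""
    y = [0.0] * (rows * cols)
    for idx in range(rows * cols):
        c, r = divmod(idx, rows)        # output position: row c, column r
        y[idx] = x[r * cols + c]        # strided read from the input
    return y


if __name__ == "__main__":
    rows, cols = 2, 3
    x = [float(v) for v in range(rows * cols)]   # [[0, 1, 2], [3, 4, 5]]
    print(transpose_col2row(x, rows, cols))
    print(transpose_row2col(x, rows, cols))
```

Both mappings produce the same result. On a GPU they differ in which side of the copy (load or store) is contiguous across neighbouring threads, which is one likely reason the two variants are timed separately.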

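Although transpose is a basic operation, it makes good practice material: it is somewhat easier than matrix multiplication, yet most of the optimization techniques used there can also be applied here.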
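
As one example of such a technique, the `bcf` (bank-conflict-free) variants listed above remove shared memory bank conflicts. The README does not say how, so the sketch below assumes the common approach of padding each tile row by one element. It models scalar `float` accesses only (the float4 kernels have a more involved access pattern), and `TILE` and `max_conflict_degree` are hypothetical names.

```python
NUM_BANKS = 32   # NVIDIA shared memory is organised as 32 banks of 4-byte words
TILE = 32        # hypothetical tile edge: one warp wide


def max_conflict_degree(row_stride):
    """Worst-case number of threads in a warp that hit the same bank when
    thread t reads element [t][0] of a tile whose rows are row_stride floats apart."""
    hits = [0] * NUM_BANKS
    for t in range(TILE):
        word_index = t * row_stride          # row t, column 0
        hits[word_index % NUM_BANKS] += 1
    return max(hits)


if __name__ == "__main__":
    print("unpadded tile[32][32]:", max_conflict_degree(TILE))
    print("padded   tile[32][33]:", max_conflict_degree(TILE + 1))
```

With a row stride of 32 words, every thread reading down a column lands in the same bank. Adding one word of padding shifts each row by one bank, so the 32 column accesses spread across all 32 banks.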

Testing

# Build for the Ada architecture only. If unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes longer.
export TORCH_CUDA_ARCH_LIST=Ada 
python3 mat_transpose.py

Output:

------------------------------------------------------------------------------------------------------------------------
                                                  S=1024, K=1024
                  out_original: [0.2706067, 1.89055979, 0.62714416], validate False, time:0.00007796ms
               out_f32_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03732634ms
               out_f32_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03055906ms
           out_f32_col2row(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02096868ms
           out_f32_row2col(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03112197ms
             out_f32_diagnonal: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02037907ms
             out_f32x4_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.06107259ms
             out_f32x4_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02692676ms
         out_f32x4_col2row(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03207874ms
         out_f32x4_row2col(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01719213ms
      out_f32x4_shared_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01326251ms
      out_f32x4_shared_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02352262ms
  out_f32x4_shared_bcf_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01917195ms
  out_f32x4_shared_bcf_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01389265ms
                    out_f32_th: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.05057526ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=1024, K=2048
                  out_original: [0.1013972, 0.10635406, 0.45091254], validate False, time:0.00007367ms
               out_f32_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.11233115ms
               out_f32_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05733228ms
           out_f32_col2row(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.04851723ms
           out_f32_row2col(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05224919ms
             out_f32x4_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.10379744ms
             out_f32x4_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05431175ms
         out_f32x4_col2row(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05774999ms
         out_f32x4_row2col(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03115702ms
      out_f32x4_shared_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03814983ms
      out_f32x4_shared_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03473568ms
  out_f32x4_shared_bcf_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03495407ms
  out_f32x4_shared_bcf_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03433728ms
                    out_f32_th: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.08867288ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=1024, K=4096
                  out_original: [1.78550363, -1.60489535, -0.16560346], validate False, time:0.00007296ms
               out_f32_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.19823909ms
               out_f32_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.11195445ms
           out_f32_col2row(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.09996772ms
           out_f32_row2col(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.09864736ms
             out_f32x4_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.19718719ms
             out_f32x4_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.11092091ms
         out_f32x4_col2row(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.10105634ms
         out_f32x4_row2col(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06530714ms
      out_f32x4_shared_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06287837ms
      out_f32x4_shared_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.07055283ms
  out_f32x4_shared_bcf_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06612253ms
  out_f32x4_shared_bcf_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06411195ms
                    out_f32_th: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.17973542ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=2048, K=1024
                  out_original: [-0.96589017, -0.53940338, 1.51841831], validate False, time:0.00007153ms
               out_f32_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.10408664ms
               out_f32_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05784106ms
           out_f32_col2row(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04911971ms
           out_f32_row2col(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04792857ms
             out_f32x4_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.15571523ms
             out_f32x4_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.07688594ms
         out_f32x4_col2row(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05413485ms
         out_f32x4_row2col(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.03497577ms
      out_f32x4_shared_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04818010ms
      out_f32x4_shared_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05148292ms
  out_f32x4_shared_bcf_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04849076ms
  out_f32x4_shared_bcf_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.03030324ms
                    out_f32_th: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.09853792ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=2048, K=2048
                  out_original: [0.66138971, 0.43854904, -1.19618118], validate False, time:0.00007439ms
               out_f32_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.24223709ms
               out_f32_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.15707016ms
           out_f32_col2row(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.09814286ms
           out_f32_row2col(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.13747311ms
             out_f32_diagnonal: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.08852434ms
             out_f32x4_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.26274681ms
             out_f32x4_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.12002778ms
         out_f32x4_col2row(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.15025878ms
         out_f32x4_row2col(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07008457ms
      out_f32x4_shared_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07605863ms
      out_f32x4_shared_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.09375811ms
  out_f32x4_shared_bcf_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07940960ms
  out_f32x4_shared_bcf_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07159257ms
                    out_f32_th: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.25392270ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=2048, K=4096
                  out_original: [0.21140628, 0.86610204, -0.61084032], validate False, time:0.00007534ms
               out_f32_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.51111245ms
               out_f32_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.29512668ms
           out_f32_col2row(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.25763965ms
           out_f32_row2col(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.25509524ms
             out_f32x4_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.47753954ms
             out_f32x4_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.27053690ms
         out_f32x4_col2row(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.26033616ms
         out_f32x4_row2col(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.16601658ms
      out_f32x4_shared_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.14935517ms
      out_f32x4_shared_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.17617536ms
  out_f32x4_shared_bcf_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.14183927ms
  out_f32x4_shared_bcf_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.17589092ms
                    out_f32_th: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.43119144ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=4096, K=1024
                  out_original: [-0.33594334, -0.13206008, 0.8452214], validate False, time:0.00007868ms
               out_f32_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.26727128ms
               out_f32_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.17777562ms
           out_f32_col2row(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09764647ms
           out_f32_row2col(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.13735604ms
             out_f32x4_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.25628328ms
             out_f32x4_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.15057874ms
         out_f32x4_col2row(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.12607431ms
         out_f32x4_row2col(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09281611ms
      out_f32x4_shared_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.07143378ms
      out_f32x4_shared_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.08804989ms
  out_f32x4_shared_bcf_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09320903ms
  out_f32x4_shared_bcf_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.07376838ms
                    out_f32_th: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.25272131ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=4096, K=2048
                  out_original: [1.44601941, 1.46612203, -2.00953078], validate False, time:0.00007796ms
               out_f32_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.51826644ms
               out_f32_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.31751609ms
           out_f32_col2row(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.26685858ms
           out_f32_row2col(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.18520737ms
             out_f32x4_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.29121876ms
             out_f32x4_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.16650081ms
         out_f32x4_col2row(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.14630580ms
         out_f32x4_row2col(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09408069ms
      out_f32x4_shared_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09475493ms
      out_f32x4_shared_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09508491ms
  out_f32x4_shared_bcf_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09532118ms
  out_f32x4_shared_bcf_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09467864ms
                    out_f32_th: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.26716113ms
------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------
                                                  S=4096, K=4096
                  out_original: [-1.07092094, -1.13755226, 0.99070781], validate False, time:0.00007606ms
               out_f32_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.75331712ms
               out_f32_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.52119255ms
           out_f32_col2row(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.36621094ms
           out_f32_row2col(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.36603284ms
             out_f32_diagnonal: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.37416911ms
             out_f32x4_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.96249247ms
             out_f32x4_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.56916833ms
         out_f32x4_col2row(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.48158646ms
         out_f32x4_row2col(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.30216074ms
      out_f32x4_shared_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.32637930ms
      out_f32x4_shared_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.32455182ms
  out_f32x4_shared_bcf_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.30707669ms
  out_f32x4_shared_bcf_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.31853962ms
                    out_f32_th: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.91187215ms
------------------------------------------------------------------------------------------------------------------------
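
To compare kernels across these blocks, the log lines can be parsed and ranked. The sketch below is a minimal helper, not part of the repository. It assumes the line format shown above and treats `out_f32_th` as the baseline, which is assumed here to be the PyTorch reference. `parse_times` and `speedups` are hypothetical names, and `SAMPLE` copies three lines from the S=1024, K=1024 block.

```python
import re

# Matches lines such as:
#   out_f32_th: [0.27, 0.62, 1.89], validate True , time:0.05057526ms
LINE_RE = re.compile(
    r"^\s*(?P<name>out_\S+):\s*\[.*?\],\s*validate\s+(?P<ok>True|False)\s*,"
    r"\s*time:(?P<ms>[0-9.]+)ms\s*$"
)


def parse_times(log_text):
    """Return {kernel_name: time_ms} for lines that passed validation."""
    times = {}
    for line in log_text.splitlines():
        m = LINE_RE.match(line)
        if m and m.group("ok") == "True":
            times[m.group("name")] = float(m.group("ms"))
    return times


def speedups(times, baseline="out_f32_th"):
    """Speedup of each kernel relative to the baseline, fastest first."""
    base = times[baseline]
    return sorted(((name, base / ms) for name, ms in times.items()),
                  key=lambda kv: kv[1], reverse=True)


SAMPLE = """\
                  out_original: [0.2706067, 1.89055979, 0.62714416], validate False, time:0.00007796ms
      out_f32x4_shared_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01326251ms
                    out_f32_th: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.05057526ms
"""

if __name__ == "__main__":
    for name, s in speedups(parse_times(SAMPLE)):
        print(f"{name:>30s}: {s:.2f}x")
```

The `out_original` row is skipped because it reports `validate False`: it is the untransposed input, so its time is not a transpose measurement.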