Skip to content

Commit

Permalink
update unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeizh2 committed Jul 8, 2024
1 parent 586654c commit 8716a96
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/gc/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ void populateCPURuntimePasses(mlir::PassManager &pm) {
}

void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
pm.addPass(createLowerAffinePass());
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createConvertSCFToCFPass());
pm.addPass(cpuruntime::createCPURuntimeToLLVM());
Expand Down
26 changes: 25 additions & 1 deletion test/gc/Transform/flashAttention.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
// RUN: gc-opt --split-input-file --flash-attention-conversion %s
// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void

// Thin wrapper around linalgx.scaled_dot_product_attention, giving the
// --flash-attention-conversion pass a function-level anchor to rewrite.
// %arg0, %arg1, %arg2: Q, K, V tensors of shape 1x16x384x64xf32
// (batch x heads x seq_len x head_dim -- presumably; confirm against the
// linalgx op definition). %arg3: additive attention mask, 1x16x384x384xf32.
// Returns the attention output with the same shape as Q/K/V.
func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
// Destination tensor the attention result is computed into.
%0 = tensor.empty() : tensor<1x16x384x64xf32>
%1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
return %1 : tensor<1x16x384x64xf32>
}

// Runtime entry point for gc-cpu-runner: builds all-ones Q/K/V/mask inputs,
// runs @flash_attention, and prints one element of the result so the output
// can be eyeballed (note: the RUN lines above do not pipe into FileCheck, so
// the trailing CHECK line is not machine-verified -- TODO confirm intent).
func.func @main() {
%cst = arith.constant 1.000000e+00 : f32

// Shape-only destination tensors for the fills below.
%QKVShape = tensor.empty() : tensor<1x16x384x64xf32>
%maskShape = tensor.empty() : tensor<1x16x384x384xf32>

// Fill every input (including the mask) with the constant 1.0.
%Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
%K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
%V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
%mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<1x16x384x384xf32>) -> tensor<1x16x384x384xf32>

%out = func.call @flash_attention(%Q, %K, %V, %mask) :
(tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
-> (tensor<1x16x384x64xf32>)

// Print element [0,0,0,0] of the result.
// NOTE(review): the printf label lists only three indices while the extract
// uses four -- consider "output[0, 0, 0, 0]" for accuracy (string left
// unchanged here since it is runtime-visible output).
%idx = arith.constant 0 : index
%val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<1x16x384x64xf32>
cpuruntime.printf "output[0, 0, 0]: %f\n" %val : f32

return
}
// CHECK: output[0, 0, 0]: 1.0

0 comments on commit 8716a96

Please sign in to comment.