From 8716a96bd4ed030e9abe7889f1503b40123d0baf Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Sun, 7 Jul 2024 19:39:25 -0700
Subject: [PATCH] update unit test

---
 lib/gc/Transforms/Pipeline.cpp        |  1 +
 test/gc/Transform/flashAttention.mlir | 26 +++++++++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 72003224a..559c7ff00 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -94,6 +94,7 @@ void populateCPURuntimePasses(mlir::PassManager &pm) {
 }
 
 void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
+  pm.addPass(createLowerAffinePass());
   pm.addPass(createFinalizeMemRefToLLVMConversionPass());
   pm.addPass(createConvertSCFToCFPass());
   pm.addPass(cpuruntime::createCPURuntimeToLLVM());
diff --git a/test/gc/Transform/flashAttention.mlir b/test/gc/Transform/flashAttention.mlir
index 8b1456ee4..ca6907cc8 100644
--- a/test/gc/Transform/flashAttention.mlir
+++ b/test/gc/Transform/flashAttention.mlir
@@ -1,7 +1,31 @@
-// RUN: gc-opt --split-input-file --flash-attention-conversion %s
+// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void
 
 func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
   %0 = tensor.empty() : tensor<1x16x384x64xf32>
   %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
   return %1 : tensor<1x16x384x64xf32>
 }
+
+func.func @main() {
+  %cst = arith.constant 1.000000e+00 : f32
+
+  %QKVShape = tensor.empty() : tensor<1x16x384x64xf32>
+  %maskShape = tensor.empty() : tensor<1x16x384x384xf32>
+
+  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<1x16x384x384xf32>) -> tensor<1x16x384x384xf32>
+
+  %out = func.call @flash_attention(%Q, %K, %V, %mask) :
+         (tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
+         -> (tensor<1x16x384x64xf32>)
+
+  %idx = arith.constant 0 : index
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<1x16x384x64xf32>
+  cpuruntime.printf "output[0, 0, 0]: %f\n" %val : f32
+
+  return
+}
+// CHECK: output[0, 0, 0]: 1.0
+
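Reviewer note: the added createLowerAffinePass() runs ahead of the SCF-to-CF and
MemRef-to-LLVM conversions so that any affine ops still present at that point in
the pipeline (the lowered flash-attention IR appears to contain some) are first
rewritten into scf/arith/memref ops those conversions can consume.

As the RUN line stands, the trailing "// CHECK:" directive is informational only,
because nothing in the pipeline invokes FileCheck. A possible follow-up, sketched
under the assumption that FileCheck is available in the lit environment and that
gc-cpu-runner writes the printf output to stdout, would be:

// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s \
// RUN:   | gc-cpu-runner -e main -entry-point-result=void \
// RUN:   | FileCheck %s

With Q, K, V and the mask all filled with 1.0, the softmax weights are uniform over
the 384 identical scores and V is all ones, so each output element should print as
1.0, which is what the CHECK line expects.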