From e13ec1061fb12ba58ad27e94f1472c731282833a Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Sun, 14 Jul 2024 18:28:53 -0700 Subject: [PATCH] fix sdpa op definition --- include/gc/Dialect/Linalgx/LinalgxOps.td | 25 ++++++++++++++++++ .../Dialect/Linalgx/LinalgxStructuredOps.td | 26 ------------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/include/gc/Dialect/Linalgx/LinalgxOps.td b/include/gc/Dialect/Linalgx/LinalgxOps.td index 4491967c3..9bd28c4c2 100644 --- a/include/gc/Dialect/Linalgx/LinalgxOps.td +++ b/include/gc/Dialect/Linalgx/LinalgxOps.td @@ -11,8 +11,33 @@ include "LinalgxDialect.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" + // Base class for Linalg dialect ops that do not correspond to library calls. class Linalgx_Op traits = []> : Op; +def Linalgx_ScaledDotProductAttentionOp + : Linalgx_Op<"scaled_dot_product_attention", + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "Attention structure."; + let description = [{ + Q, K, V, attention_mask. + Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V. + }]; + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs); + let results = (outs Variadic:$results); + + let hasVerifier = 1; + let assemblyFormat = [{ + attr-dict + `ins` `(` $inputs `:` type($inputs) `)` + `outs` `(` $outputs `:` type($outputs) `)` + (`->` type($results)^)? + }]; +} #endif // LINALGX_OPS \ No newline at end of file diff --git a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td index cf149d573..dee5eef74 100644 --- a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td +++ b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td @@ -23,9 +23,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" -class Linalgx_Op traits = []> : - Op; - // Base Tablegen class for Linalg ops. // Linalg ops that correspond to library calls operate on ShapedType as their // first operands. These may be optionally followed by non-view operands @@ -315,27 +312,4 @@ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul", }]; } -def Linalgx_ScaledDotProductAttentionOp - : Linalgx_Op<"scaled_dot_product_attention", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods]> { - let summary = "Attention structure."; - let description = [{ - Q, K, V, attention_mask. - Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V. - }]; - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs); - let results = (outs Variadic:$results); - - let hasVerifier = 1; - let assemblyFormat = [{ - attr-dict - `ins` `(` $inputs `:` type($inputs) `)` - `outs` `(` $outputs `:` type($outputs) `)` - (`->` type($results)^)? - }]; -} - #endif // LINALGX_STRUCTURED_OPS \ No newline at end of file