Skip to content

Commit

Permalink
fix sdpa op definition
Browse files — browse the repository at this point in the history
  • Loading branch information
yifeizh2 committed Jul 15, 2024
1 parent 3567769 commit e13ec10
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
25 changes: 25 additions & 0 deletions include/gc/Dialect/Linalgx/LinalgxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,33 @@

include "LinalgxDialect.td"

include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"

// Base class for Linalg dialect ops that do not correspond to library calls.
// `mnemonic` is the op name suffix after the dialect prefix; `traits` is the
// list of op traits/interfaces mixed into the generated op class.
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
Op<LinalgxDialect, mnemonic, traits>;

// Scaled dot-product attention as a single aggregate op.
// Declared as a plain Linalgx_Op (not a structured linalg op): operands are
// constrained to ranked tensors only — no memrefs — which is the fix this
// commit makes relative to the previous TensorOrMemref-based definition.
// AttrSizedOperandSegments is required because both $inputs and $outputs are
// variadic; AggregatedOpInterface::decomposeOperation lowers the op into its
// constituent matmul/softmax ops (implementation lives in the dialect's .cpp).
def Linalgx_ScaledDotProductAttentionOp
    : Linalgx_Op<"scaled_dot_product_attention",
        [AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
  let summary = "Attention structure.";
  let description = [{
    Q, K, V, attention_mask.
    Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
  }];
  // Inputs are expected to be Q, K, V, attention_mask (per the description);
  // operand counts are validated by the verifier, not by the declaration.
  let arguments = (ins
    Variadic<AnyRankedTensor>:$inputs,
    Variadic<AnyRankedTensor>:$outputs);
  let results = (outs Variadic<AnyRankedTensor>:$results);

  // Custom C++ verifier (::verify) checks operand/result consistency.
  let hasVerifier = 1;
  // Assembly mirrors linalg's destination-passing style:
  //   scaled_dot_product_attention ins(...) outs(...) -> type
  // The result-type clause is optional (anchored on type($results)).
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($results)^)?
  }];
}
#endif // LINALGX_OPS
26 changes: 0 additions & 26 deletions include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
Op<LinalgxDialect, mnemonic, traits>;

// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on ShapedType as their
// first operands. These may be optionally followed by non-view operands
Expand Down Expand Up @@ -315,27 +312,4 @@ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
}];
}

// NOTE(review): definition removed by this commit — superseded by the copy in
// LinalgxOps.td, which tightens operand types from TensorOrMemref to
// AnyRankedTensor and moves the op out of the structured-ops file.
def Linalgx_ScaledDotProductAttentionOp
    : Linalgx_Op<"scaled_dot_product_attention",
        [AttrSizedOperandSegments,
        DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
  let summary = "Attention structure.";
  let description = [{
    Q, K, V, attention_mask.
    Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
  }];
  let arguments = (ins
    Variadic<TensorOrMemref>:$inputs,
    Variadic<TensorOrMemref>:$outputs);
  let results = (outs Variadic<TensorOrMemref>:$results);

  let hasVerifier = 1;
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($results)^)?
  }];
}

#endif // LINALGX_STRUCTURED_OPS

0 comments on commit e13ec10

Please sign in to comment.