Skip to content

Commit

Permalink
start work rotary
Browse files (browse the repository at this point in the history)
  • Loading branch information
aciddelgado committed Nov 20, 2023
1 parent 93cb019 commit 3c332f0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);

auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Attr("do_rotary",
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_interleaved",
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size)",
Expand Down Expand Up @@ -1040,6 +1048,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"total_sequence_length",
"Scalar tensor of total sequence length (past + new).",
"M")
.Input(7,
"cos_cache",
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T")
.Input(8,
"sin_cache",
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T")
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
Expand Down

0 comments on commit 3c332f0

Please sign in to comment.