
Commit

docs
aciddelgado committed Nov 5, 2023
1 parent 2fab38e commit 791621b
Showing 2 changed files with 7 additions and 11 deletions.
16 changes: 6 additions & 10 deletions docs/ContribOperators.md
@@ -2393,19 +2393,15 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes

<dl>
-<dt><tt>is_past_bsnh</tt> : int</dt>
-<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
-<dt><tt>unidirectional</tt> : int</dt>
-<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
</dl>
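The attributes above describe grouped-query attention: `kv_num_heads` key/value heads are shared across `num_heads` query heads, and `scale` defaults to 1/sqrt(head_size). Below is a minimal NumPy sketch of that grouping, assuming a (sequence, heads, head_size) layout and a causal mask; it is illustrative only, not the kernel implementation.

```python
import numpy as np

def grouped_query_attention(q, k, v, num_heads, kv_num_heads, scale=None):
    """Toy grouped-query attention.
    q: (seq, num_heads, head_size); k, v: (seq, kv_num_heads, head_size).
    Each group of num_heads // kv_num_heads query heads shares one k/v head."""
    seq, _, head_size = q.shape
    if scale is None:
        scale = 1.0 / np.sqrt(head_size)      # default value of the scale attribute
    group = num_heads // kv_num_heads         # query heads per shared k/v head
    causal = np.triu(np.full((seq, seq), -np.inf), k=1)  # token i attends to tokens <= i
    out = np.empty_like(q)
    for h in range(num_heads):
        kv_h = h // group                     # k/v head used by query head h
        scores = (q[:, h, :] @ k[:, kv_h, :].T) * scale + causal
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[:, h, :] = weights @ v[:, kv_h, :]
    return out
```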

-#### Inputs (3 - 6)
+#### Inputs

<dl>
<dt><tt>query</tt> : T</dt>
@@ -2418,8 +2414,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
-<dt><tt>past_sequence_length</tt> (optional) : M</dt>
-<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
+<dt><tt>attention_mask</tt> : M</dt>
+<dd>2d Tensor of shape (batch_size, past_sequence_length + sequence_length). Must be a right padding mask.</dd>
</dl>
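A right padding mask of shape (batch_size, past_sequence_length + sequence_length) marks real tokens with 1 on the left and padding with 0 on the right. A small sketch of building one from per-sequence token counts follows; the helper name is illustrative, and int64 matches the M type constraint further down.

```python
import numpy as np

def make_right_padding_mask(total_lengths, past_sequence_length, sequence_length):
    """Build a (batch_size, past_sequence_length + sequence_length) mask where
    each row is 1 for its first total_lengths[b] positions and 0 afterwards,
    so padding only appears on the right."""
    width = past_sequence_length + sequence_length
    positions = np.arange(width)[None, :]                          # (1, width)
    lengths = np.asarray(total_lengths, dtype=np.int64)[:, None]   # (batch, 1)
    return (positions < lengths).astype(np.int64)                  # int64 to match M

# Example: batch of 2, 4 cached tokens + 3 new tokens, real lengths 7 and 5.
mask = make_right_padding_mask([7, 5], past_sequence_length=4, sequence_length=3)
# mask -> [[1 1 1 1 1 1 1]
#          [1 1 1 1 1 0 0]]
```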

#### Outputs
@@ -2438,8 +2434,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
-<dt><tt>M</tt> : tensor(int32), tensor(int64)</dt>
-<dd>Constrain past sequence length to int tensor.</dd>
+<dt><tt>M</tt> : tensor(int64)</dt>
+<dd>Constrain mask to int tensor.</dd>
</dl>


@@ -4950,7 +4946,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>

-RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
that are multiplied to query and key before the inner product of query and key is taken.
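As a rough sketch of that rotation (assuming the half-split pairing and base 10000 of the standard RoPE formulation, not necessarily this operator's exact layout):

```python
import numpy as np

def apply_rope(x, position, base=10000.0):
    """Rotate feature pairs of x (shape: (head_size,)) by position-dependent
    angles, so the inner product of rotated q and k encodes relative position."""
    head_size = x.shape[-1]
    half = head_size // 2
    inv_freq = base ** (-np.arange(half) / half)   # one frequency per feature pair
    angles = position * inv_freq
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:half], x[half:]                    # the two halves rotated together
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])
```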

#### Version
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -835,7 +835,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* attention_mask:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
