Extend Attention Bias Broadcast Support #21710

Merged (14 commits) on Aug 16, 2024
28 changes: 14 additions & 14 deletions docs/ContribOperators.md
@@ -180,8 +180,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)</dd>
<dt><tt>past</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
- <dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
+ <dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
</dl>
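
For reference, a minimal NumPy sketch of the broadcast semantics that the new `(batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)` shape describes. This is an illustration of the math only, not the ONNX Runtime kernel; the function name and sizes are invented for the example.

```python
# Hypothetical NumPy sketch (not the ONNX Runtime kernel) of how an
# attention_bias with broadcastable leading dims is added to QxK'.
import numpy as np

def attention_with_bias(q, k, v, attention_bias=None):
    # q, k, v: (batch_size, num_heads, sequence_length, head_size)
    # attention_bias: (batch_size or 1, num_heads or 1,
    #                  sequence_length, total_sequence_length) or None
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    if attention_bias is not None:
        # Dims of size 1 broadcast over batch and/or heads.
        scores = scores + attention_bias
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

# Example: one bias pattern shared across the batch and all heads.
q = k = v = np.random.rand(2, 4, 8, 16).astype(np.float32)
bias = np.random.rand(1, 1, 8, 8).astype(np.float32)
out = attention_with_bias(q, k, v, bias)  # (2, 4, 8, 16)
```
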
@@ -1166,7 +1166,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention</dd>
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.</dd>
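
The key-cache reordering described in `past_key` above can be pictured with a small NumPy sketch. It only illustrates the layout transform `(B, N, S_max, H) -> (B, N, H/x, S_max, x)` with `x = 16 / sizeof(T)`; the sizes and dtype (float16, so x = 8) are assumptions for the example, and the real kernel reorders the buffer in place.

```python
# Hypothetical NumPy illustration of the past_key layout transform described
# above, assuming T = float16 so x = 16 / sizeof(T) = 8. Sizes are invented.
import numpy as np

batch, heads, max_seq, head_size = 2, 4, 32, 64
x = 16 // np.dtype(np.float16).itemsize  # x = 16 / sizeof(T) = 8

keys = np.random.rand(batch, heads, max_seq, head_size).astype(np.float16)

# (B, N, S_max, H) viewed as (B, N, S_max, H/x, x), reordered to
# (B, N, H/x, S_max, x).
reordered = keys.reshape(batch, heads, max_seq, head_size // x, x)
reordered = reordered.transpose(0, 1, 3, 2, 4)
print(reordered.shape)  # (2, 4, 8, 32, 8)
```
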
@@ -1256,8 +1256,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Mask values of shape (batch_size, total_sequence_length)</dd>
<dt><tt>past</tt> : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
- <dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
+ <dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
@@ -3202,8 +3202,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
- <dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
+ <dd>bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
@@ -3516,8 +3516,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>In packing mode, it specifies the offset of each token(batch_size, sequence_length).</dd>
<dt><tt>cumulative_sequence_length</tt> : M</dt>
<dd>A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
- <dd>A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)or (1, num_heads, sequence_length, sequence_length).It specifies the additional bias to QxK'</dd>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
+ <dd>A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length).It specifies the additional bias to QxK'</dd>
</dl>

#### Outputs
@@ -3591,8 +3591,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Offset of each token before packing, with shape (batch_size, sequence_length).</dd>
<dt><tt>cumulative_sequence_length</tt> : M</dt>
<dd>A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.</dd>
- <dt><tt>relative_position_bias</tt> (optional) : T</dt>
- <dd>It specifies the additional bias to QxK'. The shape is (batch_size, num_heads, sequence_length, sequence_length) or (1, num_heads, sequence_length, sequence_length)</dd>
+ <dt><tt>attention_bias</tt> (optional) : T</dt>
+ <dd>It specifies the additional bias to QxK'. The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)</dd>
</dl>
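
A hypothetical sketch of the packed-mode inputs described above: building `cumulative_sequence_length` from per-sequence lengths and checking the broadcastable dims of `attention_bias`. The helper name and example lengths are invented, `sequence_length` is assumed to be the padded (maximum) length, and this is not the operator's shape-inference code.

```python
# Hypothetical sketch of packed-mode inputs (not the operator's own code).
import numpy as np

lengths = np.array([5, 3, 7], dtype=np.int32)  # invented per-sequence lengths
batch_size, num_heads = len(lengths), 4
sequence_length = int(lengths.max())  # assumed: padded (maximum) length

# cumulative_sequence_length has shape (batch_size + 1): a running total.
cumulative_sequence_length = np.concatenate(
    ([0], np.cumsum(lengths))).astype(np.int32)  # [ 0  5  8 15]

def attention_bias_shape_ok(shape):
    # Dims 0 and 1 may be 1 (broadcast) or must match batch_size / num_heads.
    b, n, s, t = shape
    return (b in (1, batch_size) and n in (1, num_heads)
            and s == sequence_length and t == sequence_length)

print(attention_bias_shape_ok((1, 1, sequence_length, sequence_length)))  # True
print(attention_bias_shape_ok((2, num_heads, sequence_length, sequence_length)))  # False
```
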

#### Outputs
@@ -4468,7 +4468,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
and present state are optional. Present state could appear in output even when past state is not in input.
- Current version does not support past/present, relative_position_bias and qkv_hidden_sizes.
+ Current version does not support past/present, attention_bias and qkv_hidden_sizes.
TODO: Support them if needed in the future.
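
The `(2 * batch_size)` mask_index convention for left-side padding described above can be constructed as in the sketch below. The sequence lengths are invented example values; this illustrates the layout only and is not provided by the operator.

```python
# Invented example of the (2 * batch_size) mask_index layout for left-side
# padding: exclusive end positions first, then inclusive start positions.
import numpy as np

max_sequence_length = 8
lengths = [5, 3]  # two left-padded sequences of these lengths (invented)

ends = [max_sequence_length] * len(lengths)          # exclusive end positions
starts = [max_sequence_length - n for n in lengths]  # inclusive start positions

mask_index = np.array(ends + starts, dtype=np.int32)  # shape (2 * batch_size,)
print(mask_index)  # [8 8 3 5]
```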

#### Version
@@ -4533,8 +4533,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).</dd>
<dt><tt>past</tt> (optional) : Q</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).</dd>
- <dt><tt>relative_position_bias</tt> (optional) : S</dt>
- <dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).</dd>
+ <dt><tt>attention_bias</tt> (optional) : S</dt>
+ <dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).</dd>
</dl>

#### Outputs