Revert "phi2 contrib ops changes (#19112)"
This reverts commit 21034a2.
mszhanyi committed Jan 23, 2024
1 parent 6ca7c1a commit 4b0ad85
Showing 18 changed files with 81 additions and 280 deletions.
12 changes: 3 additions & 9 deletions docs/ContribOperators.md
@@ -3031,8 +3031,6 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

#### Inputs (1 - 8)
@@ -5023,10 +5021,6 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with rotary_embedding_dim</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Rotary embedding dimension. Default value is 0.</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0</dd>
</dl>
@@ -5039,9 +5033,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>

#### Outputs
@@ -5054,7 +5048,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain input and output types to integer tensors</dd>
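The reverted documentation above fixes the cos/sin cache shape at (max_sequence_length, head_size / 2). As a minimal sketch of how such a cache could be precomputed — assuming the usual RoPE frequency base of 10000, which this diff does not specify — the helper below fills both caches row by position:

```cpp
// Sketch only: precompute RotaryEmbedding caches of shape
// (max_sequence_length, head_size / 2), matching the reverted doc text.
// The 10000.0f frequency base is the common RoPE default, assumed here.
#include <cmath>
#include <vector>

struct RotaryCache {
  std::vector<float> cos_cache;  // max_sequence_length * (head_size / 2)
  std::vector<float> sin_cache;  // max_sequence_length * (head_size / 2)
};

RotaryCache BuildRotaryCache(int max_sequence_length, int head_size, float base = 10000.0f) {
  const int half = head_size / 2;
  RotaryCache cache;
  cache.cos_cache.resize(static_cast<size_t>(max_sequence_length) * half);
  cache.sin_cache.resize(static_cast<size_t>(max_sequence_length) * half);
  for (int pos = 0; pos < max_sequence_length; ++pos) {
    for (int i = 0; i < half; ++i) {
      // Inverse frequency for dimension pair i.
      const float inv_freq = 1.0f / std::pow(base, (2.0f * i) / head_size);
      const float angle = pos * inv_freq;
      cache.cos_cache[static_cast<size_t>(pos) * half + i] = std::cos(angle);
      cache.sin_cache[static_cast<size_t>(pos) * half + i] = std::sin(angle);
    }
  }
  return cache;
}
```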
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -868,7 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
6 changes: 0 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -211,12 +211,6 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
relative_position_bias,
&parameters));

if (parameters.do_rotary) {
ORT_NOT_IMPLEMENTED(
"Rotary embedding is not supported in Attention CPU kernel. \
Please fuse the model with MHA + RotaryEmbedding.");
}

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int input_hidden_size = parameters.input_hidden_size;
4 changes: 1 addition & 3 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -40,7 +40,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
num_heads_ = static_cast<int>(num_heads);

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}

// Reshape Q/K/V from BxSxD to BxSxNxH
@@ -284,9 +283,8 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
nullptr,
&parameters,
num_heads_,
mask_filter_value_,
scale,
is_unidirectional_,
mask_filter_value_,
past_present_share_buffer,
false));

1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -18,7 +18,6 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
bool is_unidirectional_;
};

} // namespace contrib
8 changes: 3 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -25,7 +25,6 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
@@ -316,7 +315,7 @@ Status CheckInputs(const T* query,
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
output_parameters->is_unidirectional = is_unidirectional;
output_parameters->is_unidirectional = false;
output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
@@ -343,7 +342,6 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing,
int max_threads_per_block) {
@@ -352,8 +350,8 @@
}

return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
past_present_share_buffer, dmmha_packing);
past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer,
dmmha_packing);
}

} // namespace multihead_attention_helper
47 changes: 17 additions & 30 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -27,13 +27,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);

if (rotary_embedding_dim > 0) {
ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
}
}

template <typename T>
@@ -48,8 +42,6 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
num_heads,
rotary_embedding_dim,
&parameters));

Tensor* output = context->Output(0, input->Shape());
@@ -67,66 +59,61 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int n_heads = parameters.num_heads;
const int num_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;

const int half_head_size = head_size / 2;
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
int seq_stride = n_heads * head_stride;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
// Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = n_heads * head_stride;
batch_stride = num_heads * head_stride;
}

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();

const int loop_len = batch_size * sequence_length * n_heads;
const double cost = static_cast<double>(rotary_emb_dim);
const int loop_len = batch_size * sequence_length * num_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
const int n = static_cast<int>(ptr % n_heads);
const int b = static_cast<int>((ptr / num_heads) / sequence_length);
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;

// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(pos_ids_data[0]) + s
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_emb_dim;
const int cache_offset = position_id * half_head_size;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;

int cache_idx = 0;
T sign = 0;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
for (int i = 0; i < head_size; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
cache_idx = i % half_head_size;
sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
}
}
});

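For reference alongside the reverted kernel above, here is a self-contained sketch of the full-head rotation that the CPU kernel applies after the revert, covering both the interleaved and half-split index patterns. It assumes a single head of float data, with the caller having already resolved the cache row for the token's position id:

```cpp
// Sketch only: apply rotary embedding to one head of size head_size,
// mirroring the reverted CPU loop (the rotation always spans the full head).
// cos_row / sin_row point at the cache row for this token's position id.
void RotateOneHead(const float* input, float* output, int head_size,
                   const float* cos_row, const float* sin_row, bool interleaved) {
  const int half = head_size / 2;
  for (int i = 0; i < head_size; ++i) {
    int cache_idx;
    float sign;
    int j;
    if (interleaved) {
      // Pairs are (0,1), (2,3), ...: even index takes -x[i+1], odd takes +x[i-1].
      cache_idx = (i / 2) % half;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;
    } else {
      // First half pairs with the second half: (0, half), (1, half + 1), ...
      cache_idx = i % half;
      sign = (i < half) ? -1.0f : 1.0f;
      j = (i + half) % head_size;
    }
    output[i] = input[i] * cos_row[cache_idx] + sign * input[j] * sin_row[cache_idx];
  }
}
```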
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -16,8 +16,6 @@ class RotaryEmbedding final : public OpKernel {

protected:
float scale;
int num_heads;
int rotary_embedding_dim;
bool interleaved;
};

55 changes: 24 additions & 31 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -11,29 +11,26 @@ namespace rotary_embedding_helper {

// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
int batch_size; // Batch size used by input
int sequence_length; // Sequence length used by input
int hidden_size; // Hidden size used by input
int head_size; // Head size
int rotary_embedding_dim; // Rotary embedding dimension.
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
int batch_size; // Batch size used by input
int sequence_length; // Sequence length used by input
int hidden_size; // Hidden size used by input
int head_size; // Head size used by cos/sin cache * 2
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};

template <typename T>
Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
int num_heads,
int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
// cos cache : (max_sequence_length, rotary_embedding_dim / 2)
// sin cache : (max_sequence_length, rotary_embedding_dim / 2)
// cos cache : (max_sequence_length, head_size / 2)
// sin cache : (max_sequence_length, head_size / 2)

// Check input
const auto& input_dims = input->Shape().GetDims();
@@ -63,12 +60,6 @@ Status CheckInputs(const T* input,
"the same shape");
}

// Check num_heads and rotary_embedding_dim
if (rotary_embedding_dim > 0 && num_heads == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
"specified");
}

// Get attributes from inputs
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
@@ -82,13 +73,8 @@ Status CheckInputs(const T* input,
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
: static_cast<int>(hidden_size / num_heads);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}

int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
int num_heads = hidden_size / head_size;
int position_ids_format = -1;

// Check position_ids input shapes
@@ -105,15 +91,23 @@ Status CheckInputs(const T* input,
} else {
position_ids_format = 0;
}

// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
"head_size / 2, got ", cos_cache_dims[1]);
}
// Check sin_cache input shapes
if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
"max_sequence_length, got ", sin_cache_dims[0]);
}
if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
"head_size / 2, got ", sin_cache_dims[1]);
}

// Set rotary parameters
@@ -123,11 +117,10 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}

return Status::OK();
3 changes: 0 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -44,8 +44,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);

scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");

disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
@@ -107,7 +105,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
num_heads_,
mask_filter_value_,
scale_,
is_unidirectional_,
false, // past_present_share_buffer
false, // dmmha_packing
device_prop.maxThreadsPerBlock));
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -25,7 +25,6 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
float scale_;
bool is_unidirectional_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
(The remaining 7 changed files are not shown on this page.)
