
Commit

fix warning and lint
aciddelgado committed Nov 4, 2023
1 parent 40f6e3b commit e2eadab
Showing 6 changed files with 27 additions and 27 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -97,11 +97,11 @@ struct GroupQueryAttentionParameters {
int head_size;
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool has_mask;
bool is_unidirectional; // causal
bool kv_share_buffer;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -266,9 +266,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
}
}

// #ifdef DEBUG_GENERATION
// DumpScores("TimestampLogitsProcessor", next_token_scores);
// #endif
}

private:
24 changes: 12 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -17,18 +17,18 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GroupQueryAttention, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", {DataTypeImpl::GetTensorType<int64_t>()}) \
.MayInplace(3, 1) \
.MayInplace(4, 2), \
GroupQueryAttention<T>);

// REGISTER_KERNEL_TYPED(float)
@@ -131,7 +131,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

ORT_ENFORCE(use_flash_attention);

#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -24,7 +24,7 @@ class GroupQueryAttention final : public CudaKernel {
int kv_num_heads_; // different for k and v for group query attention
int past_sequence_length_;
bool is_unidirectional_; // causal
bool kv_share_buffer_; // kv-cache
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -185,7 +185,7 @@ Status CheckInputs(const Tensor* query,
const auto& attention_mask_shape = attention_mask->Shape().GetDims();
if (attention_mask_shape[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"attention_mask dim 0 must be batch_size.");
"attention_mask dim 0 must be batch_size.");
}
if (attention_mask_shape[1] == kv_sequence_length) {
is_prompt = true;
@@ -197,7 +197,7 @@ Status CheckInputs(const Tensor* query,
if (kv_share_buffer) {
if (attention_mask == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"attention_mask tensor must be present when kv-share buffer is on.");
"attention_mask tensor must be present when kv-share buffer is on.");
}
present_sequence_length = max_sequence_length;
} else {
@@ -208,11 +208,11 @@ Status CheckInputs(const Tensor* query,
if (parameters != nullptr) {
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length; // sequence length of Q
output_parameters->past_sequence_length = past_sequence_length; // max sequence length of past kv tensors
output_parameters->kv_sequence_length = kv_sequence_length; // max sequence length of new kv tensors
output_parameters->present_sequence_length = present_sequence_length; // max sequence length of present kv tensors
output_parameters->max_sequence_length = max_sequence_length; // max sequence length of kv buffer tensors TODO(aciddelgado): always same as present, remove

Check warning (GitHub Actions / cpplint) on line 215 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h: Lines should be <= 120 characters long [whitespace/line_length] [2]
output_parameters->mask_sequence_length = mask_sequence_length;
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
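The line-length warning above is triggered by the max_sequence_length assignment, whose trailing TODO comment pushes the line past 120 characters. A minimal sketch of one way to satisfy cpplint here (not what this commit does; the restructuring of the comment is my own) is to move the comment above the assignment:

    // Max sequence length of kv buffer tensors.
    // TODO(aciddelgado): always same as present, remove.
    output_parameters->max_sequence_length = max_sequence_length;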
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -583,7 +583,7 @@ Status FlashAttention(
} else {
// Launch kernel to copy seqlen
int thr_per_blk = 256;
- int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
+ int blk_in_grid = int(ceil(float(batch_size) / thr_per_blk));

Check warning (GitHub Actions / cpplint) on line 586 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu: Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
}
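The [readability/casting] warning above (and its twin in the EfficientAttention path below) points at the functional-style casts used to compute the launch grid. A lint-clean sketch of the same computation (an alternative suggestion, not part of this commit; integer ceiling division avoids the float round-trip entirely) could look like:

    int thr_per_blk = 256;
    // Round batch_size up to a whole number of blocks without any cast.
    int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
    // Or keep the floating-point ceil and use static_cast, as cpplint suggests:
    // int blk_in_grid = static_cast<int>(ceil(static_cast<float>(batch_size) / thr_per_blk));
    repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);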

@@ -693,7 +693,7 @@ Status EfficientAttention(
if (!parameters.has_mask) {
// Launch kernel to copy seqlen
int thr_per_blk = 256;
- int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
+ int blk_in_grid = int(ceil(float(batch_size) / thr_per_blk));

Check warning (GitHub Actions / cpplint) on line 696 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu: Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
} else {
ORT_RETURN_IF_ERROR(LaunchGetCacheSeqlens(parameters, data.attention_mask, data.seqlens_k, parameters.is_prompt, stream, 256));

Check warning (GitHub Actions / cpplint) on line 699 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu: Lines should be <= 120 characters long [whitespace/line_length] [2]
