Skip to content

Commit

Permalink
cpp lint sucks
Browse files Browse the repository at this point in the history
  • Loading branch information
aciddelgado committed Jun 21, 2024
1 parent 5ee475e commit 9f73c44
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
7 changes: 2 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
#include "utils.h"

#include <cmath>
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>

Check warning on line 3 in onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: alibi.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h:3: Found C system header after other header. Should be: alibi.h, c system, c++ system, other. [build/include_order] [4]
#include <cutlass/array.h>

Check warning on line 4 in onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: alibi.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h:4: Found C system header after other header. Should be: alibi.h, c system, c++ system, other. [build/include_order] [4]

#include <cmath>
#include "utils.h"

Check warning on line 5 in onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h:5: Include the directory when naming header files [build/include_subdir] [4]

namespace onnxruntime {
namespace flash {
Expand Down
16 changes: 9 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t n
return bytes;
}

size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q, size_t head_size_rounded) {
size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads,
size_t seqlen_q, size_t head_size_rounded) {
size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded;
return bytes;
}
Expand All @@ -158,8 +159,8 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
size_t num_splits_heuristic(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, size_t head_size, size_t num_SMs,
size_t max_splits) {
size_t num_splits_heuristic(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
size_t head_size, size_t num_SMs, size_t max_splits) {
// This needs to match with run_mha_fwd_splitkv_dispatch
const size_t block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const size_t num_n_blocks = (seqlen_k + block_n - 1) / block_n;
Expand Down Expand Up @@ -209,8 +210,8 @@ size_t num_splits_heuristic(size_t batch_size, size_t seqlen_q, size_t seqlen_k,
}

// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes)
std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
size_t head_size, size_t num_SMs) {
std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k,
size_t num_heads, size_t head_size, size_t num_SMs) {
size_t max_splits = 128;
// split kv buffers
size_t num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size,
Expand All @@ -233,7 +234,8 @@ std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_
// // TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
// // CHECK_DEVICE(alibi_slopes);
// // TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
// // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
// // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads})
// || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
// params.alibi_slopes_ptr = alibi_slopes;
// params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? num_heads : 0; // TODO: flag for bool
// } else {
Expand Down Expand Up @@ -319,7 +321,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
void* out, // half (total_q, num_heads, head_size)
int* cu_seqlens_q, // int (batch_size + 1)
int* cu_seqlens_k, // int (batch_size + 1)
void* seqused_k, // batch_size; If given, only this many elements of each batch element's keys are used.
void* seqused_k, // batch_size; If given, use this many elements of each batch element's keys.
int* block_table, // batch_size x max_num_blocks_per_seq
void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q)
int batch_size,
Expand Down

0 comments on commit 9f73c44

Please sign in to comment.