Add GroupQueryAttention with KV-Cache #3425

Open

wants to merge 44 commits into base: develop
Conversation

@turneram (Contributor) commented Sep 6, 2024

No description provided.

turneram marked this pull request as ready for review September 11, 2024 19:10
@migraphx-bot (Collaborator)

| Test | Batch | Rate new (54ff0e) | Rate old (e230c0) | Diff | Compare |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 3,249.59 | 3,249.77 | -0.01% | |
| torchvision-resnet50_fp16 | 64 | 6,985.02 | 6,987.71 | -0.04% | |
| torchvision-densenet121 | 32 | 2,429.05 | 2,431.48 | -0.10% | |
| torchvision-densenet121_fp16 | 32 | 4,102.80 | 4,103.92 | -0.03% | |
| torchvision-inceptionv3 | 32 | 1,639.90 | 1,637.67 | 0.14% | |
| torchvision-inceptionv3_fp16 | 32 | 2,745.89 | 2,744.19 | 0.06% | |
| cadene-inceptionv4 | 16 | 779.12 | 779.19 | -0.01% | |
| cadene-resnext64x4 | 16 | 809.10 | 808.74 | 0.04% | |
| slim-mobilenet | 64 | 7,457.79 | 7,462.54 | -0.06% | |
| slim-nasnetalarge | 64 | 208.17 | 208.50 | -0.16% | |
| slim-resnet50v2 | 64 | 3,435.34 | 3,435.17 | 0.00% | |
| bert-mrpc-onnx | 8 | 1,147.62 | 1,150.08 | -0.21% | |
| bert-mrpc-tf | 1 | 308.71 | 314.23 | -1.76% | |
| pytorch-examples-wlang-gru | 1 | 396.79 | 420.51 | -5.64% | 🔴 |
| pytorch-examples-wlang-lstm | 1 | 381.23 | 495.59 | -23.08% | 🔴 |
| torchvision-resnet50_1 | 1 | 812.82 | 770.67 | 5.47% | 🔆 |
| cadene-dpn92_1 | 1 | 398.06 | 402.30 | -1.05% | |
| cadene-resnext101_1 | 1 | 381.02 | 381.59 | -0.15% | |
| onnx-taau-downsample | 1 | 343.77 | 343.63 | 0.04% | |
| dlrm-criteoterabyte | 1 | 35.04 | 35.05 | -0.05% | |
| dlrm-criteoterabyte_fp16 | 1 | 58.00 | 58.08 | -0.13% | |
| agentmodel | 1 | 8,074.94 | 8,076.83 | -0.02% | |
| unet_fp16 | 2 | 58.08 | 57.92 | 0.27% | |
| resnet50v1_fp16 | 1 | 931.57 | 935.45 | -0.42% | |
| resnet50v1_int8 | 1 | 944.13 | 956.44 | -1.29% | |
| bert_base_cased_fp16 | 64 | 1,154.61 | 1,153.21 | 0.12% | |
| bert_large_uncased_fp16 | 32 | 356.03 | 355.68 | 0.10% | |
| bert_large_fp16 | 1 | 211.76 | 211.87 | -0.05% | |
| distilgpt2_fp16 | 16 | 2,154.18 | 2,159.18 | -0.23% | |
| yolov5s | 1 | 531.07 | 533.70 | -0.49% | |
| tinyllama | 1 | 43.42 | 43.69 | -0.61% | |
| vicuna-fastchat | 1 | 177.91 | 172.04 | 3.41% | 🔆 |
| whisper-tiny-encoder | 1 | 418.10 | 417.90 | 0.05% | |
| whisper-tiny-decoder | 1 | 431.26 | 424.90 | 1.50% | |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

}

template <class T, class U>
void apply_attention(T qkv,
Collaborator:

Need a description of all the arguments, including the fact that they're all iterators and therefore the contents are mutable (is that right?). Is output the only output? What gets populated, and what is the purpose of the content, i.e. is it a final result or an intermediate step? It appears this is only a helper function designed to keep the compute() method from getting too large, correct?

args[0] = args[0].reshape(shape{output_shape_0.type(),
{batch_size,
sequence_length,
static_cast<std::size_t>(num_heads + 2 * kv_num_heads),
Collaborator:

Description needed of what's being done with this dimension. It looks important.
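
One hedged reading of this reshape (not the author's description, and the trailing head_size dimension is assumed since the excerpt is truncated): with a packed QKV input, the hidden dimension holds num_heads query heads followed by kv_num_heads key heads and kv_num_heads value heads, so the reshape exposes that packing as an explicit head axis that can then be sliced into Q, K, and V. A minimal sketch:

#include <cstddef>
#include <vector>

// Illustrative only: the packed hidden dimension
//   (num_heads + 2 * kv_num_heads) * head_size
// is split into an explicit head axis plus a per-head width.
std::vector<std::size_t> packed_qkv_lens(std::size_t batch_size,
                                         std::size_t sequence_length,
                                         std::size_t num_heads,
                                         std::size_t kv_num_heads,
                                         std::size_t head_size)
{
    return {batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size};
}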

{
std::string name() const { return "instructions_tuple"; }

shape compute_shape(const std::vector<shape>& inputs) const { return shape(inputs); }
Collaborator:

Is this shape constructor deprecated?

@@ -64,6 +64,7 @@
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/group_query_attention.hpp>
Collaborator:

Is instructions_tuple a new op as well?

nextafterf(x, numeric_max<T>()) >= y;
}

template <class T>
Collaborator:

Suggested change:

-template <class T>
+/**
+ * Calculate softmax function in-place in array score.
+ */
+template <class T>

shape compute_shape(std::vector<shape> inputs) const
{
auto query_lens = inputs.front().lens();
std::vector<std::size_t> output_lens{query_lens.at(0), num_heads, query_lens.at(2), 4096};
Collaborator:

Do we want to keep this magic number?

Collaborator:

I'm curious where it came from.

Contributor (author):

Just an artifact of early hard-coding that I happened to miss because for llama2 it always ends up being 4096.
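
A minimal sketch of how the literal could be removed, assuming the intended value is the present KV-cache sequence length that the operator already carries as its present_kv_seqlen attribute (seen in the attribute parsing elsewhere in this PR); illustrative only, not the author's fix:

shape compute_shape(std::vector<shape> inputs) const
{
    auto query_lens = inputs.front().lens();
    // Derive the last dimension from the operator's attribute instead of hard-coding 4096.
    std::vector<std::size_t> output_lens{query_lens.at(0),
                                         static_cast<std::size_t>(num_heads),
                                         query_lens.at(2),
                                         static_cast<std::size_t>(present_kv_seqlen)};
    return {inputs.front().type(), output_lens};
}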

};
MIGRAPHX_REGISTER_OP(gpu_concat_past_present);

struct find_group_query_attention
Collaborator:

What is the effect of this matcher? Is it here because a group_query_attention by itself won't work?

dims=tsl_val.shape,
vals=tsl_val.astype(int))
cc_val = np.ones([4096, 64], dtype=np.float16)
cos_cache = helper.make_tensor(name="cos_cache",
Collaborator:

What sort of values would cos_cache and sin_cache hold in a realistic scenario?

Collaborator:

Partial answer: these are (I think) rotation matrices used in Rotary Position Embedding (RoPE), as described in "RoFormer: Enhanced Transformer with Rotary Position Embedding", one of several possible positional embedding schemes that can be used for attention models. Relative, as opposed to absolute, position embedding is a key feature of GQA.
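
For concreteness, a minimal sketch (not MIGraphX code) of the values a realistic cache would hold, following the RoFormer formulation theta_i = base^(-2i / rotary_dim) and cos_cache[m][i] = cos(m * theta_i). Under that reading the [4096, 64] shape in the test would be [max_sequence_length, rotary_dim / 2], and the all-ones data is presumably just a placeholder:

#include <cmath>
#include <cstddef>
#include <vector>

// Builds a cos cache of shape [max_seq_len, rotary_dim / 2], flattened row-major;
// the sin cache is identical with std::sin in place of std::cos.
std::vector<float> make_cos_cache(std::size_t max_seq_len, std::size_t rotary_dim)
{
    const float base = 10000.0f;
    std::vector<float> cache(max_seq_len * (rotary_dim / 2));
    for(std::size_t m = 0; m < max_seq_len; ++m)
    {
        for(std::size_t i = 0; i < rotary_dim / 2; ++i)
        {
            const float theta = std::pow(base, -2.0f * static_cast<float>(i) /
                                                   static_cast<float>(rotary_dim));
            cache[m * (rotary_dim / 2) + i] = std::cos(static_cast<float>(m) * theta);
        }
    }
    return cache;
}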

auto rotary_interleaved = v.at("rotary_interleaved").to<int>();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
assert(v.contains("present_kv_seqlen"));
Collaborator:

There is no need to assert when you are already calling at().

transposed_qkv = mpm.get_module().insert_instruction(
ins, make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), transposed_qkv);
transposed_qkv =
mpm.get_module().insert_instruction(ins, make_op("contiguous"), transposed_qkv);
Collaborator:

Why is a contiguous inserted?

const int kv_num_heads = params.kv_num_heads;
const int packed_batch_stride = (num_heads + 2 * kv_num_heads) * sequence_length * head_size;
const int kv_num_heads_factor = num_heads / kv_num_heads;
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
Collaborator:

Why static_cast to size_t when you can just declare the original variable as size_t?

const int batch_size = params.batch_size;
const int sequence_length = params.sequence_length;
const int head_size = params.head_size;
const size_t present_buffer_sequence_length = params.seqlen_present_kv_cache;
Collaborator:

Use index_int instead of size_t.

int rotary_interleaved;
int past_present_share_buffer;

__host__ __device__ void print() const
Collaborator:

Use operator<< instead, similar to how the shape class is used:

template <class Stream>
friend constexpr const Stream& operator<<(const Stream& ss, const gqa_parameters& gp)
{
    ss << "scale: " << gp.scale << "\n";
    ...
    return ss;
}

This way we can print this using our print function.

};

template <class S, class... Ts>
__device__ gqa_parameters make_gqa_params(S s, Ts... ts)
Collaborator:

This function doesn't seem necessary since this class doesn't have any template parameters; you can just construct it directly.

Collaborator:

Actually, are these variables known at compile time? If that's the case, then we should make them all integral constants.
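
A generic illustration of the suggestion (plain C++, not MIGraphX's kernel-generation machinery; the struct and parameter names are made up): values fixed at compile time can travel as integral constants so the compiler can fold them, instead of being plain runtime ints.

#include <type_traits>

template <int NumHeads, int KvNumHeads>
struct gqa_compile_time_params
{
    using num_heads    = std::integral_constant<int, NumHeads>;
    using kv_num_heads = std::integral_constant<int, KvNumHeads>;
    // The grouping factor becomes a compile-time constant as well.
    static constexpr int kv_num_heads_factor = NumHeads / KvNumHeads;
};

// Usage: gqa_compile_time_params<32, 8>::kv_num_heads_factor == 4 at compile time.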

int kv_num_heads;
int local_window_size;
int rotary_interleaved;
int past_present_share_buffer;
Collaborator:

Please initialize all these variables.

float beta;

template <class C, class A, class B>
__device__ void compute(C cmat, const A amat, const B bmat, const std::size_t idx)
Collaborator:

Is it possible to pass 2D tensor_views to do the gemms instead?
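
A generic sketch of what passing 2D views could look like (plain C++, not MIGraphX's tensor_view API; the view2d type and gemm signature are made up for illustration): the strides travel with the data, so separate lda/ldb/ldc members are no longer needed.

#include <cstddef>

// Minimal row-major 2D view: stride is the leading dimension.
template <class T>
struct view2d
{
    T* data;
    std::size_t rows, cols, stride;
    T& operator()(std::size_t i, std::size_t j) const { return data[i * stride + j]; }
};

// C = alpha * A * B + beta * C, with every layout detail carried by the views.
template <class C, class A, class B>
void gemm(C c, A a, B b, float alpha, float beta)
{
    for(std::size_t i = 0; i < c.rows; ++i)
        for(std::size_t j = 0; j < c.cols; ++j)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < a.cols; ++k)
                acc += a(i, k) * b(k, j);
            c(i, j) = alpha * acc + beta * c(i, j);
        }
}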

std::size_t _k;
std::size_t lda;
std::size_t ldb;
std::size_t ldc;
Collaborator:

Use index_int instead of std::size_t.


apply_map.emplace("gpu::compute_attention_probabilities", [=](instruction_ref ins) {
auto s = ins->get_shape().sub_shapes().front();
auto output = insert_allocation(ins, s);
Collaborator:

I don't think this allocation will be handled by memory coloring. Rather than use a tuple, you can just use identity operators to ensure the order when accessing the buffer.

{
for(int i = 0; i < d; i++)
{
y[i] = 1.0f / static_cast<float>(d);
Collaborator:

Odd, when can we ever get into this case? The sum of exponentials should never become negative and should never converge to zero, even if x[i] is negative, since e^x is never negative and you're doing things in place.

{
if(max < x[i])
max = x[i];
}
Collaborator:

I'm going to assume we have no way of knowing the order here, or of getting this input sorted, to avoid an extra O(d) of checks.

template <class T>
__device__ bool float_equal(T x, T y)
{
return isfinite(x) and isfinite(y) and nextafterf(x, numeric_lowest<T>()) <= y and
Collaborator:

I don't think you need the isfinite here: if you remove the max value, then when the max is, say, infinity, you get e^-inf, which is zero in this case.

@@ -4455,6 +4455,113 @@ def group_norm_invalid_bias_shape_test():
return group_norm_test([1, 4, 3, 3], [2], [3], [1, 4, 3, 3], 2)


@onnx_test()
def group_query_attention_test():
Collaborator:

Add another case without some of the parameters to test the defaults.

@TedThemistokleous (Collaborator) left a comment:

A few comments and questions.

I'm curious about the softmax step and the additional tests. Paul and Brian's comments seem to cover a few other things I was curious about.

Labels: Perf Improve, roadmap (Tasks to finish for a release)
Projects: None yet
6 participants