Add GroupQueryAttention with KV-Cache #3425
base: develop
Conversation
This build is not recommended to merge 🔴
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
}

template <class T, class U>
void apply_attention(T qkv,
Need a description of all the arguments, including the fact that they're all iterators and therefore the contents are mutable (is that right?). Is output the only output? What gets populated, and what's the purpose of the content, i.e. is it a final result or an intermediate step? It appears this is only a helper function designed to keep the compute() method from getting too large -- correct?
args[0] = args[0].reshape(shape{output_shape_0.type(),
                                {batch_size,
                                 sequence_length,
                                 static_cast<std::size_t>(num_heads + 2 * kv_num_heads),
Description needed of what's being done with this dimension. It looks important.
{
    std::string name() const { return "instructions_tuple"; }

    shape compute_shape(const std::vector<shape>& inputs) const { return shape(inputs); }
Is this shape constructor deprecated?
@@ -64,6 +64,7 @@
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/group_query_attention.hpp>
Is instructions_tuple a new op as well?
       nextafterf(x, numeric_max<T>()) >= y;
}

template <class T>
Suggested change:

- template <class T>
+ /**
+  * Calculate softmax function in-place in array score.
+  */
+ template <class T>
shape compute_shape(std::vector<shape> inputs) const
{
    auto query_lens = inputs.front().lens();
    std::vector<std::size_t> output_lens{query_lens.at(0), num_heads, query_lens.at(2), 4096};
Do we want to keep this magic number?
I'm curious where it came from.
Just an artifact of early hard-coding that I happened to miss because for llama2 it always ends up being 4096.
};
MIGRAPHX_REGISTER_OP(gpu_concat_past_present);

struct find_group_query_attention
What is the effect of this matcher? Is it here because a group_query_attention by itself won't work?
                                  dims=tsl_val.shape,
                                  vals=tsl_val.astype(int))
    cc_val = np.ones([4096, 64], dtype=np.float16)
    cos_cache = helper.make_tensor(name="cos_cache",
What sort of values would cos_cache and sin_cache hold in a realistic scenario?
Partial answer: these are (I think) the rotation matrices used in Rotary Position Embedding (RoPE), as described in "RoFormer: Enhanced Transformer with Rotary Position Embedding", one of several possible positional embedding schemes that can be used for attention models. Relative, as opposed to absolute, position embedding is a key feature of GQA.
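For illustration, in a realistic run the caches would hold the cosines and sines of the RoPE rotation angles rather than all-ones. A minimal host-side sketch of how they could be populated, assuming the usual base-10000 formulation and a [max_seq_len, rotary_dim / 2] layout (the function and names below are illustrative, not the test's actual code):

#include <cmath>
#include <cstddef>
#include <vector>

// RoPE angle for position p and frequency index i: p * 10000^(-2i / rotary_dim).
// With max_seq_len = 4096 and rotary_dim = 128 this matches the [4096, 64]
// cache shape used in the test above.
void fill_rope_caches(std::vector<float>& cos_cache,
                      std::vector<float>& sin_cache,
                      std::size_t max_seq_len,
                      std::size_t rotary_dim)
{
    const std::size_t half = rotary_dim / 2;
    cos_cache.resize(max_seq_len * half);
    sin_cache.resize(max_seq_len * half);
    for(std::size_t p = 0; p < max_seq_len; ++p)
    {
        for(std::size_t i = 0; i < half; ++i)
        {
            const double angle =
                p * std::pow(10000.0, -2.0 * static_cast<double>(i) / rotary_dim);
            cos_cache[p * half + i] = static_cast<float>(std::cos(angle));
            sin_cache[p * half + i] = static_cast<float>(std::sin(angle));
        }
    }
}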
auto rotary_interleaved = v.at("rotary_interleaved").to<int>();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
assert(v.contains("present_kv_seqlen"));
There is no need to assert when you are already calling at().
transposed_qkv = mpm.get_module().insert_instruction(
    ins, make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), transposed_qkv);
transposed_qkv =
    mpm.get_module().insert_instruction(ins, make_op("contiguous"), transposed_qkv);
Why is a contiguous inserted?
const int kv_num_heads = params.kv_num_heads;
const int packed_batch_stride = (num_heads + 2 * kv_num_heads) * sequence_length * head_size;
const int kv_num_heads_factor = num_heads / kv_num_heads;
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
Why static_cast to size_t when you can just declare the original variable as size_t?
const int batch_size = params.batch_size;
const int sequence_length = params.sequence_length;
const int head_size = params.head_size;
const size_t present_buffer_sequence_length = params.seqlen_present_kv_cache;
Use index_int instead of size_t.
int rotary_interleaved;
int past_present_share_buffer;

__host__ __device__ void print() const
Use operator<< instead, similar to how the shape class is used:

template <class Stream>
friend constexpr const Stream& operator<<(const Stream& ss, const gqa_parameters& gp)
{
    ss << "scale: " << gp.scale << "\n";
    ...
    return ss;
}
This way we can print this using our print function.
};

template <class S, class... Ts>
__device__ gqa_parameters make_gqa_params(S s, Ts... ts)
This function doesn't seem necessary as this class doesn't have any template parameters; you can just construct it directly.
Actually, are these variables known at compile time? If that's the case then we should make them all integral constants.
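As a rough sketch of that idea (the struct and parameter names are illustrative assumptions, not the PR's actual types): if these values are fixed when the kernel source is generated, they can be encoded in the type so the compiler treats them as literals.

// Compile-time GQA parameters: each value becomes part of the type instead of
// a runtime member, so it folds into a constant during codegen.
template <int NumHeads, int KvNumHeads, bool RotaryInterleaved>
struct gqa_static_params
{
    static constexpr int num_heads           = NumHeads;
    static constexpr int kv_num_heads        = KvNumHeads;
    static constexpr bool rotary_interleaved = RotaryInterleaved;
};

// Usage sketch: the host side would instantiate e.g. gqa_static_params<32, 8, false>
// and the kernel reads gqa_static_params<...>::num_heads as a literal constant.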
int kv_num_heads;
int local_window_size;
int rotary_interleaved;
int past_present_share_buffer;
Please initialize all these variables.
float beta;

template <class C, class A, class B>
__device__ void compute(C cmat, const A amat, const B bmat, const std::size_t idx)
Is it possible to pass 2D tensor_views to do the gemms instead?
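For reference, a minimal sketch of what a GEMM against 2-D views might look like, assuming a tensor_view-like type that exposes rows(), cols(), and operator()(row, col); these names are assumptions for illustration, not the MIGraphX kernel API:

#include <cstddef>

// Naive GEMM over 2-D views: C = alpha * A * B + beta * C.
template <class CView, class AView, class BView>
__device__ void gemm_2d(CView c, const AView a, const BView b, float alpha, float beta)
{
    for(std::size_t i = 0; i < a.rows(); ++i)
    {
        for(std::size_t j = 0; j < b.cols(); ++j)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < a.cols(); ++k)
                acc += a(i, k) * b(k, j);
            c(i, j) = alpha * acc + beta * c(i, j);
        }
    }
}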
std::size_t _k;
std::size_t lda;
std::size_t ldb;
std::size_t ldc;
Use index_int instead of std::size_t.
apply_map.emplace("gpu::compute_attention_probabilities", [=](instruction_ref ins) {
    auto s = ins->get_shape().sub_shapes().front();
    auto output = insert_allocation(ins, s);
I don't think this allocation will be handled by memory coloring. Rather than using a tuple, you can just use identity operators to ensure the order when accessing the buffer.
{
    for(int i = 0; i < d; i++)
    {
        y[i] = 1.0f / static_cast<float>(d);
Odd, when can we ever get into this case? The sum of exponents should never become negative and never converge to zero even if x[i] is negative, since e^x is never negative and you're doing things in-place.
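For context, a minimal host-side sketch of the usual in-place softmax with max subtraction (an illustrative version, not the PR's kernel): since every exp(x[i] - max) is positive and the entry equal to the max contributes exp(0) = 1, the sum of exponents is at least 1 for finite inputs, so the uniform 1/d fallback should be unreachable.

#include <cmath>
#include <cstddef>

// In-place softmax over x[0..d): subtract the max for numerical stability,
// exponentiate, then normalize by the sum of exponents.
template <class T>
void softmax_inplace(T* x, std::size_t d)
{
    T max_val = x[0];
    for(std::size_t i = 1; i < d; ++i)
        if(max_val < x[i])
            max_val = x[i];

    T sum = T{0};
    for(std::size_t i = 0; i < d; ++i)
    {
        x[i] = std::exp(x[i] - max_val); // the max entry becomes exp(0) = 1
        sum += x[i];
    }

    for(std::size_t i = 0; i < d; ++i)
        x[i] /= sum; // sum >= 1 for finite inputs, so no divide-by-zero fallback is needed
}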
{
    if(max < x[i])
        max = x[i];
}
I'm going to assume we have no way of knowing the order here or getting this input sorted to avoid an extra O(d) check.
template <class T>
__device__ bool float_equal(T x, T y)
{
    return isfinite(x) and isfinite(y) and nextafterf(x, numeric_lowest<T>()) <= y and
I don't think you need the isfinite here, since if you remove the max value and the max is, say, infinity, you get e^-inf, which is zero in this case.
@@ -4455,6 +4455,113 @@ def group_norm_invalid_bias_shape_test():
    return group_norm_test([1, 4, 3, 3], [2], [3], [1, 4, 3, 3], 2)


@onnx_test()
def group_query_attention_test():
Add another case without some of the parameters to test defaults
A few comments and questions. Curious about the softmax step and additional tests. Paul and Brian's comments seem to cover a few other things I was curious about.
No description provided.