[GPU] Fix beam search accuracy/functional errors (#27835)
### Details:
- Fixed the "error set arg" failure caused by uninitialized internal memory buffers required by the indirect sdpa_opt kernel in the micro_sdpa + sdpa_opt case
- Fixed the beam_table offset in shape_info
- Fixed the beam table index calculation in the rearrange_cache() call to account for concat_axis
- Backport of #27833

---------

Co-authored-by: Pavel Durandin <[email protected]>
sshlyapn and p-durandin authored Dec 4, 2024
1 parent 07f6c9a commit e042531
Showing 4 changed files with 36 additions and 18 deletions.
@@ -57,20 +57,34 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod

 protected:
     std::vector<layout> get_internal_buffer_layouts_impl() const override {
-        // TODO: current implementation is supposed to have the same kernel version for both indirect/default paths,
-        // considering this, we may assume that both indirect/default kernels have absolutely the same intermediate
-        // buffers number and its' sizes (since update_dispatch_data is called for both kernels too), and
-        // do not double memory allocations during reallocate_if_needed() function call
+        // Look for the first sdpa_opt kernel entry. Currently, it can be used as default sdpa, indirect sdpa, or for both
+        // default and indirect cases. All sdpa_opt kernels use the same internal buffers, so we can find the first sdpa_opt
+        // and use its internal buffer configuration. The following scenarios are possible:
+        // 1) _kernels_data[0] - micro_sdpa (default)
+        //    => internal buffers are not needed
+        // 2) _kernels_data[0] - sdpa_opt (default)
+        //    => use internal buffers from [0] kernel
+        // 3) _kernels_data[0] - sdpa_opt (default)
+        //    _kernels_data[1] - sdpa_opt (indirect)
+        //    => use internal buffers from [0] kernel
+        // 4) _kernels_data[0] - micro_sdpa (default)
+        //    _kernels_data[1] - sdpa_opt (indirect)
+        //    => use internal buffers from [1] kernel
+        size_t kernel_idx = _kernels_data.size();
+        if (_kernels_data.size() >= 1 && !_kernels_data[0].internalBufferSizes.empty()) {
+            kernel_idx = 0;
+        } else if (_kernels_data.size() >= 2 && !_kernels_data[1].internalBufferSizes.empty()) {
+            kernel_idx = 1;
+        }

         std::vector<layout> layouts;
-        for (size_t i = 0; i < _kernels_data.size(); i++) {
-            if (!_kernels_data[i].internalBufferSizes.empty()) {
-                auto dtype = from_data_type(_kernels_data[i].internalBufferDataType);
-                const auto bpp = data_type_traits::size_of(dtype);
-                for (auto size : _kernels_data[i].internalBufferSizes) {
-                    layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                           {1, 1, 1, (tensor::value_type)(size / bpp)}};
-                    layouts.push_back(inbuf_layout);
-                }
-            }
+        if (kernel_idx < _kernels_data.size()) {
+            auto dtype = from_data_type(_kernels_data[kernel_idx].internalBufferDataType);
+            const auto bpp = data_type_traits::size_of(dtype);
+            for (auto size : _kernels_data[kernel_idx].internalBufferSizes) {
+                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flatten to x channel)
+                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                layouts.push_back(inbuf_layout);
+            }
         }
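In short, the new code picks a single kernel entry to source internal buffer layouts from, instead of appending buffers for every kernel that declares them. A minimal standalone sketch of that selection rule, using a simplified KernelData stand-in (hypothetical; the real struct in the kernel selector carries far more state):

    #include <cstddef>
    #include <vector>

    struct KernelData {
        std::vector<size_t> internalBufferSizes;  // byte sizes of the kernel's internal buffers
    };

    // Returns the index of the first kernel that declares internal buffers,
    // or kernels.size() if none does (scenario 1: micro_sdpa only).
    size_t pick_buffer_source(const std::vector<KernelData>& kernels) {
        if (!kernels.empty() && !kernels[0].internalBufferSizes.empty())
            return 0;  // scenarios 2 and 3: sdpa_opt is the default kernel
        if (kernels.size() >= 2 && !kernels[1].internalBufferSizes.empty())
            return 1;  // scenario 4: micro_sdpa default + sdpa_opt indirect
        return kernels.size();
    }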

@@ -332,7 +346,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         }

         if (indirect && has_indirect_inputs(impl_param)) {
-            params.beam_table.SetDynamicShapeOffset(get_beam_table_id(desc));
+            params.beam_table.SetDynamicShapeOffset(in_offsets_map.at(get_beam_table_id(desc)));
         }

         return params;
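The one-line fix above changes what is passed to SetDynamicShapeOffset: get_beam_table_id(desc) returns the beam table's input index, while the kernel expects the beam table's starting offset inside the flat shape_info buffer, which in_offsets_map provides. A toy illustration of the distinction, with made-up offsets (assuming, purely for the example, 8 shape_info slots per dynamic input):

    #include <cstddef>
    #include <iostream>
    #include <map>

    int main() {
        // Hypothetical layout: each dynamic input occupies 8 shape_info slots.
        std::map<size_t, size_t> in_offsets_map{{0, 0}, {1, 8}, {2, 16}, {3, 24}, {4, 32}};
        size_t beam_table_id = 4;  // input index, as get_beam_table_id() might return
        std::cout << "old offset: " << beam_table_id << "\n";                     // 4 (wrong)
        std::cout << "new offset: " << in_offsets_map.at(beam_table_id) << "\n";  // 32
    }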
@@ -87,10 +87,10 @@ static void rearrange_cache(cldnn::memory::ptr kv_in_mem, cldnn::memory::ptr bt_
         for (size_t f = 0; f < kv_shape[1]; f++) {
             for (size_t y = 0; y < kv_shape[2]; y++) {
                 for (size_t x = 0; x < kv_shape[3]; x++) {
-                    size_t b_kv = bt_in_ptr[b * kv_shape[concat_axis] + y];
+                    auto out_idx = std::vector<int>{static_cast<int>(b), static_cast<int>(f), static_cast<int>(y), static_cast<int>(x)};
+
+                    size_t b_kv = bt_in_ptr[b * kv_shape[concat_axis] + out_idx[concat_axis]]; // bt_idx = b * total_seq_len + seq_len_idx
                     auto in_idx = std::vector<int>{static_cast<int>(b_kv), static_cast<int>(f), static_cast<int>(y), static_cast<int>(x)};
-                    auto out_idx = std::vector<int>{static_cast<int>(b), static_cast<int>(f), static_cast<int>(y), static_cast<int>(x)};
-
                     cldnn::tensor in(cldnn::format::bfyx, in_idx, 0);
                     cldnn::tensor out(cldnn::format::bfyx, out_idx, 0);
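This reference-implementation change matters whenever the KV cache is concatenated along an axis other than y: the beam table is indexed by batch and by position along concat_axis, so hard-coding y picked the wrong source batch for layouts where concat_axis != 2. A self-contained sketch of the corrected lookup (shapes and beam table contents invented for illustration, not taken from the real test):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
        const size_t concat_axis = 3;              // sequence dimension on x here
        std::vector<size_t> kv_shape{2, 2, 4, 6};  // b, f, y, x
        // beam_table[b * seq_len + seq_pos]: which source batch each position came from.
        std::vector<int> beam_table{0, 0, 1, 1, 0, 0,   // batch 0, positions 0..5
                                    1, 1, 0, 0, 1, 1};  // batch 1, positions 0..5

        size_t b = 0, f = 1, y = 2, x = 5;
        std::vector<size_t> out_idx{b, f, y, x};
        size_t wrong = beam_table[b * kv_shape[concat_axis] + y];                     // old: always y
        size_t right = beam_table[b * kv_shape[concat_axis] + out_idx[concat_axis]];  // new: x here
        std::cout << wrong << " vs " << right << "\n";  // prints "1 vs 0" -- different source batch
    }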
src/plugins/intel_gpu/src/runtime/ocl/ocl_stream.cpp (4 changes: 2 additions & 2 deletions)
@@ -124,7 +124,7 @@ void set_arguments_impl(ocl_kernel_type& kernel,
         switch (scalar.t) {
             case scalar_t::UINT8:
                 status = kernel.setArg(i, scalar.v.u8);
-                GPU_DEBUG_TRACE_DETAIL << "kernel: " << kernel.get() << " set scalar " << i << " (u8): " << scalar.v.u8 << "\n";
+                GPU_DEBUG_TRACE_DETAIL << "kernel: " << kernel.get() << " set scalar " << i << " (u8): " << static_cast<int>(scalar.v.u8) << "\n";
                 break;
             case scalar_t::UINT16:
                 status = kernel.setArg(i, scalar.v.u16);
@@ -140,7 +140,7 @@
                 break;
             case scalar_t::INT8:
                 status = kernel.setArg(i, scalar.v.s8);
-                GPU_DEBUG_TRACE_DETAIL << "kernel: " << kernel.get() << " set scalar " << i << " (s8): " << scalar.v.s8 << "\n";
+                GPU_DEBUG_TRACE_DETAIL << "kernel: " << kernel.get() << " set scalar " << i << " (s8): " << static_cast<int>(scalar.v.s8) << "\n";
                 break;
             case scalar_t::INT16:
                 status = kernel.setArg(i, scalar.v.s16);
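Both trace-log hunks work around the same C++ quirk: streaming an int8_t/uint8_t goes through the char overloads of operator<<, so the byte prints as a glyph (or an unprintable control character) rather than a number. Casting to int restores numeric output, as this short demo shows:

    #include <cstdint>
    #include <iostream>

    int main() {
        std::uint8_t u8 = 65;
        std::int8_t  s8 = -3;
        std::cout << u8 << " " << s8 << "\n";                                      // "A" plus an unprintable byte
        std::cout << static_cast<int>(u8) << " " << static_cast<int>(s8) << "\n";  // "65 -3"
    }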
@@ -343,6 +343,10 @@ std::vector<Params> get_test_params() {
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
     p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});

+    // Beam search
+    p.push_back({with_rearrange, !with_mask, !with_scale, !causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, !with_mask, !with_scale, !causal, !compressed, 4, ov::element::Type_t::f16, 5, 16, 1, {0, 2, 1, 3}});
+
     // Compressed
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
