Refactor to delta + LUT
vshampor committed Dec 11, 2024
1 parent 20798ba commit 9540c7b
Showing 8 changed files with 120 additions and 67 deletions.
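In short: instead of receiving precomputed per-layer f32 `rotation_coefficients`, PagedAttention now takes a per-layer i32 `rotation_deltas` input plus a single f32 `rotation_trig_lut` table shared by all layers, and the coefficients are recovered by a row lookup. Below is a minimal sketch of that lookup, assuming each delta selects one LUT row of cos/sin values (the commit itself only wires the inputs through, so the exact gather semantics are an assumption):

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: reconstructs what the removed `rotation_coefficients` input
// used to carry explicitly. Assumes rotation_deltas holds one LUT row index
// per rotated entry and each LUT row holds that delta's cos/sin coefficients.
std::vector<float> lookup_rotation_coefficients(const std::vector<int32_t>& rotation_deltas,
                                                const std::vector<std::vector<float>>& rotation_trig_lut) {
    std::vector<float> coefficients;
    for (int32_t delta : rotation_deltas) {
        const auto& row = rotation_trig_lut.at(static_cast<std::size_t>(delta));
        coefficients.insert(coefficients.end(), row.begin(), row.end());
    }
    return coefficients;
}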
@@ -29,6 +29,7 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation,
-ParameterVector& rotation_coefficients_inputs_for_each_layer,
-ParameterVector& rotated_block_indices_inputs_for_each_layer);
+ParameterVector& rotated_block_indices_inputs_for_each_layer,
+ParameterVector& rotation_deltas_inputs_for_each_layer,
+std::shared_ptr<op::v0::Parameter> model_rotation_trig_lut);
};
@@ -15,6 +15,7 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
@@ -64,17 +65,6 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
return node_tuple(kv_past_par, kv_current2, kv_current_reshaped, kv_concat);
}

-template <class T>
-void insert_rotation_inputs_as(OutputVector& pa_arguments, size_t layer_index) {
-auto rotation_coefficients = setName(std::make_shared<T>(ov::element::f32, ov::PartialShape{-1}),
-"rotation_coefficients." + std::to_string(layer_index - 1));
-auto rotated_block_indices = setName(std::make_shared<T>(ov::element::i32, ov::PartialShape{-1}),
-"rotated_block_indices." + std::to_string(layer_index - 1));
-
-pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients);
-pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices);
-}

ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters,
ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
@@ -86,8 +76,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation,
-ParameterVector& rotation_coefficients_inputs_for_each_layer,
-ParameterVector& rotated_block_indices_inputs_for_each_layer) {
+ParameterVector& rotated_block_indices_inputs_for_each_layer,
+ParameterVector& rotation_deltas_inputs_for_each_layer,
+std::shared_ptr<op::v0::Parameter> model_rotation_trig_lut) {
MATCHER_SCOPE(StateManagementPattern);

auto k_current = pattern::any_input();
@@ -193,8 +184,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
&block_indices_inputs_for_each_layer,
&score_results,
&layer_index,
&rotation_coefficients_inputs_for_each_layer,
&rotated_block_indices_inputs_for_each_layer](ov::pass::pattern::Matcher& m) {
&rotated_block_indices_inputs_for_each_layer,
&rotation_deltas_inputs_for_each_layer](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto real_q = pattern_map.at(q);

@@ -400,16 +391,17 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
OPENVINO_ASSERT(pa_arguments.size() == 13);

if (allow_cache_rotation) {
-auto rotation_coefficients = setName(std::make_shared<v0::Parameter>(element::f32, PartialShape{-1}),
-"rotation_coefficients." + std::to_string(layer_index - 1));
auto rotated_block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
"rotated_block_indices." + std::to_string(layer_index - 1));
+auto rotation_deltas = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
+"rotation_deltas." + std::to_string(layer_index - 1));

-pa_arguments.insert(pa_arguments.begin() + 13, rotation_coefficients);
-pa_arguments.insert(pa_arguments.begin() + 14, rotated_block_indices);
+pa_arguments.insert(pa_arguments.begin() + 13, rotated_block_indices);
+pa_arguments.insert(pa_arguments.begin() + 14, rotation_deltas);
+pa_arguments.insert(pa_arguments.begin() + 15, model_rotation_trig_lut);

-rotation_coefficients_inputs_for_each_layer.push_back(rotation_coefficients);
rotated_block_indices_inputs_for_each_layer.push_back(rotated_block_indices);
+rotation_deltas_inputs_for_each_layer.push_back(rotation_deltas);
}

auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);
29 changes: 20 additions & 9 deletions src/core/src/op/paged_attention.cpp
@@ -19,7 +19,7 @@ void PagedAttentionExtension::validate_and_infer_types() {
OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types);

NODE_VALIDATION_CHECK(this,
-get_input_size() == 13 || get_input_size() == 15,
-"PagedAttensionExtension expects 13 or 15 inputs, but it has ",
+get_input_size() == 13 || get_input_size() == 16,
+"PagedAttentionExtension expects 13 or 16 inputs, but it has ",
get_input_size());

@@ -147,30 +147,41 @@ void PagedAttentionExtension::validate_and_infer_types() {
get_input_element_type(12),
".");

-if (get_input_size() == 15) {
+if (get_input_size() == 16) {
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1,
-"Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ",
+"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(13).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
-get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::f32,
-"Element type of `rotation_coefficients` input should be f32, but it is ",
+get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::i32,
+"Element type of `rotated_block_indices` input should be i32, but it is ",
get_input_element_type(13),
".");

NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 1,
-"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
+"Input `rotation_deltas` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(14).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32,
-"Element type of `rotated_block_indices` input should be i32, but it is ",
+"Element type of `rotation_deltas` input should be i32, but it is ",
get_input_element_type(14),
".");
+NODE_VALIDATION_CHECK(
+this,
+get_input_partial_shape(15).rank().is_dynamic() || get_input_partial_shape(15).rank().get_length() == 2,
+"Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ",
+get_input_partial_shape(15).rank().get_length(),
+".");
+NODE_VALIDATION_CHECK(this,
+get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32,
+"Element type of `rotation_trig_lut` input should be f32, but it is ",
+get_input_element_type(15),
+".");
}

// value head_size may be not same with key
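Together with the existing 13-input form, the checks above pin down the 16-input layout summarized below. The names for indices 0 through 12 are taken from the type_prop test later in this commit; the enum is a reference summary, not code from the commit:

#include <cstddef>

// Reference summary of the PagedAttentionExtension input layout (assumed
// names for inputs 0-12, taken from the type_prop test in this commit).
enum PagedAttentionInputIdx : std::size_t {
    QUERY = 0,
    KEY,
    VALUE,
    KEY_CACHE,
    VALUE_CACHE,
    PAST_LENS,
    SUBSEQUENCE_BEGINS,
    BLOCK_INDICES,
    BLOCK_INDICES_BEGINS,
    SCALE,
    SLIDING_WINDOW,
    ALIBI_SLOPES,
    MAX_CONTEXT_LEN,
    ROTATED_BLOCK_INDICES = 13,  // i32, rank 1
    ROTATION_DELTAS = 14,        // i32, rank 1
    ROTATION_TRIG_LUT = 15,      // f32, rank 2
};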
17 changes: 13 additions & 4 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -54,6 +54,12 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices);
}

+std::shared_ptr<v0::Parameter> model_rotation_trig_lut;
+
+if (m_allow_cache_rotation) {
+model_rotation_trig_lut = setName(std::make_shared<v0::Parameter>(element::f32, PartialShape{-1, -1}), "rotation_trig_lut");
+}
+
auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

auto get_parameter = [=](const std::shared_ptr<ov::Model>& model,
@@ -98,8 +104,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
ParameterVector parameters_to_remove;
ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model
ParameterVector block_indices_inputs_for_each_layer;
-ParameterVector rotation_coefficients_inputs_for_each_layer;
ParameterVector rotated_block_indices_inputs_for_each_layer;
+ParameterVector rotation_deltas_inputs_for_each_layer;
+
ResultVector score_results;

std::shared_ptr<v0::Parameter> position_ids;
@@ -133,8 +140,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
m_use_per_layer_block_indices_inputs,
m_use_score_outputs,
m_allow_cache_rotation,
-rotation_coefficients_inputs_for_each_layer,
-rotated_block_indices_inputs_for_each_layer);
+rotated_block_indices_inputs_for_each_layer,
+rotation_deltas_inputs_for_each_layer,
+model_rotation_trig_lut);
manager.register_pass<PrevSequenceLengthPattern>(prev_max_seq_len, batch_dim);
manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
manager.register_pass<PositionIDsReplacer>(unsqueezed_position_ids->output(0));
@@ -191,8 +199,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
}

if (m_allow_cache_rotation) {
-model->add_parameters(rotation_coefficients_inputs_for_each_layer);
model->add_parameters(rotated_block_indices_inputs_for_each_layer);
+model->add_parameters(rotation_deltas_inputs_for_each_layer);
+model->add_parameters({ model_rotation_trig_lut });
}

model->add_parameters(kv_parameters);
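A transformed model should therefore expose per-layer `rotated_block_indices.<i>` and `rotation_deltas.<i>` inputs plus one shared `rotation_trig_lut` input. A hedged usage sketch follows; the three-flag constructor is an assumption inferred from the m_use_per_layer_block_indices_inputs, m_use_score_outputs, and m_allow_cache_rotation members referenced above:

#include <iostream>
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"

// Sketch: run SDPAToPagedAttention with cache rotation enabled and list the
// parameters it adds (the constructor flags are assumed, not shown in this diff).
void convert_and_dump(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>(/*use_per_layer_block_indices_inputs=*/false,
                                                          /*use_score_outputs=*/false,
                                                          /*allow_cache_rotation=*/true);
    manager.run_passes(model);
    for (const auto& param : model->get_parameters()) {
        // Expect per-layer rotated_block_indices.<i> / rotation_deltas.<i>
        // plus a single shared rotation_trig_lut parameter.
        std::cout << param->get_friendly_name() << '\n';
    }
}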
10 changes: 6 additions & 4 deletions src/core/tests/type_prop/paged_attention.cpp
@@ -44,7 +44,7 @@ TEST(type_prop, paged_attention_static_13_inputs) {
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4}));
}

-TEST(type_prop, paged_attention_static_15_inputs) {
+TEST(type_prop, paged_attention_static_16_inputs) {
const auto query = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
const auto key = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
const auto value = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
@@ -59,8 +59,9 @@ TEST(type_prop, paged_attention_static_15_inputs) {
const auto alibi_slopes = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{9});
const auto max_context_len = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});

-const auto rotation_coefficients = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{12});
const auto rotated_block_indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{3});
+const auto rotation_deltas = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{12});
+const auto rotation_trig_lut = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{5, 256});

ov::OutputVector args = {query,
key,
@@ -75,8 +76,9 @@ TEST(type_prop, paged_attention_static_15_inputs) {
sliding_window,
alibi_slopes,
max_context_len,
-rotation_coefficients,
-rotated_block_indices};
+rotated_block_indices,
+rotation_deltas,
+rotation_trig_lut};

const auto op = std::make_shared<op::PagedAttentionExtension>(args);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
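The `{5, 256}` LUT shape in the test is an arbitrary static placeholder. For RoPE-style cache rotation, one plausible way to populate such a table is one row per possible position delta, holding the cos terms followed by the sin terms; this is an assumption, since the commit does not show LUT construction:

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch under assumptions: row d holds the trig terms for an entry whose
// position shifted by d. The cos-then-sin row layout and the inverse-frequency
// formula mirror common RoPE implementations, not code from this commit.
std::vector<std::vector<float>> make_rotation_trig_lut(std::size_t max_delta,
                                                       std::size_t head_size,
                                                       float theta = 10000.0f) {
    std::vector<std::vector<float>> lut(max_delta, std::vector<float>(head_size));
    for (std::size_t d = 0; d < max_delta; ++d) {
        for (std::size_t i = 0; i < head_size / 2; ++i) {
            const float inv_freq = std::pow(theta, -2.0f * static_cast<float>(i) / static_cast<float>(head_size));
            const float angle = static_cast<float>(d) * inv_freq;
            lut[d][i] = std::cos(angle);                  // cos half of the row
            lut[d][head_size / 2 + i] = std::sin(angle);  // sin half of the row
        }
    }
    return lut;
}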