+optional reshape for second input to matmul

+batchless param dims produced by the partitioner while replacing consts with params: squeeze the ineffective dims, then insert a reshape
esmirno committed Dec 11, 2024
1 parent f4ee57b commit dc55719
Showing 2 changed files with 120 additions and 19 deletions.
@@ -21,6 +21,7 @@
#include "openvino/util/xml_parse_utils.hpp"
#include "patterns/dcoff.hpp"
#include "patterns/opt.hpp"
#include "openvino/op/ops.hpp"

namespace ov {
namespace npuw {
@@ -1516,9 +1517,39 @@ void Partitioner::createFunction(FunctionPipeline& func_ggg) {
LOG_DEBUG("Handling a Constant input " << prod_output);
LOG_BLOCK();

auto new_param = std::make_shared<ov::op::v0::Parameter>(prod_output.get_element_type(),
prod_output.get_partial_shape());
input_desc.replace_source_output(new_param); // (n)/1/i/a
// TODO: tricky part - when a 4D const becomes a parameter, it is no longer batch friendly,
// so let's squeeze this shape
auto partial_sh = prod_output.get_partial_shape();
std::shared_ptr<ov::op::v0::Parameter> new_param;
std::shared_ptr<ov::Node> new_param_or_reshape;

if (partial_sh.all_non_negative()) {
auto static_shape = prod_output.get_shape();
std::vector<size_t> dims;
bool need_reshape = false;
for (auto s : static_shape) {
if (s != 1) {
dims.push_back(s);
need_reshape = true;
}
}
new_param = std::make_shared<ov::op::v0::Parameter>(prod_output.get_element_type(), ov::Shape{dims});
// don't need 2 reshapes if the consumer is already a Reshape
if (need_reshape && !ov::as_type<ov::op::v1::Reshape>(input_desc.get_node())) {
auto new_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{static_shape.size()}, static_shape);
new_param_or_reshape = std::make_shared<ov::op::v1::Reshape>(new_param, new_const, false);
} else {
new_param_or_reshape = new_param;
}
LOG_DEBUG("PARTITIONER: a new Constant shape: input " << new_param);
LOG_DEBUG("PARTITIONER: a new reshape inserted: " << new_param_or_reshape);

} else {
new_param = std::make_shared<ov::op::v0::Parameter>(prod_output.get_element_type(),
prod_output.get_partial_shape());
new_param_or_reshape = new_param;
}
input_desc.replace_source_output(new_param_or_reshape); // (n)/1/i/a
function._model->add_parameters({std::move(new_param)});
LOG_DEBUG("Register Parameter[" << new_param_idx << "] as input to " << iport.first << " / "
<< iport.second);
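For illustration, a minimal standalone sketch (not the plugin code) of the idea behind this hunk: when a 4D constant becomes a Parameter, squeeze away all size-1 dims so the Parameter stays batch friendly, then restore the original shape with a Reshape for the consumer. The shapes below are made up for the example.

#include <iostream>
#include <openvino/openvino.hpp>
#include <openvino/op/ops.hpp>

int main() {
    const ov::Shape original{512, 256, 1, 1};  // hypothetical 4D const shape

    // Squeeze out every dim equal to 1, as the new partitioner code does.
    std::vector<size_t> squeezed;
    for (auto d : original) {
        if (d != 1) {
            squeezed.push_back(d);
        }
    }

    // The replacement Parameter gets the squeezed shape, i.e. {512, 256}.
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{squeezed});

    // A Reshape restores the original 4D shape for the downstream consumer.
    auto target = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{original.size()}, original);
    auto reshape = std::make_shared<ov::op::v1::Reshape>(param, target, false);

    std::cout << reshape->get_output_shape(0) << std::endl;  // the restored {512, 256, 1, 1}
    return 0;
}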
102 changes: 86 additions & 16 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
@@ -1590,15 +1590,18 @@ SliceLastMatmulMultiply::SliceLastMatmulMultiply() {

ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto convert = opp::wrap_type<ov::op::v0::Convert>({param->output(0)});
auto param1_reshape = opp::optional<ov::op::v1::Reshape>({param, opp::any_input()});
auto convert = opp::wrap_type<ov::op::v0::Convert>({param1_reshape->output(0)});
auto param2 = opp::any_input();
auto param2_reshape = opp::optional<ov::op::v1::Reshape>({param2, opp::any_input()});
auto convert2 = opp::optional<ov::op::v0::Convert>({param2_reshape->output(0)});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2});
auto tr_input = opp::any_input();
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({tr_input, opp::any_input()});
auto transpose_in = opp::optional<ov::op::v1::Transpose>({tr_input, opp::any_input()});
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply});
auto transpose_out = opp::wrap_type<ov::op::v1::Transpose>({conv, opp::any_input()});

// Since this transpose is optional, the matcher may start at the Convolution and fully match the case without an output transpose.
auto transpose_out = opp::optional<ov::op::v1::Transpose>({conv, opp::any_input()});

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
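An aside on the pattern API used here: opp::optional (ov::pass::pattern::optional) registers a node that may be absent from the match, and the optional node shows up in the pattern map only when it actually matched, which is what the callback below checks via node_to_output.count(). A hedged, minimal sketch of that mechanism, independent of this pass (the class and names are illustrative):

#include <openvino/op/ops.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/pattern/op/optional.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>

namespace opp = ov::pass::pattern;

// Illustrative pass: match a Convolution with an *optional* input Transpose.
class OptionalTransposeDemo : public ov::pass::MatcherPass {
public:
    OptionalTransposeDemo() {
        auto in = opp::any_input();
        auto tr = opp::optional<ov::op::v1::Transpose>({in, opp::any_input()});
        auto conv = opp::wrap_type<ov::op::v1::Convolution>({tr, opp::any_input()});

        auto callback = [=](opp::Matcher& m) {
            const auto& node_to_output = m.get_pattern_value_map();
            // Present in the map only if the optional Transpose matched.
            const bool has_input_tr = node_to_output.count(tr) != 0;
            (void)has_input_tr;  // a real pass would branch on this, as ConvToMatmul does
            return false;        // demo only: the graph is left unchanged
        };
        register_matcher(std::make_shared<opp::Matcher>(conv, "OptionalTransposeDemo"), callback);
    }
};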
@@ -1608,45 +1611,94 @@ ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
auto matched_node_param2 = node_to_output.at(param2).get_node_shared_ptr();
auto matched_node_convert = node_to_output.at(convert).get_node_shared_ptr();
auto matched_node_tr_input = node_to_output.at(tr_input);
auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr();
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr();

const auto has_input_tr = node_to_output.count(transpose_in) != 0;
auto has_output_tr = node_to_output.count(transpose_out) != 0;

// Check the extension mode, where the matcher started from the optional transpose layer but did not capture it for some reason.
auto conv_node = node_to_output.at(conv).get_node_shared_ptr();
std::shared_ptr<const Node> transpose_out_node;
if (!has_output_tr) {
for (auto n : conv_node->output(0).get_target_inputs()) {
transpose_out_node = ov::as_type_ptr<ov::op::v1::Transpose>(
n.get_node()->shared_from_this());

if (transpose_out_node) {
LOG_DEBUG("ConvToMatmull: output transpose matched used expanding algorithm "
<< transpose_out_node->get_friendly_name());
has_output_tr = true;
break;
} else {
LOG_DEBUG("ConvToMatmull: output of conv: "
<< n.get_node()->get_friendly_name());
}
}
} else {
transpose_out_node = node_to_output.at(transpose_out).get_node_shared_ptr();
}

// In case the transpose is missing, we need to check the tensor dimensions: in some cases reshapes are enough, in others a pair of transposes is required.
const auto& matched_node_transpose_in = uat::_(node_to_output).at_or_at(transpose_in, tr_input).get_node_shared_ptr();
const auto& matched_node_transpose_out = has_output_tr ? transpose_out_node : conv_node;

auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr();
const auto& cvt2_or_multiply = uat::_(node_to_output).at_or_at(convert2, multiply);

const auto& shape = matched_node_param->get_shape();
const auto shape = uat::_(node_to_output).at_or_at(param1_reshape, param).get_shape();
const auto shape2 = uat::_(node_to_output).at_or_at(param2_reshape, param2).get_shape();

const auto& shape2 =
node_to_output.count(param2_reshape) ? node_to_output.at(param2_reshape).get_shape() :
node_to_output.at(param2).get_shape();
const auto& tr_in_shape = has_input_tr ?
matched_node_transpose_in->input(0).get_shape() :
matched_node_transpose_in->output(0).get_shape();

const auto& tr_in_shape = matched_node_transpose_in->input(0).get_shape();
const auto& tr_out_shape = matched_node_transpose_out->output(0).get_shape();

auto check_shape = [](const ov::Shape& shape) {
// last 2 dims are 1
return shape.size() == 4 && shape[2] == 1 && shape[3] == 1;
};

auto check_transpose_shape = [](const ov::Shape& shape) {
// first 2 dims are 1
return shape.size() == 4 && shape[0] == 1 && shape[1] == 1;
};
auto check_conv_shape_1D = [](const ov::Shape& shape) {
// in case of a missing transpose, also check whether a reshape is possible:
// dims 0, 2 and 3 are 1
return shape.size() == 4 && shape[0] == 1 && shape[2] == 1 && shape[3] == 1;
};

bool conv_in_shape = has_input_tr ? check_transpose_shape(tr_in_shape) : check_conv_shape_1D(tr_in_shape);
bool conv_out_shape = has_output_tr ? check_transpose_shape(tr_out_shape) : check_conv_shape_1D(tr_out_shape);


LOG_DEBUG("ConvToMatmull: conv_input_shape " << conv_in_shape);
LOG_DEBUG("ConvToMatmull: conv_out_shape " << conv_out_shape);
LOG_DEBUG("ConvToMatmull: matched_node_transpose_in: " << matched_node_transpose_in->get_friendly_name());
LOG_DEBUG("ConvToMatmull: matched_node_transpose_out: " << matched_node_transpose_out->get_friendly_name());

LOG_DEBUG("ConvToMatmull: matched_node_transpose_in shape: " << matched_node_transpose_in->get_shape());
LOG_DEBUG("ConvToMatmull: matched_node_transpose_out shape: " << matched_node_transpose_out->get_shape());

LOG_DEBUG("ConvToMatmull: matched_node_param->get_element_type(): " << matched_node_param->get_element_type());
LOG_DEBUG("ConvToMatmull: matched_node_param2->get_element_type(): " << matched_node_param2->get_element_type());
LOG_DEBUG("ConvToMatmull: matched_node_param2: " << matched_node_param2->get_friendly_name());
LOG_DEBUG("ConvToMatmull: check_shape(shape): " << check_shape(shape));
LOG_DEBUG("ConvToMatmull: check_shape(shape2): " << check_shape(shape2));

// if there is no transpose input - convolution input is fine, but for matmul substitution we might need to add reshape
LOG_DEBUG("ConvToMatmull: check_transpose_shape(tr_in_shape): " << check_transpose_shape(tr_in_shape));

// if there is no transpose input - convolution input is fine, but for matmul substitution we might need to add reshape
LOG_DEBUG("ConvToMatmull: check_transpose_shape(tr_out_shape): " << check_transpose_shape(tr_out_shape));

if ((matched_node_param->get_element_type() == ov::element::i4 ||
matched_node_param->get_element_type() == ov::element::i8) &&
(matched_node_param2->get_element_type() == ov::element::f32 ||
matched_node_param2->get_element_type() == ov::element::f16) &&
(ov::op::util::is_parameter(matched_node_param2) || ov::op::util::is_constant(matched_node_param2)) &&
check_shape(shape) && check_shape(shape2) && check_transpose_shape(tr_in_shape) &&
check_transpose_shape(tr_out_shape)) {
check_shape(shape) && check_shape(shape2) &&
conv_in_shape && conv_out_shape) {
// Add Reshape before Params/Const
auto new_dims = std::vector<std::size_t>{shape[0], shape[1]};
auto new_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims);
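To make the three predicates concrete, a tiny self-contained check with invented shapes (mirroring the lambdas above; this is not the plugin code):

#include <cassert>
#include <openvino/core/shape.hpp>

int main() {
    // The same predicates as in the pass, restated for the example.
    auto check_shape = [](const ov::Shape& s) { return s.size() == 4 && s[2] == 1 && s[3] == 1; };
    auto check_transpose_shape = [](const ov::Shape& s) { return s.size() == 4 && s[0] == 1 && s[1] == 1; };
    auto check_conv_shape_1D = [](const ov::Shape& s) { return s.size() == 4 && s[0] == 1 && s[2] == 1 && s[3] == 1; };

    assert(check_shape({512, 256, 1, 1}));            // weights that reshape cleanly to 2D {512, 256}
    assert(check_transpose_shape({1, 1, 128, 256}));  // activation layout behind a Transpose
    assert(check_conv_shape_1D({1, 256, 1, 1}));      // no-transpose case: still convertible
    assert(!check_conv_shape_1D({1, 256, 8, 1}));     // real spatial extent: the pattern is rejected
    return 0;
}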
@@ -1668,13 +1720,31 @@ ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
matched_node_multiply->validate_and_infer_types();
}

// Get rid of Transposes
// Get rid of input Transpose
// TODO: change transpose in case of 2D prefill
auto new_transpose_input = matched_node_tr_input;
if (!has_input_tr) {
std::vector<size_t> perm = {0, 3, 2, 1};
auto sh_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, perm);
new_transpose_input = std::make_shared<ov::op::v1::Transpose>(matched_node_tr_input, sh_pattern);
}
auto matmul =
std::make_shared<ov::op::v0::MatMul>(matched_node_tr_input, matched_node_multiply, false, true);
std::make_shared<ov::op::v0::MatMul>(new_transpose_input, matched_node_multiply, false, true);

// Get rid of output Transpose
// TODO: change reshape to transpose in case of prefill
std::shared_ptr<Node> new_transpose_output = matmul;
if (!has_output_tr) {
std::vector<size_t> perm = {0, 3, 2, 1};
auto sh_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, perm);
new_transpose_output = std::make_shared<ov::op::v1::Transpose>(matmul, sh_pattern);
}

for (auto&& r : matched_node_transpose_out->output(0).get_target_inputs()) {
r.replace_source_output(matmul);
r.replace_source_output(new_transpose_output);
}
matmul->validate_and_infer_types();

return true; // root has changed
}
return false; // root hasn't changed
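Finally, a hedged end-to-end sketch of the rewritten subgraph for the no-transpose case: the pass wraps the convolution input and the MatMul output with explicit {0, 3, 2, 1} Transposes so the replacement preserves the original NCHW layout. All shapes here are invented for the example.

#include <iostream>
#include <openvino/openvino.hpp>
#include <openvino/op/ops.hpp>

int main() {
    // Conv input in NCHW with unit spatial dims (the case check_conv_shape_1D accepts).
    auto act = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 256, 1, 1});
    // 2D matmul weights, as produced by the Reshape the pass inserts before the params.
    auto weights = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{512, 256});

    std::vector<size_t> perm{0, 3, 2, 1};
    auto perm_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, perm);

    // Input transpose: {1, 256, 1, 1} -> {1, 1, 1, 256}.
    auto tr_in = std::make_shared<ov::op::v1::Transpose>(act, perm_const);
    // MatMul with transpose_b = true: {1, 1, 1, 256} x {512, 256}^T -> {1, 1, 1, 512}.
    auto matmul = std::make_shared<ov::op::v0::MatMul>(tr_in, weights, false, true);
    // Output transpose restores NCHW: {1, 1, 1, 512} -> {1, 512, 1, 1}.
    auto tr_out = std::make_shared<ov::op::v1::Transpose>(matmul, perm_const);

    std::cout << tr_out->get_output_shape(0) << std::endl;  // {1, 512, 1, 1}
    return 0;
}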
