Skip to content

Commit

Permalink
[Common] MarkDequantizationSubgraph: avoid modification of RT Info for fold_multiply_const=true option
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Aug 28, 2024
1 parent 2e5189b commit 6cd53b5
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass {
OPENVINO_RTTI("MarkDequantizationSubgraph", "0");
MarkDequantizationSubgraph(const element::TypeVector& precisions,
const bool fold_subtract_const = false,
const bool fold_multiply_const = true);
const bool disable_fold_multiply_const = false);
};
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions,
const bool fold_subtract_const,
const bool fold_multiply_const) {
const bool disable_fold_multiply_const) {
// Dequantization subgraph may have two forms: with and without Subtract
//
// Input Input
Expand Down Expand Up @@ -103,13 +103,10 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
auto scale = multiply->get_input_node_shared_ptr(1);
if (ov::is_type<ov::op::v0::Convert>(scale) &&
ov::is_type<ov::op::v0::Constant>(scale->get_input_node_ptr(0))) {
if (!fold_multiply_const) {
if (disable_fold_multiply_const) {
ov::disable_constant_folding(scale);
ov::unmark_as_decompression(scale);
ov::enable_keep_const_precision(scale->get_input_node_shared_ptr(0));
} else {
ov::enable_constant_folding(scale);
ov::disable_keep_const_precision(scale->get_input_node_shared_ptr(0));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
ov::element::i4,
ov::element::nf4,
ov::element::f4e2m1};
CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, false, false);
CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, false, true);
CPU_SET_CALLBACK_X64(decompression_handling_manager, [&](const_node_ptr &node) -> bool {
return !is_decompression_multiply(node);
}, ov::pass::MarkDequantizationSubgraph);
Expand Down

0 comments on commit 6cd53b5

Please sign in to comment.