From 4bf7a6f5a43515b8425a65ba1bb08083e2e7392b Mon Sep 17 00:00:00 2001 From: Tomasz Krupa Date: Wed, 11 Dec 2024 13:28:20 +0000 Subject: [PATCH] Remove is_copyable() logic and copy manually if needed --- .../src/transformations/convert_precision.cpp | 8 +++++ .../tests/utils/convert_precision.cpp | 36 +++++++++++++++++++ .../rt_info/weightless_caching_attributes.hpp | 5 +-- .../openvino/core/runtime_attribute.hpp | 1 - .../op/util/weightless_caching_attributes.cpp | 11 ------ src/core/src/rt_info.cpp | 2 +- src/core/src/runtime_attribute.cpp | 4 --- .../include/intel_gpu/primitives/data.hpp | 2 +- 8 files changed, 47 insertions(+), 22 deletions(-) diff --git a/src/common/transformations/src/transformations/convert_precision.cpp b/src/common/transformations/src/transformations/convert_precision.cpp index 8a2985a284769a..aa067da4f360fd 100644 --- a/src/common/transformations/src/transformations/convert_precision.cpp +++ b/src/common/transformations/src/transformations/convert_precision.cpp @@ -8,6 +8,7 @@ #include #include "itt.hpp" +#include "openvino/core/rt_info/weightless_caching_attributes.hpp" #include "openvino/op/ops.hpp" #include "openvino/pass/constant_folding.hpp" #include "openvino/pass/manager.hpp" @@ -1405,6 +1406,13 @@ bool fuse_type_to_constant(const std::shared_ptr& node, new_const->validate_and_infer_types(); new_const->set_friendly_name(constant->get_friendly_name()); ov::copy_runtime_info(constant, new_const); + + const auto& rt_info = node->get_rt_info(); + auto weightless_caching_attr = rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()); + if (weightless_caching_attr != rt_info.end()) { + new_const->get_rt_info()[ov::WeightlessCacheAttribute::get_type_info_static()] = + weightless_caching_attr->second; + } return true; } return false; diff --git a/src/common/transformations/tests/utils/convert_precision.cpp b/src/common/transformations/tests/utils/convert_precision.cpp index 318f15ab1a64dc..c2b7133506aebe 100644 --- a/src/common/transformations/tests/utils/convert_precision.cpp +++ b/src/common/transformations/tests/utils/convert_precision.cpp @@ -13,6 +13,7 @@ #include "common_test_utils/ov_test_utils.hpp" #include "openvino/core/model.hpp" +#include "openvino/core/rt_info/weightless_caching_attributes.hpp" #include "openvino/opsets/opset1.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/opsets/opset15.hpp" @@ -2702,3 +2703,38 @@ TEST(TransformationTests, ConvertPrecision_assign_read_value_preserve_orig_types FunctionsComparator::Result result = func_comparator(model_ref, model); ASSERT_TRUE(result.valid) << result.message; } + +TEST(TransformationTests, ConvertPrecision_assign_read_value_preserve_weightless_cache_info_as_rt_attribute) { + pass::Manager manager; + + auto some_value = opset10::Constant::create(element::f32, Shape{1}, {2}); + auto& node_rt_info = some_value->get_rt_info(); + ov::WeightlessCacheAttribute attr(element::f32.size(), 0, element::f32); + node_rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = attr; + + ov::ParameterVector inputParams; + ov::ResultVector results; + results.push_back(std::make_shared(some_value->output(0))); + auto model = std::make_shared(results, inputParams); + + type_to_fuse_map empty_type_to_fuse_map = {}; + bool keep_precision_sensitive_in_fp32 = false; + bool convert_input_output_precision = false; + bool store_original_precision_as_rt_attribute = true; + manager.register_pass(precisions_map{{element::f32, element::f16}}, + empty_type_to_fuse_map, + keep_precision_sensitive_in_fp32, + convert_input_output_precision, + store_original_precision_as_rt_attribute); + manager.run_passes(model); + + const auto& ops = model->get_ops(); + auto it = std::find_if(ops.begin(), ops.end(), [](const std::shared_ptr& node) { + return ov::op::util::is_constant(node); + }); + + ASSERT_TRUE(it != ops.end()); + const auto& new_rt_info = (*it)->get_rt_info(); + auto weightless_caching_attr_it = new_rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()); + ASSERT_TRUE(weightless_caching_attr_it != new_rt_info.end()); +} diff --git a/src/core/dev_api/openvino/core/rt_info/weightless_caching_attributes.hpp b/src/core/dev_api/openvino/core/rt_info/weightless_caching_attributes.hpp index bfc260eeeda71e..e3cf2609b26c8d 100644 --- a/src/core/dev_api/openvino/core/rt_info/weightless_caching_attributes.hpp +++ b/src/core/dev_api/openvino/core/rt_info/weightless_caching_attributes.hpp @@ -29,16 +29,13 @@ class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) : original_size(original_size), bin_offset(bin_offset), - original_dtype(original_dtype), - curr_dtype(original_dtype) {} + original_dtype(original_dtype) {} bool is_copyable() const override; - bool is_copyable(const std::shared_ptr& from, const std::shared_ptr& to) const override; size_t original_size; size_t bin_offset; ov::element::Type original_dtype; - ov::element::Type curr_dtype; }; } // namespace ov diff --git a/src/core/include/openvino/core/runtime_attribute.hpp b/src/core/include/openvino/core/runtime_attribute.hpp index 0a4ba2560d5969..86d301ddbfc62f 100644 --- a/src/core/include/openvino/core/runtime_attribute.hpp +++ b/src/core/include/openvino/core/runtime_attribute.hpp @@ -31,7 +31,6 @@ class OPENVINO_API RuntimeAttribute { virtual ~RuntimeAttribute() = default; virtual bool is_copyable() const; virtual bool is_copyable(const std::shared_ptr& to) const; - virtual bool is_copyable(const std::shared_ptr& from, const std::shared_ptr& to) const; virtual Any init(const std::shared_ptr& node) const; virtual Any merge(const ov::NodeVector& nodes) const; virtual Any merge(const ov::OutputVector& outputs) const; diff --git a/src/core/src/op/util/weightless_caching_attributes.cpp b/src/core/src/op/util/weightless_caching_attributes.cpp index e02d2383ee1f18..7c540f8a3bef02 100644 --- a/src/core/src/op/util/weightless_caching_attributes.cpp +++ b/src/core/src/op/util/weightless_caching_attributes.cpp @@ -4,17 +4,6 @@ #include "openvino/core/rt_info/weightless_caching_attributes.hpp" -#include "openvino/op/util/op_types.hpp" - bool ov::WeightlessCacheAttribute::is_copyable() const { return false; } - -bool ov::WeightlessCacheAttribute::is_copyable(const std::shared_ptr& from, - const std::shared_ptr& to) const { - if (!ov::op::util::is_constant(from) || !ov::op::util::is_constant(to)) { - return false; - } - - return from->get_element_type() != to->get_element_type(); -} diff --git a/src/core/src/rt_info.cpp b/src/core/src/rt_info.cpp index f82bb69f348a1f..d69790adde7392 100644 --- a/src/core/src/rt_info.cpp +++ b/src/core/src/rt_info.cpp @@ -32,7 +32,7 @@ std::unordered_map> get_copyable_attrs(const o for (const auto& item : node->get_rt_info()) { bool copy = item.first != "opset"; if (item.second.is()) { - copy = copy && item.second.as().is_copyable(node, to); + copy = copy && item.second.as().is_copyable(to); } if (copy) { attrs[item.first].push_back(item.second); diff --git a/src/core/src/runtime_attribute.cpp b/src/core/src/runtime_attribute.cpp index bf047926d8cdc8..8a784f35bb1ea2 100644 --- a/src/core/src/runtime_attribute.cpp +++ b/src/core/src/runtime_attribute.cpp @@ -36,10 +36,6 @@ bool RuntimeAttribute::is_copyable(const std::shared_ptr& to) const { return is_copyable(); } -bool RuntimeAttribute::is_copyable(const std::shared_ptr& from, const std::shared_ptr& to) const { - return is_copyable(to); -} - std::ostream& operator<<(std::ostream& os, const RuntimeAttribute& attrubute) { return os << attrubute.to_string(); } diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/data.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/data.hpp index 3b234c0e39d9ff..8a9a35b1e92fe9 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/data.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/data.hpp @@ -158,7 +158,7 @@ struct weightless_cache_manager { } manager.run_passes(model); - const auto ops = model->get_ops(); + const auto& ops = model->get_ops(); auto it = std::find_if(ops.begin(), ops.end(), [](const std::shared_ptr& node) { return ov::op::util::is_constant(node); });