Skip to content

Commit

Permalink
Remove is_copyable() logic and copy manually if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
tkrupa-intel committed Dec 11, 2024
1 parent 0148525 commit 4bf7a6f
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info/weightless_caching_attributes.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
Expand Down Expand Up @@ -1405,6 +1406,13 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
new_const->validate_and_infer_types();
new_const->set_friendly_name(constant->get_friendly_name());
ov::copy_runtime_info(constant, new_const);

const auto& rt_info = node->get_rt_info();
auto weightless_caching_attr = rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static());
if (weightless_caching_attr != rt_info.end()) {
new_const->get_rt_info()[ov::WeightlessCacheAttribute::get_type_info_static()] =
weightless_caching_attr->second;
}
return true;
}
return false;
Expand Down
36 changes: 36 additions & 0 deletions src/common/transformations/tests/utils/convert_precision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/rt_info/weightless_caching_attributes.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset15.hpp"
Expand Down Expand Up @@ -2702,3 +2703,38 @@ TEST(TransformationTests, ConvertPrecision_assign_read_value_preserve_orig_types
FunctionsComparator::Result result = func_comparator(model_ref, model);
ASSERT_TRUE(result.valid) << result.message;
}

TEST(TransformationTests, ConvertPrecision_assign_read_value_preserve_weightless_cache_info_as_rt_attribute) {
pass::Manager manager;

auto some_value = opset10::Constant::create(element::f32, Shape{1}, {2});
auto& node_rt_info = some_value->get_rt_info();
ov::WeightlessCacheAttribute attr(element::f32.size(), 0, element::f32);
node_rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = attr;

ov::ParameterVector inputParams;
ov::ResultVector results;
results.push_back(std::make_shared<ov::op::v0::Result>(some_value->output(0)));
auto model = std::make_shared<ov::Model>(results, inputParams);

type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = false;
bool convert_input_output_precision = false;
bool store_original_precision_as_rt_attribute = true;
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32,
convert_input_output_precision,
store_original_precision_as_rt_attribute);
manager.run_passes(model);

const auto& ops = model->get_ops();
auto it = std::find_if(ops.begin(), ops.end(), [](const std::shared_ptr<Node>& node) {
return ov::op::util::is_constant(node);
});

ASSERT_TRUE(it != ops.end());
const auto& new_rt_info = (*it)->get_rt_info();
auto weightless_caching_attr_it = new_rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static());
ASSERT_TRUE(weightless_caching_attr_it != new_rt_info.end());
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,13 @@ class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute {
WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype)
: original_size(original_size),
bin_offset(bin_offset),
original_dtype(original_dtype),
curr_dtype(original_dtype) {}
original_dtype(original_dtype) {}

bool is_copyable() const override;
bool is_copyable(const std::shared_ptr<Node>& from, const std::shared_ptr<Node>& to) const override;

size_t original_size;
size_t bin_offset;
ov::element::Type original_dtype;
ov::element::Type curr_dtype;
};

} // namespace ov
1 change: 0 additions & 1 deletion src/core/include/openvino/core/runtime_attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class OPENVINO_API RuntimeAttribute {
virtual ~RuntimeAttribute() = default;
virtual bool is_copyable() const;
virtual bool is_copyable(const std::shared_ptr<Node>& to) const;
virtual bool is_copyable(const std::shared_ptr<Node>& from, const std::shared_ptr<Node>& to) const;
virtual Any init(const std::shared_ptr<Node>& node) const;
virtual Any merge(const ov::NodeVector& nodes) const;
virtual Any merge(const ov::OutputVector& outputs) const;
Expand Down
11 changes: 0 additions & 11 deletions src/core/src/op/util/weightless_caching_attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,6 @@

#include "openvino/core/rt_info/weightless_caching_attributes.hpp"

#include "openvino/op/util/op_types.hpp"

bool ov::WeightlessCacheAttribute::is_copyable() const {
return false;
}

bool ov::WeightlessCacheAttribute::is_copyable(const std::shared_ptr<ov::Node>& from,
const std::shared_ptr<ov::Node>& to) const {
if (!ov::op::util::is_constant(from) || !ov::op::util::is_constant(to)) {
return false;
}

return from->get_element_type() != to->get_element_type();
}
2 changes: 1 addition & 1 deletion src/core/src/rt_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::unordered_map<std::string, std::vector<ov::Any>> get_copyable_attrs(const o
for (const auto& item : node->get_rt_info()) {
bool copy = item.first != "opset";
if (item.second.is<ov::RuntimeAttribute>()) {
copy = copy && item.second.as<ov::RuntimeAttribute>().is_copyable(node, to);
copy = copy && item.second.as<ov::RuntimeAttribute>().is_copyable(to);
}
if (copy) {
attrs[item.first].push_back(item.second);
Expand Down
4 changes: 0 additions & 4 deletions src/core/src/runtime_attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ bool RuntimeAttribute::is_copyable(const std::shared_ptr<Node>& to) const {
return is_copyable();
}

bool RuntimeAttribute::is_copyable(const std::shared_ptr<Node>& from, const std::shared_ptr<Node>& to) const {
return is_copyable(to);
}

std::ostream& operator<<(std::ostream& os, const RuntimeAttribute& attrubute) {
return os << attrubute.to_string();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ struct weightless_cache_manager {
}

manager.run_passes(model);
const auto ops = model->get_ops();
const auto& ops = model->get_ops();
auto it = std::find_if(ops.begin(), ops.end(), [](const std::shared_ptr<ov::Node>& node) {
return ov::op::util::is_constant(node);
});
Expand Down

0 comments on commit 4bf7a6f

Please sign in to comment.