Skip to content

Commit

Permalink
[GPU] Fix weightless caching with int4 models
Browse files Browse the repository at this point in the history
  • Loading branch information
tkrupa-intel committed Jan 23, 2025
1 parent b65a324 commit 5700e0e
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1417,13 +1417,7 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
new_const->validate_and_infer_types();
new_const->set_friendly_name(constant->get_friendly_name());
ov::copy_runtime_info(constant, new_const);

const auto& rt_info = node->get_rt_info();
auto weightless_caching_attr = rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static());
if (weightless_caching_attr != rt_info.end()) {
new_const->get_rt_info()[ov::WeightlessCacheAttribute::get_type_info_static()] =
weightless_caching_attr->second;
}
ov::copy_weightless_cache_attr(constant, new_const);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/runtime_attribute.hpp"
#include "transformations_visibility.hpp"

namespace ov {

/**
 * @brief Copies the WeightlessCacheAttribute entry from one node's runtime info to another's.
 *
 * If the runtime info of @p from has an entry keyed by WeightlessCacheAttribute's type info,
 * that entry is assigned into the runtime info of @p to under the same key. If @p from has
 * no such entry, @p to is left unchanged.
 *
 * @param from Node whose runtime info is searched for the attribute.
 * @param to   Node whose runtime info receives the attribute.
 */
TRANSFORMATIONS_API void copy_weightless_cache_attr(const std::shared_ptr<Node>& from, const std::shared_ptr<Node>& to);

/**
* @brief Holds weightless caching attributes of a single constant.
*
Expand Down
11 changes: 11 additions & 0 deletions src/core/src/op/util/weightless_caching_attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,14 @@
// This attribute always reports itself as non-copyable.
// NOTE(review): presumably this is why callers propagate it explicitly via
// copy_weightless_cache_attr instead of relying on generic rt_info copying - confirm.
bool ov::WeightlessCacheAttribute::is_copyable() const {
    return false;
}

// Propagates the weightless-caching attribute from `from` to `to`.
//
// Looks up the entry keyed by WeightlessCacheAttribute's type info in the runtime info of
// `from`. When it is present, the stored value is assigned into the runtime info of `to`
// under the same key; when it is absent, `to` is not touched.
//
// Both pointers are dereferenced without a null check.
TRANSFORMATIONS_API void ov::copy_weightless_cache_attr(const std::shared_ptr<ov::Node>& from,
                                                        const std::shared_ptr<ov::Node>& to) {
    const auto& attr_key = ov::WeightlessCacheAttribute::get_type_info_static();
    const auto& source_rt_info = from->get_rt_info();

    const auto found = source_rt_info.find(attr_key);
    if (found == source_rt_info.end()) {
        // Nothing to propagate: the source node carries no weightless-cache attribute.
        return;
    }

    to->get_rt_info()[attr_key] = found->second;
}
2 changes: 2 additions & 0 deletions src/core/src/pass/constant_folding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/cc/pass/itt.hpp"
#include "openvino/core/constant_fold_utils.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/rt_info/weightless_caching_attributes.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/util/op_types.hpp"
Expand Down Expand Up @@ -153,6 +154,7 @@ bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr<ov::Model>& m
copy_runtime_info_from_input_values(original_node);
// Propagate runtime info attributes to replacement
copy_runtime_info(original_node, replacement_ptr);
ov::copy_weightless_cache_attr(original_node, replacement_ptr);

rewritten = true;
}
Expand Down
Loading

0 comments on commit 5700e0e

Please sign in to comment.