Skip to content

Commit

Permalink
Repsect KV cahe precision from IR
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Jan 15, 2025
1 parent 79ed0eb commit 05f79a8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 14 deletions.
14 changes: 13 additions & 1 deletion src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ namespace ov::genai {
template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;


ov::element::Type get_model_kv_cache_precision(std::shared_ptr<ov::Model> model) {
const std::vector<std::string> kv_cache_precision_path = { "runtime_options", ov::hint::kv_cache_precision.name() };
ov::element::Type ir_kv_cache_precision = ov::element::undefined;

if (model->has_rt_info(kv_cache_precision_path)) {
ir_kv_cache_precision = model->get_rt_info<ov::element::Type>(kv_cache_precision_path);
}

return ir_kv_cache_precision;
}

ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
const std::shared_ptr<ov::Model>& model,
const Tokenizer& tokenizer,
Expand All @@ -24,7 +36,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
m_is_validation_mode_enabled = is_validation_mode_enabled;

ov::Core core = utils::singleton_core();
DeviceConfig device_config(core, scheduler_config, device, properties);
DeviceConfig device_config(core, scheduler_config, device, properties, get_model_kv_cache_precision(model));

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
Expand Down
2 changes: 2 additions & 0 deletions src/cpp/src/continuous_batching_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

namespace ov::genai {

ov::element::Type get_model_kv_cache_precision(std::shared_ptr<ov::Model> model);

class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
protected:
std::shared_ptr<Scheduler> m_scheduler;
Expand Down
25 changes: 14 additions & 11 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,30 @@ class DeviceConfig {
}

public:
DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config = {}) {
DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device,
const ov::AnyMap& plugin_config = {}, ov::element::Type ir_kv_cache_precision = ov::element::undefined) {
m_device = device;

// keep information about blocsk
// keep information about block size
m_block_size = get_block_size_by_device(device);

if (m_device == "CPU") {
// if user sets ov::kv_cache_precision hint
// if user sets properties affecting KV cache precision
const auto kv_cache_precision_it = plugin_config.find(ov::hint::kv_cache_precision.name());
const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name());

if (kv_cache_precision_it != plugin_config.end()) {
const auto kv_cache_precision = kv_cache_precision_it->second.as<ov::element::Type>();
m_kv_cache_type = kv_cache_precision;
} else if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) {
// ACCURACY mode will use f32 KV cache type
m_kv_cache_type = ov::element::f32;
} else if (ir_kv_cache_precision != ov::element::undefined) {
// check that kv_cache_precision is set in runtime_info section of OpenVINO IR
m_kv_cache_type = ir_kv_cache_precision;
} else {
// ACCURACY mode will use f32 kvcache
const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name());
if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) {
m_kv_cache_type = ov::element::f32;
} else {
// x86 and arm have different default kv cache type
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
// x86 and ARM have different default kv cache type, take this information from the plugin
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
} else if (m_device.find("GPU") != std::string::npos) {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con
ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties;

ov::Core core = utils::singleton_core();
DeviceConfig main_device_config(core, main_scheduler_config_updated, main_device, main_model_desc.properties),
draft_device_config(core, draft_scheduler_config, draft_device, draft_properties);
DeviceConfig main_device_config(core, main_scheduler_config_updated, main_device, main_model_desc.properties, get_model_kv_cache_precision(main_model)),
draft_device_config(core, draft_scheduler_config, draft_device, draft_properties, get_model_kv_cache_precision(draft_model));

utils::set_kv_cache_type_and_shape(main_model, main_device_config);
utils::set_kv_cache_type_and_shape(draft_model, draft_device_config);
Expand Down

0 comments on commit 05f79a8

Please sign in to comment.