From 05f79a8e4a75ae8e2886161fc4346337fe59eaf9 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Wed, 15 Jan 2025 11:20:10 +0100 Subject: [PATCH] Repsect KV cahe precision from IR --- src/cpp/src/continuous_batching_impl.cpp | 14 ++++++++++- src/cpp/src/continuous_batching_impl.hpp | 2 ++ src/cpp/src/device_config.hpp | 25 +++++++++++-------- .../speculative_decoding_impl.cpp | 4 +-- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index e778e55b93..5f872db0d3 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -11,6 +11,18 @@ namespace ov::genai { template struct overloaded : Ts... {using Ts::operator()...;}; template overloaded(Ts...) -> overloaded; + +ov::element::Type get_model_kv_cache_precision(std::shared_ptr model) { + const std::vector kv_cache_precision_path = { "runtime_options", ov::hint::kv_cache_precision.name() }; + ov::element::Type ir_kv_cache_precision = ov::element::undefined; + + if (model->has_rt_info(kv_cache_precision_path)) { + ir_kv_cache_precision = model->get_rt_info(kv_cache_precision_path); + } + + return ir_kv_cache_precision; +} + ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( const std::shared_ptr& model, const Tokenizer& tokenizer, @@ -24,7 +36,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( m_is_validation_mode_enabled = is_validation_mode_enabled; ov::Core core = utils::singleton_core(); - DeviceConfig device_config(core, scheduler_config, device, properties); + DeviceConfig device_config(core, scheduler_config, device, properties, get_model_kv_cache_precision(model)); bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction; utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control); diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp index 78e6638fbc..c1de415cfe 100644 --- a/src/cpp/src/continuous_batching_impl.hpp +++ b/src/cpp/src/continuous_batching_impl.hpp @@ -10,6 +10,8 @@ namespace ov::genai { +ov::element::Type get_model_kv_cache_precision(std::shared_ptr model); + class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: std::shared_ptr m_scheduler; diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index 708360bcc8..4d13c80cd4 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -30,27 +30,30 @@ class DeviceConfig { } public: - DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config = {}) { + DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, + const ov::AnyMap& plugin_config = {}, ov::element::Type ir_kv_cache_precision = ov::element::undefined) { m_device = device; - // keep information about blocsk + // keep information about block size m_block_size = get_block_size_by_device(device); if (m_device == "CPU") { - // if user sets ov::kv_cache_precision hint + // if user sets properties affecting KV cache precision const auto kv_cache_precision_it = plugin_config.find(ov::hint::kv_cache_precision.name()); + const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name()); + if (kv_cache_precision_it != plugin_config.end()) { const auto kv_cache_precision = kv_cache_precision_it->second.as(); m_kv_cache_type = kv_cache_precision; + } else if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as() == ov::hint::ExecutionMode::ACCURACY) { + // ACCURACY mode will use f32 KV cache type + m_kv_cache_type = ov::element::f32; + } else if (ir_kv_cache_precision != ov::element::undefined) { + // check that kv_cache_precision is set in runtime_info section of OpenVINO IR + m_kv_cache_type = ir_kv_cache_precision; } else { - // ACCURACY mode will use f32 kvcache - const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name()); - if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as() == ov::hint::ExecutionMode::ACCURACY) { - m_kv_cache_type = ov::element::f32; - } else { - // x86 and arm have different default kv cache type - m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision); - } + // x86 and ARM have different default kv cache type, take this information from the plugin + m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision); } } else if (m_device.find("GPU") != std::string::npos) { auto inference_precision = core.get_property(device, ov::hint::inference_precision); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 8b2730a62b..35d5e00424 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -62,8 +62,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; ov::Core core = utils::singleton_core(); - DeviceConfig main_device_config(core, main_scheduler_config_updated, main_device, main_model_desc.properties), - draft_device_config(core, draft_scheduler_config, draft_device, draft_properties); + DeviceConfig main_device_config(core, main_scheduler_config_updated, main_device, main_model_desc.properties, get_model_kv_cache_precision(main_model)), + draft_device_config(core, draft_scheduler_config, draft_device, draft_properties, get_model_kv_cache_precision(draft_model)); utils::set_kv_cache_type_and_shape(main_model, main_device_config); utils::set_kv_cache_type_and_shape(draft_model, draft_device_config);