diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index e778e55b93..928bc72bc7 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -23,8 +23,14 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( m_generation_config = generation_config; m_is_validation_mode_enabled = is_validation_mode_enabled; + const std::vector<std::string> kv_cache_precision_path = { "runtime_options", ov::hint::kv_cache_precision.name() }; + ov::element::Type ir_kv_cache_precision = ov::element::undefined; + if (model->has_rt_info(kv_cache_precision_path)) { + ir_kv_cache_precision = model->get_rt_info<ov::element::Type>(kv_cache_precision_path); + } + ov::Core core = utils::singleton_core(); - DeviceConfig device_config(core, scheduler_config, device, properties); + DeviceConfig device_config(core, scheduler_config, device, properties, ir_kv_cache_precision); bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction; utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control); diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index 708360bcc8..3f534eaa95 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -30,27 +30,29 @@ class DeviceConfig { } public: - DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config = {}) { + DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config, ov::element::Type ir_kv_cache_precision) { m_device = device; - // keep information about blocsk + // keep information about block size m_block_size = get_block_size_by_device(device); if (m_device == "CPU") { - // if user sets ov::kv_cache_precision hint + // if user sets properties affecting KV cache precision const auto kv_cache_precision_it = 
plugin_config.find(ov::hint::kv_cache_precision.name()); + const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name()); + if (kv_cache_precision_it != plugin_config.end()) { const auto kv_cache_precision = kv_cache_precision_it->second.as<ov::element::Type>(); m_kv_cache_type = kv_cache_precision; + } else if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) { + // ACCURACY mode will use f32 KV cache type + m_kv_cache_type = ov::element::f32; + } else if (ir_kv_cache_precision != ov::element::undefined) { + // check that kv_cache_precision is set in runtime_info section of OpenVINO IR + m_kv_cache_type = ir_kv_cache_precision; } else { - // ACCURACY mode will use f32 kvcache - const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name()); - if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) { - m_kv_cache_type = ov::element::f32; - } else { - // x86 and arm have different default kv cache type - m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision); - } + // x86 and ARM have different default kv cache type, take this information from the plugin + m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision); } } else if (m_device.find("GPU") != std::string::npos) { auto inference_precision = core.get_property(device, ov::hint::inference_precision);