Skip to content

Commit

Permalink
Respect KV cache precision from IR
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Jan 15, 2025
1 parent 79ed0eb commit 9f5e415
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
8 changes: 7 additions & 1 deletion src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
m_generation_config = generation_config;
m_is_validation_mode_enabled = is_validation_mode_enabled;

const std::vector<std::string> kv_cache_precision_path = { "runtime_options", ov::hint::kv_cache_precision.name() };
ov::element::Type ir_kv_cache_precision = ov::element::undefined;
if (model->has_rt_info(kv_cache_precision_path)) {
ir_kv_cache_precision = model->get_rt_info<ov::element::Type>(kv_cache_precision_path);
}

ov::Core core = utils::singleton_core();
DeviceConfig device_config(core, scheduler_config, device, properties);
DeviceConfig device_config(core, scheduler_config, device, properties, ir_kv_cache_precision);

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
Expand Down
24 changes: 13 additions & 11 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,29 @@ class DeviceConfig {
}

public:
DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config = {}) {
DeviceConfig(ov::Core& core, const SchedulerConfig& scheduling_config, const std::string& device, const ov::AnyMap& plugin_config, ov::element::Type ir_kv_cache_precision) {
m_device = device;

// keep information about blocsk
// keep information about block size
m_block_size = get_block_size_by_device(device);

if (m_device == "CPU") {
// if user sets ov::kv_cache_precision hint
// if user sets properties affecting KV cache precision
const auto kv_cache_precision_it = plugin_config.find(ov::hint::kv_cache_precision.name());
const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name());

if (kv_cache_precision_it != plugin_config.end()) {
const auto kv_cache_precision = kv_cache_precision_it->second.as<ov::element::Type>();
m_kv_cache_type = kv_cache_precision;
} else if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) {
// ACCURACY mode will use f32 KV cache type
m_kv_cache_type = ov::element::f32;
} else if (ir_kv_cache_precision != ov::element::undefined) {
// check that kv_cache_precision is set in runtime_info section of OpenVINO IR
m_kv_cache_type = ir_kv_cache_precision;
} else {
// ACCURACY mode will use f32 kvcache
const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name());
if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) {
m_kv_cache_type = ov::element::f32;
} else {
// x86 and arm have different default kv cache type
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
// x86 and ARM have different default kv cache type, take this information from the plugin
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
} else if (m_device.find("GPU") != std::string::npos) {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
Expand Down

0 comments on commit 9f5e415

Please sign in to comment.