diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index dd3051b8b0..a8cf844cb7 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -46,22 +46,6 @@ void print_tensor(const ov::Tensor& tensor) { std::cout << "]" << std::endl; } -int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape()[0] <= batch_idx) { - OPENVINO_THROW("logits batch size doesn't match the number of beams"); - } - - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size; - size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size; - const float* logits_data = logits.data() + batch_offset + sequence_offset; - - int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data; - float max_logit = logits_data[out_token]; - - return out_token; -} - /** * Initializes position ids based on attention mask and starting position */ diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index c25b2c3913..8c56c39a8c 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -47,8 +47,6 @@ Tensor init_attention_mask(const Tensor& position_ids); void print_tensor(const ov::Tensor& tensor); -int64_t argmax(const ov::Tensor& logits, const size_t batch_idx); - void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0); ov::Tensor extend_attention(ov::Tensor attention_mask); diff --git a/src/cpp/src/whisper/models/decoder.cpp b/src/cpp/src/whisper/models/decoder.cpp index 1c8df0edd9..c09a84ccdd 100644 --- a/src/cpp/src/whisper/models/decoder.cpp +++ b/src/cpp/src/whisper/models/decoder.cpp @@ -6,7 +6,7 @@ #include #include "statefull_decoder.hpp" -#include "utils.hpp" +#include "whisper/whisper_utils.hpp" #include "with_past_decoder.hpp" namespace ov::genai { diff --git a/src/cpp/src/whisper/whisper_utils.cpp b/src/cpp/src/whisper/whisper_utils.cpp index 6e56a1439d..f41d3d11d8 100644 --- 
a/src/cpp/src/whisper/whisper_utils.cpp
+++ b/src/cpp/src/whisper/whisper_utils.cpp
@@ -41,6 +41,42 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics,
     filter_by_ranges(raw_metrics.m_batch_sizes, offset, ranges);
 }
 
+/**
+ * Returns the index of the largest logit at the last sequence position of one batch row.
+ *
+ * The offset arithmetic only addresses the intended elements when logits is a dense
+ * rank-3 buffer ordered as [batch, sequence, vocabulary], so that rank is enforced.
+ *
+ * @param logits    Logits tensor; its data is read as float values.
+ * @param batch_idx Index along the first dimension of the row to inspect.
+ * @return Offset of the maximum value within the last dimension (the first one on ties).
+ * @throws Raises through OPENVINO_THROW when the rank is not 3, when batch_idx is outside
+ *         the first dimension, or when the sequence or vocabulary dimension is empty.
+ */
+int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
+    const auto& shape = logits.get_shape();
+    if (shape.size() != 3) {
+        OPENVINO_THROW("argmax expects logits of rank 3: [batch, sequence, vocabulary]");
+    }
+    if (shape[0] <= batch_idx) {
+        OPENVINO_THROW("logits batch size doesn't match the number of beams");
+    }
+
+    const size_t seq_len = shape[1];
+    const size_t vocab_size = shape.back();
+    if (seq_len == 0 || vocab_size == 0) {
+        OPENVINO_THROW("argmax expects logits with non-empty sequence and vocabulary dimensions");
+    }
+
+    // Skip every row before batch_idx, then every sequence position except the last one.
+    const size_t batch_offset = batch_idx * seq_len * vocab_size;
+    const size_t sequence_offset = (seq_len - 1) * vocab_size;
+    // NOTE(review): element type assumed to be float because the pointer is const float* - confirm.
+    const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
+
+    return static_cast<int64_t>(std::max_element(logits_data, logits_data + vocab_size) - logits_data);
+}
+
 } // namespace utils
 } // namespace genai
 } // namespace ov
diff --git a/src/cpp/src/whisper/whisper_utils.hpp b/src/cpp/src/whisper/whisper_utils.hpp
index 234feed6a8..8fd0a080c6 100644
--- a/src/cpp/src/whisper/whisper_utils.hpp
+++ b/src/cpp/src/whisper/whisper_utils.hpp
@@ -17,6 +17,8 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics,
                                 size_t offset,
                                 std::vector>& ranges);
 
+int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);
+
 } // namespace utils
 } // namespace genai
 } // namespace ov