Skip to content

Commit

Permalink
Move whisper utils
Browse files Browse the repository at this point in the history
  • Loading branch information
as-suvorov committed Jan 23, 2025
1 parent 3348ad5 commit d63e700
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
16 changes: 0 additions & 16 deletions src/cpp/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,6 @@ void print_tensor(const ov::Tensor& tensor) {
std::cout << "]" << std::endl;
}

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
if (logits.get_shape()[0] <= batch_idx) {
OPENVINO_THROW("logits batch size doesn't match the number of beams");
}

size_t vocab_size = logits.get_shape().back();
size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;

int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data;
float max_logit = logits_data[out_token];

return out_token;
}

/**
* Initializes position ids based on attention mask and starting position
*/
Expand Down
2 changes: 0 additions & 2 deletions src/cpp/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ Tensor init_attention_mask(const Tensor& position_ids);

void print_tensor(const ov::Tensor& tensor);

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);

void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0);

ov::Tensor extend_attention(ov::Tensor attention_mask);
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/whisper/models/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "whisper/whisper_utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
Expand Down
16 changes: 16 additions & 0 deletions src/cpp/src/whisper/whisper_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics,
filter_by_ranges(raw_metrics.m_batch_sizes, offset, ranges);
}

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
if (logits.get_shape()[0] <= batch_idx) {
OPENVINO_THROW("logits batch size doesn't match the number of beams");
}

size_t vocab_size = logits.get_shape().back();
size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;

int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data;
float max_logit = logits_data[out_token];

return out_token;
}

} // namespace utils
} // namespace genai
} // namespace ov
2 changes: 2 additions & 0 deletions src/cpp/src/whisper/whisper_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics,
size_t offset,
std::vector<std::pair<size_t, size_t>>& ranges);

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);

} // namespace utils
} // namespace genai
} // namespace ov

0 comments on commit d63e700

Please sign in to comment.