Skip to content

Commit

Permalink
Chmck beam search batch > 1 in ci
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Jan 20, 2025
1 parent 8aeb714 commit 1057eb8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/causal_lm_cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,23 @@ jobs:
import transformers
with open('pred.txt', 'r', errors='ignore') as file:
predictions = file.read()
print('\n\n')
print(predictions)
print('\n\n')
tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0')
prompts = [
'Alan Turing was a',
'return 0',
'你好! 你好嗎?'
]
for prompt in prompts:
tokenized = tokenizer(prompt, return_tensors='pt')
if tokenizer.chat_template:
prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
print(tokenized)
for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True)
print(ref)
idx = predictions.find(ref.replace('�', ''))
if -1 == idx:
raise RuntimeError(f'Missing "{ref=}" from predictions')
Expand Down
10 changes: 8 additions & 2 deletions src/cpp/src/llm_pipeline_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,14 @@ DecodedResults StatefulLLMPipeline::generate(
TokenizedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
std::vector<std::string> templated_input_vector;
for (auto& input : *input_vector) {
ChatHistory history({{{"role", "user"}, {"content", input}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
templated_input_vector.push_back(templated_prompt);
}
encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false));
} else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
std::string& prompt = *input_prompt;

Expand Down

0 comments on commit 1057eb8

Please sign in to comment.