Skip to content

Commit

Permalink
Add wrapper for Tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasantony committed Mar 20, 2023
1 parent e835aab commit a8acebd
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/PyLlama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@

namespace py = pybind11;

class Tokenizer {
gpt_vocab vocab{};
public:
Tokenizer(const gpt_vocab& vocabIn): vocab(vocabIn) {}
std::vector<gpt_vocab::id> tokenize(const std::string & text, bool bos) {
return llama_tokenize(vocab, text, bos);
}
std::string detokenize(const std::vector<gpt_vocab::id>& ids) {
std::string output = "";
for (auto id: ids) {
output += vocab.id_to_token.at(id);
}
return output;
}
std::string detokenize(const gpt_vocab::id& id) {
return vocab.id_to_token.at(id);
}
};

class PyLLAMA {
llama_context* ctx_ptr = nullptr;
gpt_params params{};
Expand All @@ -27,6 +46,9 @@ class PyLLAMA {
llama_free_context(ctx_ptr);
}
}
// Builds a Tokenizer from this object's vocabulary (as returned by vocab()).
// The Tokenizer constructor copies the vocabulary, so the returned object
// keeps its own copy rather than a reference into the context.
Tokenizer get_tokenizer() {
    const gpt_vocab& model_vocab = vocab();
    return Tokenizer(model_vocab);
}
// Pass through all functions for llama_context
const gpt_vocab & vocab() const {
return llama_context_get_vocab(*ctx_ptr);
Expand Down Expand Up @@ -145,8 +167,10 @@ PYBIND11_MODULE(llamacpp, m) {
.def_readwrite("repeat_last_n", &gpt_params::repeat_last_n)
.def_readwrite("n_batch", &gpt_params::n_batch);

/* Wrapper for PyLLAMA */
py::class_<PyLLAMA>(m, "PyLLAMA")
.def(py::init<gpt_params>())
.def("get_tokenizer", &PyLLAMA::get_tokenizer, "Get the tokenizer")
.def("tokenize", &PyLLAMA::tokenize, "Tokenize text")
.def("prepare_context", &PyLLAMA::prepare_context, "Prepare the LLaMA context")
.def("add_bos", &PyLLAMA::add_bos, "Add a BOS token to the input")
Expand Down Expand Up @@ -186,5 +210,12 @@ PYBIND11_MODULE(llamacpp, m) {
.def_property("antiprompt", &PyLLAMA::get_antiprompt, &PyLLAMA::set_antiprompt, "Antiprompt")
;

/* Wrapper for Tokenizer.
 * No py::init is registered in this statement, so these bindings give Python
 * no constructor for Tokenizer; PyLLAMA.get_tokenizer() is the binding in this
 * file that returns an instance. */
py::class_<Tokenizer>(m, "Tokenizer")
// Python signature: tokenize(text, add_bos=False); add_bos maps to the C++ `bos` parameter.
.def("tokenize", &Tokenizer::tokenize, "Tokenize text", py::arg("text"), py::arg("add_bos") = false)
// Two C++ overloads share the Python name "detokenize": one takes a sequence
// of token ids, the other a single id. py::overload_cast selects each overload.
.def("detokenize", py::overload_cast<const std::vector<gpt_vocab::id>&>(&Tokenizer::detokenize), "Detokenize text")
.def("detokenize", py::overload_cast<const gpt_vocab::id&>(&Tokenizer::detokenize), "Detokenize single token");

/* Wrapper for llama_model_quantize */
m.def("llama_model_quantize", &llama_model_quantize, "Quantize the LLaMA model");
}

0 comments on commit a8acebd

Please sign in to comment.