Merge pull request #1 from reinfer/dev/prashanth/return_logits
Return logits and add instructions to build wheel locally
annesylvie authored Jun 25, 2024
2 parents 174e4ee + 0a49019 commit 8bce69f
Showing 32 changed files with 505 additions and 107 deletions.
25 changes: 25 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,31 @@

### Fixes and improvements

+## [v3.24.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.24.0) (2024-01-08)
+
+### New features
+* Support of new option `offset` to ignore the token scores of special tokens
+
+## [v3.23.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.23.0) (2023-12-05)
+
+### New features
+* Support Phi model
+
+### Fixes and improvements
+* Fix the conversion for whisper without the "alignment_heads" in the "generation_config.json"
+* Fix forward batch
+
+## [v3.22.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.22.0) (2023-11-22)
+
+### New features
+* Support "sliding window" and "chunking input" for Mistral
+
+### Fixes and improvements
+* Take into account the "generation_config.json" and fix "lang_ids" getter for Whisper converter
+* Accept callback even on "generate_tokens" method
+* Fix iomp5 linking with latest Intel oneAPI on Ubuntu
+* Fixed "decoder_start_token_id" for T5
+
## [v3.21.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v3.21.0) (2023-11-09)

### New features
6 changes: 4 additions & 2 deletions CMakeLists.txt
@@ -134,7 +134,7 @@ set(SOURCES
src/ops/bias_add.cc
src/ops/bias_add_cpu.cc
src/ops/concat.cc
- src/ops/concat_split_cpu.cc
+ src/ops/concat_split_slide_cpu.cc
src/ops/conv1d.cc
src/ops/conv1d_cpu.cc
src/ops/cos.cc
@@ -168,6 +168,7 @@ set(SOURCES
src/ops/softmax.cc
src/ops/softmax_cpu.cc
src/ops/split.cc
+ src/ops/slide.cc
src/ops/sub.cc
src/ops/swish.cc
src/ops/tanh.cc
@@ -263,6 +264,7 @@ if(NOT OPENMP_RUNTIME STREQUAL "NONE")
${INTEL_ROOT}/oneAPI/compiler/latest/windows/compiler/lib/intel64_win
${INTEL_ROOT}/oneapi/compiler/latest/linux/compiler/lib/intel64_lin
${INTEL_ROOT}/oneapi/compiler/latest/mac/compiler/lib
+ ${INTEL_ROOT}/oneapi/compiler/latest/lib
)
if(IOMP5_LIBRARY)
list(APPEND LIBRARIES ${IOMP5_LIBRARY})
@@ -505,7 +507,7 @@ if (WITH_CUDA)
src/cuda/utils.cc
src/ops/alibi_add_gpu.cu
src/ops/bias_add_gpu.cu
- src/ops/concat_split_gpu.cu
+ src/ops/concat_split_slide_gpu.cu
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/gather_gpu.cu
17 changes: 17 additions & 0 deletions README.md
@@ -123,3 +123,20 @@ Executed with CUDA 11 on a [*g5.xlarge*](https://aws.amazon.com/ec2/instance-types/g5/)
* [Documentation](https://opennmt.net/CTranslate2)
* [Forum](https://forum.opennmt.net)
* [Gitter](https://gitter.im/OpenNMT/CTranslate2)


+# To build locally
+
+    cd CTranslate2
+    mkdir build
+    cd build
+    sudo cmake -DCMAKE_INSTALL_PREFIX=/usr/local -DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DOPENMP_RUNTIME=COMP -DCMAKE_BUILD_TYPE=Release ..
+    sudo make -j4
+    sudo make install
+    sudo ldconfig
+
+    # LD_LIBRARY_PATH should contain the CTranslate2 install path
+
+# Build the Python wheel
+
+    cd python
+    python setup.py bdist_wheel --dist-dir <path_to_wheel_dir>
+    auditwheel repair --plat manylinux_2_34_x86_64 <path_to_wheel>
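
A quick way to verify the built and installed wheel, as a minimal sketch using the public `ctranslate2` Python API:

```python
# Sanity check after installing the repaired wheel.
import ctranslate2

print(ctranslate2.__version__)              # the wheel is importable
print(ctranslate2.get_cuda_device_count())  # > 0 if the CUDA build sees a GPU
```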
10 changes: 5 additions & 5 deletions docker/Dockerfile
@@ -1,4 +1,4 @@
- FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
+ FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04 as builder

RUN apt-get update && \
apt-get install -y --no-install-recommends \
@@ -66,18 +66,18 @@ RUN cd python && \
python3 -m pip --no-cache-dir install -r install_requirements.txt && \
python3 setup.py bdist_wheel --dist-dir $CTRANSLATE2_ROOT

- FROM nvidia/cuda:11.2.2-base-ubuntu20.04
+ FROM nvidia/cuda:12.2.2-base-ubuntu20.04

# We remove the cuda-compat package because it conflicts with the CUDA Enhanced Compatibility.
# See e.g. https://github.com/NVIDIA/nvidia-docker/issues/1515
RUN apt-get update && \
apt-get install -y --no-install-recommends \
- libcublas-11-2 \
- libcudnn8=8.1.1.33-1+cuda11.2 \
+ libcublas-12-2 \
+ libcudnn8=8.9.7.29-1+cuda12.2 \
libgomp1 \
python3-pip \
&& \
- apt-get purge -y cuda-compat-11-2 && \
+ apt-get purge -y cuda-compat-12-2 && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

2 changes: 1 addition & 1 deletion docker/build_all.sh
@@ -42,4 +42,4 @@ build()
fi
}

- build Dockerfile ubuntu20.04-cuda11.2
+ build Dockerfile ubuntu20.04-cuda12.2
2 changes: 2 additions & 0 deletions include/ctranslate2/decoding.h
@@ -15,6 +15,8 @@ namespace ctranslate2 {
std::vector<std::vector<size_t>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
+ std::vector<float> logits;
+ // (max_decoding_steps)
};

struct DecodingStepResult {
3 changes: 2 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -35,7 +35,8 @@ namespace ctranslate2 {
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
- StorageView* position_bias = nullptr) const;
+ StorageView* position_bias = nullptr,
+ dim_t offset = 0) const;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/transformer.h
@@ -91,7 +91,8 @@ namespace ctranslate2 {
const Padder* input_padder = nullptr,
const Padder* memory_padder = nullptr,
bool return_normalized_attention = true,
- StorageView* position_bias = nullptr) const;
+ StorageView* position_bias = nullptr,
+ dim_t offset = 0) const;

DataType output_type() const override {
return _ff.output_type();
@@ -209,6 +210,7 @@ namespace ctranslate2 {
std::vector<std::vector<dim_t>> _alignment_heads;
bool _average_alignment_heads;
Dense _proj;
+ const dim_t _sliding_window;
};

}
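The new `_sliding_window` member ties into the sliding-window support added for Mistral (see the CHANGELOG above). A conceptual Python sketch of the cache policy a sliding window implies, assuming the standard Mistral-style behavior of keeping only the most recent positions; this is not the C++ implementation:

```python
# Conceptual sketch: once the KV cache grows past the window size,
# only the most recent `window` positions are kept (the Slide op added
# in this PR performs the slicing on the C++ side).
def slide_cache(cache: list, window: int) -> list:
    if window <= 0 or len(cache) <= window:  # window 0 disables sliding
        return cache
    return cache[-window:]

print(slide_cache(list(range(10)), window=4))  # [6, 7, 8, 9]
```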
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
@@ -36,3 +36,4 @@
#include "median_filter.h"
#include "rotary.h"
#include "alibi_add.h"
+ #include "slide.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/slide.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "op.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    class Slide : public Op {
+    public:
+      Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy = false);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+    private:
+      dim_t _axis;
+      dim_t _index;
+      dim_t _size;
+      bool _no_copy;
+
+      void check_arguments() const;
+
+      template <Device D, typename T>
+      void compute(const StorageView& input, StorageView& output, const dim_t& index) const;
+    };
+
+  }
+}
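Judging from the constructor arguments, `Slide` extracts a contiguous slice of length `size` starting at `index` along `axis`. A NumPy sketch of the assumed semantics (not the CTranslate2 API):

```python
import numpy as np

def slide(x: np.ndarray, axis: int, index: int, size: int) -> np.ndarray:
    # Assumed semantics of ops::Slide: take `size` elements starting
    # at `index` along `axis`.
    return np.take(x, np.arange(index, index + size), axis=axis)

x = np.arange(12).reshape(3, 4)
print(slide(x, axis=1, index=1, size=2))  # [[1 2] [5 6] [9 10]]
```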
4 changes: 3 additions & 1 deletion include/ctranslate2/scoring.h
@@ -12,6 +12,7 @@ namespace ctranslate2 {
struct ScoringOptions {
// Truncate the inputs after this many tokens (set 0 to disable truncation).
size_t max_input_length = 1024;
+ dim_t offset = 0;
};

struct ScoringResult {
@@ -38,6 +39,7 @@ namespace ctranslate2 {
layers::DecoderState& state,
const std::vector<std::vector<size_t>>& sequences,
const Vocabulary& vocabulary,
- const dim_t preferred_size_multiple = 1);
+ const dim_t preferred_size_multiple = 1,
+ const dim_t offset = 0);

}
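The new `ScoringOptions::offset` is exposed as an `offset` argument on the Python scoring methods later in this diff; a usage sketch with placeholder model path and tokens:

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")  # placeholder model
results = translator.score_batch(
    [["▁Hello", "▁world"]],  # source tokens
    [["▁Hallo", "▁Welt"]],   # target tokens
    offset=1,                # ignore the first target token when scoring
)
print(results[0].tokens)
print(results[0].log_probs)
```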
6 changes: 5 additions & 1 deletion include/ctranslate2/translation.h
@@ -87,6 +87,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
+ std::vector<float> logits;

TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
: hypotheses(std::move(hypotheses_))
@@ -95,10 +96,12 @@

TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
std::vector<float> scores_,
- std::vector<std::vector<std::vector<float>>> attention_)
+ std::vector<std::vector<std::vector<float>>> attention_,
+ std::vector<float> logits_)
: hypotheses(std::move(hypotheses_))
, scores(std::move(scores_))
, attention(std::move(attention_))
+ , logits(std::move(logits_))
{
}

@@ -109,6 +112,7 @@
: hypotheses(num_hypotheses)
, scores(with_score ? num_hypotheses : 0, static_cast<float>(0))
, attention(with_attention ? num_hypotheses : 0)
+ , logits(with_score ? num_hypotheses : 0)
{
}

7 changes: 6 additions & 1 deletion python/cpp/translation_result.cc
@@ -16,11 +16,14 @@ namespace ctranslate2 {
"Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
.def_readonly("attention", &TranslationResult::attention,
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
+ .def_readonly("logits", &TranslationResult::logits,
+ "Logits for each decoding step")

.def("__repr__", [](const TranslationResult& result) {
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})

@@ -39,8 +42,10 @@
throw py::index_error();
py::dict hypothesis;
hypothesis["tokens"] = result.hypotheses[i];
- if (result.has_scores())
+ if (result.has_scores()){
hypothesis["score"] = result.scores[i];
+ hypothesis["logits"] = result.logits[i];
+ };
if (result.has_attention())
hypothesis["attention"] = result.attention[i];
return hypothesis;
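With the binding above, the logits are reachable from Python next to the scores. A sketch (placeholder model path; per the `TranslationResult` constructor earlier in the diff, `logits` is sized like `scores`, one entry per hypothesis):

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")  # placeholder model
results = translator.translate_batch(
    [["▁Hello", "▁world"]],
    return_scores=True,  # logits are allocated alongside scores in this PR
)
best = results[0]
print(best.hypotheses[0])
print(best.scores[0])
print(best.logits[0])    # new field added by this PR
```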
9 changes: 8 additions & 1 deletion python/cpp/translator.cc
@@ -228,10 +228,12 @@ namespace ctranslate2 {
size_t max_batch_size,
const std::string& batch_type_str,
size_t max_input_length,
+ dim_t offset,
bool asynchronous) {
const auto batch_type = str_to_batch_type(batch_type_str);
ScoringOptions options;
options.max_input_length = max_input_length;
+ options.offset = offset;

std::shared_lock lock(_mutex);
assert_model_is_ready();
@@ -252,6 +254,7 @@
size_t read_batch_size,
const std::string& batch_type_str,
size_t max_input_length,
+ dim_t offset,
bool with_tokens_score,
const TokenizeFn& source_tokenize_fn,
const TokenizeFn& target_tokenize_fn,
@@ -263,7 +266,7 @@
const auto batch_type = str_to_batch_type(batch_type_str);
ScoringOptions options;
options.max_input_length = max_input_length;

+ options.offset = offset;
std::shared_lock lock(_mutex);
assert_model_is_ready();

@@ -592,6 +595,7 @@
py::arg("max_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("max_input_length")=1024,
+ py::arg("offset") = 0,
py::arg("asynchronous")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
@@ -606,6 +610,7 @@
minimized.
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
max_input_length: Truncate inputs after this many tokens (0 to disable).
+ offset: Ignore the first n target tokens in the score calculation.
asynchronous: Run the scoring asynchronously.
Returns:
@@ -621,6 +626,7 @@
py::arg("read_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("max_input_length")=1024,
+ py::arg("offset")=0,
py::arg("with_tokens_score")=false,
py::arg("source_tokenize_fn")=nullptr,
py::arg("target_tokenize_fn")=nullptr,
@@ -649,6 +655,7 @@
batch_type: Whether :obj:`max_batch_size` and :obj:`read_batch_size` are the
number of "examples" or "tokens".
max_input_length: Truncate inputs after this many tokens (0 to disable).
+ offset: Ignore the first n target tokens in the score calculation.
with_tokens_score: Include the token-level scores in the output file.
source_tokenize_fn: Function to tokenize source lines.
target_tokenize_fn: Function to tokenize target lines.
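The same `offset` option is added to the file-based scorer; a sketch with placeholder paths (one tokenized sentence per line):

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")  # placeholder model
translator.score_file(
    "source.txt",            # placeholder input files
    "target.txt",
    "scores.txt",            # output file with sentence scores
    offset=1,                # skip the first target token in each score
    with_tokens_score=True,  # also write token-level scores
)
```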
9 changes: 5 additions & 4 deletions python/ctranslate2/converters/opennmt_py.py
@@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
num_heads = getattr(opt, "heads", 8)
num_kv = getattr(opt, "num_kv", 0)
- if num_kv == num_heads:
+ if num_kv == num_heads or num_kv == 0:
num_kv = None
rotary_dim = 0 if with_rotary else None
+ rotary_interleave = getattr(opt, "rotary_interleave", True)
ffn_glu = activation_fn == "silu"
sliding_window = getattr(opt, "sliding_window", 0)

@@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
alibi=with_alibi,
rms_norm=opt.layer_norm == "rms",
rotary_dim=rotary_dim,
- rotary_interleave=True,
+ rotary_interleave=rotary_interleave,
multi_query_attention=getattr(opt, "multiquery", False),
num_heads_kv=num_kv,
sliding_window=sliding_window,
@@ -329,7 +330,7 @@ def set_linear(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)
bias = variables.get("%s.bias" % scope)
if bias is not None:
- spec.bias = bias.numpy()
+ spec.bias = bias


def set_embeddings(spec, variables, scope):
@@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope):


def _get_variable(variables, name):
- return variables[name].numpy()
+ return variables[name]


def main():
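The first hunk is easiest to read in isolation: a KV-head count equal to the full head count (plain multi-head attention) or 0 (option unset) both mean "no separate KV heads", which the spec encodes as `None`. A sketch (the helper name is illustrative, not from the converter):

```python
def normalize_num_kv(num_kv: int, num_heads: int):
    # Mirrors the converter logic above: MHA (num_kv == num_heads) and
    # "option unset" (num_kv == 0) both map to None in the model spec.
    if num_kv == num_heads or num_kv == 0:
        return None
    return num_kv

assert normalize_num_kv(8, 8) is None  # plain multi-head attention
assert normalize_num_kv(0, 8) is None  # option not set
assert normalize_num_kv(2, 8) == 2     # grouped-query attention
```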
*(Diffs for the remaining changed files were not loaded.)*