Shallow Contextual Biasing for Whisper #1789

Open · wants to merge 2 commits into master
6 changes: 5 additions & 1 deletion docs/decoding.md
@@ -79,7 +79,7 @@ The prefix effectively changes the target context and the rest of the translation

> Dieses Projekt ist auf das effiziente **Servieren** von Standard-Übersetzungsmodellen ausgerichtet, ist aber auch ein Ort für Experimente rund um Modellkompression und Inferenzbeschleunigung.

## Biased decoding
## Biased prefix decoding for translation

Instead of using {ref}`decoding:autocompletion` to force a translation to start with a `target_prefix` argument, we can "bias" a translation towards a prefix by setting `prefix_bias_beta` to a value in (0, 1). The higher `prefix_bias_beta` is, the stronger the bias. A translation can diverge from a prefix when `prefix_bias_beta` is low and the translator is confident in decoding tokens that are different from the prefix's tokens. See [section 4.2](https://arxiv.org/abs/1912.03393) for more details on the biasing algorithm.

@@ -113,6 +113,10 @@ Lowering the bias by setting `prefix_bias_beta=0.1` results in a divergence in the translation:

> Dieses Projekt ist auf **die** effiziente Bedienung von Standard-Übersetzungsmodellen ausgerichtet, ist aber auch ein Ort für Experimente rund um Modellkompression und Inferenzbeschleunigung.
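For reference, a minimal sketch of biased decoding through the Python API (the model path and token lists are placeholders; `translate_batch` accepts both `target_prefix` and `prefix_bias_beta`):

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")

results = translator.translate_batch(
    [source_tokens],                # tokenized source sentence
    target_prefix=[prefix_tokens],  # tokenized prefix to bias towards
    prefix_bias_beta=0.1,           # low value: the output may diverge from the prefix
)
print(results[0].hypotheses[0])
```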

## Shallow biasing for contextual ASR

Whisper models accept a `sequence_bias` argument: a list of `(sequence, biasing_multiplier)` pairs that boosts (multiplier > 1) or penalizes (multiplier < 1) hypotheses containing the listed token sequences during beam search. See [Section 3.3](https://aclanthology.org/2024.lrec-main.328.pdf) for the general concept, and the [HuggingFace implementation](https://huggingface.co/docs/transformers/en/internal/generation_utils#transformers.SequenceBiasLogitsProcessor) for an additive variant of the same idea.
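As an illustration, a minimal sketch of the Python usage (the biasing phrase, multiplier, and model path are placeholders; `features` and `prompts` are prepared as in the regular Whisper transcription workflow):

```python
import ctranslate2
import transformers

model = ctranslate2.models.Whisper("whisper-tiny-ct2")
processor = transformers.WhisperProcessor.from_pretrained("openai/whisper-tiny")

# Encode the biasing phrase. Also biasing its token prefixes lets the boost
# take effect from the first matching token during beam search.
phrase = processor.tokenizer.encode(" Mr. Quilter", add_special_tokens=False)
sequence_bias = [(phrase[:n], 1.3) for n in range(1, len(phrase) + 1)]

results = model.generate(
    features,  # ctranslate2.StorageView of log-Mel features
    prompts,   # one token list per audio segment
    beam_size=2,
    sequence_bias=sequence_bias,  # multiplier > 1 boosts, < 1 penalizes
)
```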

## Alternatives at a position

Combining `target_prefix` with the `return_alternatives` flag returns alternative sequences just after the prefix:
1 change: 1 addition & 0 deletions include/ctranslate2/decoding.h
@@ -161,6 +161,7 @@ namespace ctranslate2 {
std::vector<size_t> disable_ids;
std::vector<size_t> disable_ids_begin;
std::vector<std::vector<size_t>> disable_sequences;
std::vector<std::pair<std::vector<size_t>, float>> sequence_bias;
std::vector<std::shared_ptr<LogitsProcessor>> logits_processors;
std::function<bool(DecodingStepResult)> callback = nullptr;
};
63 changes: 63 additions & 0 deletions include/ctranslate2/decoding_utils.h
@@ -70,6 +70,46 @@ namespace ctranslate2 {
std::vector<int32_t> _flat_indices;
};

// Helper class to bias token logits in the model output.
class BiasTokens {
public:
BiasTokens(StorageView& logits);

void add(dim_t batch_id, dim_t token_id, float bias_value) {
const auto flat_index = batch_id * _vocabulary_size + token_id;

if (_logits_data) {
// On CPU, directly scale the logit in place.
_logits_data[flat_index] *= bias_value;
} else {
// On GPU, accumulate a sorted list of unique flat indices and bias values to apply later.
const auto it = std::lower_bound(_flat_indices.begin(), _flat_indices.end(), flat_index,
[](const auto& a, const auto& b) { return a.first < b; });

if (it == _flat_indices.end() || it->first != flat_index) {
_flat_indices.emplace(it, flat_index, bias_value);
} else {
it->second *= bias_value;
}
}
}

// Bias a token for all batches.
void add(dim_t token_id, float bias_value) {
for (dim_t batch_id = 0; batch_id < _batch_size; ++batch_id)
add(batch_id, token_id, bias_value);
}

void apply();

private:
StorageView& _logits;
float* _logits_data;
const dim_t _batch_size;
const dim_t _vocabulary_size;
std::vector<std::pair<int32_t, float>> _flat_indices;
};
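Conceptually, `apply()` flushes the deferred (index, value) pairs with one indexed pointwise multiply over the flattened logits; on CPU the multiplication already happened in `add()`. A minimal NumPy sketch of the operation (illustrative names, not the actual implementation):

```python
import numpy as np

def apply_bias(logits: np.ndarray, flat_indices: list[tuple[int, float]]) -> None:
    """Multiply selected positions of a (batch_size, vocab_size) logits array in place."""
    flat = logits.reshape(-1)          # flat_index = batch_id * vocab_size + token_id
    for index, value in flat_indices:  # indices are unique; repeated biases were pre-merged
        flat[index] *= value
```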

// Base class for processing the output logits.
class LogitsProcessor {
public:
@@ -82,6 +122,7 @@
virtual void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) = 0;
@@ -109,6 +150,7 @@
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;
@@ -124,6 +166,7 @@
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;
@@ -139,6 +182,7 @@
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;
@@ -148,13 +192,31 @@
std::vector<std::vector<size_t>> _sequences;
};

// Bias the generation of some sequences of tokens.
class BiasSequences : public LogitsProcessor {
public:
BiasSequences(std::vector<std::pair<std::vector<size_t>, float>> sequences);
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;

private:
std::vector<std::pair<size_t, float>> _ids;
std::vector<std::pair<std::vector<size_t>, float>> _sequences;
};
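For intuition, a Python sketch of one plausible matching rule applied at each decoding step (the actual `BiasSequences::apply` implementation is not shown in this excerpt; `_ids` would hold the length-1 sequences and `_sequences` the longer ones):

```python
def bias_step(generated: list[int],
              ids: list[tuple[int, float]],
              sequences: list[tuple[list[int], float]],
              add_bias) -> None:
    """Scale the logits of tokens that start or extend a bias sequence.

    generated: tokens decoded so far for one hypothesis.
    add_bias:  callback(token_id, bias) multiplying that token's logit.
    """
    # Length-1 bias sequences can be applied unconditionally at every step.
    for token_id, bias in ids:
        add_bias(token_id, bias)

    for sequence, bias in sequences:
        # A longer sequence can always start: bias its first token.
        add_bias(sequence[0], bias)
        # Bias the continuation whenever the tail of the generated output
        # matches a proper prefix of the sequence.
        for length in range(1, min(len(sequence), len(generated) + 1)):
            if generated[-length:] == sequence[:length]:
                add_bias(sequence[length], bias)
```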

// Disable the generation of some tokens.
class SuppressTokens : public LogitsProcessor {
public:
SuppressTokens(std::vector<size_t> ids);
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;
Expand All @@ -170,6 +232,7 @@ namespace ctranslate2 {
void apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
BiasTokens& bias_tokens,
const StorageView& sequences,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) override;
3 changes: 3 additions & 0 deletions include/ctranslate2/models/whisper.h
@@ -56,6 +56,9 @@ namespace ctranslate2 {
// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};

// List of sequences and a bias factor to contextualize decoding.
std::vector<std::pair<std::vector<size_t>, float>> sequence_bias = {};
};

struct WhisperGenerationResult {
2 changes: 2 additions & 0 deletions include/ctranslate2/primitives.h
@@ -19,6 +19,8 @@ namespace ctranslate2 {
static void strided_fill(T* x, T a, dim_t inc_x, dim_t size);
template <typename T>
static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices);
template <typename T>
static void indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices);

template <typename T>
static void copy(const T* x, T* y, dim_t size);
8 changes: 8 additions & 0 deletions python/cpp/whisper.cc
@@ -45,6 +45,7 @@ namespace ctranslate2 {
size_t max_initial_timestamp_index,
bool suppress_blank,
const std::optional<std::vector<int>>& suppress_tokens,
const std::optional<std::vector<std::pair<std::vector<size_t>, float>>>& sequence_bias,
size_t sampling_topk,
float sampling_temperature) {
std::vector<std::future<models::WhisperGenerationResult>> futures;
@@ -69,6 +70,10 @@
options.suppress_tokens = suppress_tokens.value();
else
options.suppress_tokens.clear();
if (sequence_bias)
options.sequence_bias = sequence_bias.value();
else
options.sequence_bias.clear();
std::shared_lock lock(_mutex);
assert_model_is_ready();

@@ -254,6 +259,7 @@
py::arg("max_initial_timestamp_index")=50,
py::arg("suppress_blank")=true,
py::arg("suppress_tokens")=std::vector<int>{-1},
py::arg("sequence_bias")=std::vector<std::pair<int, float>>{},
py::arg("sampling_topk")=1,
py::arg("sampling_temperature")=1,
py::call_guard<py::gil_scoped_release>(),
@@ -286,6 +292,8 @@
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model ``config.json`` file.
sequence_bias: List of pairs of a token sequence and a biasing factor used to
boost or penalize that sequence during decoding.
sampling_topk: Randomly sample predictions from the top K candidates.
sampling_temperature: Sampling temperature to generate more random samples.

141 changes: 141 additions & 0 deletions python/tests/test_transformers.py
@@ -832,6 +832,146 @@ def _get_features(audio):
transcription = processor.decode(token_ids)
assert transcription == expected_transcription


@test_utils.only_on_linux
@test_utils.on_available_devices
@pytest.mark.parametrize(
"model_name,prompts,expected_transcriptions,expected_no_speech_probs",
[
(
"openai/whisper-tiny",
[
[
"<|startoftranscript|>",
"<|en|>",
"<|transcribe|>",
"<|notimestamps|>",
],
[
"<|startoftranscript|>",
"<|en|>",
"<|transcribe|>",
"<|notimestamps|>",
"ĠAnd",
"Ġthus",
"Ġmy",
],
],
[
" Mr. Quiltre is the apostle of the middle classes and we are glad"
" to welcome his gospel.",
" And thus my fellow Americans ask not what your country can do for you,"
" ask what you can do for your country.",
],
[
pytest.approx(0.0022832120303064585, abs=1e-4),
pytest.approx(0.06885894387960434, abs=1e-3),
],
),
(
"openai/whisper-tiny",
[
["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
],
[
" Mr. Quiltre is the apostle of the middle classes and we are glad"
" to welcome his gospel.",
" And so, my fellow Americans, ask not what your country can do for you,"
" ask what you can do for your country.",
],
[
pytest.approx(0.0022832120303064585, abs=1e-4),
pytest.approx(0.06885894387960434, abs=1e-3),
],
)
],
)
def test_transformers_contextually_biased_whisper(
self,
tmp_dir,
device,
model_name,
prompts,
expected_transcriptions,
expected_no_speech_probs,
):
import transformers

converter = ctranslate2.converters.TransformersConverter(model_name)
output_dir = str(tmp_dir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)
audio_paths = [
os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy"),
os.path.join(test_utils.get_data_dir(), "audio", "jfk.npy"),
]
audio = list(map(np.load, audio_paths))

processor = transformers.WhisperProcessor.from_pretrained(model_name)

def _get_features(audio):
# Pad after computing the log-Mel spectrogram to match the openai/whisper behavior.
inputs = processor(audio, padding=False, sampling_rate=16000)
features = inputs.input_features[0]
features = np.pad(features, [(0, 0), (0, 3000 - features.shape[-1])])
return features

features = np.stack(list(map(_get_features, audio)))
features = ctranslate2.StorageView.from_array(features)

model = ctranslate2.models.Whisper(output_dir, device=device)

assert model.is_multilingual == (not model_name.endswith(".en"))

if model.is_multilingual:
for result in model.detect_language(features):
best_lang, best_prob = result[0]
assert best_lang == "<|en|>"
assert best_prob > 0.9
else:
with pytest.raises(RuntimeError, match="multilingual"):
model.detect_language(features)

# Bias the first generated words towards "Mr. Quiltre" by boosting prefixes of its token sequence.
results = model.generate(
features,
prompts,
beam_size=2,
num_hypotheses=2,
return_no_speech_prob=True,
sequence_bias=[
    ([2221, 13, 2326, 2352, 265], 1.3),
    ([2221, 13, 2326, 2352], 1.3),
    ([2221, 13, 2326], 1.3),
],
)

timestamp_begin = (
processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
)

for prompt, result, expected_transcription, expected_no_speech_prob in zip(
prompts, results, expected_transcriptions, expected_no_speech_probs
):
assert len(result.sequences_ids) == 2
assert result.no_speech_prob == expected_no_speech_prob

for tokens in result.sequences_ids:
if "<|notimestamps|>" in prompt:
assert all(token < timestamp_begin for token in tokens)
else:
assert tokens[0] >= timestamp_begin
assert tokens[-1] >= timestamp_begin
assert tokens[-1] > tokens[0]

token_ids = list(
filter(lambda token: token < timestamp_begin, result.sequences_ids[0])
)

transcription = processor.decode(token_ids)
assert transcription == expected_transcription

@test_utils.only_on_linux
@test_utils.on_available_devices
@pytest.mark.parametrize(
@@ -1025,6 +1165,7 @@ def test_transformers_wav2vec2(
assert transcription == expected_transcription[0]



class TestWav2Vec2Bert:
@classmethod
def teardown_class(cls):
2 changes: 1 addition & 1 deletion python/tests/test_utils.py
@@ -13,7 +13,7 @@ def get_data_dir():
)

# Verify that downloaded files are present.
translit_model = os.path.join(data_dir, "models", "transliteration-aren-all")
translit_model = os.path.join(data_dir, "models", "v1", "aren-transliteration")
if not os.path.isdir(translit_model):
pytest.skip("Data files are not available")

10 changes: 10 additions & 0 deletions src/cpu/primitives.cc
@@ -62,6 +62,14 @@ namespace ctranslate2 {
x[indices[i]] = a;
}

template<>
template <typename T>
void primitives<Device::CPU>::indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices) {
for (dim_t i = 0; i < num_indices; ++i) {
x[indices[i]] *= values[i];
}
}

template<>
template <typename T>
void primitives<Device::CPU>::copy(const T* x, T* y, dim_t size) {
@@ -1153,6 +1161,8 @@
template void \
primitives<Device::CPU>::indexed_fill(T*, T, const int32_t*, dim_t); \
template void \
primitives<Device::CPU>::indexed_pointwise_multiply(T* x, const T*, const int32_t*, dim_t); \
template void \
primitives<Device::CPU>::copy(const T* x, T* y, dim_t size); \
template T \
primitives<Device::CPU>::sum(const T* array, dim_t size); \