Skip to content

Commit

Permalink
Add an option tail_paddings
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 7, 2023
1 parent 191407a commit 947173e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/scripts/test-offline-whisper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ for name in ${names[@]}; do
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.onnx \
--whisper-decoder=$repo/${name}-decoder.onnx \
--whisper-tail-paddings=500 \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
Expand All @@ -58,6 +59,7 @@ for name in ${names[@]}; do
--tokens=$repo/${name}-tokens.txt \
--whisper-encoder=$repo/${name}-encoder.int8.onnx \
--whisper-decoder=$repo/${name}-decoder.int8.onnx \
--whisper-tail-paddings=500 \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
tail_padding_frames = 300;
}

if (config_.model_config.whisper.tail_paddings > 0) {
tail_padding_frames = config_.model_config.whisper.tail_paddings;
}

int32_t actual_frames =
std::min(num_frames + tail_padding_frames, max_num_frames);

Expand Down
11 changes: 10 additions & 1 deletion sherpa-onnx/csrc/offline-whisper-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
"Valid values: transcribe, translate. "
"Note that for non-multilingual models, it supports "
"only 'transcribe'");

po->Register(
"whisper-tail-paddings", &tail_paddings,
"Suggest value: 50 for English models. 300 for multilingual models. "
"Since we have removed the 30-second constraint, we need to add some "
"tail padding frames "
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
"English models and 300 for multilingual models.");
}

bool OfflineWhisperModelConfig::Validate() const {
Expand Down Expand Up @@ -63,7 +71,8 @@ std::string OfflineWhisperModelConfig::ToString() const {
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\", ";
os << "language=\"" << language << "\", ";
os << "task=\"" << task << "\")";
os << "task=\"" << task << "\", ";
os << "tail_paddings=" << tail_paddings << ")";

return os.str();
}
Expand Down
18 changes: 16 additions & 2 deletions sherpa-onnx/csrc/offline-whisper-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,26 @@ struct OfflineWhisperModelConfig {
// Note: For non-multilingual models, it supports only "transcribe"
std::string task = "transcribe";

// Number of tail padding frames.
//
// Since we remove the 30-second constraint, we need to add some paddings
// at the end.
//
// Recommended values:
// - 50 for English models
// - 300 for multilingual models
int32_t tail_paddings = -1;

OfflineWhisperModelConfig() = default;
OfflineWhisperModelConfig(const std::string &encoder,
const std::string &decoder,
const std::string &language,
const std::string &task)
: encoder(encoder), decoder(decoder), language(language), task(task) {}
const std::string &task, int32_t tail_paddings)
: encoder(encoder),
decoder(decoder),
language(language),
task(task),
tail_paddings(tail_paddings) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/python/csrc/offline-whisper-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ void PybindOfflineWhisperModelConfig(py::module *m) {
using PyClass = OfflineWhisperModelConfig;
py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string &>(),
const std::string &, const std::string &, int32_t>(),
py::arg("encoder"), py::arg("decoder"), py::arg("language"),
py::arg("task"))
py::arg("task"), py::arg("tail_paddings") = -1)
.def_readwrite("encoder", &PyClass::encoder)
.def_readwrite("decoder", &PyClass::decoder)
.def_readwrite("language", &PyClass::language)
.def_readwrite("task", &PyClass::task)
.def_readwrite("tail_paddings", &PyClass::tail_paddings)
.def("__str__", &PyClass::ToString);
}

Expand Down

0 comments on commit 947173e

Please sign in to comment.