Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support reading multi-channel wave files with 8/16/32-bit encoded samples #1258

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions .github/scripts/test-offline-ctc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,28 @@ done


# test wav reader for non-standard wav files
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/naudio.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/junk-padding.wav
waves=(
naudio.wav
junk-padding.wav
int8-1-channel-zh.wav
int8-2-channel-zh.wav
int8-4-channel-zh.wav
int16-1-channel-zh.wav
int16-2-channel-zh.wav
int32-1-channel-zh.wav
int32-2-channel-zh.wav
float32-1-channel-zh.wav
float32-2-channel-zh.wav
)
for w in "${waves[@]}"; do
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$w

time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/model.int8.onnx \
./naudio.wav \
./junk-padding.wav
time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/model.int8.onnx \
$w
rm -v $w
done

rm -rf $repo

Expand Down
17 changes: 8 additions & 9 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,35 +143,34 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test online punctuation
- name: Test offline CTC
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-online-punctuation
export EXE=sherpa-onnx-offline

.github/scripts/test-online-punctuation.sh
.github/scripts/test-offline-ctc.sh
du -h -d1 .

- name: Test offline transducer
- name: Test online punctuation
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
export EXE=sherpa-onnx-online-punctuation

.github/scripts/test-offline-transducer.sh
.github/scripts/test-online-punctuation.sh
du -h -d1 .


- name: Test offline CTC
- name: Test offline transducer
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline

.github/scripts/test-offline-ctc.sh
.github/scripts/test-offline-transducer.sh
du -h -d1 .

- name: Test online transducer
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-tts-frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
Expand Down
127 changes: 110 additions & 17 deletions sherpa-onnx/csrc/wave-reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ struct WaveHeader {
};
static_assert(sizeof(WaveHeader) == 44);

/*
sox int16-1-channel-zh.wav -b 8 int8-1-channel-zh.wav

sox int16-1-channel-zh.wav -c 2 int16-2-channel-zh.wav

we use audacity to generate int32-1-channel-zh.wav and float32-1-channel-zh.wav
because sox uses WAVE_FORMAT_EXTENSIBLE, which is not easy to support
in sherpa-onnx.
*/

// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
Expand Down Expand Up @@ -114,9 +124,18 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
is.read(reinterpret_cast<char *>(&header.audio_format),
sizeof(header.audio_format));

if (header.audio_format != 1) { // 1 for PCM
if (header.audio_format != 1 && header.audio_format != 3) {
// 1 for integer PCM
// 3 for floating point PCM
// see https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
// and https://github.com/microsoft/DirectXTK/wiki/Wave-Formats
SHERPA_ONNX_LOGE("Expected audio_format 1 (integer PCM) or 3 (float PCM). Given: %d\n",
header.audio_format);

if (header.audio_format == static_cast<int16_t>(0xfffe)) {
SHERPA_ONNX_LOGE("We don't support WAVE_FORMAT_EXTENSIBLE files.");
}

*is_ok = false;
return {};
}
Expand All @@ -125,10 +144,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
sizeof(header.num_channels));

if (header.num_channels != 1) { // we support only single channel for now
SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n",
header.num_channels);
*is_ok = false;
return {};
SHERPA_ONNX_LOGE(
"Warning: %d channels are found. We only use the first channel.\n",
header.num_channels);
}

is.read(reinterpret_cast<char *>(&header.sample_rate),
Expand Down Expand Up @@ -161,8 +179,9 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}

if (header.bits_per_sample != 16) { // we support only 16 bits per sample
SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 &&
header.bits_per_sample != 32) {
SHERPA_ONNX_LOGE("Expected bits_per_sample 8, 16 or 32. Given: %d\n",
header.bits_per_sample);
*is_ok = false;
return {};
Expand Down Expand Up @@ -199,21 +218,95 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,

*sampling_rate = header.sample_rate;

// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
std::vector<int16_t> samples(header.subchunk2_size / 2);
std::vector<float> ans;

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
if (header.bits_per_sample == 16 && header.audio_format == 1) {
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
std::vector<int16_t> samples(header.subchunk2_size / 2);
SHERPA_ONNX_LOGE("%d samples, bytes: %d", (int)samples.size(),
header.subchunk2_size);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);

// samples are interleaved
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels] / 32768.;
}
} else if (header.bits_per_sample == 8 && header.audio_format == 1) {
// number of samples == number of bytes for 8-bit encoded samples
//
// For 8-bit encoded samples, they are unsigned!
std::vector<uint8_t> samples(header.subchunk2_size);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
// Note(fangjun): We want to normalize each sample into the range [-1, 1]
// Since each original sample is in the range [0, 255], dividing
// them by 128 converts them to the range [0, 255/128];
// so after subtracting 1, we get a value within the range [-1, 1)
//
ans[i] = samples[i * header.num_channels] / 128. - 1;
}
} else if (header.bits_per_sample == 32 && header.audio_format == 1) {
// 32 here is for int32
//
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains 4 bytes, so it is divided by 4 here
std::vector<int32_t> samples(header.subchunk2_size / 4);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
// Divide by 2^31 written as a float literal: `1 << 31` overflows a signed
// int (UB before C++20; INT_MIN in C++20, which would flip the sign).
ans[i] = static_cast<float>(samples[i * header.num_channels]) / 2147483648.f;
}
} else if (header.bits_per_sample == 32 && header.audio_format == 3) {
// 32 here is for float32
//
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains 4 bytes, so it is divided by 4 here
std::vector<float> samples(header.subchunk2_size / 4);

is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
if (!is) {
SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size);
*is_ok = false;
return {};
}

ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels];
}
} else {
SHERPA_ONNX_LOGE(
"Unsupported combination of bits_per_sample (%d) and audio_format (%d). "
"Supported: 8/16/32-bit integer PCM (format 1) and 32-bit float PCM "
"(format 3).",
header.bits_per_sample, header.audio_format);
*is_ok = false;
return {};
}

std::vector<float> ans(samples.size());
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i] / 32768.;
}

*is_ok = true;
return ans;
}
Expand Down
17 changes: 8 additions & 9 deletions sherpa-onnx/jni/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,9 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
return (jlong)model;
}


SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jobject _config) {
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_setConfig(
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

Expand Down Expand Up @@ -350,9 +346,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
// [3]: lang, jstring
// [4]: emotion, jstring
// [5]: event, jstring
env->SetObjectArrayElement(obj_arr, 3, env->NewStringUTF(result.lang.c_str()));
env->SetObjectArrayElement(obj_arr, 4, env->NewStringUTF(result.emotion.c_str()));
env->SetObjectArrayElement(obj_arr, 5, env->NewStringUTF(result.event.c_str()));
env->SetObjectArrayElement(obj_arr, 3,
env->NewStringUTF(result.lang.c_str()));
env->SetObjectArrayElement(obj_arr, 4,
env->NewStringUTF(result.emotion.c_str()));
env->SetObjectArrayElement(obj_arr, 5,
env->NewStringUTF(result.event.c_str()));

return obj_arr;
}
Loading