Migrated to end-to-end models trained using Icefall
alumae committed Mar 30, 2023
1 parent 6292194 commit d35dc2d
Showing 4 changed files with 134 additions and 136 deletions.
79 changes: 58 additions & 21 deletions asr.py
@@ -1,4 +1,3 @@
-from vosk import Model, KaldiRecognizer, SetLogLevel
 import sys
 import os
 import sys
@@ -8,12 +7,30 @@
 from queue import Queue
 import threading
 
+def text2result(text):
+    words = text.split()
+    return {"result": [{"word": word} for word in words]}
+
+def add_pseudo_timestamps(result, start_sample, end_sample):
+    #print(result)
+    if len(result["result"]) == 0:
+        return result
+    num_chars = sum([len(w["word"]) for w in result["result"]])
+    num_samples_per_char = (end_sample - start_sample) / num_chars
+    pos = start_sample
+    for w in result["result"]:
+        w["start"] = pos / 16000
+        pos += len(w["word"]) * num_samples_per_char
+    #print(result)
+    return result
+
+
+
 class TurnDecoder():
-    def __init__(self, model, chunk_generator):
-        self.model = model
-        self.rec = KaldiRecognizer(model, 16000)
+    def __init__(self, recognizer, chunk_generator):
+        self.recognizer = recognizer
         self.chunk_generator = chunk_generator
-        self.send_chunk_length = 16000 # how big are the chunks that we send to Kaldi
+        self.send_chunk_length = 16000 // 10 # how big are the chunks that we send to Kaldi
         self.result_queue = Queue(10)
         thread = threading.Thread(target=self.run)
         thread.daemon = True
@@ -29,30 +46,50 @@ def decode_results(self):
 
     def run(self):
         buffer = torch.tensor([])
 
+        stream = self.recognizer.create_stream()
+        last_result = ""
+        segment_id = 0
+        current_start_sample = 0
+        num_samples_consumed = 0
         for chunk in self.chunk_generator:
             buffer = torch.cat([buffer, chunk])
 
             if len(buffer) >= self.send_chunk_length:
-                bytes = (buffer * torch.iinfo(torch.int16).max).short().numpy().tobytes()
-                if self.rec.AcceptWaveform(bytes):
-                    res = self.rec.Result()
-                    jres = json.loads(res)
-                    jres["final"] = True
-                    self.result_queue.put(jres)
-                else:
-                    res = self.rec.PartialResult()
-                    jres = json.loads(res)
-                    jres["final"] = False
+                stream.accept_waveform(16000, buffer.numpy())
+                num_samples_consumed += len(buffer)
+                while self.recognizer.is_ready(stream):
+                    self.recognizer.decode_stream(stream)
+
+                is_endpoint = self.recognizer.is_endpoint(stream)
+                result = self.recognizer.get_result(stream)
+                jres = text2result(result)
+                jres = add_pseudo_timestamps(jres, current_start_sample, num_samples_consumed)
+                if result and (last_result != result) or is_endpoint:
+                    last_result = result
+                    jres["final"] = is_endpoint
                     self.result_queue.put(jres)
+
+                if is_endpoint:
+                    if result:
+                        segment_id += 1
+                    current_start_sample = num_samples_consumed
+                    self.recognizer.reset(stream)
 
                 buffer = torch.tensor([])
 
+
         if len(buffer) > 0:
-            bytes = (buffer * torch.iinfo(torch.int16).max).short().numpy().tobytes()
-            self.rec.AcceptWaveform(bytes)
-
-        res = self.rec.FinalResult()
-        jres = json.loads(res)
+            stream.accept_waveform(16000, buffer.numpy())
+            num_samples_consumed += len(buffer)
+        stream.input_finished()
+        while self.recognizer.is_ready(stream):
+            self.recognizer.decode_stream(stream)
+
+        text = self.recognizer.get_result(stream)
+        self.recognizer.reset(stream)
+        jres = text2result(text)
+        jres = add_pseudo_timestamps(jres, current_start_sample, num_samples_consumed)
         jres["final"] = True
         self.result_queue.put(jres)
         self.result_queue.put(None)
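
To make the new helpers concrete, here is a quick sketch (not part of the commit) of what text2result and add_pseudo_timestamps return for a hypothetical recognizer output; the start times are interpolated by character count, not decoded by the model:

# Hypothetical recognizer output "tere maailm" spanning 1 s of 16 kHz audio.
result = text2result("tere maailm")
# -> {"result": [{"word": "tere"}, {"word": "maailm"}]}
result = add_pseudo_timestamps(result, start_sample=0, end_sample=16000)
# 10 characters over 16000 samples -> 1600 samples per character, so
# "tere" starts at 0.0 s and "maailm" at 4 * 1600 / 16000 = 0.4 s:
# -> {"result": [{"word": "tere", "start": 0.0},
#                {"word": "maailm", "start": 0.4}]}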
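
And a minimal driver for the new TurnDecoder (a sketch, not part of the commit): the OnlineRecognizer construction mirrors the call in main.py below, while the model paths and the silence-only chunk generator are hypothetical stand-ins.

import torch
import sherpa_onnx
from asr import TurnDecoder

recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="models/sherpa/tokens.txt",
    encoder="models/sherpa/encoder.onnx",
    decoder="models/sherpa/decoder.onnx",
    joiner="models/sherpa/joiner.onnx",
    num_threads=4,
    sample_rate=16000,
    feature_dim=80,
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=2.4,
    rule2_min_trailing_silence=1.2,
    rule3_min_utterance_length=300,
    decoding_method="modified_beam_search",
    max_feature_vectors=1000,
)

def chunks():
    # 2 s of silence in 100 ms float32 chunks in [-1, 1], the format
    # TurnDecoder's chunk generator is expected to yield.
    for _ in range(20):
        yield torch.zeros(1600)

decoder = TurnDecoder(recognizer, chunks())
for res in decoder.decode_results():
    print(res["final"], [w["word"] for w in res["result"]])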
96 changes: 12 additions & 84 deletions docker/Dockerfile
@@ -1,98 +1,26 @@
-FROM debian:10.4
+FROM python:3.9-slim-buster
 
-ARG KALDI_MKL
+RUN pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
 
-RUN sed -i "s/deb.debian.org/ftp.ee.debian.org/g" /etc/apt/sources.list
-
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-        wget \
-        bzip2 \
-        unzip \
-        xz-utils \
-        g++ \
-        make \
-        cmake \
-        git \
-        python3 \
-        python3-dev \
-        python3-websockets \
-        python3-setuptools \
-        python3-pip \
-        python3-wheel \
-        python3-cffi \
-        zlib1g-dev \
-        automake \
-        autoconf \
-        libtool \
-        pkg-config \
-        ca-certificates \
-        libsndfile1 \
-        ffmpeg \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN \
-    git clone -b vosk --single-branch https://github.com/alphacep/kaldi /opt/kaldi \
-    && cd /opt/kaldi/tools \
-    && git clone -b v0.3.13 --single-branch https://github.com/xianyi/OpenBLAS \
-    && git clone -b v3.2.1 --single-branch https://github.com/alphacep/clapack \
-    && make -C OpenBLAS ONLY_CBLAS=1 DYNAMIC_ARCH=1 TARGET=NEHALEM USE_LOCKING=1 USE_THREAD=0 all \
-    && make -C OpenBLAS PREFIX=$(pwd)/OpenBLAS/install install \
-    && mkdir -p clapack/BUILD && cd clapack/BUILD && cmake .. && make -j 10 && find . -name "*.a" | xargs cp -t ../../OpenBLAS/install/lib \
-    && cd /opt/kaldi/tools \
-    && git clone --single-branch https://github.com/alphacep/openfst openfst \
-    && cd openfst \
-    && autoreconf -i \
-    && CFLAGS="-g -O3" ./configure --prefix=/opt/kaldi/tools/openfst --enable-static --enable-shared --enable-far --enable-ngram-fsts --enable-lookahead-fsts --with-pic --disable-bin \
-    && make -j 10 && make install \
-    && cd /opt/kaldi/src \
-    && ./configure --mathlib=OPENBLAS_CLAPACK --shared --use-cuda=no \
-    && sed -i 's:-msse -msse2:-msse -msse2:g' kaldi.mk \
-    && sed -i 's: -O1 : -O3 :g' kaldi.mk \
-    && make -j $(nproc) online2 lm rnnlm \
-    && find /opt/kaldi -name "*.o" -exec rm {} \;
-
-ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PATH="/root/miniconda3/bin:${PATH}"
-
-RUN wget \
-    https://repo.anaconda.com/miniconda/Miniconda3-py39_22.11.1-1-Linux-x86_64.sh \
-    && mkdir /root/.conda \
-    && bash Miniconda3-py39_22.11.1-1-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-py39_22.11.1-1-Linux-x86_64.sh
-RUN conda --version
-
-RUN \
-    git clone -b intermediate_full_results https://github.com/alumae/vosk-api /opt/vosk-api \
-    && cd /opt/vosk-api/src \
-    && KALDI_ROOT=/opt/kaldi OPENFST_ROOT=/opt/kaldi/tools/openfst/ OPENBLAS_ROOT=/opt/kaldi/tools/OpenBLAS/install make -j $(nproc) \
-    && cd /opt/vosk-api/python \
-    && python3 ./setup.py install
-
-RUN conda install -c conda-forge pynini=2.1.3
-
-RUN conda install pytorch=1.10.0 torchvision torchaudio=0.10.0 cpuonly -c pytorch
-
-RUN pip install pytorch-lightning==1.2.5 'ray[default]' torchmetrics==0.2.0 \
+RUN pip3 install pytorch-lightning==1.2.5 'ray[default]' torchmetrics==0.2.0 \
     tokenizers pytorch-nlp py-term matplotlib scipy \
-    librosa==0.8.0 lxml audiomentations pytest event-scheduler
+    librosa==0.8.0 lxml audiomentations pytest event-scheduler \
+    onnx sherpa-onnx
 
-COPY ./models /opt/models
+RUN apt-get update && apt-get install -y --no-install-recommends git ffmpeg
 
-RUN echo '2022-01-31_16:24' >/dev/null
+COPY models /opt/models
 
-RUN git clone https://github.com/alumae/streaming-punctuator /opt/streaming-punctuator
+RUN echo '2022-01-31_16:24' >/dev/null
 
 RUN git clone https://github.com/alumae/online_speaker_change_detector.git /opt/online-speaker-change-detector
 
-RUN git clone https://github.com/alumae/et-g2p-fst.git /opt/et-g2p-fst
-
-RUN echo '2022-04-11_15:54' >/dev/null \
-    && git clone https://github.com/alumae/kiirkirjutaja.git /opt/kiirkirjutaja \
+RUN mkdir /opt/kiirkirjutaja \
     && cd /opt/kiirkirjutaja && ln -s ../models
 
+COPY *.py /opt/kiirkirjutaja/
+
-ENV PYTHONPATH="/opt/streaming-punctuator:/opt/online-speaker-change-detector:/opt/et-g2p-fst"
+ENV PYTHONPATH="/opt/online-speaker-change-detector"
 
 WORKDIR /opt/kiirkirjutaja
 
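
The image is now assembled from pip wheels instead of compiling Kaldi, OpenFST and vosk-api from source. A sanity-check sketch (hypothetical, not part of the commit) that can be run inside the container to confirm the pip-installed stack:

import torch
import sherpa_onnx

print("torch:", torch.__version__)  # expected: 1.13.1+cpu per the Dockerfile
print("OnlineRecognizer available:", hasattr(sherpa_onnx, "OnlineRecognizer"))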
67 changes: 45 additions & 22 deletions main.py
@@ -9,6 +9,7 @@
 import re
 import ray
 import torch
+import sherpa_onnx
 
 # Needed for loading the speaker change detection model
 from pytorch_lightning.utilities import argparse_utils
@@ -19,11 +20,11 @@
 from asr import TurnDecoder
 from lid import LanguageFilter
 from online_scd.model import SCDModel
-import vosk
-from unk_decoder import UnkDecoder
-from compound import CompoundReconstructor
-from words2numbers import Words2Numbers
-from punctuate import Punctuate
+#import vosk
+#from unk_decoder import UnkDecoder
+#from compound import CompoundReconstructor
+#from words2numbers import Words2Numbers
+#from punctuate import Punctuate
 from confidence import confidence_filter
 from presenters import *
 import utils
@@ -34,26 +35,33 @@
 
 ray.init(num_cpus=4)
 
-RemotePunctuate = ray.remote(Punctuate)
-RemoteWords2Numbers = ray.remote(Words2Numbers)
+#RemotePunctuate = ray.remote(Punctuate)
+#RemoteWords2Numbers = ray.remote(Words2Numbers)
 
-unk_decoder = UnkDecoder()
-compound_reconstructor = CompoundReconstructor()
-remote_words2numbers = RemoteWords2Numbers.remote()
-remote_punctuate = RemotePunctuate.remote("models/punctuator/checkpoints/best.ckpt", "models/punctuator/tokenizer.json")
+#unk_decoder = UnkDecoder()
+#compound_reconstructor = CompoundReconstructor()
+#remote_words2numbers = RemoteWords2Numbers.remote()
+#remote_punctuate = RemotePunctuate.remote("models/punctuator/checkpoints/best.ckpt", "models/punctuator/tokenizer.json")
 
 
 def process_result(result):
-    result = unk_decoder.post_process(result)
+    #result = unk_decoder.post_process(result)
     text = ""
     if "result" in result:
-        text = " ".join([wi["word"] for wi in result["result"]])
-
-        text = compound_reconstructor.post_process(text)
-        text = ray.get(remote_words2numbers.post_process.remote(text))
-        text = ray.get(remote_punctuate.post_process.remote(text))
-        result = utils.reconstruct_full_result(result, text)
-        result = confidence_filter(result)
+        result_words = []
+        for word in result["result"]:
+            if word["word"] in ",.!?" and len(result_words) > 0:
+                result_words[-1]["word"] += word["word"]
+            else:
+                result_words.append(word)
+        result["result"] = result_words
+        #text = " ".join([wi["word"] for wi in result["result"]])
+
+        #text = compound_reconstructor.post_process(text)
+        #text = ray.get(remote_words2numbers.post_process.remote(text))
+        #text = ray.get(remote_punctuate.post_process.remote(text))
+        #result = utils.reconstruct_full_result(result, text)
+        #result = confidence_filter(result)
         return result
     else:
         return result
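
The rewritten process_result drops the Ray-backed compound/number/punctuation pipeline and only folds stand-alone punctuation tokens into the preceding word, presumably because the Icefall-trained model already emits punctuation. A sketch of the effect on a hypothetical result:

res = {"result": [{"word": "tere"}, {"word": ","},
                  {"word": "kuidas"}, {"word": "läheb"}, {"word": "?"}]}
print([w["word"] for w in process_result(res)["result"]])
# -> ['tere,', 'kuidas', 'läheb?']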
@@ -73,7 +81,22 @@ def main(args):
     #presenter = TerminalPresenter()
 
     scd_model = SCDModel.load_from_checkpoint("models/online-speaker-change-detector/checkpoints/epoch=102.ckpt")
-    vosk_model = vosk.Model("models/asr_model")
+    sherpa_model = sherpa_onnx.OnlineRecognizer(
+        tokens="models/sherpa/tokens.txt",
+        encoder="models/sherpa/encoder.onnx",
+        decoder="models/sherpa/decoder.onnx",
+        joiner="models/sherpa/joiner.onnx",
+        num_threads=4,
+        sample_rate=16000,
+        feature_dim=80,
+        enable_endpoint_detection=True,
+        rule1_min_trailing_silence=2.4,
+        rule2_min_trailing_silence=1.2,
+        rule3_min_utterance_length=300,
+        decoding_method="modified_beam_search",
+        max_feature_vectors=1000, # 10 seconds
+    )
+
 
     speech_segment_generator = SpeechSegmentGenerator(args.input_file)
     language_filter = LanguageFilter()
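
A note on the endpointing parameters above, as we understand sherpa-onnx's built-in rules (our reading, not stated in the commit); an endpoint fires as soon as any one rule matches:

# rule1_min_trailing_silence=2.4 -> >=2.4 s of trailing silence ends the
#                                   segment when nothing has been decoded yet
# rule2_min_trailing_silence=1.2 -> >=1.2 s of trailing silence ends the
#                                   segment once some text has been decoded
# rule3_min_utterance_length=300 -> a segment is force-ended at 300 s
# max_feature_vectors=1000 keeps about 10 s of 100-frames/s features in
# memory, bounding the stream state (matching the "10 seconds" comment above)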
@@ -90,11 +113,11 @@ def main_loop():
             presenter.new_turn()
             turn_start_time = (speech_segment.start_sample + turn.start_sample) / 16000
 
-            turn_decoder = TurnDecoder(vosk_model, language_filter.filter(turn.chunks()))
+            turn_decoder = TurnDecoder(sherpa_model, language_filter.filter(turn.chunks()))
             for res in turn_decoder.decode_results():
                 if "result" in res:
                     processed_res = process_result(res)
-
+                    #processed_res = res
                     if res["final"]:
                         presenter.final_result(processed_res["result"])
                     else: