diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..f8f2779
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,127 @@
+# Created by .ignore support plugin (hsz.mobi)
+### VirtualEnv template
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+.Python
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+[Ss]cripts
+pyvenv.cfg
+.venv
+pip-selfcheck.json
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+### Example user template template
+### Example user template
+
+# IntelliJ project files
+.idea
+*.iml
+out
+gen
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9717b09
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,127 @@
+# Created by .ignore support plugin (hsz.mobi)
+experiments/
+tests/test_wavs/
+
+**/.DS_Store
+
+.idea/
+warp-ctc/
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+### VirtualEnv template
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+.Python
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+[Ss]cripts
+pyvenv.cfg
+.venv
+pip-selfcheck.json
+
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..c1c8b03
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,11 @@
+language: python
+
+python:
+  - "3.6"
+
+install:
+  - bash scripts/install_dependencies.sh
+  - pip install -e .
+
+script:
+  - pytest
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..3b9c9ab
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,38 @@
+FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-devel
+ARG CUDA=false
+
+
+WORKDIR /workspace/
+COPY . .
+# install basics
+RUN apt-get update -y
+RUN apt-get install -y git curl ca-certificates bzip2 cmake tree htop bmon iotop sox libsox-dev libsox-fmt-all vim wget
+
+ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
+
+# install python deps
+RUN pip install -r requirements.txt
+
+RUN rm -rf warp-ctc
+RUN git clone https://github.com/SeanNaren/warp-ctc.git
+RUN if [ "$CUDA" = false ] ; then sed -i 's/option(WITH_OMP \"compile warp-ctc with openmp.\" ON)/option(WITH_OMP \"compile warp-ctc with openmp.\" ${CUDA_FOUND})/' warp-ctc/CMakeLists.txt ; else export CUDA_HOME="/usr/local/cuda" ; fi
+RUN cd warp-ctc; mkdir build; cd build; cmake ..; make
+RUN cd warp-ctc/pytorch_binding && python setup.py install
+RUN rm -rf warp-ctc
+
+RUN pip install -r post_requirements.txt
+
+
+#TODO: Do we need those two below?
+# install ctcdecode
+#RUN git clone --recursive https://github.com/parlance/ctcdecode.git
+#RUN cd ctcdecode; pip install .
+
+# install deepspeech.pytorch
+ADD . /workspace/deepspeech.pytorch
+RUN cd deepspeech.pytorch; pip install -r requirements.txt
+
+# launch jupyter
+RUN pip install jupyter
+RUN mkdir data; mkdir notebooks;
+#CMD jupyter-notebook --ip="*" --no-browser --allow-root
\ No newline at end of file
diff --git a/loader.py b/loader.py
deleted file mode 100644
index 00b06d1..0000000
--- a/loader.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# ----------------------------------------------------------------------------
-# Based on SeanNaren's deepspeech.pytorch:
-# https://github.com/SeanNaren/deepspeech.pytorch
-# ----------------------------------------------------------------------------
-
-import math
-import warnings
-from typing import Tuple
-
-import librosa
-import numpy as np
-import torch
-import torchaudio
-from scipy import signal
-from torch.utils.data import Dataset, DataLoader, Sampler
-from torch.distributed import get_rank
-from torch.distributed import get_world_size
-
-windows = {"bartlett": torch.bartlett_window,
-           "blackman": torch.blackman_window,
-           "hamming": torch.hamming_window,
-           "hann": torch.hann_window}
-
-windows_legacy = {'hamming': signal.hamming,
-                  'hann': signal.hann,
-                  'blackman': signal.blackman,
-                  'bartlett': signal.bartlett}
-
-
-class DataProcessor(object):
-    def __init__(self, audio_conf, labels="abc", normalize=False, augment=False, legacy=True):
-        """
-        Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
-        a comma. Each new line is a different sample. Example below:
-
-        /path/to/audio.wav,/path/to/audio.txt
-        ...
-
-        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
-        :param labels: String containing all the possible characters to map to
-        :param normalize: Apply standard mean and deviation normalization to audio tensor
-        :param augment(default False): Apply random tempo and gain perturbations
-        """
-        self.labels_map = dict([(labels[i], i) for i in range(len(labels))])
-        self.window_stride = audio_conf["window_stride"]
-        self.window_size = audio_conf["window_size"]
-        self.sample_rate = audio_conf["sample_rate"]
-        self.window = windows_legacy.get(audio_conf["window"], windows_legacy["hamming"]) if legacy else windows.get(audio_conf["window"], windows["hamming"])
-        self.normalize = normalize
-        self.augment = augment
-        self.legacy = legacy
-        self.transform = torchaudio.transforms.Spectrogram(n_fft=int(self.sample_rate * self.window_size),
-                                                           hop=int(self.sample_rate * self.window_stride),
-                                                           window=self.window, normalize=self.normalize)
-
-    @staticmethod
-    def retrieve_file(audio_path, legacy=True):
-        sound, sample_rate = torchaudio.load(audio_path)
-        if legacy:
-            sound = sound.numpy().T
-            if len(sound.shape) > 1:
-                if sound.shape[1] == 1:
-                    sound = sound.squeeze()
-                else:
-                    sound = sound.mean(axis=1)
-        return sound, sample_rate
-
-    @staticmethod
-    def augment_audio(sound, tempo_range: Tuple = (0.85, 1.15), gain_range: Tuple = (-6, 8)):
-        """
-        Changes tempo and gain of the wave
-        """
-        warnings.warn("Augmentation is not implemented")  # TODO: Implement
-        return sound
-
-    def parse_audio(self, audio_path):
-        sound, sample_rate = self.retrieve_file(audio_path, self.legacy)
-        if sample_rate != self.sample_rate:
-            raise ValueError(f"The stated sample rate {self.sample_rate} and the factual rate {sample_rate} differ!")
-
-        if self.augment:
-            sound = self.augment_audio(sound)
-
-        if self.legacy:
-            n_fft = int(self.sample_rate * self.window_size)
-            win_length = n_fft
-            hop_length = int(self.sample_rate * self.window_stride)
-            # STFT
-            D = librosa.stft(sound, n_fft=n_fft, hop_length=hop_length,
-                             win_length=win_length, window=self.window)
-            spectrogram, phase = librosa.magphase(D)
-            # S = log(S+1)
-
-            spectrogram = torch.FloatTensor(np.log1p(spectrogram))
-        else:
-            # TODO: Why these are different from librosa.stft?
-            sound = sound.cuda()
-            spectrogram = self.transform(sound)[-1, :, :].transpose(0, 1)
-
-            # spectrogram = torch.stft(torch.from_numpy(sound.numpy().T.squeeze()),
-            #                          n_fft=int(self.sample_rate * self.window_size),
-            #                          hop_length=int(self.sample_rate * self.window_stride),
-            #                          win_length=int(self.sample_rate * self.window_size),
-            #                          window=torch.hamming_window(int(self.sample_rate * self.window_size)))[:, :, -1]
-
-        if self.normalize:
-            mean = spectrogram.mean()
-            std = spectrogram.std()
-            spectrogram.add_(-mean)
-            spectrogram.div_(std)
-
-        return spectrogram
-
-    def parse_transcript(self, transcript_path):
-        with open(transcript_path, 'r', encoding='utf8') as transcript_file:
-            transcript = transcript_file.read().replace('\n', '')
-        # TODO: Is it fast enough?
-        transcript = list(filter(None, [self.labels_map.get(x) for x in list(transcript)]))
-        return transcript
-
-
-class AudioDataset(Dataset):
-    def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False, legacy=True):
-        """
-        Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
-        a comma. Each new line is a different sample. Example below:
-
-        /path/to/audio.wav,/path/to/audio.txt
-        ...
-
-        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
-        :param manifest_filepath: Path to manifest csv as describe above
-        :param labels: String containing all the possible characters to map to
-        :param normalize: Apply standard mean and deviation normalization to audio tensor
-        :param augment(default False): Apply random tempo and gain perturbations
-        """
-        super(AudioDataset, self).__init__()
-        with open(manifest_filepath) as f:
-            ids = f.readlines()
-        ids = [x.strip().split(',') for x in ids]
-        self.ids = ids
-        self.size = len(ids)
-        self.processor = DataProcessor(audio_conf, labels, normalize, augment, legacy)
-
-    def __getitem__(self, index):
-        sample = self.ids[index]
-        audio_path, transcript_path = sample[0], sample[1]
-
-        spectrogram = self.processor.parse_audio(audio_path)
-        transcript = self.processor.parse_transcript(transcript_path)
-
-        return spectrogram, transcript
-
-    def __len__(self):
-        return self.size
-
-
-# TODO: Optimise
-def _collate_fn(batch):
-    batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
-    longest_sample = batch[0][0]
-    freq_size, max_seqlength = longest_sample.size()
-    minibatch_size = len(batch)
-    inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
-    input_percentages = torch.FloatTensor(minibatch_size)
-    target_sizes = np.zeros(minibatch_size, dtype=np.int32)
-
-    # TODO: Numpy broadcasting magic
-    targets = []
-
-    for x in range(minibatch_size):
-        inputs[x][0].narrow(1, 0, batch[x][0].size(1)).copy_(batch[x][0])
-        input_percentages[x] = batch[x][0].size(1) / float(max_seqlength)
-        target_sizes[x] = len(batch[x][1])
-        targets.extend(batch[x][1])
-
-    return inputs, torch.IntTensor(targets), input_percentages, torch.from_numpy(target_sizes)
-
-
-class AudioDataLoader(DataLoader):
-    def __init__(self, *args, **kwargs):
-        """
-        Creates a data loader for AudioDatasets.
-        """
-        super(AudioDataLoader, self).__init__(*args, **kwargs)
-        self.collate_fn = _collate_fn
-
-
-class BucketingSampler(Sampler):
-    def __init__(self, data_source, batch_size=1):
-        """
-        Samples batches assuming they are in order of size to batch similarly sized samples together.
-        """
-        super(BucketingSampler, self).__init__(data_source)
-        self.data_source = data_source
-        ids = list(range(0, len(data_source)))
-        # TODO: Optimise
-        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
-
-    def __iter__(self):
-        for ids in self.bins:
-            np.random.shuffle(ids)
-            yield ids
-
-    def __len__(self):
-        return len(self.bins)
-
-    def shuffle(self, epoch):
-        np.random.shuffle(self.bins)
-
-
-# TODO: Optimise
-class DistributedBucketingSampler(Sampler):
-    def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None):
-        """
-        Samples batches assuming they are in order of size to batch similarly sized samples together.
-        """
-        super(DistributedBucketingSampler, self).__init__(data_source)
-        if num_replicas is None:
-            num_replicas = get_world_size()
-        if rank is None:
-            rank = get_rank()
-        self.data_source = data_source
-        self.ids = list(range(0, len(data_source)))
-        self.batch_size = batch_size
-        self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)]
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
-
-    def __iter__(self):
-        offset = self.rank
-        # add extra samples to make it evenly divisible
-        bins = self.bins + self.bins[:(self.total_size - len(self.bins))]
-        assert len(bins) == self.total_size
-        samples = bins[offset::self.num_replicas]  # Get every Nth bin, starting from rank
-        return iter(samples)
-
-    def __len__(self):
-        return self.num_samples
-
-    def shuffle(self, epoch):
-        # deterministically shuffle based on epoch
-        g = torch.Generator()
-        g.manual_seed(epoch)
-        bin_ids = list(torch.randperm(len(self.bins), generator=g))
-        self.bins = [self.bins[i] for i in bin_ids]
diff --git a/post_requirements.txt b/post_requirements.txt
new file mode 100644
index 0000000..f750f4b
--- /dev/null
+++ b/post_requirements.txt
@@ -0,0 +1 @@
+-e git+https://github.com/NVIDIA/apex.git#egg=apex
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8ae1526
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+audioread==2.1.7
+cycler==0.10.0
+decorator==4.4.0
+joblib==0.13.2
+kiwisolver==1.1.0
+librosa==0.6.3
+llvmlite==0.28.0
+matplotlib==3.1.0
+numba==0.43.1
+numpy==1.16.3
+Pillow==6.0.0
+pyparsing==2.4.0
+python-dateutil==2.8.0
+resampy==0.2.1
+scikit-learn==0.21.2
+scipy==1.3.0
+six==1.12.0
+torch==1.1.0
+torchvision==0.3.0
+tqdm==4.32.1
+pyyaml==5.1
+wget==3.2
+pytest==4.6.3
+click==7.0
+deprecation==2.0.6
+dataclasses==0.6
diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh
new file mode 100755
index 0000000..2be3b54
--- /dev/null
+++ b/scripts/install_dependencies.sh
@@ -0,0 +1,68 @@
+#!/usr/bin/env bash
+
+
+# Running without arguments -> installing into virtual env located in ./venv
+# -a= takes precedence over the virtual env and installs into the given conda env
+# -e=/path/to/venv installs into a different venv than ./venv
+# -c=true installs with cuda support (default false)
+# e.g.: bash scripts/install_dependencies.sh -e=./venv -c=true
+
+set -e
+
+# define arguments
+for i in "$@"
+do
+case ${i} in
+    -c=*|--cuda=*)
+    CUDA="${i#*=}"
+    shift # past argument=value
+    ;;
+    -a=*|--anaconda=*)
+    ANACONDA="${i#*=}"
+    shift # past argument=value
+    ;;
+    -e=*|--venv=*)
+    VENV="${i#*=}"
+    shift # past argument=value
+    ;;
+    *)
+    # unknown option
+    ;;
+esac
+done
+
+VENV=${VENV:-./venv}
+
+if [ -n "${ANACONDA}" ] ; then
+    conda activate ${ANACONDA}
+else
+    source ${VENV}/bin/activate
+fi
+
+#TODO: Infer this automatically
+CUDA=${CUDA:-false}
+
+pip install -r requirements.txt
+
+git clone https://github.com/SeanNaren/warp-ctc.git
+if [ "$CUDA" = false ] ; then
+    # This works for mac, for other OSes remove '' after -i
+    sed -i '' 's/option(WITH_OMP \"compile warp-ctc with openmp.\" ON)/option(WITH_OMP \"compile warp-ctc with openmp.\" ${CUDA_FOUND})/' warp-ctc/CMakeLists.txt
+else
+    export CUDA_HOME="/usr/local/cuda"
+fi
+cd warp-ctc; mkdir build; cd build; cmake ..; make
+cd ../pytorch_binding && MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
+cd ../..
+rm -rf warp-ctc
+
+git clone https://github.com/pytorch/audio.git
+cd audio; MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
+cd ..
+rm -rf audio
+
+pip install -r post_requirements.txt
+
+if [ -f ./src/pip-delete-this-directory.txt ]; then
+    rm -rf ./src/
+fi
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..4a1ac9b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,10 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="sonosco",
+    description="Framework for deep automatic speech recognition systems.",
+    author="Roboy",
+    packages=["sonosco"],
+    include_package_data=True,
+    dependency_links=[]
+)
diff --git a/STT_srv.py b/sonosco/STT_srv.py
similarity index 100%
rename from STT_srv.py
rename to sonosco/STT_srv.py
diff --git a/decoders/__init__.py b/sonosco/__init__.py
similarity index 100%
rename from decoders/__init__.py
rename to sonosco/__init__.py
diff --git a/models/__init__.py b/sonosco/common/__init__.py
similarity index 100%
rename from models/__init__.py
rename to sonosco/common/__init__.py
diff --git a/sonosco/common/audio_tools.py b/sonosco/common/audio_tools.py
new file mode 100644
index 0000000..fa53021
--- /dev/null
+++ b/sonosco/common/audio_tools.py
@@ -0,0 +1,35 @@
+import subprocess
+import numpy as np
+import librosa
+
+from .noise_makers import NoiseMaker, GaussianNoiseMaker
+
+
+def get_duration(file_path):
+    return float(subprocess.check_output([f'soxi -D "{file_path.strip()}"'], shell=True))
+
+
+def transcode_recording(source, destination, sample_rate):
+    subprocess.call([f"sox {source} -r {sample_rate} -b 16 -c 1 {destination}"], shell=True)
+
+
+def transcode_recordings_an4(raw_path, wav_path, sample_rate):
+    subprocess.call([f'sox -t raw -r {sample_rate} -b 16 -e signed-integer -B -c 1 \"{raw_path}\" \"{wav_path}\"'], shell=True)
+
+
+def transcode_recordings_ted3(source, destination, start_time, end_time, sample_rate):
+    subprocess.call([f"sox {source} -r {sample_rate} -b 16 -c 1 {destination} trim {start_time} ={end_time}"], shell=True)
+
+
+def shift(audio, n_samples=1600):
+    return np.roll(audio, n_samples)
+
+
+def stretch(audio, rate=1):
+    stretched_audio = librosa.effects.time_stretch(audio, rate)
+    return stretched_audio
+
+
+def pitch_shift(audio, sample_rate=16000, n_steps=3.0):
+    stretched_audio = librosa.effects.pitch_shift(audio, sr=sample_rate, n_steps=n_steps)
+    return stretched_audio
diff --git a/sonosco/common/class_utils.py b/sonosco/common/class_utils.py
new file mode 100644
index 0000000..e4af560
--- /dev/null
+++ b/sonosco/common/class_utils.py
@@ -0,0 +1,36 @@
+import inspect
+from typing import Set
+
+
+def get_constructor_args(cls) -> Set[str]:
+    """
+    E.g.
+
+    class Bar():
+        def __init__(self, arg1, arg2):
+
+    get_constructor_args(Bar)
+    # returns {'arg1', 'arg2'}
+    Args:
+        cls (object):
+
+    Returns: set containing names of constructor arguments
+
+    """
+    return set(inspect.getfullargspec(cls.__init__).args[1:])
+
+
+def get_class_by_name(name: str) -> type:
+    """
+    Returns type object of class specified by name
+    Args:
+        name: full name of the class (with packages)
+
+    Returns: class object
+
+    """
+    components = name.split('.')
+    mod = __import__(components[0])
+    for comp in components[1:]:
+        mod = getattr(mod, comp)
+    return mod
diff --git a/sonosco/common/click_extensions.py b/sonosco/common/click_extensions.py
new file mode 100644
index 0000000..3554572
--- /dev/null
+++ b/sonosco/common/click_extensions.py
@@ -0,0 +1,16 @@
+import click
+import ast
+import logging
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class PythonLiteralOption(click.Option):
+
+    def type_cast_value(self, ctx, value):
+        try:
+            return ast.literal_eval(value)
+        except Exception as e:
+            LOGGER.error(f"Malformed click input for PythonLiteralOption {e}", exc_info=True)
+            raise click.BadParameter(value)
diff --git a/sonosco/common/constants.py b/sonosco/common/constants.py
new file mode 100644
index 0000000..3d06bc6
--- /dev/null
+++ b/sonosco/common/constants.py
@@ -0,0 +1 @@
+SONOSCO = "sonosco"
\ No newline at end of file
diff --git a/sonosco/common/noise_makers.py b/sonosco/common/noise_makers.py
new file mode 100644
index 0000000..a280e94
--- /dev/null
+++ b/sonosco/common/noise_makers.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+from abc import ABC, abstractmethod
+
+
+class NoiseMaker(ABC):
+
+    @abstractmethod
+    def __call__(self, audio):
+        """Adds noise to the audio signal."""
+        pass
+
+    def add_noise(self, audio):
+        return self(audio)
+
+
+class GaussianNoiseMaker(NoiseMaker):
+
+    def __init__(self, std=0.002):
+        self.std = std
+
+    def __call__(self, audio):
+        noise = np.random.randn(len(audio))
+        return audio + self.std * noise
diff --git a/sonosco/common/path_utils.py b/sonosco/common/path_utils.py
new file mode 100644
index 0000000..2199b8f
--- /dev/null
+++ b/sonosco/common/path_utils.py
@@ -0,0 +1,19 @@
+import os
+import wget
+import yaml
+import codecs
+
+
+def try_create_directory(path: str):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def try_download(destination: str, url: str):
+    if not os.path.exists(destination):
+        wget.download(url, destination)
+
+
+def parse_yaml(file_path: str):
+    with codecs.open(file_path, "r", "utf-8") as file:
+        return yaml.load(file, Loader=yaml.FullLoader)
diff --git a/sonosco/common/utils.py b/sonosco/common/utils.py
new file mode 100644
index 0000000..d5b8d25
--- /dev/null
+++ b/sonosco/common/utils.py
@@ -0,0 +1,99 @@
+import logging
+import numpy as np
+import os
+import subprocess
+import os.path as path
+
+from shutil import copyfile
+from typing import Tuple
+
+
+def setup_logging(logger: logging.Logger, filename=None, verbosity=False):
+    logger.setLevel(logging.DEBUG)
+
+    if filename is not None:
+        add_log_file(filename, logger)
+
+    c_handler = logging.StreamHandler()
+    c_handler.setLevel(logging.DEBUG if verbosity else logging.INFO)
+    c_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    c_handler.setFormatter(c_format)
+    logger.addHandler(c_handler)
+
+
+def add_log_file(filename: str, logger: logging.Logger):
+    log_directory = os.path.dirname(filename)
+    if not os.path.exists(log_directory):
+        os.makedirs(log_directory)
+    # keep the directory supplied by the caller; only append the extension
+    filename = f"{filename}.log"
+    f_handler = logging.FileHandler(filename=filename, mode="w")
+    f_handler.setLevel(logging.DEBUG)
+    f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    f_handler.setFormatter(f_format)
+    logger.addHandler(f_handler)
+
+
+def random_float(low: float, high: float):
+    return np.random.random() * (high - low) + low
+
+
+def copy_code(source_dir, dest_dir, exclude_dirs: Tuple[str, ...] = tuple(), exclude_files: Tuple[str, ...] = tuple()):
+    """
+    Copies code from source_dir to dest_dir. Excludes specified folders and files by substring-matching.
+    Parameters:
+        source_dir (string): location of the code to copy
+        dest_dir (string): location where the code should be copied to
+        exclude_dirs (list of strings): folders containing strings specified in this list will be ignored
+        exclude_files (list of strings): files containing strings specified in this list will be ignored
+    """
+    source_basename = path.basename(source_dir)
+    for root, dirs, files in os.walk(source_dir, topdown=True):
+
+        # skip ignored dirs
+        if any(ex_subdir in root for ex_subdir in exclude_dirs):
+            continue
+
+        # construct destination dir
+        cropped_root = root[2:] if (root[:2] == './') else root
+        subdir_basename = path.basename(cropped_root)
+
+        # do not treat the root as a subdir
+        if subdir_basename == source_basename:
+            subdir_basename = ""
+        dest_subdir = os.path.join(dest_dir, subdir_basename)
+
+        # create destination folder
+        if not os.path.exists(dest_subdir):
+            os.makedirs(dest_subdir)
+
+        # copy files
+        for filename in filter(lambda x: not any(substr in x for substr in exclude_files), files):
+            source_file_path = os.path.join(root, filename)
+            dest_file_path = os.path.join(dest_subdir, filename)
+            copyfile(source_file_path, dest_file_path)
+
+
+def retrieve_git_hash():
+    """
+    Retrieves and returns the current git hash if the execution location is a git repo.
+    """
+    try:
+        git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode("utf-8").strip()
+        return git_hash
+    except subprocess.CalledProcessError as e:
+        print(e.output)
+        return False
+
+
+def save_run_params_in_file(folder_path, run_config):
+    """
+    Receives a run_config class, retrieves all member variables and saves them
+    in a config file for logging purposes.
+    Parameters:
+        folder_path - output folder
+        run_config - shallow class with parameter members
+    """
+    with open(path.join(folder_path, "run_params.conf"), 'w') as run_param_file:
+        for attr, value in sorted(run_config.__dict__.items()):
+            run_param_file.write(f"{attr}: {value}\n")
diff --git a/sonosco/config/__init__.py b/sonosco/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sonosco/config/global_settings.py b/sonosco/config/global_settings.py
new file mode 100644
index 0000000..2f2522a
--- /dev/null
+++ b/sonosco/config/global_settings.py
@@ -0,0 +1 @@
+CUDA_ENABLED = False
\ No newline at end of file
diff --git a/config/infer.yaml b/sonosco/config/infer.yaml
similarity index 100%
rename from config/infer.yaml
rename to sonosco/config/infer.yaml
diff --git a/config/test.yaml b/sonosco/config/test.yaml
similarity index 83%
rename from config/test.yaml
rename to sonosco/config/test.yaml
index 2589e15..7609ef1 100644
--- a/config/test.yaml
+++ b/sonosco/config/test.yaml
@@ -4,6 +4,6 @@ test:
   batch_size: 32 # Batch size for testing
   num_workers: 4 # Number of workers used in loading
   verbose: True # Print out decoded output and error of each sample
-  save_output: Trur # Saves output of model from test
+  save_output: True # Saves output of model from test
   output_path: "" # Where to save raw acoustic output
 
diff --git a/config/train.yaml b/sonosco/config/train.yaml
similarity index 92%
rename from config/train.yaml
rename to sonosco/config/train.yaml
index 25be72c..3e8e364 100644
--- a/config/train.yaml
+++ b/sonosco/config/train.yaml
@@ -3,7 +3,7 @@ train:
   val_manifest: 'examples/manifests/val_manifest.csv'
   labels_path: 'examples/labels.json' # Contains all characters for transcription
   log_dir: 'logs' # Location for log files
-  def_dir: 'examples/checkpoints/', # Default location to save/load models
+  def_dir: 'examples/checkpoints/' # Default location to save/load models
 
   load_from: 'asr_final.pth' # File name containing a checkpoint to continue/finetune
 
@@ -16,9 +16,10 @@ train:
   hidden_size: 800 # Hidden size of RNNs
   hidden_layers: 5 # Number of RNN layers
   rnn_type: 'gru' # Type of the RNN unit: gru|lstm are supported
+  labels: 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' # labels used by the model
 
   max_epochs: 70 # Number of training epochs
-  learning_rate: 3e-4 # Initial learning rate
+  learning_rate: 3.0e-4 # Initial learning rate
   momentum: 0.9 # Momentum
   max_norm: 800 # Norm cutoff to prevent explosion of gradients
-  learning_anneal: 1.1n # Annealing applied to learning rate every epoch
+  learning_anneal: 1.1 # Annealing applied to learning rate every epoch
diff --git a/sonosco/config/train_librispeech.yaml b/sonosco/config/train_librispeech.yaml
new file mode 100644
index 0000000..64d6294
--- /dev/null
+++ b/sonosco/config/train_librispeech.yaml
@@ -0,0 +1,49 @@
+train:
+  train_manifest: '/Users/yuriy/temp/data/libri_speech/libri_test_clean_manifest.csv'
+  val_manifest: '/Users/yuriy/temp/data/libri_speech/libri_test_clean_manifest.csv'
+  log_dir: 'logs' # Location for log files
+  def_dir: 'examples/checkpoints/' # Default location to save/load models
+
+  load_from: 'asr_final.pth' # File name containing a checkpoint to continue/finetune
+
+  sample_rate: 16000 # Sample rate
+  window_size: 0.02 # Window size for spectrogram in seconds
+  window_stride: 0.01 # Window stride for spectrogram in seconds
+  window: 'hamming' # Window type for spectrogram generation
+
+  batch_size: 32 # Batch size for training
+  hidden_size: 800 # Hidden size of RNNs
+  hidden_layers: 5 # Number of RNN layers
+  rnn_type: 'gru' # Type of the RNN unit: gru|lstm are supported
+  labels: 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' # labels used by the model
+
+  max_epochs: 70 # Number of training epochs
+  learning_rate: 3.0e-4 # Initial learning rate
+  momentum: 0.9 # Momentum
+  max_norm: 800 # Norm cutoff to prevent explosion of gradients
+  learning_anneal: 1.1 # Annealing applied to learning rate every epoch
+  sortaGrad: True # Turn on ordering of dataset on sequence length for the first epoch
+
+  checkpoint: True # Enables checkpoint saving of model
+  checkpoint_per_epoch: 1 # Save checkpoint per x epochs
+  silent: False # Turn on progress tracking per iteration
+  verbose: False # Turn on verbose progress tracking
+  continue: False # Continue training with a pre-trained model
+  finetune: False # Finetune a pre-trained model
+
+  num_data_workers: 8 # Number of workers used in data-loading
+  augment: False # Use random tempo and gain perturbations
+  shuffle: True # Turn on shuffling and sample from dataset based on sequence length (smallest to largest)
+
+  seed: 123456 # Seed to generators
+  cuda: True # Use cuda to train model
+  half_precision: True # Uses half precision to train a model
+  apex: True # Uses mixed precision to train a model
+  static_loss_scaling: False # Static loss scale for mixed precision
+  dynamic_loss_scaling: True # Use dynamic loss scaling for mixed precision
+
+  dist_url: 'tcp://127.0.0.1:1550' # URL used to set up distributed training
+  dist_backend: 'nccl' # Distributed backend
+  world_size: 1 # Number of distributed processes
+  rank: 0 # The rank of the current process
+  gpu_rank: 0 # If using distributed parallel for multi_gpu, sets the GPU for the process
\ No newline at end of file
diff --git a/sonosco/datasets/__init__.py b/sonosco/datasets/__init__.py
new file mode 100644
index 0000000..b74d5b6
--- /dev/null
+++ b/sonosco/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import AudioDataProcessor, AudioDataset
+from .samplers import BucketingSampler
+from .loader import AudioDataLoader, create_data_loaders
diff --git a/sonosco/datasets/dataset.py b/sonosco/datasets/dataset.py
new file mode 100644
index 0000000..8345f30
--- /dev/null
+++ b/sonosco/datasets/dataset.py
@@ -0,0 +1,53 @@
+# ----------------------------------------------------------------------------
+# Based on SeanNaren's deepspeech.pytorch:
+# https://github.com/SeanNaren/deepspeech.pytorch
+# ----------------------------------------------------------------------------
+
+import logging
+
+from torch.utils.data import Dataset
+from .processor import AudioDataProcessor
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class AudioDataset(Dataset):
+
+    def __init__(self, processor: AudioDataProcessor, manifest_filepath):
+        """
+        Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
+        a comma. Each new line is a different sample. Example below:
+        /path/to/audio.wav,/path/to/audio.txt
+        ...
+        :param processor: Data processor object
+        :param manifest_filepath: Path to manifest csv as described above
+        """
+        super().__init__()
+        with open(manifest_filepath) as f:
+            ids = f.readlines()
+        ids = [x.strip().split(',') for x in ids]
+        self.ids = ids
+        self.size = len(ids)
+        self.processor = processor
+
+    def get_raw(self, index):
+        sample = self.ids[index]
+        audio_path, transcript_path = sample[0], sample[1]
+
+        sound = self.processor.parse_audio(audio_path, raw=True)
+        transcript = self.processor.parse_transcript(transcript_path)
+
+        return sound, transcript
+
+    def __getitem__(self, index):
+        sample = self.ids[index]
+        audio_path, transcript_path = sample[0], sample[1]
+
+        spectrogram = self.processor.parse_audio(audio_path)
+        transcript = self.processor.parse_transcript(transcript_path)
+
+        return spectrogram, transcript
+
+    def __len__(self):
+        return self.size
diff --git a/sonosco/datasets/download_datasets/__init__.py b/sonosco/datasets/download_datasets/__init__.py
new file mode 100644
index 0000000..2ef1650
--- /dev/null
+++ b/sonosco/datasets/download_datasets/__init__.py
@@ -0,0 +1,3 @@
+def download_all_datasets(path: str):
+    """Downloads all datasets which are missing."""
+    pass
diff --git a/sonosco/datasets/download_datasets/an4.py b/sonosco/datasets/download_datasets/an4.py
new file mode 100644
index 0000000..374cb5b
--- /dev/null
+++ b/sonosco/datasets/download_datasets/an4.py
@@ -0,0 +1,110 @@
+import os
+import click
+import io
+import shutil
+import tarfile
+import logging
+import sonosco.common.audio_tools as audio_tools
+import sonosco.common.path_utils as path_utils
+
+from sonosco.datasets.download_datasets.data_utils import create_manifest
+from sonosco.common.utils import setup_logging
+from sonosco.common.constants import *
+
+LOGGER = logging.getLogger(__name__)
+
+AN4_URL = 'http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz'
+
+
+def try_download_an4(target_dir, sample_rate, min_duration, max_duration):
+    path_to_data = os.path.join(os.path.expanduser("~"), target_dir)
+    if not os.path.exists(path_to_data):
+        os.makedirs(path_to_data)
+    target_unpacked_dir = os.path.join(path_to_data, "an4_unpacked")
+    path_utils.try_create_directory(target_unpacked_dir)
+
+    extracted_dir = os.path.join(path_to_data, "An4")
+    if os.path.exists(extracted_dir):
+        shutil.rmtree(extracted_dir)
+    LOGGER.info("Start downloading...")
+    file_name = AN4_URL.split("/")[-1]
+
+    target_filename = os.path.join(target_unpacked_dir, file_name)
+    path_utils.try_download(target_filename, AN4_URL)
+    LOGGER.info("Download complete")
+    LOGGER.info("Unpacking...")
+    tar = tarfile.open(target_filename)
+    tar.extractall(extracted_dir)
+    tar.close()
+    assert os.path.exists(extracted_dir), f"Archive {file_name} was not properly uncompressed"
+    LOGGER.info("Converting files to wav and extracting transcripts...")
+
+    create_wav_and_transcripts(path_to_data, 'train', sample_rate, extracted_dir, 'an4_clstk')
+    create_wav_and_transcripts(path_to_data, 'test', sample_rate, extracted_dir, 'an4test_clstk')
+
+    create_manifest(os.path.join(path_to_data, 'train'), os.path.join(path_to_data, 'an4_train_manifest.csv'), min_duration, max_duration)
+    create_manifest(os.path.join(path_to_data, 'test'), os.path.join(path_to_data, 'an4_val_manifest.csv'), min_duration, max_duration)
+
+
+def create_wav_and_transcripts(path, data_tag, sample_rate, extracted_dir, wav_subfolder_name):
+    tag_path = os.path.join(path, data_tag)
+    transcript_path_new = os.path.join(tag_path, 'txt')
+    wav_path_new = os.path.join(tag_path, 'wav')
+
+    path_utils.try_create_directory(transcript_path_new)
+    path_utils.try_create_directory(wav_path_new)
+
+    wav_path_ext = os.path.join(extracted_dir, 'an4/wav')
+    file_ids = os.path.join(extracted_dir, f'an4/etc/an4_{data_tag}.fileids')
+    transcripts_ext = os.path.join(extracted_dir, f'an4/etc/an4_{data_tag}.transcription')
+    path = os.path.join(wav_path_ext, wav_subfolder_name)
+    convert_audio_to_wav(path, sample_rate)
+    format_files(file_ids, transcript_path_new, wav_path_new, transcripts_ext, wav_path_ext)
+
+
+def convert_audio_to_wav(train_path, sample_rate):
+    with os.popen('find %s -type f -name "*.raw"' % train_path) as pipe:
+        for line in pipe:
+            raw_path = line.strip()
+            new_path = line.replace('.raw', '.wav').strip()
+            audio_tools.transcode_recordings_an4(raw_path=raw_path, wav_path=new_path, sample_rate=sample_rate)
+
+
+def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
+    with open(file_ids, 'r') as f:
+        with open(transcripts, 'r') as t:
+            paths = f.readlines()
+            transcripts = t.readlines()
+            for x in range(len(paths)):
+                path = os.path.join(wav_path, paths[x].strip()) + '.wav'
+                filename = path.split('/')[-1]
+                extracted_transcript = _process_transcript(transcripts, x)
+                current_path = os.path.abspath(path)
+                new_path = os.path.join(new_wav_path, filename)
+                text_path = os.path.join(new_transcript_path, filename.replace('.wav', '.txt'))
+                with io.FileIO(text_path, "w") as file:
+                    file.write(extracted_transcript.encode('utf-8'))
+                os.rename(current_path, new_path)
+
+
+def _process_transcript(transcripts, x):
+    extracted_transcript = transcripts[x].split('(')[0].split('<')[0].strip().upper()
+    return extracted_transcript
+
+
+@click.command()
+@click.option("--target-dir", default="temp/data/an4", type=str, help="Directory to store the dataset.")
+@click.option("--sample-rate", default=16000, type=int, help="Sample rate.")
+@click.option("--min-duration", default=1, type=int,
+              help="Prunes training samples shorter than the min duration (given in seconds).")
+@click.option("--max-duration", default=15, type=int,
+              help="Prunes training samples longer than the max duration (given in seconds).")
+def main(**kwargs):
+    """Processes and downloads an4 dataset."""
+    try_download_an4(**kwargs)
+
+
+if __name__ == '__main__':
+    LOGGER = logging.getLogger(SONOSCO)
+    setup_logging(LOGGER)
+    main()
diff --git a/sonosco/datasets/download_datasets/common_voice.py b/sonosco/datasets/download_datasets/common_voice.py
new file mode 100644
index 0000000..41bc102
--- /dev/null
+++ b/sonosco/datasets/download_datasets/common_voice.py
@@ -0,0 +1,107 @@
+import os
+import click
+import logging
+import tarfile
+import shutil
+import csv
+import sonosco.common.audio_tools as audio_tools
+import sonosco.common.path_utils as path_utils
+
+from multiprocessing.pool import ThreadPool
+from sonosco.datasets.download_datasets.data_utils import create_manifest
+from sonosco.common.utils import setup_logging
+from sonosco.common.constants import *
+
+
+LOGGER = logging.getLogger(__name__)
+
+COMMON_VOICE_URL = "https://common-voice-data-download.s3.amazonaws.com/cv_corpus_v1.tar.gz"
+
+
+def try_download_common_voice(target_dir, sample_rate, files_to_use, min_duration, max_duration):
+    path_to_data = os.path.join(os.path.expanduser("~"), target_dir)
+    path_utils.try_create_directory(path_to_data)
+
+    target_unpacked_dir = os.path.join(path_to_data, "common_unpacked")
+    path_utils.try_create_directory(target_unpacked_dir)
+
+    extracted_dir = os.path.join(path_to_data, "CommonVoice")
+    if os.path.exists(extracted_dir):
+        shutil.rmtree(extracted_dir)
+    LOGGER.info("Start downloading...")
+    file_name = COMMON_VOICE_URL.split("/")[-1]
+    target_filename = os.path.join(target_unpacked_dir, file_name)
+    path_utils.try_download(target_filename, COMMON_VOICE_URL)
+
+    LOGGER.info("Download complete")
+    LOGGER.info("Unpacking...")
+    tar = tarfile.open(target_filename)
+    tar.extractall(extracted_dir)
+    tar.close()
+    shutil.rmtree(target_unpacked_dir)
+    assert os.path.exists(extracted_dir), f"Archive {file_name} was not properly uncompressed"
+    LOGGER.info("Converting files to wav and extracting transcripts...")
+    for csv_file in files_to_use.split(','):
+        convert_to_wav(os.path.join(extracted_dir, 'cv_corpus_v1/', csv_file),
+                       os.path.join(target_dir, os.path.splitext(csv_file)[0]),
+                       sample_rate)
+    LOGGER.info(f"Finished {COMMON_VOICE_URL}")
+    shutil.rmtree(extracted_dir)
+
+    LOGGER.info('Creating manifests...')
+    for csv_file in files_to_use.split(','):
+        create_manifest(os.path.join(path_to_data, os.path.splitext(csv_file)[0]),
+                        os.path.splitext(csv_file)[0] + '_manifest.csv',
+                        min_duration,
+                        max_duration)
+
+
+def convert_to_wav(csv_file, target_dir, sample_rate):
+    """ Read *.csv file description, convert mp3 to wav, process text.
+        Save results to target_dir.
+    Args:
+        csv_file: str, path to *.csv file with data description, usually start from 'cv-'
+        target_dir: str, path to dir to save results; wav/ and txt/ dirs will be created
+    """
+    wav_dir = os.path.join(target_dir, 'wav/')
+    txt_dir = os.path.join(target_dir, 'txt/')
+    path_utils.try_create_directory(wav_dir)
+    path_utils.try_create_directory(txt_dir)
+    path_to_data = os.path.dirname(csv_file)
+
+    def process(x):
+        file_path, text = x
+        file_name = os.path.splitext(os.path.basename(file_path))[0]
+        text = text.strip().upper()
+        with open(os.path.join(txt_dir, file_name + '.txt'), 'w') as f:
+            f.write(text)
+        audio_tools.transcode_recording(source=os.path.join(path_to_data, file_path),
+                                        destination=os.path.join(wav_dir, file_name + '.wav'),
+                                        sample_rate=sample_rate)
+
+    LOGGER.info('Converting mp3 to wav for {}.'.format(csv_file))
+    with open(csv_file) as csvfile:
+        reader = csv.DictReader(csvfile)
+        data = [(row['filename'], row['text']) for row in reader]
+        with ThreadPool(10) as pool:
+            pool.map(process, data)
+
+
+@click.command()
+@click.option("--target-dir", default="temp/data/common_voice", type=str, help="Directory to store the dataset.")
+@click.option("--sample-rate", default=16000, type=int, help="Sample rate.")
+@click.option("--files-to-use", type=str,
+              default="cv-valid-dev.csv,cv-valid-test.csv,cv-valid-train.csv")
+@click.option("--min-duration", default=1, type=int,
+              help="Prunes training samples shorter than the min duration (given in seconds).")
+@click.option("--max-duration", default=15, type=int,
+              help="Prunes training samples longer than the max duration (given in seconds).")
+def main(**kwargs):
+    global LOGGER
+    LOGGER = logging.getLogger(SONOSCO)
+    setup_logging(LOGGER)
+    try_download_common_voice(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sonosco/datasets/download_datasets/data_utils.py b/sonosco/datasets/download_datasets/data_utils.py
new file mode 100644
index 0000000..167eb92
--- /dev/null
+++ b/sonosco/datasets/download_datasets/data_utils.py
@@ -0,0 +1,44 @@
+import fnmatch
+import io
+import os
+import logging
+import torch.distributed as dist
+import sonosco.common.audio_tools as audio_tools
+
+from tqdm import tqdm
+
+LOGGER = logging.getLogger(__name__)
+
+
+def create_manifest(data_path, output_path, min_duration=None, max_duration=None):
+    LOGGER.info(f"Creating a manifest for path: {data_path}")
+    file_paths = [os.path.join(dirpath, f)
+                  for dirpath, dirnames, files in os.walk(data_path)
+                  for f in fnmatch.filter(files, '*.wav')]
+    LOGGER.info(f"Found {len(file_paths)} .wav files")
+    file_paths = order_and_prune_files(file_paths, min_duration, max_duration)
+    with io.FileIO(output_path, "w") as file:
+        for wav_path in tqdm(file_paths, total=len(file_paths)):
+            transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt')
+            sample = f"{os.path.abspath(wav_path)},{os.path.abspath(transcript_path)}\n"
+            file.write(sample.encode('utf-8'))
+
+
+def order_and_prune_files(file_paths, min_duration, max_duration):
+    LOGGER.info("Sorting manifests...")
+    path_and_duration = [(path, audio_tools.get_duration(path)) for path in file_paths]
+
+    if min_duration and max_duration:
+        LOGGER.info(f"Pruning manifests between {min_duration} and {max_duration} seconds")
+        path_and_duration = [(path, duration) for path, duration in path_and_duration
+                             if min_duration <= duration <= max_duration]
+
+    path_and_duration.sort(key=lambda e: e[1])
+    return [x[0] for x in path_and_duration]
+
+
+def reduce_tensor(tensor, world_size):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= world_size
+    return rt
diff --git a/sonosco/datasets/download_datasets/librispeech.py b/sonosco/datasets/download_datasets/librispeech.py
new file mode 100644
index 0000000..246f0e2
--- /dev/null
+++ b/sonosco/datasets/download_datasets/librispeech.py
@@ -0,0 +1,132 @@
+import os
+import click
+import tarfile
+import shutil
+import logging
+import sonosco.common.audio_tools as audio_tools
+import sonosco.common.path_utils as path_utils
+
+from sonosco.datasets.download_datasets.data_utils import create_manifest
+from sonosco.common.utils import setup_logging
+from sonosco.common.constants import *
+from tqdm import tqdm
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+LIBRI_SPEECH_URLS = {
+    "train": ["http://www.openslr.org/resources/12/train-clean-100.tar.gz",
+              "http://www.openslr.org/resources/12/train-clean-360.tar.gz",
+              "http://www.openslr.org/resources/12/train-other-500.tar.gz"],
+
+    "val": ["http://www.openslr.org/resources/12/dev-clean.tar.gz",
+            "http://www.openslr.org/resources/12/dev-other.tar.gz"],
+
+    "test_clean": ["http://www.openslr.org/resources/12/test-clean.tar.gz"],
+    "test_other": ["http://www.openslr.org/resources/12/test-other.tar.gz"]
+}
+
+
+def try_download_librispeech(target_dir, sample_rate, files_to_use, min_duration, max_duration):
+    path_to_data = os.path.join(os.path.expanduser("~"), target_dir)
+    if not os.path.exists(path_to_data):
+        os.makedirs(path_to_data)
+
+    for split_type, lst_libri_urls in LIBRI_SPEECH_URLS.items():
+        split_dir = os.path.join(path_to_data, split_type)
+        path_utils.try_create_directory(split_dir)
+        split_wav_dir = os.path.join(split_dir, "wav")
+        path_utils.try_create_directory(split_wav_dir)
+        split_txt_dir = os.path.join(split_dir, "txt")
+        path_utils.try_create_directory(split_txt_dir)
+        extracted_dir = os.path.join(split_dir, "LibriSpeech")
+
+        if os.path.exists(extracted_dir):
+            shutil.rmtree(extracted_dir)
+
+        for url in lst_libri_urls:
+            # check if we want to dl this file
+            dl_flag = False
+            for f in files_to_use:
+                if url.find(f) != -1:
+                    dl_flag = True
+            if not dl_flag:
+                LOGGER.info(f"Skipping url: {url}")
+                continue
+
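+            # download and unpack this archive, then convert its flac files below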
+ filename = url.split("/")[-1] + target_filename = os.path.join(split_dir, filename) + LOGGER.info(f"Downloading from {url}") + path_utils.try_download(target_filename, url) + LOGGER.info("Download complete") + LOGGER.info(f"Unpacking {filename}...") + tar = tarfile.open(target_filename) + tar.extractall(split_dir) + tar.close() + os.remove(target_filename) + assert os.path.exists(extracted_dir), f"Archive {filename} was not properly uncompressed" + + LOGGER.info("Converting flac files to wav and extracting transcripts...") + for root, subdirs, files in tqdm(os.walk(extracted_dir)): + for f in files: + if f.find(".flac") != -1: + _process_file(wav_dir=split_wav_dir, txt_dir=split_txt_dir, + base_filename=f, root_dir=root, sample_rate=sample_rate) + + LOGGER.info(f"Finished {url}") + shutil.rmtree(extracted_dir) + + manifest_path = os.path.join(path_to_data, f"libri_{split_type}_manifest.csv") + if os.path.exists(manifest_path): + continue + + if split_type == 'train': # Prune to min/max duration + create_manifest(split_dir, manifest_path, min_duration, max_duration) + else: + create_manifest(split_dir, manifest_path) + + +def _preprocess_transcript(phrase): + return phrase.strip().upper() + + +def _process_file(wav_dir, txt_dir, base_filename, root_dir, sample_rate): + full_recording_path = os.path.join(root_dir, base_filename) + assert os.path.exists(full_recording_path) and os.path.exists(root_dir) + wav_recording_path = os.path.join(wav_dir, base_filename.replace(".flac", ".wav")) + audio_tools.transcode_recording(full_recording_path, wav_recording_path, sample_rate) + # process transcript + txt_transcript_path = os.path.join(txt_dir, base_filename.replace(".flac", ".txt")) + transcript_file = os.path.join(root_dir, "-".join(base_filename.split('-')[:-1]) + ".trans.txt") + assert os.path.exists(transcript_file), f"Transcript file {transcript_file} does not exist" + transcriptions = open(transcript_file).read().strip().split("\n") + transcriptions = {t.split()[0].split("-")[-1]: " ".join(t.split()[1:]) for t in transcriptions} + with open(txt_transcript_path, "w") as f: + key = base_filename.replace(".flac", "").split("-")[-1] + assert key in transcriptions, f"{key} is not in the transcriptions" + f.write(_preprocess_transcript(transcriptions[key])) + f.flush() + + +@click.command() +@click.option("--target-dir", default="temp/data/libri_speech", type=str, help="Directory to store the dataset.") +@click.option("--sample-rate", default=16000, type=int, help="Sample rate.") +@click.option("--files-to-use", multiple=True, + default=["train-clean-100.tar.gz", "train-clean-360.tar.gz", "train-other-500.tar.gz", + "dev-clean.tar.gz", "dev-other.tar.gz", "test-clean.tar.gz", "test-other.tar.gz"], + type=str, help="List of file names to download.") +@click.option("--min-duration", default=1, type=int, + help="Prunes training samples shorter than the min duration (given in seconds).") +@click.option("--max-duration", default=15, type=int, + help="Prunes training samples longer than the max duration (given in seconds).") +def main(**kwargs): + """Processes and downloads LibriSpeech dataset.""" + try_download_librispeech(**kwargs) + + +if __name__ == "__main__": + LOGGER = logging.getLogger(SONOSCO) + setup_logging(LOGGER) + main() diff --git a/sonosco/datasets/download_datasets/merge_manifests.py b/sonosco/datasets/download_datasets/merge_manifests.py new file mode 100644 index 0000000..6218d52 --- /dev/null +++ b/sonosco/datasets/download_datasets/merge_manifests.py @@ -0,0 +1,30 @@ +import 
argparse +import io +import os + +from tqdm import tqdm +from .data_utils import order_and_prune_files + + +parser = argparse.ArgumentParser(description='Merges all manifest CSV files in specified folder.') +parser.add_argument('--merge-dir', default='manifests/', help='Path to all manifest files you want to merge') +parser.add_argument('--min-duration', default=1, type=int, + help='Prunes any samples shorter than the min duration (given in seconds, default 1)') +parser.add_argument('--max-duration', default=15, type=int, + help='Prunes any samples longer than the max duration (given in seconds, default 15)') +parser.add_argument('--output-path', default='merged_manifest.csv', help='Output path to merged manifest') + +args = parser.parse_args() + +file_paths = [] +for file in os.listdir(args.merge_dir): + if file.endswith(".csv"): + with open(os.path.join(args.merge_dir, file), 'r') as fh: + file_paths += fh.readlines() +file_paths = [file_path.split(',')[0] for file_path in file_paths] +file_paths = order_and_prune_files(file_paths, args.min_duration, args.max_duration) +with io.FileIO(args.output_path, "w") as file: + for wav_path in tqdm(file_paths, total=len(file_paths)): + transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') + sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' + file.write(sample.encode('utf-8')) diff --git a/sonosco/datasets/download_datasets/ted3.py b/sonosco/datasets/download_datasets/ted3.py new file mode 100644 index 0000000..bb42c32 --- /dev/null +++ b/sonosco/datasets/download_datasets/ted3.py @@ -0,0 +1,124 @@ +import os +import click +import logging +import argparse +import subprocess +import unicodedata +import tarfile +import io +import shutil +import sonosco.common.audio_tools as audio_tools +import sonosco.common.path_utils as path_utils +from sonosco.datasets.download_datasets.data_utils import create_manifest +from sonosco.common.utils import setup_logging +from sonosco.common.constants import * +from tqdm import tqdm + +LOGGER = logging.getLogger(__name__) + +TED_LIUM_V2_DL_URL = "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz" + +def try_download_ted3(target_dir, sample_rate, min_duration, max_duration): + path_to_data = os.path.join(os.path.expanduser("~"), target_dir) + path_utils.try_create_directory(path_to_data) + + target_unpacked_dir = os.path.join(path_to_data, "ted3_unpacked") + path_utils.try_create_directory(target_unpacked_dir) + + extracted_dir = os.path.join(path_to_data, "Ted3") + if os.path.exists(extracted_dir): + shutil.rmtree(extracted_dir) + LOGGER.info("Start downloading...") + file_name = TED_LIUM_V2_DL_URL.split("/")[-1] + target_filename = os.path.join(target_unpacked_dir, file_name) + path_utils.try_download(target_filename, TED_LIUM_V2_DL_URL) + + LOGGER.info("Download complete") + LOGGER.info("Unpacking...") + tar = tarfile.open(target_filename) + tar.extractall(extracted_dir) + tar.close() + os.remove(target_unpacked_dir) + assert os.path.exists(extracted_dir), f"Archive {file_name} was not properly uncompressed" + LOGGER.info("Converting files to wav and extracting transcripts...") + prepare_dir(path_to_data, sample_rate) + create_manifest(path_to_data, os.path.join(path_to_data,'ted3_train_manifest.csv'), min_duration, max_duration) + + +def get_utterances_from_stm(stm_file): + """ + Return list of entries containing phrase and its start/end timings + :param stm_file: + :return: + """ + res = [] + with io.open(stm_file, "r", encoding='utf-8') as f: + for 
stm_line in f: + tokens = stm_line.split() + start_time = float(tokens[3]) + end_time = float(tokens[4]) + filename = tokens[0] + transcript = unicodedata.normalize("NFKD", + " ".join(t for t in tokens[6:]).strip()). \ + encode("utf-8", "ignore").decode("utf-8", "ignore") + if transcript != "ignore_time_segment_in_scoring": + res.append({ + "start_time": start_time, "end_time": end_time, + "filename": filename, "transcript": transcript + }) + return res + + +def _preprocess_transcript(phrase): + return phrase.strip().upper() + + +def filter_short_utterances(utterance_info, min_len_sec=1.0): + return utterance_info["end_time"] - utterance_info["start_time"] > min_len_sec + + +def prepare_dir(ted_dir, sample_rate): + # directories to store converted wav files and their transcriptions + wav_dir = os.path.join(ted_dir, "wav") + path_utils.try_create_directory(wav_dir) + txt_dir = os.path.join(ted_dir, "txt") + path_utils.try_create_directory(txt_dir) + counter = 0 + entries = os.listdir(os.path.join(ted_dir, "sph")) + for sph_file in tqdm(entries, total=len(entries)): + speaker_name = sph_file.split('.sph')[0] + + sph_file_full = os.path.join(ted_dir, "sph", sph_file) + stm_file_full = os.path.join(ted_dir, "stm", "{}.stm".format(speaker_name)) + + assert os.path.exists(sph_file_full) and os.path.exists(stm_file_full) + all_utterances = get_utterances_from_stm(stm_file_full) + + all_utterances = filter(filter_short_utterances, all_utterances) + for utterance_id, utterance in enumerate(all_utterances): + target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(utterance["filename"], str(utterance_id))) + target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(utterance["filename"], str(utterance_id))) + audio_tools.transcode_recordings_ted3(sph_file_full, target_wav_file, utterance["start_time"], utterance["end_time"], + sample_rate=sample_rate) + with io.FileIO(target_txt_file, "w") as f: + f.write(_preprocess_transcript(utterance["transcript"]).encode('utf-8')) + counter += 1 + +@click.command() +@click.option("--target-dir", default="temp/data/ted3", type=str, help="Directory to store the dataset.") +@click.option("--sample-rate", default=16000, type=int, help="Sample rate.") + +@click.option("--min-duration", default=1, type=int, + help="Prunes training samples shorter than the min duration (given in seconds).") +@click.option("--max-duration", default=15, type=int, + help="Prunes training samples longer than the max duration (given in seconds).") + + +def main(**kwargs): + global LOGGER + logger = logging.getLogger(SONOSCO) + setup_logging(logger) + try_download_ted3(**kwargs) + +if __name__ == "__main__": + main() diff --git a/sonosco/datasets/download_datasets/voxforge.py b/sonosco/datasets/download_datasets/voxforge.py new file mode 100644 index 0000000..393fd6e --- /dev/null +++ b/sonosco/datasets/download_datasets/voxforge.py @@ -0,0 +1,113 @@ +import os +import click +import logging +from six.moves import urllib +import argparse +import re +import tempfile +import shutil +import subprocess +import tarfile +import io +import sonosco.common.audio_tools as audio_tools +import sonosco.common.path_utils as path_utils +from sonosco.datasets.download_datasets.data_utils import create_manifest +from sonosco.common.utils import setup_logging +from sonosco.common.constants import * +from tqdm import tqdm + +LOGGER = logging.getLogger(__name__) + +VOXFORGE_URL_16kHz = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit/' + +def 
try_download_voxforge(target_dir, sample_rate, min_duration, max_duration): + path_to_data = os.path.join(os.path.expanduser("~"), target_dir) + path_utils.try_create_directory(path_to_data) + + LOGGER.info("Start downloading...") + request = urllib.request.Request(VOXFORGE_URL_16kHz) + response = urllib.request.urlopen(request) + content = response.read() + all_files = re.findall("href\=\"(.*\.tgz)\"", content.decode("utf-8")) + for f in tqdm(all_files, total=len(all_files)): + prepare_sample(f.replace(".tgz", ""), VOXFORGE_URL_16kHz + f, path_to_data, sample_rate) + create_manifest(path_to_data, os.path.join(path_to_data,'voxforge_train_manifest.csv'), min_duration, max_duration) + +def _get_recordings_dir(sample_dir, recording_name): + wav_dir = os.path.join(sample_dir, recording_name, "wav") + if os.path.exists(wav_dir): + return "wav", wav_dir + flac_dir = os.path.join(sample_dir, recording_name, "flac") + if os.path.exists(flac_dir): + return "flac", flac_dir + raise Exception("wav or flac directory was not found for recording name: {}".format(recording_name)) + + +def prepare_sample(recording_name, url, target_folder, sample_rate): + """ + Downloads and extracts a sample from VoxForge and puts the wav and txt files into :target_folder. + """ + wav_dir = os.path.join(target_folder, "wav") + path_utils.try_create_directory(wav_dir) + txt_dir = os.path.join(target_folder, "txt") + path_utils.try_create_directory(txt_dir) + # check if sample is processed + filename_set = set(['_'.join(wav_file.split('_')[:-1]) for wav_file in os.listdir(wav_dir)]) + if recording_name in filename_set: + return + + request = urllib.request.Request(url) + response = urllib.request.urlopen(request) + content = response.read() + response.close() + with tempfile.NamedTemporaryFile(suffix=".tgz", mode='wb') as target_tgz: + target_tgz.write(content) + target_tgz.flush() + dirpath = tempfile.mkdtemp() + + tar = tarfile.open(target_tgz.name) + tar.extractall(dirpath) + tar.close() + + recordings_type, recordings_dir = _get_recordings_dir(dirpath, recording_name) + tgz_prompt_file = os.path.join(dirpath, recording_name, "etc", "PROMPTS") + + if os.path.exists(recordings_dir) and os.path.exists(tgz_prompt_file): + transcriptions = open(tgz_prompt_file).read().strip().split("\n") + transcriptions = {t.split()[0]: " ".join(t.split()[1:]) for t in transcriptions} + for wav_file in os.listdir(recordings_dir): + recording_id = wav_file.split('.{}'.format(recordings_type))[0] + transcription_key = recording_name + "/mfc/" + recording_id + if transcription_key not in transcriptions: + continue + utterance = transcriptions[transcription_key] + + target_wav_file = os.path.join(wav_dir, "{}_{}.wav".format(recording_name, recording_id)) + target_txt_file = os.path.join(txt_dir, "{}_{}.txt".format(recording_name, recording_id)) + with io.FileIO(target_txt_file, "w") as file: + file.write(utterance.encode('utf-8')) + original_wav_file = os.path.join(recordings_dir, wav_file) + audio_tools.transcode_recording(original_wav_file, target_wav_file, sample_rate) + + shutil.rmtree(dirpath) + +@click.command() +@click.option("--target-dir", default="temp/data/voxforge", type=str, help="Directory to store the dataset.") +@click.option("--sample-rate", default=16000, type=int, help="Sample rate.") + +@click.option("--min-duration", default=1, type=int, + help="Prunes training samples shorter than the min duration (given in seconds).") +@click.option("--max-duration", default=15, type=int, + help="Prunes training samples longer than the 
+def main(**kwargs):
+    global LOGGER
+    LOGGER = logging.getLogger(SONOSCO)
+    setup_logging(LOGGER)
+    try_download_voxforge(**kwargs)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/sonosco/datasets/loader.py b/sonosco/datasets/loader.py
new file mode 100644
index 0000000..e23c2fd
--- /dev/null
+++ b/sonosco/datasets/loader.py
@@ -0,0 +1,61 @@
+import logging
+import torch
+import torch.nn
+
+from torch.utils.data import DataLoader
+from .dataset import AudioDataProcessor, AudioDataset
+from .samplers import BucketingSampler
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class AudioDataLoader(DataLoader):
+
+    def __init__(self, *args, **kwargs):
+        """
+        Creates a data loader for AudioDatasets.
+        """
+        super(AudioDataLoader, self).__init__(*args, **kwargs)
+        self.collate_fn = self._collate_fn
+
+    def _collate_fn(self, batch):
+        # sort the batch in decreasing order of sequence length
+        batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
+
+        # to pad the tensors to equal length, transpose the tensors in the batch:
+        # they have shape freq_size x sequence_length and need to be of shape
+        # sequence_length x freq_size, as the sequence length differs but freq_size does not
+        inputs = torch.nn.utils.rnn.pad_sequence([x[0].transpose(0, 1) for x in batch], batch_first=True)
+
+        # transpose the inputs back from shape batch_size x sequence_length x freq_size
+        # to batch_size x freq_size x sequence_length; additionally, unsqueeze a channel dimension
+        inputs = inputs.transpose(1, 2).unsqueeze(1)
+        input_lengths = torch.IntTensor([x[0].size(1) for x in batch])  # tensor of input lengths
+
+        targets_arr = list(zip(*batch))[1]  # extract the targets from the batch (an array of tuples)
+        target_lengths = torch.IntTensor([len(x) for x in targets_arr])  # tensor of target lengths
+        targets = torch.cat([torch.IntTensor(x) for x in targets_arr])  # flat tensor of all targets
+
+        return inputs, targets, input_lengths, target_lengths
+
+
+def create_data_loaders(**kwargs):
+    processor = AudioDataProcessor(**kwargs)
+
+    # create train loader
+    train_dataset = AudioDataset(processor, manifest_filepath=kwargs["train_manifest"])
+    LOGGER.info(f"Training dataset containing {len(train_dataset)} samples is created")
+    sampler = BucketingSampler(train_dataset, batch_size=kwargs["batch_size"])
+    train_loader = AudioDataLoader(dataset=train_dataset, num_workers=kwargs["num_data_workers"], batch_sampler=sampler)
+    LOGGER.info("Training data loader created.")
+
+    # create validation loader
+    val_dataset = AudioDataset(processor, manifest_filepath=kwargs["val_manifest"])
+    LOGGER.info(f"Validation dataset containing {len(val_dataset)} samples is created")
+    sampler = BucketingSampler(val_dataset, batch_size=kwargs["batch_size"])
+    val_loader = AudioDataLoader(dataset=val_dataset, num_workers=kwargs["num_data_workers"], batch_sampler=sampler)
+    LOGGER.info("Validation data loader created.")
+
+    return train_loader, val_loader
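A quick sketch of what _collate_fn above produces, using two dummy spectrograms; the shapes and label lists are made up for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

freq_size = 3
batch = [(torch.rand(freq_size, 5), [1, 2]), (torch.rand(freq_size, 2), [3])]  # (spectrogram, transcript)
batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)

# transpose to (seq_len, freq_size) so that pad_sequence pads the time axis
inputs = pad_sequence([x[0].transpose(0, 1) for x in batch], batch_first=True)
inputs = inputs.transpose(1, 2).unsqueeze(1)  # back to (batch, channel, freq_size, max_seq_len)
print(inputs.shape)  # torch.Size([2, 1, 3, 5])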
diff --git a/sonosco/datasets/processor.py b/sonosco/datasets/processor.py
new file mode 100644
index 0000000..d8e91e2
--- /dev/null
+++ b/sonosco/datasets/processor.py
@@ -0,0 +1,93 @@
+import logging
+import torch
+import librosa
+import numpy as np
+import sonosco.common.audio_tools as audio_tools
+import sonosco.common.utils as utils
+import sonosco.common.noise_makers as noise_makers
+
+
+LOGGER = logging.getLogger(__name__)
+
+MIN_STRETCH = 0.7
+MAX_STRETCH = 1.3
+MIN_PITCH = 0.7
+MAX_PITCH = 1.5
+MAX_SHIFT = 4000
+
+
+class AudioDataProcessor:
+
+    def __init__(self, window_stride, window_size, sample_rate, labels="abc", normalize=False, augment=False, **kwargs):
+        """
+        Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
+        a comma. Each new line is a different sample. Example below:
+        /path/to/audio.wav,/path/to/audio.txt
+        ...
+        :param window_stride: number of seconds to skip between each window
+        :param window_size: number of seconds to use for a window of the spectrogram
+        :param sample_rate: sample rate of the recordings
+        :param labels: string containing all the possible characters to map to
+        :param normalize: apply standard mean and deviation normalization to the audio tensor
+        :param augment: apply random tempo and gain perturbations (default: False)
+        """
+        self.window_stride = window_stride
+        self.window_size = window_size
+        self.sample_rate = sample_rate
+        self.labels_map = {label: i for i, label in enumerate(labels)}
+        self.normalize = normalize
+        self.augment = augment
+
+    @property
+    def window_stride_samples(self):
+        return int(self.sample_rate * self.window_stride)
+
+    @property
+    def window_size_samples(self):
+        return int(self.sample_rate * self.window_size)
+
+    def retrieve_file(self, audio_path):
+        sound, sample_rate = librosa.load(audio_path, sr=self.sample_rate)
+        return sound, sample_rate
+
+    def augment_audio(self, sound, stretch=True, shift=False, pitch=True, noise=True):
+        augmented = audio_tools.stretch(sound, utils.random_float(MIN_STRETCH, MAX_STRETCH)) if stretch else sound
+        augmented = audio_tools.shift(augmented, np.random.randint(MAX_SHIFT)) if shift else augmented
+        augmented = audio_tools.pitch_shift(augmented, self.sample_rate,
+                                            n_steps=utils.random_float(MIN_PITCH, MAX_PITCH)) if pitch else augmented
+
+        if noise:
+            noise_maker = noise_makers.GaussianNoiseMaker()
+            augmented = noise_maker.add_noise(augmented)
+
+        return augmented
+
+    def parse_audio(self, audio_path, raw=False):
+        sound, sample_rate = self.retrieve_file(audio_path)
+
+        if sample_rate != self.sample_rate:
+            raise ValueError(f"The stated sample rate {self.sample_rate} and the actual rate {sample_rate} differ!")
+
+        if self.augment:
+            sound = self.augment_audio(sound)
+
+        if raw:
+            return sound
+
+        # TODO: comment why take the last element?
+        complex_spectrogram = librosa.stft(sound,
+                                           n_fft=self.window_size_samples,
+                                           hop_length=self.window_stride_samples,
+                                           win_length=self.window_size_samples)
+        spectrogram, phase = librosa.magphase(complex_spectrogram)
+        # S = log(S + 1)
+        spectrogram = torch.from_numpy(np.log1p(spectrogram))
+
+        return spectrogram
+
+    def parse_transcript(self, transcript_path):
+        with open(transcript_path, 'r', encoding='utf8') as transcript_file:
+            transcript = transcript_file.read().replace('\n', '')
+        # TODO: Is it fast enough?
+        # map characters to label indices, skipping characters that are not in the map
+        # (a plain filter(None, ...) would also drop the valid label index 0)
+        transcript = [self.labels_map[x] for x in transcript if x in self.labels_map]
+        LOGGER.debug(f"transcript_path: {transcript_path} transcript: {transcript}")
+        return transcript
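For reference, the window arithmetic used by parse_audio above, evaluated for the configuration used elsewhere in this diff (16 kHz audio, 20 ms windows, 10 ms stride); the one-second noise input is a stand-in:

import numpy as np
import librosa

sample_rate = 16000
n_fft = int(sample_rate * 0.02)   # window_size_samples = 320
hop = int(sample_rate * 0.01)     # window_stride_samples = 160
sound = np.random.randn(sample_rate).astype(np.float32)  # 1 s of noise
spectrogram = np.abs(librosa.stft(sound, n_fft=n_fft, hop_length=hop, win_length=n_fft))
print(spectrogram.shape)  # (161, 101): n_fft // 2 + 1 frequency bins, one frame every 10 ms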
diff --git a/sonosco/datasets/samplers.py b/sonosco/datasets/samplers.py
new file mode 100644
index 0000000..416b754
--- /dev/null
+++ b/sonosco/datasets/samplers.py
@@ -0,0 +1,69 @@
+import math
+
+import numpy as np
+import torch
+from torch.utils.data import Sampler
+from torch.distributed.deprecated import get_rank
+from torch.distributed.deprecated import get_world_size
+
+
+class BucketingSampler(Sampler):
+    def __init__(self, data_source, batch_size=1):
+        """
+        Samples batches assuming they are in order of size to batch similarly sized samples together.
+        """
+        super(BucketingSampler, self).__init__(data_source)
+        self.data_source = data_source
+        ids = list(range(0, len(data_source)))
+        # TODO: Optimise
+        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
+
+    def __iter__(self):
+        for ids in self.bins:
+            np.random.shuffle(ids)
+            yield ids
+
+    def __len__(self):
+        return len(self.bins)
+
+    def shuffle(self, epoch):
+        np.random.shuffle(self.bins)
+
+
+# TODO: Optimise
+class DistributedBucketingSampler(Sampler):
+    def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None):
+        """
+        Samples batches assuming they are in order of size to batch similarly sized samples together.
+        """
+        super(DistributedBucketingSampler, self).__init__(data_source)
+        if num_replicas is None:
+            num_replicas = get_world_size()
+        if rank is None:
+            rank = get_rank()
+        self.data_source = data_source
+        self.ids = list(range(0, len(data_source)))
+        self.batch_size = batch_size
+        self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)]
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        offset = self.rank
+        # add extra samples to make it evenly divisible
+        bins = self.bins + self.bins[:(self.total_size - len(self.bins))]
+        assert len(bins) == self.total_size
+        samples = bins[offset::self.num_replicas]  # get every Nth bin, starting from rank
+        return iter(samples)
+
+    def __len__(self):
+        return self.num_samples
+
+    def shuffle(self, epoch):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(epoch)
+        bin_ids = list(torch.randperm(len(self.bins), generator=g))
+        self.bins = [self.bins[i] for i in bin_ids]
diff --git a/sonosco/decoders/__init__.py b/sonosco/decoders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/decoders/beam_decoder.py b/sonosco/decoders/beam_decoder.py
similarity index 100%
rename from decoders/beam_decoder.py
rename to sonosco/decoders/beam_decoder.py
diff --git a/decoders/decoder.py b/sonosco/decoders/decoder.py
similarity index 100%
rename from decoders/decoder.py
rename to sonosco/decoders/decoder.py
diff --git a/decoders/greedy_decoder.py b/sonosco/decoders/greedy_decoder.py
similarity index 100%
rename from decoders/greedy_decoder.py
rename to sonosco/decoders/greedy_decoder.py
diff --git a/infer.py b/sonosco/infer.py
similarity index 100%
rename from infer.py
rename to sonosco/infer.py
diff --git a/sonosco/models/__init__.py b/sonosco/models/__init__.py
new file mode 100644
index 0000000..eaab54d
--- /dev/null
+++ b/sonosco/models/__init__.py
@@ -0,0 +1 @@
+from .deepspeech2 import DeepSpeech2
\ No newline at end of file
diff --git a/models/deepspeech2.py b/sonosco/models/deepspeech2.py
similarity index 69%
rename from models/deepspeech2.py
rename to sonosco/models/deepspeech2.py
index dcfa100..0fff74f 100644
--- a/models/deepspeech2.py
+++ b/sonosco/models/deepspeech2.py
@@ -2,106 +2,15 @@
 # Based on SeanNaren's deepspeech.pytorch:
 # https://github.com/SeanNaren/deepspeech.pytorch
 # ----------------------------------------------------------------------------
-
 import math
-from collections import OrderedDict
-
 import torch
+import logging
 import torch.nn as nn
-import torch.nn.functional as F
-
-
-supported_rnns = {
-    'lstm': nn.LSTM,
-    'rnn': nn.RNN,
-    'gru': nn.GRU
-}
-supported_rnns_inv = dict((v, k) for k, v in supported_rnns.items())
-
-
-class SequenceWise(nn.Module):
-    def __init__(self, module):
-        """
-        Collapses input of dim T*N*H to (T*N)*H, and applies to a module.
-        Allows handling of variable sequence lengths and minibatch sizes.
-        :param module: Module to apply input to.
-        """
-        super(SequenceWise, self).__init__()
-        self.module = module
-
-    def forward(self, x):
-        t, n = x.size(0), x.size(1)
-        x = x.view(t * n, -1)
-        x = self.module(x)
-        x = x.view(t, n, -1)
-        return x
-
-    def __repr__(self):
-        tmpstr = self.__class__.__name__ + ' (\n'
-        tmpstr += self.module.__repr__()
-        tmpstr += ')'
-        return tmpstr
-
-
-class MaskConv(nn.Module):
-    def __init__(self, seq_module):
-        """
-        Adds padding to the output of the module based on the given lengths. This is to ensure that the
-        results of the model do not change when batch sizes change during inference.
-        Input needs to be in the shape of (BxCxDxT)
-        :param seq_module: The sequential module containing the conv stack.
-        """
-        super(MaskConv, self).__init__()
-        self.seq_module = seq_module
-
-    def forward(self, x, lengths):
-        """
-        :param x: The input of size BxCxDxT
-        :param lengths: The actual length of each sequence in the batch
-        :return: Masked output from the module
-        """
-        for module in self.seq_module:
-            x = module(x)
-            mask = torch.ByteTensor(x.size()).fill_(0)
-            if x.is_cuda:
-                mask = mask.cuda()
-            for i, length in enumerate(lengths):
-                length = length.item()
-                if (mask[i].size(2) - length) > 0:
-                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
-            x = x.masked_fill(mask, 0)
-        return x, lengths
-
-
-class InferenceBatchSoftmax(nn.Module):
-    def forward(self, input_):
-        if not self.training:
-            return F.softmax(input_, dim=-1)
-        else:
-            return input_
-
-
-class BatchRNN(nn.Module):
-    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, batch_norm=True):
-        super(BatchRNN, self).__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
-        self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
-                            bidirectional=True, bias=True)
-
-    def flatten_parameters(self):
-        self.rnn.flatten_parameters()
+from collections import OrderedDict
+from .modules import MaskConv, BatchRNN, SequenceWise, InferenceBatchSoftmax, supported_rnns, supported_rnns_inv
-    def forward(self, x, output_lengths):
-        if self.batch_norm is not None:
-            x = self.batch_norm(x)
-        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
-        x, h = self.rnn(x)
-        x, _ = nn.utils.rnn.pad_packed_sequence(x)
-        if self.bidirectional:
-            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)  # (TxNxH*2) -> (TxNxH) by sum
-        return x
+LOGGER = logging.getLogger(__name__)
 
 
 class DeepSpeech2(nn.Module):
@@ -134,7 +43,8 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hid_size=768, nb_layers=5
             nn.Hardtanh(0, 20, inplace=True)
         ))
         # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1
-        rnn_in_size = int(math.floor((sample_rate * window_size) / 2) + 1)
+        rnn_in_size = int(math.floor((sample_rate * window_size) / 4) + 1)
+        LOGGER.debug(f"Initial calculated feature size: {rnn_in_size}")
         rnn_in_size = int(math.floor(rnn_in_size + 2 * 20 - 41) / 2 + 1)
         rnn_in_size = int(math.floor(rnn_in_size + 2 * 10 - 21) / 2 + 1)
         rnn_in_size *= 32
@@ -158,6 +68,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hid_size=768, nb_layers=5
     def forward(self, x, lengths):
         # if x.is_cuda and self.mixed_precision:
         #     x = x.half()
+        LOGGER.debug(f"Actual initial size: {x.size()}")
         lengths = lengths.cpu().int()
         output_lengths = self.get_seq_lens(lengths)
         x, _ = self.conv(x, output_lengths)
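The feature-size computation above, worked through for the defaults of the serializable model below (sample_rate=16000, window_size=0.02, using the /2 variant that file keeps):

import math

sample_rate, window_size = 16000, 0.02
size = int(math.floor((sample_rate * window_size) / 2) + 1)  # 161 frequency bins (n_fft // 2 + 1)
size = int(math.floor(size + 2 * 20 - 41) / 2 + 1)           # 81 after the 41x11, stride (2, 2) conv
size = int(math.floor(size + 2 * 10 - 21) / 2 + 1)           # 41 after the 21x11, stride (2, 1) conv
size *= 32                                                   # 1312: flattened over 32 channels
print(size)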
diff --git a/sonosco/models/deepspeech2_sonosco.py b/sonosco/models/deepspeech2_sonosco.py
new file mode 100644
index 0000000..c95e7eb
--- /dev/null
+++ b/sonosco/models/deepspeech2_sonosco.py
@@ -0,0 +1,115 @@
+import math
+
+from torch import nn
+from .modules import MaskConv, BatchRNN, SequenceWise, InferenceBatchSoftmax, supported_rnns, supported_rnns_inv
+from .serialization import serializable
+from collections import OrderedDict
+from dataclasses import field
+
+
+@serializable
+class DeepSpeech2(nn.Module):
+    rnn_type: nn.RNNBase = nn.LSTM
+    labels: str = "abc"
+    rnn_hid_size: int = 768
+    nb_layers: int = 5
+    audio_conf: dict = field(default_factory=dict)
+    bidirectional: bool = True
+    version: str = '0.0.1'
+
+    def __post_init__(self):
+        sample_rate = self.audio_conf.get("sample_rate", 16000)
+        window_size = self.audio_conf.get("window_size", 0.02)
+        num_classes = len(self.labels)
+        self.conv = MaskConv(nn.Sequential(
+            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
+            nn.BatchNorm2d(32),
+            nn.Hardtanh(0, 20, inplace=True),
+            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
+            nn.BatchNorm2d(32),
+            nn.Hardtanh(0, 20, inplace=True)
+        ))
+        # Based on the above convolutions and spectrogram size, using the conv formula (W - F + 2P)/S + 1
+        rnn_in_size = int(math.floor((sample_rate * window_size) / 2) + 1)
+        rnn_in_size = int(math.floor(rnn_in_size + 2 * 20 - 41) / 2 + 1)
+        rnn_in_size = int(math.floor(rnn_in_size + 2 * 10 - 21) / 2 + 1)
+        rnn_in_size *= 32
+
+        rnns = [('0', BatchRNN(input_size=rnn_in_size, hidden_size=self.rnn_hid_size, rnn_type=self.rnn_type,
+                               batch_norm=False))]
+        rnns.extend([(f"{x + 1}", BatchRNN(input_size=self.rnn_hid_size, hidden_size=self.rnn_hid_size,
+                                           rnn_type=self.rnn_type))
+                     for x in range(self.nb_layers - 1)])
+        self.rnns = nn.Sequential(OrderedDict(rnns))
+
+        fully_connected = nn.Sequential(
+            nn.BatchNorm1d(self.rnn_hid_size),
+            nn.Linear(self.rnn_hid_size, num_classes, bias=False)
+        )
+
+        self.fc = nn.Sequential(
+            SequenceWise(fully_connected),
+        )
+
+        self.inference_softmax = InferenceBatchSoftmax()
+
+    def forward(self, x, lengths):
+        # if x.is_cuda and self.mixed_precision:
+        #     x = x.half()
+        lengths = lengths.cpu().int()
+        output_lengths = self.get_seq_lens(lengths)
+        x, _ = self.conv(x, output_lengths)
+
+        sizes = x.size()
+        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # collapse the feature dimension
+        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH
+
+        for rnn in self.rnns:
+            x = rnn(x, output_lengths)
+
+        if not self.bidirectional:  # no need for a lookahead layer in the bidirectional case
+            x = self.lookahead(x)
+
+        x = self.fc(x)
+        x = x.transpose(0, 1)
+        # identity in training mode, softmax in eval mode
+        x = self.inference_softmax(x)
+        return x, output_lengths
+
+    def get_seq_lens(self, input_length):
+        """
+        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
+        containing the sizes of the sequences that will be output by the network.
+        :param input_length: 1D Tensor
+        :return: 1D Tensor scaled by model
+        """
+        seq_len = input_length
+        for m in self.conv.modules():
+            if type(m) == nn.modules.conv.Conv2d:
+                seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) / m.stride[1] + 1)
+        return seq_len.int()
+
+    @staticmethod
+    def get_param_size(model):
+        params = 0
+        for p in model.parameters():
+            tmp = 1
+            for x in p.size():
+                tmp *= x
+            params += tmp
+        return params
+
+    def __repr__(self):
+        rep = f"DeepSpeech2 version: {self.version}\n" + \
+              "=======================================\n" + \
+              "Recurrent Neural Network Properties\n" + \
+              f"  RNN Type:  \t{self.rnn_type.__name__.lower()}\n" + \
+              f"  RNN Layers:\t{self.nb_layers}\n" + \
+              f"  RNN Size:  \t{self.rnn_hid_size}\n" + \
+              f"  Classes:   \t{len(self.labels)}\n" + \
+              "---------------------------------------\n" + \
+              "Model Features\n" + \
+              f"  Labels:       \t{self.labels}\n" + \
+              f"  Sample Rate:  \t{self.audio_conf.get('sample_rate', 'n/a')}\n" + \
+              f"  Window Type:  \t{self.audio_conf.get('window', 'n/a')}\n" + \
+              f"  Window Size:  \t{self.audio_conf.get('window_size', 'n/a')}\n" + \
+              f"  Window Stride:\t{self.audio_conf.get('window_stride', 'n/a')}"
+        return rep
diff --git a/sonosco/models/loader.py b/sonosco/models/loader.py
new file mode 100644
index 0000000..66e0b11
--- /dev/null
+++ b/sonosco/models/loader.py
@@ -0,0 +1,89 @@
+import logging
+import torch
+import deprecation
+import torch.nn as nn
+
+from sonosco.common.class_utils import get_constructor_args, get_class_by_name
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Loader:
+
+    @deprecation.deprecated(
+        details="This type of loading may cause problems when the path of the model class changes. "
+                "Please use it only for models saved with the save_model_simple method")
+    def load_model_simple(self, path: str):
+        """
+        Loads a model that was serialized as a whole via the pickle protocol.
+
+        Args:
+            path (str): path to the pickled model
+
+        Returns:
+            the unpickled model
+        """
+        return torch.load(path)
+
+    def load_model_from_path(self, cls_path: str, path: str, deserialize_method_name: str = 'deserialize') -> nn.Module:
+        """
+        Loads the model from a pickle file.
+
+        If deserialize_method_name exists, the deserialized content of the pickle file in path is passed to the
+        deserialize_method_name method. In this case, the responsibility of creating the cls object stays
+        at the caller side.
+
+        Args:
+            cls_path (str): name of the class of the model
+            path (str): path to the pickle-serialized model or model parameters
+            deserialize_method_name (str): name of the function that this method should call in order to
+                deserialize the model. Must accept a single argument of type dict.
+
+        Returns (nn.Module): loaded model
+        """
+        return self.load_model(get_class_by_name(cls_path), path, deserialize_method_name)
+
+    def load_model(self, cls: type, path: str, deserialize_method_name: str = 'deserialize') -> nn.Module:
+        """
+        Loads the model from a pickle file.
+
+        If deserialize_method_name exists, the deserialized content of the pickle file in path is passed to the
+        deserialize_method_name method. In this case, the responsibility of creating the cls object stays
+        at the caller side.
+
+        Args:
+            cls (type): class object of the model
+            path (str): path to the pickle-serialized model or model parameters
+            deserialize_method_name (str): name of the function that this method should call in order to
+                deserialize the model. Must accept a single argument of type dict.
+
+        Returns (nn.Module): loaded model
+        """
+        package = torch.load(path, map_location=lambda storage, loc: storage)
+        if hasattr(cls, deserialize_method_name) and callable(getattr(cls, deserialize_method_name)):
+            return getattr(cls, deserialize_method_name)(package)
+        constructor_args = get_constructor_args(cls)
+        stored_keys = set(package.keys())
+        stored_keys.remove('state_dict')
+
+        args_to_apply = constructor_args & stored_keys
+        # if the lengths are not equal, there is some inconsistency between save and load
+        if len(args_to_apply) != len(constructor_args):
+            not_in_constructor = stored_keys - constructor_args
+            if not_in_constructor:
+                LOGGER.warning(
+                    f"The following fields were deserialized "
+                    f"but could not be found in the constructor of the provided class: {not_in_constructor}")
+            not_in_package = constructor_args - stored_keys
+            if not_in_package:
+                LOGGER.warning(
+                    f"The following fields exist in the class constructor "
+                    f"but could not be found in the serialized package: {not_in_package}")
+
+        # only pass the arguments that the constructor actually accepts
+        filtered_package = {key: package[key] for key in args_to_apply}
+        model = cls(**filtered_package)
+        model.load_state_dict(package['state_dict'])
+        return model
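A sketch of how the loader is meant to be used with a @serializable model (see serialization.py below), assuming a checkpoint previously written by Saver.save_model; the checkpoint path is an arbitrary example:

from sonosco.models.loader import Loader
from sonosco.models.deepspeech2_sonosco import DeepSpeech2

loader = Loader()
# rebuilds the model from the serialized constructor fields, then applies the stored state_dict
model = loader.load_model(DeepSpeech2, "checkpoints/deepspeech2.pth")
model.eval()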
+ """ + super(MaskConv, self).__init__() + self.seq_module = seq_module + + def forward(self, x, lengths): + """ + :param x: The input of size BxCxDxT + :param lengths: The actual length of each sequence in the batch + :return: Masked output from the module + """ + for module in self.seq_module: + x = module(x) + mask = torch.ByteTensor(x.size()).fill_(0) + if x.is_cuda: + mask = mask.cuda() + for i, length in enumerate(lengths): + length = length.item() + if (mask[i].size(2) - length) > 0: + mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1) + x = x.masked_fill(mask, 0) + return x, lengths + + +class InferenceBatchSoftmax(nn.Module): + def forward(self, input_): + if not self.training: + return functional.softmax(input_, dim=-1) + else: + return input_ + + +class BatchRNN(nn.Module): + def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, batch_norm=True, bidirectional=False): + super(BatchRNN, self).__init__() + self.bidirectional = bidirectional + self.input_size = input_size + self.hidden_size = hidden_size + self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None + self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, + bidirectional=bidirectional, bias=True) + + def flatten_parameters(self): + self.rnn.flatten_parameters() + + def forward(self, x, output_lengths): + if self.batch_norm is not None: + x = self.batch_norm(x) + x = nn.utils.rnn.pack_padded_sequence(x, output_lengths) + x, h = self.rnn(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x) + if self.bidirectional: + x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum + return x diff --git a/sonosco/models/saver.py b/sonosco/models/saver.py new file mode 100644 index 0000000..4198a2c --- /dev/null +++ b/sonosco/models/saver.py @@ -0,0 +1,56 @@ +import logging +import torch +import deprecation +import torch.nn as nn + +from .serialization import is_serializable + +LOGGER = logging.getLogger(__name__) + + +class Saver: + + def __init__(self) -> None: + super().__init__() + + @deprecation.deprecated( + details="This type of saving may cause problems when path of model class changes. Pleas use save_model instead") + def save_model_simple(self, model: nn.Module, path: str) -> None: + """ + Simply saves the model using pickle protocol. + Args: + model (nn.Module): model to save + path (str) : path where to save the model + + Returns: + + """ + torch.save(model, path) + + def save_model(self, model: nn.Module, path: str) -> None: + """ + Saves the model using pickle protocol. + + If the infer_structure is True this method infers all the meta parameters of the model and save them together + with learnable parameters. + + If the infer_structure is False and method specified by serialize_method_name exists, the return value of the + serialize_method_name method is saved. + + If neither of above only learnable parameters a.k.a. state_dict are saved. + + Args: + model (nn.Module): model to save + path (str) : path where to save the model + infer_structure (bool): indicator whether to infer the model structure + serialize_method_name (str): name of the function that this method should call in order to serialize the + model. Must return dict. 
diff --git a/sonosco/models/saver.py b/sonosco/models/saver.py
new file mode 100644
index 0000000..4198a2c
--- /dev/null
+++ b/sonosco/models/saver.py
@@ -0,0 +1,56 @@
+import logging
+import torch
+import deprecation
+import torch.nn as nn
+
+from .serialization import is_serializable
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Saver:
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @deprecation.deprecated(
+        details="This type of saving may cause problems when the path of the model class changes. "
+                "Please use save_model instead")
+    def save_model_simple(self, model: nn.Module, path: str) -> None:
+        """
+        Simply saves the whole model using the pickle protocol.
+
+        Args:
+            model (nn.Module): model to save
+            path (str): path where to save the model
+        """
+        torch.save(model, path)
+
+    def save_model(self, model: nn.Module, path: str) -> None:
+        """
+        Saves the model using the pickle protocol.
+
+        The model class must be decorated with @serializable; its generated __serialize__ method
+        provides the meta parameters of the model, which are stored together with the learnable
+        parameters (the state_dict).
+
+        Args:
+            model (nn.Module): model to save
+            path (str): path where to save the model
+
+        Raises:
+            TypeError: if the model class is not decorated with @serializable
+        """
+        if is_serializable(model):
+            entity_to_save = model.__serialize__()
+            torch.save(entity_to_save, path)
+        else:
+            raise TypeError("Only @serializable classes can be serialized")
diff --git a/sonosco/models/serialization.py b/sonosco/models/serialization.py
new file mode 100644
index 0000000..4f824af
--- /dev/null
+++ b/sonosco/models/serialization.py
@@ -0,0 +1,86 @@
+from dataclasses import _process_class, _create_fn, _set_new_attribute, fields, is_dataclass
+
+__primitives = {int, float, str, bool}
+__iterables = [list, set, tuple]
+
+
+def serializable(_cls=None):
+    """
+    Returns the same class as was passed in, with generated __init__ and __serialize__ methods.
+
+    Args:
+        _cls: the class to decorate
+
+    Returns:
+        the decorated class
+    """
+
+    def wrap(cls):
+        cls = _process_class(cls, init=True, repr=False, eq=False, order=False, unsafe_hash=False, frozen=False)
+        _set_new_attribute(cls, '__serialize__', __add_serialize(cls))
+        return cls
+
+    # See if we're being called as @serializable or @serializable().
+    if _cls is None:
+        # We're called with parens.
+        return wrap
+
+    # We're called as @serializable without parens.
+    return wrap(_cls)
+
+
+def is_serializable(obj):
+    return hasattr(obj, '__serialize__')
+
+
+def __add_serialize(cls):
+    fields_to_serialize = fields(cls)
+    # avoid a name clash in the unlikely case that a field is itself called 'self'
+    self_name = '__sonosco_self__' if 'self' in [f.name for f in fields_to_serialize] else 'self'
+    serialize_body = __create_serialize_body(cls, fields_to_serialize)
+    return _create_fn('__serialize__', [self_name], [f'return {serialize_body}'], return_type=dict)
+
+
+def __create_serialize_body(cls, fields_to_serialize):
+    body_lines = ["{"]
+    for field in fields_to_serialize:
+        if __is_primitive(field) or __is_iterable_of_primitives(field):
+            body_lines.append(__create_dict_entry(field.name, f"self.{field.name}"))
+        elif is_dataclass(field.type):
+            body_lines.append(__create_dict_entry(field.name, f"self.{field.name}.__serialize__()"))
+        elif __is_nn_class(field.type):
+            body_lines.append(f"'{field.name}': {{")
+            __extract_from_nn(cls, body_lines)
+            body_lines.append("},")
+        else:
+            __throw_unsupported_data_type()
+    body_lines.append(__create_dict_entry("state_dict", "self.state_dict()"))
+    body_lines.append("}")
+    # the generated function body is a single dict literal
+    return " ".join(body_lines)
+
+
+def __extract_from_nn(cls, body_lines):
+    constants = list(filter(lambda el: not el.startswith('_'), cls.__constants__))
+    for constant in constants:
+        body_lines.append(__create_dict_entry(constant, f"self.{constant}"))
+
+
+def __is_iterable_of_primitives(field):
+    return getattr(field.type, '__origin__', None) in __iterables and field.type.__args__[0] in __primitives
+
+
+def __throw_unsupported_data_type():
+    raise TypeError("Unsupported data type. Only primitives, lists of primitives, torch.nn.Module, "
+                    "@serializable and @dataclass objects can be serialized")
+
+
+def __create_dict_entry(key, value):
+    return f"'{key}': {value},"
+
+
+def __is_primitive(field):
+    return field.type in __primitives
+
+
+def __is_nn_class(cls):
+    return hasattr(cls, '__constants__')
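A minimal round trip through the machinery above, on a toy module (TinyModel is not part of this PR). Note the explicit super().__init__() in __post_init__, needed so that nn.Module is set up before submodules are registered:

import torch.nn as nn
from sonosco.models.serialization import serializable
from sonosco.models.saver import Saver
from sonosco.models.loader import Loader

@serializable
class TinyModel(nn.Module):
    hidden: int = 8

    def __post_init__(self):
        super().__init__()  # initialize nn.Module internals before registering submodules
        self.fc = nn.Linear(self.hidden, 2)

model = TinyModel(hidden=16)
package = model.__serialize__()  # {'hidden': 16, 'state_dict': OrderedDict([...])}
Saver().save_model(model, "tiny.pth")
restored = Loader().load_model(TinyModel, "tiny.pth")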
diff --git a/modelwrapper.py b/sonosco/modelwrapper.py
similarity index 98%
rename from modelwrapper.py
rename to sonosco/modelwrapper.py
index 3a0ec81..b313603 100644
--- a/modelwrapper.py
+++ b/sonosco/modelwrapper.py
@@ -14,8 +14,13 @@
 import torch.distributed as dist
 import torch.utils.data.distributed
-from apex.fp16_utils import FP16_Optimizer
-from apex.parallel import DistributedDataParallel
+
+try:
+    from apex.fp16_utils import FP16_Optimizer
+    from apex.parallel import DistributedDataParallel
+except Exception as e:
+    print(f"Apex import failed: {e}")
+
 from tqdm import tqdm
 from warpctc_pytorch import CTCLoss
@@ -320,7 +325,7 @@ def train(self, seed, cuda, mixed_precision, world_size, gpu_rank, rank, save_fo
             train_sampler.shuffle(epoch)
 
     def validate(self):
-
+        pass
 
     def test(self):
         torch.set_grad_enabled(False)
@@ -386,8 +391,7 @@ def test(self):
         if save_output:
             np.save(output_path, output_data)
 
-
-def infer(self, sound):
+    def infer(self, sound):
         pass
 
     @staticmethod
@@ -401,7 +405,6 @@ def get_default_path(def_path: str) -> str:
             default = latest_subdir + "/final.pth"
         return default
 
-
     def print_training_info(self, epoch, loss, cer, wer):
         print(f"\nTraining Information\n " + \
               f"- Epoch:\t{epoch}\n " + \
diff --git a/sonosco/run_training.py b/sonosco/run_training.py
new file mode 100644
index 0000000..76e0a4e
--- /dev/null
+++ b/sonosco/run_training.py
@@ -0,0 +1,46 @@
+import logging
+import click
+import torch.nn.functional as torch_functional
+
+from sonosco.common.constants import SONOSCO
+from sonosco.common.utils import setup_logging
+from sonosco.common.path_utils import parse_yaml
+from sonosco.training import Experiment, ModelTrainer
+from sonosco.datasets import create_data_loaders
+from sonosco.models import DeepSpeech2
+
+LOGGER = logging.getLogger(SONOSCO)
+
+
+@click.command()
+@click.option("-e", "--experiment_name", default="default", type=click.STRING, help="Experiment name.")
+@click.option("-c", "--config_path", default="config/train.yaml", type=click.STRING,
+              help="Path to train configurations.")
+def main(experiment_name, config_path):
+    Experiment.create(experiment_name)
+    config = parse_yaml(config_path)["train"]
+
+    train_loader, val_loader = create_data_loaders(**config)
+
+    def custom_loss(batch, model):
+        batch_x, batch_y, input_lengths, target_lengths = batch
+        model_output, output_lengths = model(batch_x, input_lengths)
+        loss = torch_functional.ctc_loss(model_output.transpose(0, 1), batch_y, output_lengths, target_lengths)
+        return loss, model_output
+
+    # TODO: change to load different models dynamically
+    model = DeepSpeech2(labels=config["labels"])
+
+    trainer = ModelTrainer(model, loss=custom_loss, epochs=config["max_epochs"],
+                           train_data_loader=train_loader, val_data_loader=val_loader,
+                           lr=config["learning_rate"], custom_model_eval=True)
+
+    try:
+        trainer.start_training()
+    except KeyboardInterrupt:
+        trainer.stop_training()
+
+
+if __name__ == '__main__':
+    setup_logging(LOGGER)
+    main()
diff --git a/test.py b/sonosco/test.py
similarity index 100%
rename from test.py
rename to sonosco/test.py
diff --git a/sonosco/training/__init__.py b/sonosco/training/__init__.py
new file mode 100644
index 0000000..05caba2
--- /dev/null
+++ b/sonosco/training/__init__.py
@@ -0,0 +1,2 @@
+from .experiment import Experiment
+from .trainer import ModelTrainer
diff --git a/sonosco/training/abstract_callback.py b/sonosco/training/abstract_callback.py
new file mode 100644
index 0000000..8d920b6
--- /dev/null
+++ b/sonosco/training/abstract_callback.py
@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+
+
+class AbstractCallback(ABC):
+    """
+    Interface that defines how callbacks must be specified.
+    """
+
+    @abstractmethod
+    def __call__(self, epoch, step, performance_measures, context):
+        """
+        Called after every batch by the ModelTrainer.
+
+        Parameters:
+            epoch (int): current epoch number
+            step (int): current batch number
+            performance_measures (dict): losses and metrics based on a running average
+            context (ModelTrainer): reference to the calling ModelTrainer, allows to access members
+        """
+        pass
+
+    def close(self):
+        """
+        Handle cleanup work if necessary. Will be called at the end of the last epoch.
+        """
+        pass
diff --git a/sonosco/training/early_stopping.py b/sonosco/training/early_stopping.py
new file mode 100644
index 0000000..8eefa6b
--- /dev/null
+++ b/sonosco/training/early_stopping.py
@@ -0,0 +1,45 @@
+import logging
+import sys
+
+from .abstract_callback import AbstractCallback
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class EarlyStopping(AbstractCallback):
+    """
+    Early stopping to terminate training if the monitored metric did not improve
+    over a number of epochs.
+
+    Args:
+        monitor (string): name of the relevant loss or metric (usually 'val_loss')
+        min_delta (float): minimum change in the monitored metric to qualify as an improvement
+        patience (int): number of epochs to wait for an improvement before terminating the training
+    """
+
+    def __init__(self, monitor='val_loss', min_delta=0, patience=5):
+        self.monitor = monitor
+        self.min_delta = min_delta
+        self.patience = patience
+        self.last_best = sys.float_info.max
+        self.counter = 0
+        self.stopped_epoch = 0
+
+    def __call__(self, epoch, step, performance_measures, context):
+        if step != len(context.train_data_loader) - 1:  # only continue at the end of an epoch
+            return
+
+        if self.monitor not in performance_measures:
+            return
+
+        current_loss = performance_measures[self.monitor]
+        if (self.last_best - current_loss) >= self.min_delta:
+            self.last_best = current_loss
+            self.counter = 0
+        else:
+            self.counter += 1
+
+        if self.counter >= self.patience:
+            context._stop_training = True  # make the ModelTrainer stop
+            LOGGER.info(f"Early stopping after epoch {epoch}")
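Driving the EarlyStopping bookkeeping above by hand; SimpleNamespace stands in for the ModelTrainer context, and the loss values are made up:

from types import SimpleNamespace
from sonosco.training.early_stopping import EarlyStopping

stopper = EarlyStopping(monitor='val_loss', patience=2)
context = SimpleNamespace(train_data_loader=[None] * 10, _stop_training=False)

# step=9 is the last batch of each epoch, so the callback actually runs
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.95]):
    stopper(epoch, step=9, performance_measures={'val_loss': val_loss}, context=context)
print(context._stop_training)  # True: no improvement for 2 epochs after 0.8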
diff --git a/sonosco/training/experiment.py b/sonosco/training/experiment.py
new file mode 100644
index 0000000..07fd1cf
--- /dev/null
+++ b/sonosco/training/experiment.py
@@ -0,0 +1,108 @@
+import os
+import os.path as path
+import datetime
+import logging
+import sonosco.common.path_utils as path_utils
+import sonosco.common.utils as utils
+
+from time import time
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Experiment:
+    """
+    Generates a folder where all experiments will be stored, and then a named experiment with the current
+    timestamp and the provided name. Automatically starts logging the console output and creates a copy
+    of the currently executed code in the experiment folder. The experiment's subfolder paths are provided
+    to the outside as member variables. It also allows adding more subfolders conveniently.
+
+    Args:
+        experiment_name (string): name of the experiment to be created
+        experiments_path (string): location where all experiments will be stored, default is './experiments'
+
+    Example:
+        >>> experiment = Experiment('mnist_classification')
+        >>> print(experiment.plots)  # path to experiment plots
+    """
+
+    def __init__(self,
+                 experiment_name,
+                 experiments_path=None,
+                 sub_directories=("plots", "logs", "code"),
+                 exclude_dirs=('__pycache__', '.git', 'experiments'),
+                 exclude_files=('.pyc',)):
+
+        self.experiments_path = self._set_experiments_dir(experiments_path)
+        self.name = self._set_experiment_name(experiment_name)
+        self.path = path.join(self.experiments_path, self.name)  # path to the current experiment
+        self.logs = path.join(self.experiments_path, "logs")
+
+        self.code = path.join(self.experiments_path, "code")
+        self._sub_directories = list(sub_directories)  # a list, so add_directory can append to it
+
+        self._exclude_dirs = exclude_dirs
+        self._exclude_files = exclude_files
+
+        self._init_directories()
+        self._copy_sourcecode()
+        self._set_logging()
+
+    @staticmethod
+    def _set_experiments_dir(experiments_path):
+        if experiments_path is not None:
+            return experiments_path
+
+        local_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+        local_path = local_path if local_path != '' else './'
+        return path.join(local_path, "experiments")
+
+    @staticmethod
+    def _set_experiment_name(experiment_name):
+        date_time = datetime.datetime.fromtimestamp(time()).strftime('%Y-%m-%d_%H:%M:%S')
+        return f"{date_time}_{experiment_name}"
+
+    def _set_logging(self):
+        utils.add_log_file(self.logs, LOGGER)
+
+    def _init_directories(self):
+        """ Create all basic directories. """
+        path_utils.try_create_directory(self.experiments_path)
+        path_utils.try_create_directory(path.join(self.experiments_path, self.name))
+        for sub_dir_name in self._sub_directories:
+            self.add_directory(sub_dir_name)
+
+    def _add_member(self, key, value):
+        """ Add a member variable named 'key' with value 'value' to the experiment instance. """
+        self.__dict__[key] = value
+
+    def _copy_sourcecode(self):
+        """ Copy code from the execution directory into the experiment's code directory. """
+        sources_path = os.path.dirname(os.path.dirname(__file__))
+        sources_path = sources_path if sources_path != '' else './'
+        utils.copy_code(sources_path, self.code,
+                        exclude_dirs=self._exclude_dirs,
+                        exclude_files=self._exclude_files)
+
+    def add_directory(self, dir_name):
+        """
+        Add a sub-directory to the experiment. The directory will be automatically
+        created and provided to the outside as a member variable.
+        """
+        # store in the sub-dir list
+        if dir_name not in self._sub_directories:
+            self._sub_directories.append(dir_name)
+        # add as a member
+        dir_path = path.join(self.experiments_path, self.name, dir_name)
+        self._add_member(dir_name, dir_path)
+        # create the directory
+        path_utils.try_create_directory(dir_path)
+
+    @staticmethod
+    def add_file(folder_path, filename, content):
+        """ Adds a file with the provided content to the folder. Convenience function. """
+        with open(path.join(folder_path, filename), 'w') as text_file:
+            text_file.write(content)
+
+    @staticmethod
+    def create(name: str):
+        return Experiment(name)
""" + with open(path.join(folder_path, filename), 'w') as text_file: + text_file.write(content) + + @staticmethod + def create(name: str): + return Experiment(name) diff --git a/sonosco/training/gradient_collector.py b/sonosco/training/gradient_collector.py new file mode 100644 index 0000000..386b874 --- /dev/null +++ b/sonosco/training/gradient_collector.py @@ -0,0 +1,50 @@ +import logging +import numpy as np +import torch + +from .abstract_callback import AbstractCallback + + +LOGGER = logging.getLogger(__name__) + + +class LayerwiseGradientNorm(AbstractCallback): + """ Collects the layer-wise gradient norms for each epoch. """ + + def __init__(self): + self.layer_grads = dict() + self._batch_layer_grads = dict() + + def __call__(self, epoch, step, performance_measures, context): + """ + Store gradient norms for each batch and compute means after the + epoch's last batch. + """ + self._store_batch_layer_grads(context.model) + + if step == (len(context.train_data_loader) - 1): # end of epoch + self._store_layer_grads() + self._batch_layer_grads = dict() + + def _store_batch_layer_grads(self, model): + """ Store gradient norm of each layer for current batch. """ + for name, param in model.named_parameters(): + + if not param.requires_grad or param.grad is None: + continue + + if not name in self._batch_layer_grads: + self._batch_layer_grads[name] = [] + + grad_norm = torch.sqrt(torch.sum(param.grad**2)).item() + self._batch_layer_grads[name].append(grad_norm) + + def _store_layer_grads(self): + """ Compute mean of all batch steps in epoch. """ + for name, grads in self._batch_layer_grads.items(): + + if name not in self.layer_grads: + self.layer_grads[name] = [] + + layer_epoch_grad = np.mean(grads) + self.layer_grads[name].append(layer_epoch_grad) diff --git a/sonosco/training/history_recorder.py b/sonosco/training/history_recorder.py new file mode 100644 index 0000000..5737a7b --- /dev/null +++ b/sonosco/training/history_recorder.py @@ -0,0 +1,26 @@ +import logging +import torch + +from collections import defaultdict +from .abstract_callback import AbstractCallback + + +LOGGER = logging.getLogger(__name__) + + +class HistoryRecorder(AbstractCallback): + """ Records all losses and metrics during training. """ + + def __init__(self, epoch_steps): + self.history = defaultdict(list) + self._epoch_steps = epoch_steps + + def __call__(self, epoch, step, performance_measures, context): + + if step % self._epoch_steps == 0: # only record at end of epoch + return + + for key, value in performance_measures.items(): + if type(value) == torch.Tensor: + value = value.item() + self.history[key].append(value) diff --git a/sonosco/training/learning_rates.py b/sonosco/training/learning_rates.py new file mode 100644 index 0000000..e977514 --- /dev/null +++ b/sonosco/training/learning_rates.py @@ -0,0 +1,131 @@ +import logging +import sys + +from .abstract_callback import AbstractCallback + + +LOGGER = logging.getLogger(__name__) + + +class StepwiseLearningRateReduction(AbstractCallback): + """ + Reduces the learning rate of the optimizer every N epochs. 
diff --git a/sonosco/training/learning_rates.py b/sonosco/training/learning_rates.py
new file mode 100644
index 0000000..e977514
--- /dev/null
+++ b/sonosco/training/learning_rates.py
@@ -0,0 +1,131 @@
+import logging
+import sys
+
+from .abstract_callback import AbstractCallback
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class StepwiseLearningRateReduction(AbstractCallback):
+    """
+    Reduces the learning rate of the optimizer every N epochs.
+
+    Args:
+        epoch_steps (int): number of epochs after which the learning rate is reduced
+        reduction_factor (float): multiplicative factor for the learning rate reduction
+        min_lr (float): lower bound for the learning rate
+    """
+
+    def __init__(self, epoch_steps, reduction_factor, min_lr=None):
+        self._epoch_steps = epoch_steps
+        self._reduction_factor = reduction_factor
+        self._min_lr = min_lr
+
+    def __call__(self, epoch, step, performance_measures, context):
+        # execute at the beginning of every Nth epoch
+        if epoch > 0 and step == 0 and epoch % self._epoch_steps == 0:
+
+            # reduce the lr for each param group (necessary for e.g. Adam)
+            for param_group in context.optimizer.param_groups:
+                new_lr = param_group['lr'] * self._reduction_factor
+
+                if self._min_lr is not None and new_lr < self._min_lr:
+                    continue
+
+                param_group['lr'] = new_lr
+                LOGGER.info("Epoch {}: Reducing learning rate to {}".format(epoch, new_lr))
+
+
+class ScheduledLearningRateReduction(AbstractCallback):
+    """
+    Reduces the learning rate of the optimizer at every scheduled epoch.
+
+    Args:
+        epoch_schedule (list of int): defines at which epochs the learning rate will be reduced
+        reduction_factor (float): multiplicative factor for the learning rate reduction
+        min_lr (float): lower bound for the learning rate
+    """
+
+    def __init__(self, epoch_schedule, reduction_factor, min_lr=None):
+        self._epoch_schedule = sorted(epoch_schedule)
+        self._reduction_factor = reduction_factor
+        self._min_lr = min_lr
+
+    def __call__(self, epoch, step, performance_measures, context):
+        if not self._epoch_schedule:  # stop if the schedule is empty
+            return
+
+        next_epoch_step = self._epoch_schedule[0]
+        if epoch >= next_epoch_step and step == 0:
+
+            # reduce the lr for each param group (necessary for e.g. Adam)
+            for param_group in context.optimizer.param_groups:
+                new_lr = param_group['lr'] * self._reduction_factor
+
+                if self._min_lr is not None and new_lr < self._min_lr:
+                    continue
+
+                param_group['lr'] = new_lr
+                LOGGER.info("Epoch {}: Reducing learning rate to {}".format(epoch, new_lr))
+
+            self._epoch_schedule.pop(0)
+
+
+class ReduceLROnPlateau(AbstractCallback):
+    """
+    Reduces the learning rate if the train or validation loss plateaus.
+
+    Args:
+        monitor (string): name of the relevant loss or metric (usually 'val_loss')
+        factor (float): factor by which the lr is decreased at each step
+        patience (int): number of epochs to wait on plateau for loss improvement before reducing the lr
+        min_delta (float): minimum improvement necessary to reset patience
+        cooldown (int): number of epochs to cool down after a lr reduction
+        min_lr (float): minimum value the learning rate can decrease to
+        verbose (bool): print to console
+    """
+
+    def __init__(self, monitor='val_loss', factor=0.1, patience=10, min_delta=0, cooldown=0, min_lr=0, verbose=False):
+        self.monitor = monitor
+        if factor >= 1.0 or factor < 0:
+            raise ValueError('ReduceLROnPlateau only supports a factor in [0, 1).')
+        self.factor = factor
+        self.min_lr = min_lr
+        self.min_delta = min_delta
+        self.patience = patience
+        self.verbose = verbose
+        self.cooldown = cooldown
+        self.cooldown_counter = 0
+        self.wait = 0
+        self.best_loss = sys.float_info.max
+
+    def __call__(self, epoch, step, performance_measures, context):
+        if self.monitor not in performance_measures:
+            return
+
+        if step != len(context.train_data_loader) - 1:  # only continue at the end of an epoch
+            return
+
+        if self.cooldown_counter > 0:  # in cooldown phase
+            self.cooldown_counter -= 1
+            self.wait = 0
+
+        current_loss = performance_measures[self.monitor]
+        if (self.best_loss - current_loss) >= self.min_delta:  # loss improved, save it and reset the wait counter
+            self.best_loss = current_loss
+            self.wait = 0
+
+        elif self.cooldown_counter <= 0:  # no improvement and not in cooldown
+
+            if self.wait >= self.patience:  # waited long enough, reduce the lr
+                for param_group in context.optimizer.param_groups:
+                    old_lr = param_group['lr']
+                    new_lr = old_lr * self.factor
+                    if new_lr >= self.min_lr:  # only decrease if there is still enough buffer space
+                        if self.verbose:
+                            LOGGER.info("Epoch {}: Reducing learning rate from {} to {}".format(epoch, old_lr, new_lr))  # TODO: print per param group?
+                        param_group['lr'] = new_lr
+                self.cooldown_counter = self.cooldown  # new cooldown phase after the lr reduction
+                self.wait = 0
+            else:
+                self.wait += 1
diff --git a/sonosco/training/model_checkpoint.py b/sonosco/training/model_checkpoint.py
new file mode 100644
index 0000000..78f8f4a
--- /dev/null
+++ b/sonosco/training/model_checkpoint.py
@@ -0,0 +1,43 @@
+import logging
+import sys
+import os.path as path
+import torch
+
+from .abstract_callback import AbstractCallback
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ModelCheckpoint(AbstractCallback):
+    """
+    Saves the model and optimizer state at the point with the lowest validation error throughout training.
+
+    Args:
+        output_path (string): path to the directory where the checkpoint will be saved to
+        model_name (string): name of the checkpoint file
+    """
+
+    def __init__(self, output_path, model_name='model_checkpoint.pt'):
+        self.output_path = path.join(output_path, model_name)
+        self.best_val_score = sys.float_info.max
+
+    def __call__(self, epoch, step, performance_measures, context):
+        if 'val_loss' not in performance_measures:
+            return
+
+        if performance_measures['val_loss'] < self.best_val_score:
+            self.best_val_score = performance_measures['val_loss']
+            self._save_checkpoint(context.model, context.optimizer, epoch)
+
+    def _save_checkpoint(self, model, optimizer, epoch):
+        LOGGER.info("Saving model at checkpoint.")
+        model.eval()
+        model_state_dict = model.state_dict()
+        optimizer_state_dict = optimizer.state_dict()
+        torch.save({'arch': model.__class__.__name__,
+                    'epoch': epoch,
+                    'model_state_dict': model_state_dict,
+                    'optimizer_state_dict': optimizer_state_dict
+                    }, self.output_path)
+        model.train()
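The checkpoint logic above, driven by hand with a toy model and a dummy context (SimpleNamespace stands in for the ModelTrainer defined next; the directory and loss values are placeholders):

import torch
import torch.nn as nn
from types import SimpleNamespace
from sonosco.training.model_checkpoint import ModelCheckpoint

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
context = SimpleNamespace(model=model, optimizer=optimizer)

checkpoint = ModelCheckpoint(output_path=".", model_name="best.pt")
checkpoint(epoch=0, step=0, performance_measures={'val_loss': 0.9}, context=context)  # saves ./best.pt
checkpoint(epoch=1, step=0, performance_measures={'val_loss': 1.2}, context=context)  # worse, no save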
diff --git a/sonosco/training/trainer.py b/sonosco/training/trainer.py
new file mode 100644
index 0000000..7c9bd83
--- /dev/null
+++ b/sonosco/training/trainer.py
@@ -0,0 +1,250 @@
+import logging
+import torch
+import torch.optim.optimizer
+import torch.nn.utils.clip_grad as grads
+
+from collections import defaultdict
+from typing import Callable, Union, Tuple, List, Any
+from torch.utils.data import DataLoader
+from .abstract_callback import AbstractCallback
+
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ModelTrainer:
+    """
+    This class handles the training of a pytorch model. It provides convenience
+    functionality to add metrics and callbacks and is inspired by the keras API.
+
+    Args:
+        model (nn.Module): model to be trained
+        optimizer (optim.Optimizer): optimizer class used for training, e.g. torch.optim.Adam
+        loss (function): loss function that accepts either (model_output, label), or (batch, model) if custom_model_eval is true
+        epochs (int): number of epochs to train
+        train_data_loader (utils.data.DataLoader): training data
+        val_data_loader (utils.data.DataLoader, optional): validation data
+        custom_model_eval (boolean, optional): enables a training mode where the model is evaluated inside the loss function
+        gpu (int, optional): if not set, training runs on the cpu; otherwise an int is expected that determines the training gpu
+        clip_grads (float, optional): if set, training gradients will be clipped at the specified norm
+    """
+
+    def __init__(self,
+                 model: torch.nn.Module,
+                 loss: Union[Callable[[Any, Any], Any],
+                             Callable[[torch.Tensor, torch.Tensor, torch.nn.Module], float]],
+                 epochs: int,
+                 train_data_loader: DataLoader,
+                 val_data_loader: DataLoader = None,
+                 optimizer=torch.optim.Adam,
+                 lr: float = 1e-4,
+                 custom_model_eval: bool = False,
+                 gpu: int = None,
+                 clip_grads: float = None,
+                 metrics: List[Callable[[torch.Tensor, Any], Union[float, torch.Tensor]]] = None,
+                 callbacks: List[AbstractCallback] = None):
+
+        self.model = model
+        self.train_data_loader = train_data_loader
+        self.val_data_loader = val_data_loader
+        self.optimizer = optimizer(self.model.parameters(), lr=lr)
+        self.loss = loss
+        self._epochs = epochs
+        self._metrics = metrics if metrics is not None else list()
+        self._callbacks = callbacks if callbacks is not None else list()
+        self._gpu = gpu
+        self._custom_model_eval = custom_model_eval
+        self._clip_grads = clip_grads
+        self._stop_training = False  # used to stop training externally
+
+    def set_metrics(self, metrics):
+        """
+        Set metric functions that receive y_pred and y_true. Metrics are expected to return
+        a basic numeric type like float or int.
+        """
+        self._metrics = metrics
+
+    def add_metric(self, metric):
+        self._metrics.append(metric)
+
+    def set_callbacks(self, callbacks):
+        """
+        Set callbacks that are callable functionals and receive epoch, step, loss, context.
+        Context is a pointer to the ModelTrainer instance. Callbacks are called after each
+        processed batch.
+        """
+        self._callbacks = callbacks
+
+    def add_callback(self, callback):
+        self._callbacks.append(callback)
+
+    def start_training(self):
+        self.model.train()  # train mode
+        for epoch in range(1, self._epochs + 1):
+            self._epoch_step(epoch)
+
+            if self._stop_training:
+                break
+
+        self._close_callbacks()
""" + running_batch_loss = 0 + running_metrics = defaultdict(float) + + for step, (batch_x, batch_y, input_lengths, target_lengths) in enumerate(self.train_data_loader): + batch = (batch_x, batch_y, input_lengths, target_lengths) + batch = self._recursive_to_cuda(batch) # move to GPU + + # compute training batch + loss, model_output, grad_norm = self._train_on_batch(batch) + running_batch_loss += loss.item() + + # compute metrics + self._compute_running_metrics(model_output, batch, running_metrics) + running_metrics['gradient_norm'] += grad_norm # add grad norm to metrics + + # evaluate validation set at end of epoch + if self.val_data_loader and step == (len(self.train_data_loader) - 1): + self._compute_validation_error(running_metrics) + + # print current loss and metrics and provide it to callbacks + performance_measures = self._construct_performance_dict(step, running_batch_loss, running_metrics) + self._print_step_info(epoch, step, performance_measures) + self._apply_callbacks(epoch, step, performance_measures) + + def stop_training(self): + self._stop_training = True + + def _comp_gradients(self): + """ Compute the gradient norm for all model parameters. """ + grad_sum = 0 + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grad_sum += torch.sum(param.grad ** 2) + grad_norm = torch.sqrt(grad_sum).item() + return grad_norm + + def _train_on_batch(self, batch): + """ Compute loss depending on settings, compute gradients and apply optimization step. """ + # evaluate loss + batch_x, batch_y, input_lengths, target_lengths = batch + if self._custom_model_eval: + loss, model_output = self.loss(batch, self.model) + else: + model_output = self.model(batch_x, input_lengths) + loss = self.loss(model_output, batch_y) + + self.optimizer.zero_grad() # reset gradients + loss.backward() # backpropagation + + # gradient clipping + if self._clip_grads is not None: + grads.clip_grad_norm(self.model.parameters(), self._clip_grads) + + grad_norm = self._comp_gradients() # compute average gradient norm + + self.optimizer.step() # apply optimization step + return loss, model_output, grad_norm + + def _compute_validation_error(self, running_metrics): + """ Evaluate the model's validation error. """ + running_val_loss = 0 + + self.model.eval() + for batch in self.val_data_loader: + batch = self._recursive_to_cuda(batch) + + # evaluate loss + batch_x, batch_y = batch + if self._custom_model_eval: # e.g. used for sequences and other complex model evaluations + val_loss, model_output = self.loss(batch, self.model) + else: + model_output = self.model(batch_x) + val_loss = self.loss(model_output, batch_y) + + # compute running validation loss and metrics. add 'val_' prefix to all measures. + running_val_loss += val_loss.item() + self._compute_running_metrics(model_output, batch, running_metrics, prefix='val_') + self.model.train() + + # add loss to metrics and normalize all validation measures + running_metrics['val_loss'] = running_val_loss + for key, value in running_metrics.items(): + if 'val_' not in key: + continue + running_metrics[key] = value / len(self.val_data_loader) + + def _compute_running_metrics(self, + y_pred: torch.Tensor, + batch: Tuple[torch.Tensor, torch.Tensor], + running_metrics: dict, + prefix: str = ''): + """ + Computes all metrics based on predictions and batches and adds them to the metrics + dictionary. Allows to prepend a prefix to the metric names in the dictionary. 
+ """ + for metric in self._metrics: + if self._custom_model_eval: + metric_result = metric(y_pred, batch) + else: + batch_y = batch[1] + metric_result = metric(y_pred, batch_y) + + # convert to float if metric returned tensor + if type(metric_result) == torch.Tensor: + metric_result = metric_result.item() + + running_metrics[prefix + metric.__name__] += metric_result + + def _construct_performance_dict(self, train_step, running_batch_loss, running_metrics): + """ + Constructs a combined dictionary of losses and metrics for callbacks based on + the current running averages. + """ + performance_dict = defaultdict() + for key, value in running_metrics.items(): + if 'val_' not in key: + performance_dict[key] = value / (train_step + 1.) + else: + performance_dict[key] = value # validation metrics, already normalized + + performance_dict['loss'] = running_batch_loss / (train_step + 1.) + return performance_dict + + def _apply_callbacks(self, epoch, step, performance_measures): + """ Call all registered callbacks with current batch information. """ + for callback in self._callbacks: + callback(epoch, step, performance_measures, self) + + def _close_callbacks(self): + """ Signal callbacks training is finished. """ + for callback in self._callbacks: + callback.close() + + def _print_step_info(self, epoch, step, performance_measures): + """ Print running averages for loss and metrics during training. """ + output_message = "epoch {} batch {}/{}".format(epoch, step, len(self.train_data_loader) - 1) + delim = " " + for metric_name in sorted(list(performance_measures.keys())): + if metric_name == 'gradient_norm': + continue + output_message += delim + "{}: {:.6f}".format(metric_name, performance_measures[metric_name]) + LOGGER.info(output_message) + + def _recursive_to_cuda(self, tensors): + """ + Recursively iterates nested lists in depth-first order and transfers all tensors + to specified cuda device. 
+    def _recursive_to_cuda(self, tensors):
+        """
+        Recursively iterates nested lists and tuples in depth-first order and transfers
+        all tensors to the specified cuda device.
+
+        Parameters:
+            tensors (list, tuple or Tensor): tensors or containers of tensors, can be nested
+        """
+        if self._gpu is None:  # keep on cpu
+            return tensors
+
+        if isinstance(tensors, (list, tuple)):  # recurse into containers (training batches are tuples)
+            return type(tensors)(self._recursive_to_cuda(tensor) for tensor in tensors)
+
+        return tensors.to(device=self._gpu)
diff --git a/utils.py b/sonosco/utils.py
similarity index 89%
rename from utils.py
rename to sonosco/utils.py
index 5e219d1..72d8353 100644
--- a/utils.py
+++ b/sonosco/utils.py
@@ -1,5 +1,11 @@
 import torch
-from apex.fp16_utils import BN_convert_float
+
+try:
+    from apex.fp16_utils import BN_convert_float
+except ImportError as e:
+    print(f"Apex import failed: {e}")
+
+
 import torch.distributed as dist
 
 from models.deepspeech2 import DeepSpeech2
@@ -43,7 +49,7 @@ def check_loss(loss, loss_value):
 
 
 def load_model(device, model_path, is_cuda):
-    model = DeepSpeech.load_model(model_path)
+    model = DeepSpeech2.load_model(model_path)
     model.eval()
     model = model.to(device)
     if is_cuda and model.mixed_precision:
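With the guarded import above, BN_convert_float stays undefined when apex is missing, so mixed-precision code paths would fail later with a NameError. A common hardening of this pattern binds a fallback in the except branch; this sketch assumes a pass-through is acceptable for the non-apex path:

try:
    from apex.fp16_utils import BN_convert_float
except ImportError as e:
    print(f"Apex import failed: {e}")

    def BN_convert_float(module):
        # no-op fallback (assumption): lets callers run without apex;
        # raise here instead if mixed precision must not silently degrade
        return module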
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..2acd1a9
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,79 @@
+import logging
+import os
+import shutil
+
+import pytest
+import numpy as np
+import librosa
+
+from sonosco.common.constants import SONOSCO
+from sonosco.common.utils import setup_logging
+from sonosco.datasets.dataset import AudioDataset, AudioDataProcessor
+from sonosco.datasets.samplers import BucketingSampler
+from sonosco.datasets.loader import DataLoader
+from sonosco.datasets.download_datasets.librispeech import try_download_librispeech
+
+
+LIBRI_SPEECH_DIR = "temp/test_data/libri_speech"
+TEST_WAVS_DIR = "test_wavs"
+SAMPLE_RATE = 16000
+
+
+@pytest.fixture
+def logger():
+    logger = logging.getLogger(SONOSCO)
+    setup_logging(logger)
+    return logger
+
+
+def test_librispeech_download(logger):
+    # prepare: remove any previous download so the test exercises the full path
+    if os.path.exists(LIBRI_SPEECH_DIR):
+        shutil.rmtree(LIBRI_SPEECH_DIR)
+
+    # get manifest file
+    manifest_directory = os.path.join(os.path.expanduser("~"), LIBRI_SPEECH_DIR)
+    test_manifest = os.path.join(manifest_directory, "libri_test_clean_manifest.csv")
+
+    if not os.path.exists(test_manifest):
+        logger.info("Starting to download dataset")
+        try_download_librispeech(LIBRI_SPEECH_DIR, SAMPLE_RATE, ["test-clean.tar.gz", "test-other.tar.gz"], 1, 15)
+
+    assert os.path.exists(test_manifest)
+
+
+def test_librispeech_clean(logger):
+    # create data processor
+    audio_conf = dict(sample_rate=SAMPLE_RATE, window_size=.02, window_stride=.01,
+                      labels='ABCDEFGHIJKLMNOPQRSTUVWXYZ', normalize=True, augment=False)
+    processor = AudioDataProcessor(**audio_conf)
+
+    # get manifest file
+    manifest_directory = os.path.join(os.path.expanduser("~"), LIBRI_SPEECH_DIR)
+    test_manifest = os.path.join(manifest_directory, "libri_test_clean_manifest.csv")
+
+    if not os.path.exists(test_manifest):
+        try_download_librispeech(LIBRI_SPEECH_DIR, SAMPLE_RATE, ["test-clean.tar.gz", "test-other.tar.gz"], 1, 15)
+
+    assert os.path.exists(test_manifest)
+
+    # create audio dataset
+    test_dataset = AudioDataset(processor, manifest_filepath=test_manifest)
+    logger.info("Dataset is created")
+
+    # write a few random samples to disk for manual inspection
+    if os.path.exists(TEST_WAVS_DIR):
+        shutil.rmtree(TEST_WAVS_DIR)
+
+    os.makedirs(TEST_WAVS_DIR)
+
+    n_samples = len(test_dataset)
+
+    ids = np.random.randint(n_samples, size=min(10, n_samples))
+
+    for index in ids:
+        sound, transcription = test_dataset.get_raw(index)
+        librosa.output.write_wav(os.path.join(TEST_WAVS_DIR, f"audio_{index}.wav"),
+                                 sound, SAMPLE_RATE)
diff --git a/train.py b/train.py
deleted file mode 100644
index a763a47..0000000
--- a/train.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import argparse
-from typing import Dict
-
-import yaml
-
-from modelwrapper import ModelWrapper
-
-parser = argparse.ArgumentParser(description='ASR training')
-parser.add_argument('--config', metavar='DIR',
-                    help='Path to train config file', default='config/train.yaml')
-
-if __name__ == '__main__':
-    args = parser.parse_args()
-    with open(args.config, 'r') as file:
-        config = yaml.load(file)
-        config_dict: Dict = config["train"]
-        model = ModelWrapper(**config_dict)
-        model.train()
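For reference, wiring the new dataset classes into a training-ready loader looks as follows. This is a sketch: the manifest path, batch size, and worker count are illustrative, and the 4-tuple batch layout mirrors the unpacking in the trainer's training loop above.

from sonosco.datasets.dataset import AudioDataset, AudioDataProcessor
from sonosco.datasets.samplers import BucketingSampler
from sonosco.datasets.loader import DataLoader

# construct processor and dataset as in test_librispeech_clean
processor = AudioDataProcessor(sample_rate=16000, window_size=.02, window_stride=.01,
                               labels='ABCDEFGHIJKLMNOPQRSTUVWXYZ', normalize=True, augment=False)
dataset = AudioDataset(processor, manifest_filepath="libri_test_clean_manifest.csv")  # path is illustrative

# bucketing groups similar-length utterances so batches need less padding
sampler = BucketingSampler(dataset, batch_size=16)
loader = DataLoader(dataset=dataset, num_workers=4, batch_sampler=sampler)

for batch_x, batch_y, input_lengths, target_lengths in loader:
    pass  # hand each batch to the trainer here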