diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..bb29d8c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.pt filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000..32f7bbd --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +# AudioCLIP +## Extending [CLIP](https://github.com/openai/CLIP) to Image, Text and Audio + +This repository contains implementation of the models described in the paper [arXiv:2106.13043](https://arxiv.org/abs/2106.13043). +This work based on our previous works: +* [ESResNe(X)t-fbsp: Learning Robust Time-Frequency Transformation of Audio (2021)](https://github.com/AndreyGuzhov/ESResNeXt-fbsp). +* [ESResNet: Environmental Sound Classification Based on Visual Domain Models (2020)](https://github.com/AndreyGuzhov/ESResNet). + +### Abstract + +In the past, the rapidly evolving field of sound classification greatly benefited from the application of methods from other domains. +Today, we observe the trend to fuse domain-specific tasks and approaches together, which provides the community with new outstanding models. + +In this work, we present an extension of the CLIP model that handles audio in addition to text and images. +Our proposed model incorporates the ESResNeXt audio-model into the CLIP framework using the AudioSet dataset. +Such a combination enables the proposed model to perform bimodal and unimodal classification and querying, while keeping CLIP's ability to generalize to unseen datasets in a zero-shot inference fashion. + +AudioCLIP achieves new state-of-the-art results in the Environmental Sound Classification (ESC) task, out-performing other approaches by reaching accuracies of 90.07% on the UrbanSound8K and 97.15% on the ESC-50 datasets. +Further it sets new baselines in the zero-shot ESC-task on the same datasets (68.78% and 69.40%, respectively). + +Finally, we also assess the cross-modal querying performance of the proposed model as well as the influence of full and partial training on the results. +For the sake of reproducibility, our code is published. + +### How to Run the Model + +The required Python version is >= 3.7. + +#### AudioCLIP + +##### On the [ESC-50](https://github.com/karolpiczak/ESC-50) dataset + python main.py --config protocols/audioclip-esc50.json --Dataset.args.root /path/to/ESC50 + +##### On the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset + python main.py --config protocols/audioclip-us8k.json --Dataset.args.root /path/to/UrbanSound8K + +### Cite Us + +``` +@misc{guzhov2021audioclip, + title={AudioCLIP: Extending CLIP to Image, Text and Audio}, + author={Andrey Guzhov and Federico Raue and Jörn Hees and Andreas Dengel}, + year={2021}, + eprint={2106.13043}, + archivePrefix={arXiv}, + primaryClass={cs.SD} +} +``` diff --git a/assets/AudioCLIP-Full-Training.pt b/assets/AudioCLIP-Full-Training.pt new file mode 100644 index 0000000..3d85a17 --- /dev/null +++ b/assets/AudioCLIP-Full-Training.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2441d35b353352c8b1bbfb8f7c687f46314c3d2909e940eaf763b8c17f632c44 +size 537302068 diff --git a/assets/AudioCLIP-Partial-Training.pt b/assets/AudioCLIP-Partial-Training.pt new file mode 100644 index 0000000..fc67480 --- /dev/null +++ b/assets/AudioCLIP-Partial-Training.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1921885b6c3b5de0619f2ef4a9700ecbade758a398b34eb064230c30baef0c75 +size 537302068 diff --git a/assets/CLIP.pt b/assets/CLIP.pt new file mode 100644 index 0000000..f5823d4 --- /dev/null +++ b/assets/CLIP.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:124e6d59d54c0837c456c953b3147adf438a21570d5c5b01ad83f6ad3e78d15e +size 408415159 diff --git a/assets/ESRNXFBSP.pt b/assets/ESRNXFBSP.pt new file mode 100644 index 0000000..ec3bb0a --- /dev/null +++ b/assets/ESRNXFBSP.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acfcc6b0a0f07025cc660ed6dcf9b5a05fbd7c8ebf65edaa0a81d88a92ff6834 +size 124795647 diff --git a/assets/README.md b/assets/README.md new file mode 100644 index 0000000..a84a522 --- /dev/null +++ b/assets/README.md @@ -0,0 +1 @@ +This folder contains snapshots of the pre-trained models. diff --git a/assets/bpe_simple_vocab_16e6.txt.gz b/assets/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..36a1585 --- /dev/null +++ b/assets/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/ignite_trainer/README.md b/ignite_trainer/README.md new file mode 100644 index 0000000..c58920d --- /dev/null +++ b/ignite_trainer/README.md @@ -0,0 +1,3 @@ +# Training Wrapper + +Utility code to run training and evaluation of the model. diff --git a/ignite_trainer/__init__.py b/ignite_trainer/__init__.py new file mode 100644 index 0000000..f57a10e --- /dev/null +++ b/ignite_trainer/__init__.py @@ -0,0 +1,16 @@ +import os as _os +import sys as _sys + +from ignite_trainer.version import __version__ +from ._trainer import main, run +from ._utils import load_class +from ._interfaces import AbstractNet, AbstractTransform + +__all__ = [ + '__version__', + 'main', 'run', + 'load_class', + 'AbstractNet', 'AbstractTransform' +] + +_sys.path.extend([_os.getcwd()]) diff --git a/ignite_trainer/_interfaces.py b/ignite_trainer/_interfaces.py new file mode 100644 index 0000000..882d425 --- /dev/null +++ b/ignite_trainer/_interfaces.py @@ -0,0 +1,37 @@ +import abc +import torch + +from typing import Tuple +from typing import Union +from typing import Callable +from typing import Optional + + +TensorPair = Tuple[torch.Tensor, torch.Tensor] +TensorOrTwo = Union[torch.Tensor, TensorPair] + + +class AbstractNet(abc.ABC, torch.nn.Module): + + @abc.abstractmethod + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> TensorOrTwo: + pass + + @abc.abstractmethod + def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + pass + + @property + @abc.abstractmethod + def loss_fn_name(self) -> str: + pass + + +class AbstractTransform(abc.ABC, Callable[[torch.Tensor], torch.Tensor]): + + @abc.abstractmethod + def __call__(self, x: torch.Tensor) -> torch.Tensor: + pass + + def __repr__(self): + return self.__class__.__name__ + '()' diff --git a/ignite_trainer/_trainer.py b/ignite_trainer/_trainer.py new file mode 100644 index 0000000..b78cac0 --- /dev/null +++ b/ignite_trainer/_trainer.py @@ -0,0 +1,763 @@ +import io +import os +import glob +import json +import time +import tqdm +import signal +import argparse +import numpy as np + +import torch +import torch.utils.data + +import torchvision as tv + +import ignite.engine as ieng +import ignite.metrics as imet +import ignite.handlers as ihan + +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Union +from typing import Optional + +from termcolor import colored + +from collections import defaultdict +from collections.abc import Iterable + +from ignite_trainer import _utils +from ignite_trainer import _visdom +from ignite_trainer import _interfaces + +VISDOM_HOST = 'localhost' +VISDOM_PORT = 8097 +VISDOM_ENV_PATH = os.path.join(os.path.expanduser('~'), 'logs') +BATCH_TRAIN = 128 +BATCH_TEST = 1024 +WORKERS_TRAIN = 0 +WORKERS_TEST = 0 +EPOCHS = 100 +LOG_INTERVAL = 50 +SAVED_MODELS_PATH = os.path.join(os.path.expanduser('~'), 'saved_models') + + +def run(experiment_name: str, + visdom_host: str, + visdom_port: int, + visdom_env_path: str, + model_class: str, + model_args: Dict[str, Any], + optimizer_class: str, + optimizer_args: Dict[str, Any], + dataset_class: str, + dataset_args: Dict[str, Any], + batch_train: int, + batch_test: int, + workers_train: int, + workers_test: int, + transforms: List[Dict[str, Union[str, Dict[str, Any]]]], + epochs: int, + log_interval: int, + saved_models_path: str, + performance_metrics: Optional = None, + scheduler_class: Optional[str] = None, + scheduler_args: Optional[Dict[str, Any]] = None, + model_suffix: Optional[str] = None, + setup_suffix: Optional[str] = None, + orig_stdout: Optional[io.TextIOBase] = None, + skip_train_val: bool = False): + + with _utils.tqdm_stdout(orig_stdout) as orig_stdout: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + num_gpus = torch.cuda.device_count() + + if num_gpus > 1: + experiment_name = f'{experiment_name}-x{num_gpus}' + + transforms_train = list() + transforms_test = list() + + for idx, transform in enumerate(transforms): + use_train = transform.get('train', True) + use_test = transform.get('test', True) + + transform = _utils.load_class(transform['class'])(**transform['args']) + + if use_train: + transforms_train.append(transform) + if use_test: + transforms_test.append(transform) + + transforms[idx]['train'] = use_train + transforms[idx]['test'] = use_test + + transforms_train = tv.transforms.Compose(transforms_train) + transforms_test = tv.transforms.Compose(transforms_test) + + Dataset: Type = _utils.load_class(dataset_class) + + train_loader, eval_loader = _utils.get_data_loaders( + Dataset, + dataset_args, + batch_train, + batch_test, + workers_train, + workers_test, + transforms_train, + transforms_test + ) + + Network: Type = _utils.load_class(model_class) + model: _interfaces.AbstractNet = Network(**model_args) + + if hasattr(train_loader.dataset, 'class_weights'): + model.register_buffer('class_weights', train_loader.dataset.class_weights.clone().exp(), persistent=False) + if hasattr(train_loader.dataset, 'label_to_class_idx'): + model.label_to_class_idx = {idx: lb for idx, lb in train_loader.dataset.label_to_class_idx.items()} + + model = torch.nn.DataParallel(model, device_ids=range(num_gpus)) + model = model.to(device) + + # disable all parameters + for p in model.parameters(): + p.requires_grad = False + + # enable only audio-related parameters + for p in model.module.audio.parameters(): + p.requires_grad = True + + # disable fbsp-parameters + for p in model.module.audio.fbsp.parameters(): + p.requires_grad = False + + # disable logit scaling + model.module.logit_scale_ai.requires_grad = False + model.module.logit_scale_at.requires_grad = False + + # add only enabled parameters to optimizer's list + param_groups = [ + {'params': [p for p in model.module.parameters() if p.requires_grad]} + ] + + # enable fbsp-parameters + for p in model.module.audio.fbsp.parameters(): + p.requires_grad = True + + # enable logit scaling + model.module.logit_scale_ai.requires_grad = True + model.module.logit_scale_at.requires_grad = True + + # add fbsp- and logit scaling parameters to a separate group without weight decay + param_groups.append({ + 'params': [ + p for p in model.module.audio.fbsp.parameters() + ] + [ + model.module.logit_scale_ai, + model.module.logit_scale_at + ], + 'weight_decay': 0.0 + }) + + Optimizer: Type = _utils.load_class(optimizer_class) + optimizer: torch.optim.Optimizer = Optimizer( + param_groups, + **{**optimizer_args, **{'lr': optimizer_args['lr'] * num_gpus}} + ) + + if scheduler_class is not None: + Scheduler: Type = _utils.load_class(scheduler_class) + + if scheduler_args is None: + scheduler_args = dict() + + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = Scheduler(optimizer, **scheduler_args) + else: + scheduler = None + + model_short_name = ''.join([c for c in Network.__name__ if c == c.upper()]) + model_name = '{}{}'.format( + model_short_name, + '-{}'.format(model_suffix) if model_suffix is not None else '' + ) + visdom_env_name = '{}_{}_{}{}'.format( + Dataset.__name__, + experiment_name, + model_name, + '-{}'.format(setup_suffix) if setup_suffix is not None else '' + ) + + vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port, visdom_env_name, visdom_env_path) + + prog_bar_epochs = tqdm.tqdm(total=epochs, desc='Epochs', file=orig_stdout, dynamic_ncols=True, unit='epoch') + prog_bar_iters = tqdm.tqdm(desc='Batches', file=orig_stdout, dynamic_ncols=True) + + num_params_total = sum(p.numel() for p in model.parameters()) + num_params_train = sum(p.numel() for grp in optimizer.param_groups for p in grp['params']) + + params_total_label = '' + params_train_label = '' + if num_params_total > 1e6: + num_params_total /= 1e6 + params_total_label = 'M' + elif num_params_total > 1e3: + num_params_total /= 1e3 + params_total_label = 'k' + + if num_params_train > 1e6: + num_params_train /= 1e6 + params_train_label = 'M' + elif num_params_train > 1e3: + num_params_train /= 1e3 + params_train_label = 'k' + + tqdm.tqdm.write(f'\n{Network.__name__}\n') + tqdm.tqdm.write('Total number of parameters: {:.2f}{}'.format(num_params_total, params_total_label)) + tqdm.tqdm.write('Number of trainable parameters: {:.2f}{}'.format(num_params_train, params_train_label)) + + def training_step(engine: ieng.Engine, batch) -> torch.Tensor: + model.train() + model.epoch = engine.state.epoch + model.batch_idx = (engine.state.iteration - 1) % len(train_loader) + model.num_batches = len(train_loader) + + optimizer.zero_grad() + + audio, image, text = batch + if audio is not None: + audio = audio.to(device) + if image is not None: + image = image.to(device) + + batch_indices = torch.arange(audio.shape[0], dtype=torch.int64, device=device) + _, loss = model(audio, image, text, batch_indices) + + if loss.ndim > 0: + loss = loss.mean() + + loss.backward(retain_graph=False) + optimizer.step(None) + + return loss.item() + + def eval_step(_: ieng.Engine, batch) -> _interfaces.TensorPair: + model.eval() + + with torch.no_grad(): + audio, _, text = batch + + ((audio_features, _, _), _), _ = model( + audio=audio, + batch_indices=torch.arange(audio.shape[0], dtype=torch.int64, device=device) + ) + audio_features = audio_features.unsqueeze(1) + + ((_, _, text_features), _), _ = model( + text=[ + [eval_loader.dataset.class_idx_to_label[class_idx]] + for class_idx in sorted(eval_loader.dataset.class_idx_to_label.keys()) + ], + batch_indices=torch.arange( + len(eval_loader.dataset.class_idx_to_label), dtype=torch.int64, device=device + ) + ) + text_features = text_features.unsqueeze(1).transpose(0, 1) + + logit_scale_at = torch.clamp(model.module.logit_scale_at.exp(), min=1.0, max=100.0) + y_pred = (logit_scale_at * audio_features @ text_features.transpose(-1, -2)).squeeze(1) + + y = torch.zeros( + audio.shape[0], len(eval_loader.dataset.class_idx_to_label), dtype=torch.int8, device=device + ) + for item_idx, labels in enumerate(text): + class_ids = list(sorted([ + eval_loader.dataset.label_to_class_idx[lb] for lb in labels + ])) + y[item_idx][class_ids] = 1 + + if model.module.multilabel: + y_pred = torch.sigmoid(y_pred / logit_scale_at - 0.5) + else: + y_pred = torch.softmax(y_pred, dim=-1) + y = y.argmax(dim=-1) + + return y_pred, y + + trainer = ieng.Engine(training_step) + validator_train = ieng.Engine(eval_step) + validator_eval = ieng.Engine(eval_step) + + # placeholder for summary window + vis.text( + text='', + win=experiment_name, + env=visdom_env_name, + opts={'title': 'Summary', 'width': 940, 'height': 416}, + append=vis.win_exists(experiment_name, visdom_env_name) + ) + + default_metrics = { + "Loss": { + "window_name": None, + "x_label": "#Epochs", + "y_label": model.loss_fn_name if not isinstance(model, torch.nn.DataParallel) else model.module.loss_fn_name, + "width": 940, + "height": 416, + "lines": [ + { + "line_label": "SMA", + "object": imet.RunningAverage(output_transform=lambda x: x), + "test": False, + "update_rate": "iteration" + } + ] + } + } + + performance_metrics = {**default_metrics, **performance_metrics} + checkpoint_metrics = list() + + for scope_name, scope in performance_metrics.items(): + scope['window_name'] = scope.get('window_name', scope_name) or scope_name + + for line in scope['lines']: + if 'object' not in line: + line['object']: imet.Metric = _utils.load_class(line['class'])(**line['args']) + + line['metric_label'] = '{}: {}'.format(scope['window_name'], line['line_label']) + + line['update_rate'] = line.get('update_rate', 'epoch') + line_suffixes = list() + if line['update_rate'] == 'iteration': + line['object'].attach(trainer, line['metric_label']) + line['train'] = False + line['test'] = False + + line_suffixes.append(' Train.') + + if line.get('train', True): + line['object'].attach(validator_train, line['metric_label']) + line_suffixes.append(' Train.') + if line.get('test', True): + line['object'].attach(validator_eval, line['metric_label']) + line_suffixes.append(' Eval.') + + if line.get('is_checkpoint', False): + checkpoint_metrics.append(line['metric_label']) + + for line_suffix in line_suffixes: + _visdom.plot_line( + vis=vis, + window_name=scope['window_name'], + env=visdom_env_name, + line_label=line['line_label'] + line_suffix, + x_label=scope['x_label'], + y_label=scope['y_label'], + width=scope['width'], + height=scope['height'], + draw_marker=(line['update_rate'] == 'epoch') + ) + + if checkpoint_metrics: + score_name = 'performance' + + def get_score(engine: ieng.Engine) -> float: + current_mode = getattr(engine.state.dataloader.iterable.dataset, dataset_args['training']['key']) + val_mode = dataset_args['training']['no'] + + score = 0.0 + if current_mode == val_mode: + for metric_name in checkpoint_metrics: + try: + score += engine.state.metrics[metric_name] + except KeyError: + pass + + return score + + model_saver = ihan.ModelCheckpoint( + os.path.join(saved_models_path, visdom_env_name), + filename_prefix=visdom_env_name, + score_name=score_name, + score_function=get_score, + n_saved=3, + save_as_state_dict=True, + require_empty=False, + create_dir=True + ) + + validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED, model_saver, {model_name: model}) + + if not skip_train_val: + @trainer.on(ieng.Events.STARTED) + def engine_started(engine: ieng.Engine): + log_validation(engine, False) + + @trainer.on(ieng.Events.EPOCH_STARTED) + def reset_progress_iterations(engine: ieng.Engine): + prog_bar_iters.clear() + prog_bar_iters.n = 0 + prog_bar_iters.last_print_n = 0 + prog_bar_iters.start_t = time.time() + prog_bar_iters.last_print_t = time.time() + prog_bar_iters.total = len(engine.state.dataloader) + + @trainer.on(ieng.Events.ITERATION_COMPLETED) + def log_training(engine: ieng.Engine): + prog_bar_iters.update(1) + + num_iter = (engine.state.iteration - 1) % len(train_loader) + 1 + + early_stop = np.isnan(engine.state.output) or np.isinf(engine.state.output) + + if num_iter % log_interval == 0 or num_iter == len(train_loader) or early_stop: + tqdm.tqdm.write( + 'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format( + engine.state.epoch, num_iter, len(train_loader), engine.state.output + ) + ) + + x_pos = engine.state.epoch + num_iter / len(train_loader) - 1 + for scope_name, scope in performance_metrics.items(): + for line in scope['lines']: + if line['update_rate'] == 'iteration': + line_label = '{} Train.'.format(line['line_label']) + line_value = engine.state.metrics[line['metric_label']] + + if engine.state.epoch >= 1: + _visdom.plot_line( + vis=vis, + window_name=scope['window_name'], + env=visdom_env_name, + line_label=line_label, + x_label=scope['x_label'], + y_label=scope['y_label'], + x=np.full(1, x_pos), + y=np.full(1, line_value) + ) + + if early_stop: + tqdm.tqdm.write(colored('Early stopping due to invalid loss value.', 'red')) + trainer.terminate() + + def log_validation(engine: ieng.Engine, + train: bool = True): + + if train: + run_type = 'Train.' + data_loader = train_loader + validator = validator_train + else: + run_type = 'Eval.' + data_loader = eval_loader + validator = validator_eval + + prog_bar_validation = tqdm.tqdm( + data_loader, + desc=f'Validation {run_type}', + file=orig_stdout, + dynamic_ncols=True, + leave=False + ) + validator.run(prog_bar_validation) + prog_bar_validation.clear() + prog_bar_validation.close() + + tqdm_info = [ + 'Epoch: {}'.format(engine.state.epoch) + ] + for scope_name, scope in performance_metrics.items(): + for line in scope['lines']: + if line['update_rate'] == 'epoch': + try: + line_label = '{} {}'.format(line['line_label'], run_type) + line_value = validator.state.metrics[line['metric_label']] + + _visdom.plot_line( + vis=vis, + window_name=scope['window_name'], + env=visdom_env_name, + line_label=line_label, + x_label=scope['x_label'], + y_label=scope['y_label'], + x=np.full(1, engine.state.epoch), + y=np.full(1, line_value), + draw_marker=True + ) + + tqdm_info.append('{}: {:.4f}'.format(line_label, line_value)) + except KeyError: + pass + + tqdm.tqdm.write('{} results - {}'.format(run_type, '; '.join(tqdm_info))) + + if not skip_train_val: + @trainer.on(ieng.Events.EPOCH_COMPLETED) + def log_validation_train(engine: ieng.Engine): + log_validation(engine, True) + + @trainer.on(ieng.Events.EPOCH_COMPLETED) + def log_validation_eval(engine: ieng.Engine): + log_validation(engine, False) + + if engine.state.epoch == 1: + summary = _utils.build_summary_str( + experiment_name=experiment_name, + model_short_name=model_name, + model_class=model_class, + model_args=model_args, + optimizer_class=optimizer_class, + optimizer_args=optimizer_args, + dataset_class=dataset_class, + dataset_args=dataset_args, + transforms=transforms, + epochs=epochs, + batch_train=batch_train, + log_interval=log_interval, + saved_models_path=saved_models_path, + scheduler_class=scheduler_class, + scheduler_args=scheduler_args + ) + _visdom.create_summary_window( + vis=vis, + visdom_env_name=visdom_env_name, + experiment_name=experiment_name, + summary=summary + ) + + vis.save([visdom_env_name]) + + prog_bar_epochs.update(1) + + if scheduler is not None: + scheduler.step(engine.state.epoch) + + trainer.run(train_loader, max_epochs=epochs) + + if vis_pid is not None: + tqdm.tqdm.write('Stopping visdom') + os.kill(vis_pid, signal.SIGTERM) + + del vis + del train_loader + del eval_loader + + prog_bar_iters.clear() + prog_bar_iters.close() + + prog_bar_epochs.clear() + prog_bar_epochs.close() + + tqdm.tqdm.write('\n') + + +def main(): + with _utils.tqdm_stdout() as orig_stdout: + parser = argparse.ArgumentParser() + + parser.add_argument('-c', '--config', type=str, required=True) + parser.add_argument('-H', '--visdom-host', type=str, required=False) + parser.add_argument('-P', '--visdom-port', type=int, required=False) + parser.add_argument('-E', '--visdom-env-path', type=str, required=False) + parser.add_argument('-b', '--batch-train', type=int, required=False) + parser.add_argument('-B', '--batch-test', type=int, required=False) + parser.add_argument('-w', '--workers-train', type=int, required=False) + parser.add_argument('-W', '--workers-test', type=int, required=False) + parser.add_argument('-e', '--epochs', type=int, required=False) + parser.add_argument('-L', '--log-interval', type=int, required=False) + parser.add_argument('-M', '--saved-models-path', type=str, required=False) + parser.add_argument('-R', '--random-seed', type=int, required=False) + parser.add_argument('-s', '--suffix', type=str, required=False) + parser.add_argument('-S', '--skip-train-val', action='store_true', default=False) + + args, unknown_args = parser.parse_known_args() + + if args.batch_test is None: + args.batch_test = args.batch_train + + if args.random_seed is not None: + args.suffix = '{}r-{}'.format( + '{}_'.format(args.suffix) if args.suffix is not None else '', + args.random_seed + ) + + np.random.seed(args.random_seed) + torch.random.manual_seed(args.random_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + configs_found = list(sorted(glob.glob(os.path.expanduser(args.config)))) + prog_bar_exps = tqdm.tqdm( + configs_found, + desc='Experiments', + unit='setup', + file=orig_stdout, + dynamic_ncols=True + ) + + for config_path in prog_bar_exps: + config = json.load(open(config_path)) + + if unknown_args: + tqdm.tqdm.write('\nParsing additional arguments...') + + args_not_found = list() + for arg in unknown_args: + if arg.startswith('--'): + keys = arg.strip('-').split('.') + + section = config + found = True + for key in keys: + if key in section: + section = section[key] + else: + found = False + break + + if found: + override_parser = argparse.ArgumentParser() + + section_nargs = None + section_type = type(section) if section is not None else str + + if section_type is bool: + if section_type is bool: + def infer_bool(x: str) -> bool: + return x.lower() not in ('0', 'false', 'no') + + section_type = infer_bool + + if isinstance(section, Iterable) and section_type is not str: + section_nargs = '+' + section_type = {type(value) for value in section} + + if len(section_type) == 1: + section_type = section_type.pop() + else: + section_type = str + + override_parser.add_argument(arg, nargs=section_nargs, type=section_type) + overridden_args, _ = override_parser.parse_known_args(unknown_args) + overridden_args = vars(overridden_args) + + overridden_key = arg.strip('-') + overriding_value = overridden_args[overridden_key] + + section = config + old_value = None + for i, key in enumerate(keys, 1): + if i == len(keys): + old_value = section[key] + section[key] = overriding_value + else: + section = section[key] + + tqdm.tqdm.write( + colored(f'Overriding "{overridden_key}": {old_value} -> {overriding_value}', 'magenta') + ) + else: + args_not_found.append(arg) + + if args_not_found: + tqdm.tqdm.write( + colored( + '\nThere are unrecognized arguments to override: {}'.format( + ', '.join(args_not_found) + ), + 'red' + ) + ) + + config = defaultdict(None, config) + + experiment_name = config['Setup']['name'] + + visdom_host = _utils.arg_selector( + args.visdom_host, config['Visdom']['host'], VISDOM_HOST + ) + visdom_port = int(_utils.arg_selector( + args.visdom_port, config['Visdom']['port'], VISDOM_PORT + )) + visdom_env_path = _utils.arg_selector( + args.visdom_env_path, config['Visdom']['env_path'], VISDOM_ENV_PATH + ) + batch_train = int(_utils.arg_selector( + args.batch_train, config['Setup']['batch_train'], BATCH_TRAIN + )) + batch_test = int(_utils.arg_selector( + args.batch_test, config['Setup']['batch_test'], BATCH_TEST + )) + workers_train = _utils.arg_selector( + args.workers_train, config['Setup']['workers_train'], WORKERS_TRAIN + ) + workers_test = _utils.arg_selector( + args.workers_test, config['Setup']['workers_test'], WORKERS_TEST + ) + epochs = _utils.arg_selector( + args.epochs, config['Setup']['epochs'], EPOCHS + ) + log_interval = _utils.arg_selector( + args.log_interval, config['Setup']['log_interval'], LOG_INTERVAL + ) + saved_models_path = _utils.arg_selector( + args.saved_models_path, config['Setup']['saved_models_path'], SAVED_MODELS_PATH + ) + + model_class = config['Model']['class'] + model_args = config['Model']['args'] + + optimizer_class = config['Optimizer']['class'] + optimizer_args = config['Optimizer']['args'] + + if 'Scheduler' in config: + scheduler_class = config['Scheduler']['class'] + scheduler_args = config['Scheduler']['args'] + else: + scheduler_class = None + scheduler_args = None + + dataset_class = config['Dataset']['class'] + dataset_args = config['Dataset']['args'] + + transforms = config['Transforms'] + performance_metrics = config['Metrics'] + + tqdm.tqdm.write(f'\nStarting experiment "{experiment_name}"\n') + + run( + experiment_name=experiment_name, + visdom_host=visdom_host, + visdom_port=visdom_port, + visdom_env_path=visdom_env_path, + model_class=model_class, + model_args=model_args, + optimizer_class=optimizer_class, + optimizer_args=optimizer_args, + dataset_class=dataset_class, + dataset_args=dataset_args, + batch_train=batch_train, + batch_test=batch_test, + workers_train=workers_train, + workers_test=workers_test, + transforms=transforms, + epochs=epochs, + log_interval=log_interval, + saved_models_path=saved_models_path, + performance_metrics=performance_metrics, + scheduler_class=scheduler_class, + scheduler_args=scheduler_args, + model_suffix=config['Setup']['suffix'], + setup_suffix=args.suffix, + orig_stdout=orig_stdout, + skip_train_val=args.skip_train_val + ) + + prog_bar_exps.close() + + tqdm.tqdm.write('\n') diff --git a/ignite_trainer/_utils.py b/ignite_trainer/_utils.py new file mode 100644 index 0000000..25f7e11 --- /dev/null +++ b/ignite_trainer/_utils.py @@ -0,0 +1,221 @@ +import io +import sys +import json +import tqdm +import datetime +import importlib +import contextlib + +import numpy as np + +import torch +import torch.utils.data as td + +import torchvision as tv + +from PIL import Image + +from collections import OrderedDict + +from typing import Any +from typing import Dict +from typing import List +from typing import Type +from typing import Tuple +from typing import Union +from typing import Callable +from typing import Optional + + +@contextlib.contextmanager +def tqdm_stdout(orig_stdout: Optional[io.TextIOBase] = None): + + class DummyFile(object): + file = None + + def __init__(self, file): + self.file = file + + def write(self, x): + if len(x.rstrip()) > 0: + tqdm.tqdm.write(x, file=self.file) + + def flush(self): + return getattr(self.file, 'flush', lambda: None)() + + orig_out_err = sys.stdout, sys.stderr + + try: + if orig_stdout is None: + sys.stdout, sys.stderr = map(DummyFile, orig_out_err) + yield orig_out_err[0] + else: + yield orig_stdout + except Exception as exc: + raise exc + finally: + sys.stdout, sys.stderr = orig_out_err + + +def load_class(package_name: str, class_name: Optional[str] = None) -> Type: + if class_name is None: + package_name, class_name = package_name.rsplit('.', 1) + + importlib.invalidate_caches() + + package = importlib.import_module(package_name) + cls = getattr(package, class_name) + + return cls + + +def arg_selector(arg_cmd: Optional[Any], arg_conf: Optional[Any], arg_const: Any) -> Any: + if arg_cmd is not None: + return arg_cmd + else: + if arg_conf is not None: + return arg_conf + else: + return arg_const + + +def collate_fn(batch): + batch_audio, batch_image, batch_text = zip(*batch) + + keep_ids = [idx for idx, (_, _) in enumerate(zip(batch_audio, batch_image))] + + if not all(audio is None for audio in batch_audio): + batch_audio = [batch_audio[idx] for idx in keep_ids] + batch_audio = torch.stack(batch_audio) + else: + batch_audio = None + + if not all(image is None for image in batch_image): + batch_image = [batch_image[idx] for idx in keep_ids] + batch_image = torch.stack(batch_image) + else: + batch_image = None + + if not all(text is None for text in batch_text): + batch_text = [batch_text[idx] for idx in keep_ids] + else: + batch_text = None + + return batch_audio, batch_image, batch_text + + +def get_data_loaders(Dataset: Type, + dataset_args: Dict[str, Any], + batch_train: int = 64, + batch_test: int = 1024, + workers_train: int = 0, + workers_test: int = 0, + transforms_train: Optional[Callable[ + [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] + ]] = None, + transforms_test: Optional[Callable[ + [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] + ]] = None) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + + dl_shuffle = dataset_args.pop('dl_shuffle', True) + + dataset_mode_train = {dataset_args['training']['key']: dataset_args['training']['yes']} + dataset_mode_test = {dataset_args['training']['key']: dataset_args['training']['no']} + + dataset_args_train = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_train} + dataset_args_test = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_test} + + ds_train = Dataset(**{ + **dataset_args_train, + **{'transform_audio': transforms_train}, + **{'transform_frames': tv.transforms.Compose([ + tv.transforms.ToTensor(), + tv.transforms.Resize(224, interpolation=Image.BICUBIC), + tv.transforms.CenterCrop(224), + tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ])} + }) + train_loader = torch.utils.data.DataLoader( + ds_train, + batch_size=batch_train, + shuffle=dl_shuffle, + num_workers=workers_train, + pin_memory=True, + collate_fn=collate_fn, + drop_last=True + ) + ds_eval = Dataset(**{ + **dataset_args_test, + **{'transform_audio': transforms_test}, + **{'transform_frames': tv.transforms.Compose([ + tv.transforms.ToTensor(), + tv.transforms.Resize(224, interpolation=Image.BICUBIC), + tv.transforms.CenterCrop(224), + tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ])} + }) + eval_loader = torch.utils.data.DataLoader( + ds_eval, + batch_size=batch_test, + num_workers=workers_test, + pin_memory=True, + collate_fn=collate_fn + ) + + return train_loader, eval_loader + + +def build_summary_str(experiment_name: str, + model_short_name: str, + model_class: str, + model_args: Dict[str, Any], + optimizer_class: str, + optimizer_args: Dict[str, Any], + dataset_class: str, + dataset_args: Dict[str, Any], + transforms: List[Dict[str, Union[str, Dict[str, Any]]]], + epochs: int, + batch_train: int, + log_interval: int, + saved_models_path: str, + scheduler_class: Optional[str] = None, + scheduler_args: Optional[Dict[str, Any]] = None) -> str: + + setup_title = '{}-{}'.format(experiment_name, model_short_name) + + summary_window_text = '

' + summary_window_text += ''.format(setup_title) + + summary_window_text += setup_title + + summary_window_text += '' + summary_window_text += '

' + summary_window_text += '
' + summary_window_text += '' + summary_window_text += '
' + + return summary_window_text diff --git a/ignite_trainer/_visdom.py b/ignite_trainer/_visdom.py new file mode 100644 index 0000000..9d2ffa6 --- /dev/null +++ b/ignite_trainer/_visdom.py @@ -0,0 +1,191 @@ +import os +import sys +import json +import time +import tqdm +import socket +import subprocess +import numpy as np + +import visdom + +from typing import Tuple +from typing import Optional + + +def calc_ytick_range(vis: visdom.Visdom, window_name: str, env: Optional[str] = None) -> Tuple[float, float]: + lower_bound, upper_bound = -1.0, 1.0 + + stats = vis.get_window_data(win=window_name, env=env) + + if stats: + stats = json.loads(stats) + + stats = [np.array(item['y']) for item in stats['content']['data']] + stats = [item[item != np.array([None])].astype(np.float16) for item in stats] + + if stats: + q25s = np.array([np.quantile(item, 0.25) for item in stats if len(item) > 0]) + q75s = np.array([np.quantile(item, 0.75) for item in stats if len(item) > 0]) + + if q25s.shape == q75s.shape and len(q25s) > 0: + iqrs = q75s - q25s + + lower_bounds = q25s - 1.5 * iqrs + upper_bounds = q75s + 1.5 * iqrs + + stats_sanitized = list() + idx = 0 + for item in stats: + if len(item) > 0: + item_sanitized = item[(item >= lower_bounds[idx]) & (item <= upper_bounds[idx])] + stats_sanitized.append(item_sanitized) + + idx += 1 + + stats_sanitized = np.array(stats_sanitized) + + q25_sanitized = np.array([np.quantile(item, 0.25) for item in stats_sanitized]) + q75_sanitized = np.array([np.quantile(item, 0.75) for item in stats_sanitized]) + + iqr_sanitized = np.sum(q75_sanitized - q25_sanitized) + lower_bound = np.min(q25_sanitized) - 1.5 * iqr_sanitized + upper_bound = np.max(q75_sanitized) + 1.5 * iqr_sanitized + + return lower_bound, upper_bound + + +def plot_line(vis: visdom.Visdom, + window_name: str, + env: Optional[str] = None, + line_label: Optional[str] = None, + x: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + x_label: Optional[str] = None, + y_label: Optional[str] = None, + width: int = 576, + height: int = 416, + draw_marker: bool = False) -> str: + + empty_call = not vis.win_exists(window_name) + + if empty_call and (x is not None or y is not None): + return window_name + + if x is None: + x = np.ones(1) + empty_call = empty_call & True + + if y is None: + y = np.full(1, np.nan) + empty_call = empty_call & True + + if x.shape != y.shape: + x = np.ones_like(y) + + opts = { + 'showlegend': True, + 'markers': draw_marker, + 'markersize': 5, + } + + if empty_call: + opts['title'] = window_name + opts['width'] = width + opts['height'] = height + + window_name = vis.line( + X=x, + Y=y, + win=window_name, + env=env, + update='append', + name=line_label, + opts=opts + ) + + xtickmin, xtickmax = 0.0, np.max(x) * 1.05 + ytickmin, ytickmax = calc_ytick_range(vis, window_name, env) + + opts = { + 'showlegend': True, + 'xtickmin': xtickmin, + 'xtickmax': xtickmax, + 'ytickmin': ytickmin, + 'ytickmax': ytickmax, + 'xlabel': x_label, + 'ylabel': y_label + } + + window_name = vis.update_window_opts(win=window_name, opts=opts, env=env) + + return window_name + + +def create_summary_window(vis: visdom.Visdom, + visdom_env_name: str, + experiment_name: str, + summary: str) -> str: + + return vis.text( + text=summary, + win=experiment_name, + env=visdom_env_name, + opts={'title': 'Summary', 'width': 576, 'height': 416}, + append=vis.win_exists(experiment_name, visdom_env_name) + ) + + +def connection_is_alive(host: str, port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + sock.connect((host, port)) + sock.shutdown(socket.SHUT_RDWR) + + return True + except socket.error: + return False + + +def get_visdom_instance(host: str = 'localhost', + port: int = 8097, + env_name: str = 'main', + env_path: str = 'visdom_env') -> Tuple[visdom.Visdom, Optional[int]]: + + vis_pid = None + + if not connection_is_alive(host, port): + if any(host.strip('/').endswith(lh) for lh in ['127.0.0.1', 'localhost']): + os.makedirs(env_path, exist_ok=True) + + tqdm.tqdm.write('Starting visdom on port {}'.format(port), end='') + + vis_args = [ + sys.executable, + '-m', 'visdom.server', + '-port', str(port), + '-env_path', os.path.join(os.getcwd(), env_path) + ] + vis_proc = subprocess.Popen(vis_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + time.sleep(2.0) + + vis_pid = vis_proc.pid + tqdm.tqdm.write('PID -> {}'.format(vis_pid)) + + trials_left = 5 + while not connection_is_alive(host, port): + time.sleep(1.0) + + tqdm.tqdm.write('Trying to connect ({} left)...'.format(trials_left)) + + trials_left -= 1 + if trials_left < 1: + raise RuntimeError('Visdom server is not running. Please run "python -m visdom.server".') + + vis = visdom.Visdom( + server='http://{}'.format(host), + port=port, + env=env_name + ) + + return vis, vis_pid diff --git a/ignite_trainer/version.py b/ignite_trainer/version.py new file mode 100644 index 0000000..c73c6e0 --- /dev/null +++ b/ignite_trainer/version.py @@ -0,0 +1 @@ +__version__ = '0.2.5b5' diff --git a/main.py b/main.py new file mode 100644 index 0000000..821b1c2 --- /dev/null +++ b/main.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3.7 + +import ignite_trainer as it + + +def main(): + it.main() + + +if __name__ == '__main__': + main() diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..956e839 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,3 @@ +from .clip import * +from .esresnet import * +from .audioclip import AudioCLIP diff --git a/model/audioclip.py b/model/audioclip.py new file mode 100644 index 0000000..9dd42d7 --- /dev/null +++ b/model/audioclip.py @@ -0,0 +1,255 @@ +import os + +import torch +import torch.nn.functional as F + +from model.clip import CLIP +from model.clip.clip import tokenize +from model.esresnet import ESResNeXtFBSP + +from typing import List +from typing import Tuple +from typing import Union +from typing import Optional + + +ClipFeatures = Tuple[ + Optional[torch.Tensor], # audio + Optional[torch.Tensor], # image + Optional[torch.Tensor] # audio +] + + +ClipLogits = Tuple[ + Optional[torch.Tensor], # audio x image + Optional[torch.Tensor], # audio x text + Optional[torch.Tensor] # image x text +] + + +ClipOutput = Tuple[ + Tuple[ClipFeatures, ClipLogits], + Optional[torch.Tensor] # loss +] + + +class AudioCLIP(CLIP): + + def __init__(self, + embed_dim: int = 1024, + # vision + image_resolution: int = 224, + vision_layers: Union[Tuple[int, int, int, int], int] = (3, 4, 6, 3), + vision_width: int = 64, + vision_patch_size: Optional[int] = None, + # text + context_length: int = 77, + vocab_size: int = 49408, + transformer_width: int = 512, + transformer_heads: int = 8, + transformer_layers: int = 12, + # audio + n_fft: int = 2048, + hop_length: Optional[int] = 561, + win_length: Optional[int] = 1654, + window: Optional[str] = 'blackmanharris', + normalized: bool = True, + onesided: bool = True, + spec_height: int = -1, + spec_width: int = -1, + apply_attention: bool = True, + multilabel: bool = True, + pretrained: Union[bool, str] = True): + + super(AudioCLIP, self).__init__( + embed_dim=embed_dim, + image_resolution=image_resolution, + vision_layers=vision_layers, + vision_width=vision_width, + vision_patch_size=vision_patch_size, + context_length=context_length, + vocab_size=vocab_size, + transformer_width=transformer_width, + transformer_heads=transformer_heads, + transformer_layers=transformer_layers + ) + + self.audio = ESResNeXtFBSP( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided, + spec_height=spec_height, + spec_width=spec_width, + num_classes=embed_dim, + apply_attention=apply_attention, + pretrained=False + ) + + self.multilabel = multilabel + self.pretrained = pretrained + + self.logit_scale_ai = torch.nn.Parameter(torch.log(torch.ones([]) * 100)) + self.logit_scale_at = torch.nn.Parameter(torch.log(torch.ones([]) * 100)) + + if isinstance(self.pretrained, str): + self.load_state_dict(torch.load(self.pretrained, map_location='cpu'), strict=False) + elif self.pretrained: + self.load_state_dict(torch.load( + os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'CLIP.pt'), + map_location='cpu' + ), strict=False) + print('Image & Text weights loaded') + try: + self.audio.load_state_dict(torch.load( + os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'ESRNXFBSP.pt'), + map_location='cpu' + ), strict=False) + except RuntimeError as ex: + print(ex) + print('Audio weights loaded') + + self.embed_dim = embed_dim + + @property + def device(self): + return self.visual.conv1.weight.device + + def encode_audio(self, audio: torch.Tensor) -> torch.Tensor: + return self.audio(audio.to(self.device)) + + def encode_text(self, + text: List[List[str]], + base_str: str = '{}', + batch_indices: Optional[torch.Tensor] = None) -> torch.Tensor: + + if batch_indices is not None: + text = [text[idx] for idx in batch_indices] + + text_joined = [', '.join(entities) for entities in text] + text_tokens = torch.cat([ + tokenize(base_str.format(entities)) for entities in text_joined + ]) + text_tokens = text_tokens.to(self.device) + + return super(AudioCLIP, self).encode_text(text_tokens) + + def forward(self, + audio: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, + text: Optional[List[List[str]]] = None, + batch_indices: Optional[torch.Tensor] = None) -> ClipOutput: + + audio_features = None + image_features = None + text_features = None + sample_weights = None + + if audio is not None: + audio_features = self.encode_audio(audio) + audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True) + + if image is not None: + image_features = self.encode_image(image) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + if text is not None: + if batch_indices is None: + batch_indices = torch.arange(len(text), dtype=self.dtype, device=self.device) + + text_features = self.encode_text(text, '{}', batch_indices) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + if hasattr(self, 'class_weights') and hasattr(self, 'label_to_class_idx'): + sample_weights = torch.stack([ + sum(self.class_weights[self.label_to_class_idx[label]] for label in entities) + for idx, entities in enumerate(text) if idx in batch_indices + ]) + + features: ClipFeatures = (audio_features, image_features, text_features) + + logit_scale_ai = torch.clamp(self.logit_scale_ai.exp(), min=1.0, max=100.0) + logit_scale_at = torch.clamp(self.logit_scale_at.exp(), min=1.0, max=100.0) + logit_scale_it = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0) + + logits_audio_image = None + logits_audio_text = None + logits_image_text = None + + if (audio_features is not None) and (image_features is not None): + logits_audio_image = logit_scale_ai * audio_features @ image_features.T + + if (audio_features is not None) and (text_features is not None): + logits_audio_text = logit_scale_at * audio_features @ text_features.T + + if (image_features is not None) and (text_features is not None): + logits_image_text = logit_scale_it * image_features @ text_features.T + + logits: ClipLogits = (logits_audio_image, logits_audio_text, logits_image_text) + + loss = self.loss_fn(logits, sample_weights) + if audio is not None and loss is not None: + loss = loss + self.audio.loss_ttf(self.device) + + return (features, logits), loss + + def loss_fn(self, logits: ClipLogits, sample_weights: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + logits_audio_image, logits_audio_text, logits_image_text = logits + + if logits_audio_image is not None: + batch_size = logits_audio_image.shape[0] + elif logits_audio_text is not None: + batch_size = logits_audio_text.shape[0] + elif logits_image_text is not None: + batch_size = logits_image_text.shape[0] + else: + return None + + reference = torch.arange( + batch_size, + dtype=torch.int64, + device=self.device + ) + + loss = torch.tensor(0.0, dtype=self.dtype, device=self.device) + + num_modalities: int = 0 + scale = torch.tensor(1.0, dtype=self.dtype, device=self.device) + + if logits_audio_image is not None: + loss_ai = F.cross_entropy( + logits_audio_image, reference, weight=sample_weights + ) + F.cross_entropy( + logits_audio_image.transpose(-1, -2), reference, weight=sample_weights + ) + loss = loss + loss_ai + num_modalities += 1 + + if logits_audio_text is not None: + loss_at = F.cross_entropy( + logits_audio_text, reference, weight=sample_weights + ) + F.cross_entropy( + logits_audio_text.transpose(-1, -2), reference, weight=sample_weights + ) + loss = loss + loss_at + num_modalities += 1 + + if logits_image_text is not None: + loss_it = F.cross_entropy( + logits_image_text, reference, weight=sample_weights + ) + F.cross_entropy( + logits_image_text.transpose(-1, -2), reference, weight=sample_weights + ) + loss = loss + loss_it + num_modalities += 1 + + for idx in range(num_modalities): + scale = scale * (idx + 1) + + return loss / scale + + @property + def loss_fn_name(self) -> str: + return 'Cross Entropy' diff --git a/model/clip/__init__.py b/model/clip/__init__.py new file mode 100644 index 0000000..b235f83 --- /dev/null +++ b/model/clip/__init__.py @@ -0,0 +1,5 @@ +from .model import CLIP +from .model import convert_weights + + +__all__ = ['CLIP', 'convert_weights'] diff --git a/model/clip/clip.py b/model/clip/clip.py new file mode 100644 index 0000000..c4776ef --- /dev/null +++ b/model/clip/clip.py @@ -0,0 +1,193 @@ +# CREDITS: https://github.com/openai/CLIP + +import hashlib +import os +import urllib +import warnings +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from utils.simple_tokenizer import SimpleTokenizer as _Tokenizer + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/model/clip/model.py b/model/clip/model.py new file mode 100644 index 0000000..356482d --- /dev/null +++ b/model/clip/model.py @@ -0,0 +1,433 @@ +# CREDITS: https://github.com/openai/CLIP + +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/model/esresnet/__init__.py b/model/esresnet/__init__.py new file mode 100644 index 0000000..684cdc6 --- /dev/null +++ b/model/esresnet/__init__.py @@ -0,0 +1,8 @@ +from .base import ESResNet +from .base import ESResNeXt +from .fbsp import ESResNetFBSP +from .fbsp import ESResNeXtFBSP +from .attention import Attention2d + + +__all__ = ['ESResNet', 'ESResNeXt', 'ESResNetFBSP', 'ESResNeXtFBSP', 'Attention2d'] diff --git a/model/esresnet/attention.py b/model/esresnet/attention.py new file mode 100644 index 0000000..fb3c0c3 --- /dev/null +++ b/model/esresnet/attention.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F + +from typing import Tuple + + +class Attention2d(torch.nn.Module): + + def __init__(self, + in_channels: int, + out_channels: int, + num_kernels: int, + kernel_size: Tuple[int, int], + padding_size: Tuple[int, int]): + + super(Attention2d, self).__init__() + + self.conv_depth = torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * num_kernels, + kernel_size=kernel_size, + padding=padding_size, + groups=in_channels + ) + self.conv_point = torch.nn.Conv2d( + in_channels=in_channels * num_kernels, + out_channels=out_channels, + kernel_size=(1, 1) + ) + self.bn = torch.nn.BatchNorm2d(num_features=out_channels) + self.activation = torch.nn.Sigmoid() + + def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor: + x = F.adaptive_max_pool2d(x, size) + x = self.conv_depth(x) + x = self.conv_point(x) + x = self.bn(x) + x = self.activation(x) + + return x diff --git a/model/esresnet/base.py b/model/esresnet/base.py new file mode 100644 index 0000000..a2a0acc --- /dev/null +++ b/model/esresnet/base.py @@ -0,0 +1,708 @@ +import termcolor + +import numpy as np +import scipy.signal as sps + +import torch +import torch.nn.functional as F + +import torchvision as tv + +import ignite_trainer as it + +from model.esresnet import attention +from utils.transforms import scale + +from typing import cast +from typing import List +from typing import Type +from typing import Tuple +from typing import Union +from typing import Optional + + +def conv3x3(in_planes: int, out_planes: int, stride=1, groups: int = 1, dilation: Union[int, Tuple[int, int]] = 1): + """ + CREDITS: https://github.com/pytorch/vision + 3x3 convolution with padding + """ + return torch.nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: Union[int, Tuple[int, int]] = 1): + """ + CREDITS: https://github.com/pytorch/vision + 1x1 convolution + """ + return torch.nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=1, + stride=stride, + bias=False + ) + + +class BasicBlock(torch.nn.Module): + + """ + CREDITS: https://github.com/pytorch/vision + """ + + expansion: int = 1 + + def __init__(self, + inplanes: int, + planes: int, + stride: Union[int, Tuple[int, int]] = 1, + downsample: Optional[torch.nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: Union[int, Tuple[int, int]] = 1, + norm_layer: Optional[Type[torch.nn.Module]] = None): + + super(BasicBlock, self).__init__() + + if norm_layer is None: + norm_layer = torch.nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = torch.nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(torch.nn.Module): + + """ + CREDITS: https://github.com/pytorch/vision + """ + + expansion: int = 4 + + def __init__(self, + inplanes: int, + planes: int, + stride: Union[int, Tuple[int, int]] = 1, + downsample: Optional[torch.nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: Union[int, Tuple[int, int]] = 1, + norm_layer: Optional[Type[torch.nn.Module]] = None): + + super(Bottleneck, self).__init__() + + if norm_layer is None: + norm_layer = torch.nn.BatchNorm2d + + width = int(planes * (base_width / 64.0)) * groups + + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = torch.nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNetWithAttention(it.AbstractNet): + + """ + CREDITS: https://github.com/pytorch/vision + """ + + def __init__(self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + apply_attention: bool = False, + num_channels: int = 3, + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: bool = None, + norm_layer: Optional[Type[torch.nn.Module]] = None): + + super(ResNetWithAttention, self).__init__() + + self.apply_attention = apply_attention + + if norm_layer is None: + norm_layer = torch.nn.BatchNorm2d + + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + + if len(replace_stride_with_dilation) != 3: + raise ValueError( + f'replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}' + ) + + self.groups = groups + self.base_width = width_per_group + + self.conv1 = torch.nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = torch.nn.ReLU() + self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + if self.apply_attention: + self.att1 = attention.Attention2d( + in_channels=64, + out_channels=64 * block.expansion, + num_kernels=1, + kernel_size=(3, 1), + padding_size=(1, 0) + ) + + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + if self.apply_attention: + self.att2 = attention.Attention2d( + in_channels=64 * block.expansion, + out_channels=128 * block.expansion, + num_kernels=1, + kernel_size=(1, 5), + padding_size=(0, 2) + ) + + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + if self.apply_attention: + self.att3 = attention.Attention2d( + in_channels=128 * block.expansion, + out_channels=256 * block.expansion, + num_kernels=1, + kernel_size=(3, 1), + padding_size=(1, 0) + ) + + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + if self.apply_attention: + self.att4 = attention.Attention2d( + in_channels=256 * block.expansion, + out_channels=512 * block.expansion, + num_kernels=1, + kernel_size=(1, 5), + padding_size=(0, 2) + ) + + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + if self.apply_attention: + self.att5 = attention.Attention2d( + in_channels=512 * block.expansion, + out_channels=512 * block.expansion, + num_kernels=1, + kernel_size=(3, 5), + padding_size=(1, 2) + ) + + self.fc = torch.nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)): + torch.nn.init.constant_(m.weight, 1) + torch.nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + torch.nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + torch.nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: Union[int, Tuple[int, int]] = 1, + dilate: bool = False) -> torch.nn.Module: + + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + + if dilate: + self.dilation *= stride + stride = 1 + + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = torch.nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion) + ) + + layers = list() + layers.append(block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer + )) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer + )) + + return torch.nn.Sequential(*layers) + + def _forward_pre_processing(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(torch.get_default_dtype()) + + return x + + def _forward_pre_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + return x + + def _forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self._forward_pre_features(x) + + if self.apply_attention: + x_att = x.clone() + x = self.layer1(x) + x_att = self.att1(x_att, x.shape[-2:]) + x = x * x_att + + x_att = x.clone() + x = self.layer2(x) + x_att = self.att2(x_att, x.shape[-2:]) + x = x * x_att + + x_att = x.clone() + x = self.layer3(x) + x_att = self.att3(x_att, x.shape[-2:]) + x = x * x_att + + x_att = x.clone() + x = self.layer4(x) + x_att = self.att4(x_att, x.shape[-2:]) + x = x * x_att + else: + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + def _forward_reduction(self, x: torch.Tensor) -> torch.Tensor: + if self.apply_attention: + x_att = x.clone() + x = self.avgpool(x) + x_att = self.att5(x_att, x.shape[-2:]) + x = x * x_att + else: + x = self.avgpool(x) + + x = torch.flatten(x, 1) + + return x + + def _forward_classifier(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + + return x + + def forward(self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + x = self._forward_pre_processing(x) + x = self._forward_features(x) + x = self._forward_reduction(x) + y_pred = self._forward_classifier(x) + + loss = None + if y is not None: + loss = self.loss_fn(y_pred, y).mean() + + return y_pred if loss is None else (y_pred, loss) + + def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if isinstance(y_pred, tuple): + y_pred, *_ = y_pred + + if y_pred.shape == y.shape: + loss_pred = F.binary_cross_entropy_with_logits( + y_pred, + y.to(dtype=y_pred.dtype, device=y_pred.device), + reduction='sum' + ) / y_pred.shape[0] + else: + loss_pred = F.cross_entropy(y_pred, y.to(y_pred.device)) + + return loss_pred + + @property + def loss_fn_name(self) -> str: + return 'Cross Entropy' + + +class _ESResNet(ResNetWithAttention): + + @staticmethod + def loading_function(*args, **kwargs) -> torch.nn.Module: + raise NotImplementedError + + def __init__(self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + apply_attention: bool = False, + n_fft: int = 256, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[str] = None, + normalized: bool = False, + onesided: bool = True, + spec_height: int = 224, + spec_width: int = 224, + num_classes: int = 1000, + pretrained: Union[bool, str] = False, + lock_pretrained: Optional[Union[bool, List[str]]] = None, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: bool = None, + norm_layer: Optional[Type[torch.nn.Module]] = None): + + super(_ESResNet, self).__init__( + block=block, + layers=layers, + apply_attention=apply_attention, + num_channels=3, + num_classes=num_classes, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + replace_stride_with_dilation=replace_stride_with_dilation, + norm_layer=norm_layer + ) + + self.num_classes = num_classes + + self.fc = torch.nn.Linear( + in_features=self.fc.in_features, + out_features=self.num_classes, + bias=self.fc.bias is not None + ) + + if hop_length is None: + hop_length = int(np.floor(n_fft / 4)) + + if win_length is None: + win_length = n_fft + + if window is None: + window = 'boxcar' + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + self.normalized = normalized + self.onesided = onesided + + self.spec_height = spec_height + self.spec_width = spec_width + + self.pretrained = pretrained + self._inject_members() + if pretrained: + err_msg = self.load_pretrained() + + unlocked_weights = list() + + for name, p in self.named_parameters(): + unlock = True + if isinstance(lock_pretrained, bool): + if lock_pretrained and name not in err_msg: + unlock = False + elif isinstance(lock_pretrained, list): + if name in lock_pretrained: + unlock = False + + p.requires_grad_(unlock) + if unlock: + unlocked_weights.append(name) + + print(f'Following weights are unlocked: {unlocked_weights}') + + window_buffer: torch.Tensor = torch.from_numpy( + sps.get_window(window=window, Nx=win_length, fftbins=True) + ).to(torch.get_default_dtype()) + self.register_buffer('window', window_buffer) + + self.log10_eps = 1e-18 + + if self.apply_attention and pretrained and not isinstance(pretrained, str): + self._reset_attention() + + def _inject_members(self): + pass + + def _reset_attention(self): + print(termcolor.colored('Resetting attention blocks', 'green')) + + self.att1.bn.weight.data.fill_(1.0) + self.att1.bn.bias.data.fill_(1.0) + + self.att2.bn.weight.data.fill_(1.0) + self.att2.bn.bias.data.fill_(1.0) + + self.att3.bn.weight.data.fill_(1.0) + self.att3.bn.bias.data.fill_(1.0) + + self.att4.bn.weight.data.fill_(1.0) + self.att4.bn.bias.data.fill_(1.0) + + self.att5.bn.weight.data.fill_(1.0) + self.att5.bn.bias.data.fill_(1.0) + + def load_pretrained(self) -> str: + if isinstance(self.pretrained, bool): + state_dict = self.loading_func(pretrained=True).state_dict() + else: + state_dict = torch.load(self.pretrained, map_location='cpu') + + err_msg = '' + try: + self.load_state_dict(state_dict=state_dict, strict=True) + except RuntimeError as ex: + err_msg += f'While loading some errors occurred.\n{ex}' + print(termcolor.colored(err_msg, 'red')) + + return err_msg + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + spec = torch.stft( + x.view(-1, x.shape[-1]), + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + pad_mode='reflect', + normalized=self.normalized, + onesided=True + ) + + if not self.onesided: + spec = torch.cat((torch.flip(spec, dims=(-3,)), spec), dim=-3) + + return spec + + def split_spectrogram(self, spec: torch.Tensor, batch_size: int) -> torch.Tensor: + spec_height_per_band = spec.shape[-3] // self.conv1.in_channels + spec_height_single_band = self.conv1.in_channels * spec_height_per_band + spec = spec[:, :spec_height_single_band] + + spec = spec.reshape(batch_size, -1, spec.shape[-3] // self.conv1.in_channels, *spec.shape[-2:]) + + return spec + + def spectrogram_to_power(self, spec: torch.Tensor) -> torch.Tensor: + spec_height = spec.shape[-3] if self.spec_height < 1 else self.spec_height + spec_width = spec.shape[-2] if self.spec_width < 1 else self.spec_width + + pow_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2 + + if spec_height != pow_spec.shape[-2] or spec_width != pow_spec.shape[-1]: + pow_spec = F.interpolate( + pow_spec, + size=(spec_height, spec_width), + mode='bilinear', + align_corners=True + ) + + return pow_spec + + def _forward_pre_processing(self, x: torch.Tensor) -> torch.Tensor: + x = super(_ESResNet, self)._forward_pre_processing(x) + x = scale(x, -32768.0, 32767, -1.0, 1.0) + + spec = self.spectrogram(x) + spec_split_ch = self.split_spectrogram(spec, x.shape[0]) + pow_spec_split_ch = self.spectrogram_to_power(spec_split_ch) + pow_spec_split_ch = torch.where( + cast(torch.Tensor, pow_spec_split_ch > 0.0), + pow_spec_split_ch, + torch.full_like(pow_spec_split_ch, self.log10_eps) + ) + pow_spec_split_ch = pow_spec_split_ch.reshape( + x.shape[0], -1, self.conv1.in_channels, *pow_spec_split_ch.shape[-2:] + ) + x_db = torch.log10(pow_spec_split_ch).mul(10.0) + + return x_db + + def _forward_features(self, x_db: torch.Tensor) -> List[torch.Tensor]: + outputs = list() + for ch_idx in range(x_db.shape[1]): + ch = x_db[:, ch_idx] + out = super(_ESResNet, self)._forward_features(ch) + outputs.append(out) + + return outputs + + def _forward_reduction(self, x: List[torch.Tensor]) -> torch.Tensor: + outputs = list() + for ch in x: + out = super(_ESResNet, self)._forward_reduction(ch) + outputs.append(out) + outputs = torch.stack(outputs, dim=-1).sum(dim=-1) + + return outputs + + +class ESResNet(_ESResNet): + + loading_func = staticmethod(tv.models.resnet50) + + def __init__(self, + n_fft: int = 256, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[str] = None, + normalized: bool = False, + onesided: bool = True, + spec_height: int = 224, + spec_width: int = 224, + num_classes: int = 1000, + apply_attention: bool = False, + pretrained: bool = False, + lock_pretrained: Optional[Union[bool, List[str]]] = None): + + super(ESResNet, self).__init__( + block=Bottleneck, + layers=[3, 4, 6, 3], + apply_attention=apply_attention, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided, + spec_height=spec_height, + spec_width=spec_width, + num_classes=num_classes, + pretrained=pretrained, + lock_pretrained=lock_pretrained + ) + + +class ESResNeXt(_ESResNet): + + loading_func = staticmethod(tv.models.resnext50_32x4d) + + def __init__(self, + n_fft: int = 256, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[str] = None, + normalized: bool = False, + onesided: bool = True, + spec_height: int = 224, + spec_width: int = 224, + num_classes: int = 1000, + apply_attention: bool = False, + pretrained: Union[bool, str] = False, + lock_pretrained: Optional[Union[bool, List[str]]] = None): + + super(ESResNeXt, self).__init__( + block=Bottleneck, + layers=[3, 4, 6, 3], + apply_attention=apply_attention, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided, + spec_height=spec_height, + spec_width=spec_width, + num_classes=num_classes, + pretrained=pretrained, + lock_pretrained=lock_pretrained, + groups=32, + width_per_group=4 + ) diff --git a/model/esresnet/fbsp.py b/model/esresnet/fbsp.py new file mode 100644 index 0000000..15d329d --- /dev/null +++ b/model/esresnet/fbsp.py @@ -0,0 +1,247 @@ +import numpy as np + +import torch +import torch.nn.functional as F + +import torchvision as tv + +from utils import transforms +from model.esresnet.base import _ESResNet +from model.esresnet.base import Bottleneck + +from typing import cast +from typing import List +from typing import Tuple +from typing import Union +from typing import Optional + + +class LinearFBSP(torch.nn.Module): + + def __init__(self, out_features: int, bias: bool = True, normalized: bool = False): + super(LinearFBSP, self).__init__() + + self.out_features = out_features + self.normalized = normalized + self.eps = 1e-8 + + default_dtype = torch.get_default_dtype() + + self.register_parameter('m', torch.nn.Parameter(torch.zeros(self.out_features, dtype=default_dtype))) + self.register_parameter('fb', torch.nn.Parameter(torch.ones(self.out_features, dtype=default_dtype))) + self.register_parameter('fc', torch.nn.Parameter(torch.arange(self.out_features, dtype=default_dtype))) + self.register_parameter( + 'bias', + torch.nn.Parameter( + torch.normal( + 0.0, 0.5, (self.out_features, 2), dtype=default_dtype + ) if bias else cast( + torch.nn.Parameter, None + ) + ) + ) + + self.m.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) + self.fb.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) + self.fc.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) + + @staticmethod + def power(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + magnitudes = (x1[..., 0] ** 2 + x1[..., 1] ** 2) ** 0.5 + phases = x1[..., 1].atan2(x1[..., 0]) + + power_real = x2[..., 0] + power_imag = x2[..., 1] + + mag_out = ((magnitudes ** 2) ** (0.5 * power_real) * torch.exp(-power_imag * phases)) + + return mag_out.unsqueeze(-1) * torch.stack(( + (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).cos(), + (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).sin() + ), dim=-1) + + @staticmethod + def sinc(x: torch.Tensor) -> torch.Tensor: + return torch.where(cast(torch.Tensor, x == 0), torch.ones_like(x), torch.sin(x) / x) + + def _materialize_weights(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]: + x_is_complex = x.shape[-1] == 2 + in_features = x.shape[-1 - int(x_is_complex)] + + t = np.pi * torch.linspace(-1.0, 1.0, in_features, dtype=x.dtype, device=x.device).reshape(1, -1, 1) + self.eps + + m = self.m.reshape(-1, 1, 1) + fb = self.fb.reshape(-1, 1, 1) + fc = self.fc.reshape(-1, 1, 1) + + kernel = torch.cat((torch.cos(fc * t), -torch.sin(fc * t)), dim=-1) # complex + scale = fb.sqrt() # real + win = self.sinc(fb * t / (m + self.eps)) # real + win = self.power( + torch.cat((win, torch.zeros_like(win)), dim=-1), + torch.cat((m, torch.zeros_like(m)), dim=-1) + ) # complex + + weights = scale * torch.cat(( + win[..., :1] * kernel[..., :1] - win[..., 1:] * kernel[..., 1:], + win[..., :1] * kernel[..., 1:] + win[..., 1:] * kernel[..., :1] + ), dim=-1) + + if self.normalized: + weights = weights / (in_features ** 0.5) + + return weights, x_is_complex + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weights, x_is_complex = self._materialize_weights(x) + + if x_is_complex: + x = torch.stack(( + F.linear(x[..., 0], weights[..., 0]) - F.linear(x[..., 1], weights[..., 1]), + F.linear(x[..., 0], weights[..., 1]) + F.linear(x[..., 1], weights[..., 0]) + ), dim=-1) + else: + x = torch.stack(( + F.linear(x, weights[..., 0]), + F.linear(x, weights[..., 1]) + ), dim=-1) + + if (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)): + x = x + self.bias + + return x, weights + + def extra_repr(self) -> str: + return 'out_features={}, bias={}, normalized={}'.format( + self.out_features, + (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)), + self.normalized + ) + + +ttf_weights = dict() + + +class _ESResNetFBSP(_ESResNet): + + def _inject_members(self): + self.add_module( + 'fbsp', + LinearFBSP( + out_features=int(round(self.n_fft / 2)) + 1 if self.onesided else self.n_fft, + normalized=self.normalized, + bias=False + ) + ) + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + frames = transforms.frame_signal( + signal=x.view(-1, x.shape[-1]), + frame_length=self.win_length, + hop_length=self.hop_length, + window=self.window + ) + + if self.n_fft > self.win_length: + pad_length = self.n_fft - self.win_length + pad_left = pad_length // 2 + pad_right = pad_length - pad_left + frames = F.pad(frames, [pad_left, pad_right]) + + spec, ttf_weights_ = self.fbsp(frames) + + spec = spec.transpose(-2, -3) + ttf_weights[x.device] = ttf_weights_ + + return spec + + def loss_ttf(self, device: torch.device) -> torch.Tensor: + ttf_norm = torch.norm(ttf_weights[device], p=2, dim=[-1, -2]) + loss_ttf_norm = F.mse_loss( + ttf_norm, + torch.full_like(ttf_norm, 1.0 if self.normalized else self.n_fft ** 0.5) + ) + + return loss_ttf_norm + + def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + loss_pred = super(_ESResNetFBSP, self).loss_fn(y_pred, y) + loss_ttf_norm = self.loss_ttf(y_pred.device) + loss = loss_pred + loss_ttf_norm + + return loss + + +class ESResNetFBSP(_ESResNetFBSP): + + loading_func = staticmethod(tv.models.resnet50) + + def __init__(self, + n_fft: int = 256, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[str] = None, + normalized: bool = False, + onesided: bool = True, + spec_height: int = 224, + spec_width: int = 224, + num_classes: int = 1000, + apply_attention: bool = False, + pretrained: bool = False, + lock_pretrained: Optional[Union[bool, List[str]]] = None): + + super(ESResNetFBSP, self).__init__( + block=Bottleneck, + layers=[3, 4, 6, 3], + apply_attention=apply_attention, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided, + spec_height=spec_height, + spec_width=spec_width, + num_classes=num_classes, + pretrained=pretrained, + lock_pretrained=lock_pretrained + ) + + +class ESResNeXtFBSP(_ESResNetFBSP): + + loading_func = staticmethod(tv.models.resnext50_32x4d) + + def __init__(self, + n_fft: int = 256, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[str] = None, + normalized: bool = False, + onesided: bool = True, + spec_height: int = 224, + spec_width: int = 224, + num_classes: int = 1000, + apply_attention: bool = False, + pretrained: Union[bool, str] = False, + lock_pretrained: Optional[Union[bool, List[str]]] = None): + + super(ESResNeXtFBSP, self).__init__( + block=Bottleneck, + layers=[3, 4, 6, 3], + apply_attention=apply_attention, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided, + spec_height=spec_height, + spec_width=spec_width, + num_classes=num_classes, + pretrained=pretrained, + lock_pretrained=lock_pretrained, + groups=32, + width_per_group=4 + ) diff --git a/protocols/README.md b/protocols/README.md new file mode 100644 index 0000000..48e72c8 --- /dev/null +++ b/protocols/README.md @@ -0,0 +1,3 @@ +# Protocols + +Here are the JSON-files that describe configurations of experiments. diff --git a/protocols/audioclip-esc50.json b/protocols/audioclip-esc50.json new file mode 100644 index 0000000..32e1ada --- /dev/null +++ b/protocols/audioclip-esc50.json @@ -0,0 +1,108 @@ +{ + "Visdom": { + "host": null, + "port": null, + "env_path": null + }, + "Setup": { + "name": "Multimodal-Audio", + "suffix": "CV1", + "batch_train": 64, + "batch_test": 64, + "workers_train": 4, + "workers_test": 4, + "epochs": 50, + "log_interval": 10, + "saved_models_path": "/path/to/saved/models" + }, + "Model": { + "class": "model.audioclip.AudioCLIP", + "args": { + "multilabel": false, + "pretrained": "/path/to/assets/trained_AudioCLIP.pt" + } + }, + "Optimizer": { + "class": "torch.optim.SGD", + "args": { + "lr": 5e-5, + "momentum": 0.9, + "nesterov": true, + "weight_decay": 5e-4 + } + }, + "Scheduler": { + "class": "torch.optim.lr_scheduler.ExponentialLR", + "args": { + "gamma": 0.96 + } + }, + "Dataset": { + "class": "utils.datasets.ESC50", + "args": { + "dl_shuffle": true, + "root": "/path/to/ESC50", + "sample_rate": 44100, + "fold": 1, + "training": {"key": "train", "yes": true, "no": false} + } + }, + "Transforms": [ + { + "class": "utils.transforms.ToTensor1D", + "args": {} + }, + { + "class": "utils.transforms.RandomFlip", + "args": {"p": 0.5}, + "test": false + }, + { + "class": "utils.transforms.RandomScale", + "args": {"max_scale": 1.50}, + "test": false + }, + { + "class": "utils.transforms.RandomPadding", + "args": {"out_len": 220500}, + "test": false + }, + { + "class": "utils.transforms.RandomCrop", + "args": {"out_len": 220500}, + "test": false + }, + { + "class": "utils.transforms.RandomNoise", + "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25}, + "test": false + }, + { + "class": "utils.transforms.RandomPadding", + "args": {"out_len": 220500, "train": false}, + "train": false + }, + { + "class": "utils.transforms.RandomCrop", + "args": {"out_len": 220500, "train": false}, + "train": false + } + ], + "Metrics": { + "Performance": { + "window_name": null, + "x_label": "#Epochs", + "y_label": "Accuracy", + "width": 1890, + "height": 416, + "lines": [ + { + "line_label": "Val. Acc.", + "class": "ignite.metrics.Accuracy", + "args": {}, + "is_checkpoint": true + } + ] + } + } +} diff --git a/protocols/audioclip-us8k.json b/protocols/audioclip-us8k.json new file mode 100644 index 0000000..4c1f649 --- /dev/null +++ b/protocols/audioclip-us8k.json @@ -0,0 +1,108 @@ +{ + "Visdom": { + "host": null, + "port": null, + "env_path": null + }, + "Setup": { + "name": "Multimodal-Audio", + "suffix": "CV01", + "batch_train": 64, + "batch_test": 64, + "workers_train": 4, + "workers_test": 4, + "epochs": 50, + "log_interval": 25, + "saved_models_path": "/path/to/saved/models" + }, + "Model": { + "class": "model.audioclip.AudioCLIP", + "args": { + "multilabel": false, + "pretrained": "/path/to/assets/trained_AudioCLIP.pt" + } + }, + "Optimizer": { + "class": "torch.optim.SGD", + "args": { + "lr": 1e-5, + "momentum": 0.9, + "nesterov": true, + "weight_decay": 5e-4 + } + }, + "Scheduler": { + "class": "torch.optim.lr_scheduler.ExponentialLR", + "args": { + "gamma": 0.96 + } + }, + "Dataset": { + "class": "utils.datasets.UrbanSound8K", + "args": { + "root": "/path/to/UrbanSound8K", + "sample_rate": 44100, + "fold": 1, + "mono": false, + "training": {"key": "train", "yes": true, "no": false} + } + }, + "Transforms": [ + { + "class": "utils.transforms.ToTensor1D", + "args": {} + }, + { + "class": "utils.transforms.RandomFlip", + "args": {"p": 0.5}, + "test": false + }, + { + "class": "utils.transforms.RandomScale", + "args": {"max_scale": 1.50}, + "test": false + }, + { + "class": "utils.transforms.RandomPadding", + "args": {"out_len": 176400}, + "test": false + }, + { + "class": "utils.transforms.RandomCrop", + "args": {"out_len": 176400}, + "test": false + }, + { + "class": "utils.transforms.RandomNoise", + "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25}, + "test": false + }, + { + "class": "utils.transforms.RandomPadding", + "args": {"out_len": 176400, "train": false}, + "train": false + }, + { + "class": "utils.transforms.RandomCrop", + "args": {"out_len": 176400, "train": false}, + "train": false + } + ], + "Metrics": { + "Performance": { + "window_name": null, + "x_label": "#Epochs", + "y_label": "Accuracy", + "width": 1890, + "height": 416, + "lines": [ + { + "line_label": "Val. Acc.", + "class": "ignite.metrics.Accuracy", + "args": {}, + "is_checkpoint": true + } + ] + } + } +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f78635a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +librosa==0.7.2 +numpy==1.18.1 +pandas==1.0.3 +pytorch-ignite==0.3.0 +scikit-learn==0.22.1 +scipy==1.4.1 +termcolor==1.1.0 +torch==1.7.1 +torchvision==0.8.2 +tqdm==4.43.0 +visdom==0.1.8.9 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..c5be7f9 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,7 @@ +from . import datasets +from . import transforms + +__all__ = [ + 'datasets', + 'transforms' +] diff --git a/utils/datasets/__init__.py b/utils/datasets/__init__.py new file mode 100644 index 0000000..cba11e7 --- /dev/null +++ b/utils/datasets/__init__.py @@ -0,0 +1,4 @@ +from .esc50 import ESC50 +from .us8k import UrbanSound8K + +__all__ = ['ESC50', 'UrbanSound8K'] diff --git a/utils/datasets/esc50.py b/utils/datasets/esc50.py new file mode 100644 index 0000000..17b0111 --- /dev/null +++ b/utils/datasets/esc50.py @@ -0,0 +1,128 @@ +import os +import warnings +import multiprocessing as mp + +import tqdm +import librosa + +import numpy as np +import pandas as pd + +import torch.utils.data as td + +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union +from typing import Optional + + +class ESC50(td.Dataset): + + def __init__(self, + root: str, + sample_rate: int = 22050, + train: bool = True, + fold: Optional[int] = None, + transform_audio=None, + target_transform=None, + **_): + + super(ESC50, self).__init__() + + self.sample_rate = sample_rate + + meta = self.load_meta(os.path.join(root, 'meta', 'esc50.csv')) + + if fold is None: + fold = 5 + + self.folds_to_load = set(meta['fold']) + + if fold not in self.folds_to_load: + raise ValueError(f'fold {fold} does not exist') + + self.train = train + self.transform = transform_audio + + if self.train: + self.folds_to_load -= {fold} + else: + self.folds_to_load -= self.folds_to_load - {fold} + + self.data: Dict[Union[str, int], Dict[str, Any]] = dict() + self.load_data(meta, os.path.join(root, 'audio')) + self.indices = list(self.data.keys()) + + self.class_idx_to_label = dict() + for row in self.data.values(): + idx = row['target'] + label = row['category'] + self.class_idx_to_label[idx] = label + self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()} + + self.target_transform = target_transform + + @staticmethod + def load_meta(path_to_csv: str) -> pd.DataFrame: + meta = pd.read_csv(path_to_csv) + + return meta + + @staticmethod + def _load_worker(idx: int, filename: str, sample_rate: Optional[int] = None) -> Tuple[int, int, np.ndarray]: + wav, sample_rate = librosa.load(filename, sr=sample_rate, mono=True) + + if wav.ndim == 1: + wav = wav[:, np.newaxis] + + wav = wav.T * 32768.0 + + return idx, sample_rate, wav.astype(np.float32) + + def load_data(self, meta: pd.DataFrame, base_path: str): + items_to_load = dict() + + for idx, row in meta.iterrows(): + if row['fold'] in self.folds_to_load: + items_to_load[idx] = os.path.join(base_path, row['filename']), self.sample_rate + + items_to_load = [(idx, path, sample_rate) for idx, (path, sample_rate) in items_to_load.items()] + + num_processes = os.cpu_count() + warnings.filterwarnings('ignore') + with mp.Pool(processes=num_processes) as pool: + tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})') + for idx, sample_rate, wav in pool.starmap( + func=self._load_worker, + iterable=items_to_load, + chunksize=int(np.ceil(len(items_to_load) / num_processes)) or 1 + ): + row = meta.loc[idx] + + self.data[idx] = { + 'audio': wav, + 'sample_rate': sample_rate, + 'target': row['target'], + 'category': row['category'].replace('_', ' '), + 'fold': row['fold'], + 'esc10': row['esc10'] + } + + def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]: + if not (0 <= index < len(self)): + raise IndexError + + audio: np.ndarray = self.data[self.indices[index]]['audio'] + target: str = self.data[self.indices[index]]['category'] + + if self.transform is not None: + audio = self.transform(audio) + if self.target_transform is not None: + target = self.target_transform(target) + + return audio, None, [target] + + def __len__(self) -> int: + return len(self.indices) diff --git a/utils/datasets/us8k.py b/utils/datasets/us8k.py new file mode 100644 index 0000000..ff45324 --- /dev/null +++ b/utils/datasets/us8k.py @@ -0,0 +1,167 @@ +import os +import warnings +import multiprocessing as mp + +import tqdm +import librosa +import soundfile as sf + +import numpy as np +import pandas as pd + +import torch.utils.data as td + +import sklearn.model_selection as skms + +import utils.transforms as transforms + +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple +from typing import Optional + + +class UrbanSound8K(td.Dataset): + + def __init__(self, + root: str, + sample_rate: int = 22050, + train: bool = True, + fold: Optional[int] = None, + mono: bool = False, + transform_audio=None, + target_transform=None, + **_): + + super(UrbanSound8K, self).__init__() + + self.root = root + self.sample_rate = sample_rate + self.train = train + self.random_split_seed = None + + if fold is None: + fold = 1 + + if not (1 <= fold <= 10): + raise ValueError(f'Expected fold in range [1, 10], got {fold}') + + self.fold = fold + self.folds_to_load = set(range(1, 11)) + + if self.fold not in self.folds_to_load: + raise ValueError(f'fold {fold} does not exist') + + if self.train: + # if in training mode, keep all but test fold + self.folds_to_load -= {self.fold} + else: + # if in evaluation mode, keep the test samples only + self.folds_to_load -= self.folds_to_load - {self.fold} + + self.mono = mono + + self.transform = transform_audio + self.target_transform = target_transform + + self.data: Dict[str, Dict[str, Any]] = dict() + self.indices = dict() + self.load_data() + + self.class_idx_to_label = dict() + for row in self.data.values(): + idx = row['target'] + label = row['category'] + self.class_idx_to_label[idx] = label + self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()} + + @staticmethod + def _load_worker(fn: str, path_to_file: str, sample_rate: int, mono: bool = False) -> Tuple[str, int, np.ndarray]: + wav, sample_rate_ = sf.read( + path_to_file, + dtype='float32', + always_2d=True + ) + + wav = librosa.resample(wav.T, sample_rate_, sample_rate) + + if wav.shape[0] == 1 and not mono: + wav = np.concatenate((wav, wav), axis=0) + + wav = wav[:, :sample_rate * 4] + wav = transforms.scale(wav, wav.min(), wav.max(), -32768.0, 32767.0) + + return fn, sample_rate, wav.astype(np.float32) + + def load_data(self): + # read metadata + meta = pd.read_csv( + os.path.join(self.root, 'metadata', 'UrbanSound8K.csv'), + sep=',', + index_col='slice_file_name' + ) + + for row_idx, (fn, row) in enumerate(meta.iterrows()): + path = os.path.join(self.root, 'audio', 'fold{}'.format(row['fold']), fn) + self.data[fn] = path, self.sample_rate, self.mono + + # by default, the official split from the metadata is used + files_to_load = list() + # if the random seed is not None, the random split is used + if self.random_split_seed is not None: + # given an integer random seed + skf = skms.StratifiedKFold(n_splits=10, shuffle=True, random_state=self.random_split_seed) + + # split the US8K samples into 10 folds + for fold_idx, (train_ids, test_ids) in enumerate(skf.split( + np.zeros(len(meta)), meta['classID'].values.astype(int) + ), 1): + # if this is the fold we want to load, add the corresponding files to the list + if fold_idx == self.fold: + ids = train_ids if self.train else test_ids + filenames = meta.iloc[ids].index + files_to_load.extend(filenames) + break + else: + # if the random seed is None, use the official split + for fn, row in meta.iterrows(): + if int(row['fold']) in self.folds_to_load: + files_to_load.append(fn) + + self.data = {fn: vals for fn, vals in self.data.items() if fn in files_to_load} + self.indices = {idx: fn for idx, fn in enumerate(self.data)} + + num_processes = os.cpu_count() + warnings.filterwarnings('ignore') + with mp.Pool(processes=num_processes) as pool: + tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})') + for fn, sample_rate, wav in pool.starmap( + func=self._load_worker, + iterable=[(fn, path, sr, mono) for fn, (path, sr, mono) in self.data.items()], + chunksize=int(np.ceil(len(meta) / num_processes)) or 1 + ): + self.data[fn] = { + 'audio': wav, + 'sample_rate': sample_rate, + 'target': meta.loc[fn, 'classID'], + 'category': meta.loc[fn, 'class'].replace('_', ' ').strip(' '), + 'background': bool(meta.loc[fn, 'salience'] - 1) + } + + def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]: + if not (0 <= index < len(self)): + raise IndexError + + audio: np.ndarray = self.data[self.indices[index]]['audio'] + target: str = self.data[self.indices[index]]['category'] + + if self.transform is not None: + audio = self.transform(audio) + if self.target_transform is not None: + target = self.target_transform(target) + + return audio, None, [target] + + def __len__(self) -> int: + return len(self.data) diff --git a/utils/simple_tokenizer.py b/utils/simple_tokenizer.py new file mode 100644 index 0000000..68946a1 --- /dev/null +++ b/utils/simple_tokenizer.py @@ -0,0 +1,134 @@ +# CREDITS: https://github.com/openai/CLIP + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/utils/transforms.py b/utils/transforms.py new file mode 100644 index 0000000..082175e --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,199 @@ +import math + +import numpy as np + +import torch +import torchvision as tv + +import ignite_trainer as it + + +def scale(old_value, old_min, old_max, new_min, new_max): + old_range = (old_max - old_min) + new_range = (new_max - new_min) + new_value = (((old_value - old_min) * new_range) / old_range) + new_min + + return new_value + + +def frame_signal(signal: torch.Tensor, + frame_length: int, + hop_length: int, + window: torch.Tensor = None) -> torch.Tensor: + + if window is None: + window = torch.ones(frame_length, dtype=signal.dtype, device=signal.device) + + if window.shape[0] != frame_length: + raise ValueError('Wrong `window` length: expected {}, got {}'.format(window.shape[0], frame_length)) + + signal_length = signal.shape[-1] + + if signal_length <= frame_length: + num_frames = 1 + else: + num_frames = 1 + int(math.ceil((1.0 * signal_length - frame_length) / hop_length)) + + pad_len = int((num_frames - 1) * hop_length + frame_length) + if pad_len > signal_length: + zeros = torch.zeros(pad_len - signal_length, device=signal.device, dtype=signal.dtype) + + while zeros.dim() < signal.dim(): + zeros.unsqueeze_(0) + + pad_signal = torch.cat((zeros.expand(*signal.shape[:-1], -1)[..., :zeros.shape[-1] // 2], signal), dim=-1) + pad_signal = torch.cat((pad_signal, zeros.expand(*signal.shape[:-1], -1)[..., zeros.shape[-1] // 2:]), dim=-1) + else: + pad_signal = signal + + indices = torch.arange(0, frame_length, device=signal.device).repeat(num_frames, 1) + indices += torch.arange( + 0, + num_frames * hop_length, + hop_length, + device=signal.device + ).repeat(frame_length, 1).t_() + indices = indices.long() + + frames = pad_signal[..., indices] + frames = frames * window + + return frames + + +class ToTensor1D(tv.transforms.ToTensor): + + def __call__(self, tensor: np.ndarray): + tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) + + return tensor_2d.squeeze_(0) + + +class RandomFlip(it.AbstractTransform): + + def __init__(self, p: float = 0.5): + super(RandomFlip, self).__init__() + + self.p = p + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() > 2: + flip_mask = torch.rand(x.shape[0], device=x.device) <= self.p + x[flip_mask] = x[flip_mask].flip(-1) + else: + if torch.rand(1) <= self.p: + x = x.flip(0) + + return x + + +class RandomScale(it.AbstractTransform): + + def __init__(self, max_scale: float = 1.25): + super(RandomScale, self).__init__() + + self.max_scale = max_scale + + @staticmethod + def random_scale(max_scale: float, signal: torch.Tensor) -> torch.Tensor: + scaling = np.power(max_scale, np.random.uniform(-1, 1)) + output_size = int(signal.shape[-1] * scaling) + ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling) + + ref1 = ref.clone().type(torch.int64) + ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64)) + r = ref - ref1.type(ref.type()) + scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r + + return scaled_signal + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.random_scale(self.max_scale, x) + + +class RandomCrop(it.AbstractTransform): + + def __init__(self, out_len: int = 44100, train: bool = True): + super(RandomCrop, self).__init__() + + self.out_len = out_len + self.train = train + + def random_crop(self, signal: torch.Tensor) -> torch.Tensor: + if self.train: + left = np.random.randint(0, signal.shape[-1] - self.out_len) + else: + left = int(round(0.5 * (signal.shape[-1] - self.out_len))) + + orig_std = signal.float().std() * 0.5 + output = signal[..., left:left + self.out_len] + + out_std = output.float().std() + if out_std < orig_std: + output = signal[..., :self.out_len] + + new_out_std = output.float().std() + if orig_std > new_out_std > out_std: + output = signal[..., -self.out_len:] + + return output + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.random_crop(x) if x.shape[-1] > self.out_len else x + + +class RandomPadding(it.AbstractTransform): + + def __init__(self, out_len: int = 88200, train: bool = True): + super(RandomPadding, self).__init__() + + self.out_len = out_len + self.train = train + + def random_pad(self, signal: torch.Tensor) -> torch.Tensor: + if self.train: + left = np.random.randint(0, self.out_len - signal.shape[-1]) + else: + left = int(round(0.5 * (self.out_len - signal.shape[-1]))) + + right = self.out_len - (left + signal.shape[-1]) + + pad_value_left = signal[..., 0].float().mean().to(signal.dtype) + pad_value_right = signal[..., -1].float().mean().to(signal.dtype) + output = torch.cat(( + torch.zeros(signal.shape[:-1] + (left,), dtype=signal.dtype, device=signal.device).fill_(pad_value_left), + signal, + torch.zeros(signal.shape[:-1] + (right,), dtype=signal.dtype, device=signal.device).fill_(pad_value_right) + ), dim=-1) + + return output + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.random_pad(x) if x.shape[-1] < self.out_len else x + + +class RandomNoise(it.AbstractTransform): + + def __init__(self, snr_min_db: float = -10.0, snr_max_db: float = 100.0, p: float = 0.5): + super(RandomNoise, self).__init__() + + self.p = p + self.snr_min_db = snr_min_db + self.snr_max_db = snr_max_db + + def random_noise(self, signal: torch.Tensor) -> torch.Tensor: + target_snr = np.random.rand() * (self.snr_max_db - self.snr_min_db + 1.0) + self.snr_min_db + + signal_watts = torch.mean(signal ** 2, dim=(-1, -2)) + signal_db = 10 * torch.log10(signal_watts) + + noise_db = signal_db - target_snr + noise_watts = 10 ** (noise_db / 10) + noise = torch.normal(0.0, noise_watts.item() ** 0.5, signal.shape) + + output = signal + noise + + return output + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.random_noise(x) if np.random.rand() <= self.p else x