diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..bb29d8c
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..32f7bbd
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+# AudioCLIP
+## Extending [CLIP](https://github.com/openai/CLIP) to Image, Text and Audio
+
+This repository contains the implementation of the models described in the paper [arXiv:2106.13043](https://arxiv.org/abs/2106.13043).
+This work is based on our previous works:
+* [ESResNe(X)t-fbsp: Learning Robust Time-Frequency Transformation of Audio (2021)](https://github.com/AndreyGuzhov/ESResNeXt-fbsp).
+* [ESResNet: Environmental Sound Classification Based on Visual Domain Models (2020)](https://github.com/AndreyGuzhov/ESResNet).
+
+### Abstract
+
+In the past, the rapidly evolving field of sound classification greatly benefited from the application of methods from other domains.
+Today, we observe a trend toward fusing domain-specific tasks and approaches, which provides the community with new outstanding models.
+
+In this work, we present an extension of the CLIP model that handles audio in addition to text and images.
+Our proposed model incorporates the ESResNeXt audio-model into the CLIP framework using the AudioSet dataset.
+Such a combination enables the proposed model to perform bimodal and unimodal classification and querying, while keeping CLIP's ability to generalize to unseen datasets in a zero-shot inference fashion.
+
+AudioCLIP achieves new state-of-the-art results in the Environmental Sound Classification (ESC) task, outperforming other approaches by reaching accuracies of 90.07% on the UrbanSound8K and 97.15% on the ESC-50 datasets.
+Further, it sets new baselines in the zero-shot ESC task on the same datasets (68.78% and 69.40%, respectively).
+
+Finally, we also assess the cross-modal querying performance of the proposed model as well as the influence of full and partial training on the results.
+For the sake of reproducibility, our code is published.
+
+### How to Run the Model
+
+The required Python version is >= 3.7.
+
+#### AudioCLIP
+
+##### On the [ESC-50](https://github.com/karolpiczak/ESC-50) dataset
+    python main.py --config protocols/audioclip-esc50.json --Dataset.args.root /path/to/ESC50
+
+##### On the [UrbanSound8K](https://urbansounddataset.weebly.com/) dataset
+    python main.py --config protocols/audioclip-us8k.json --Dataset.args.root /path/to/UrbanSound8K
+
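+##### Zero-shot inference in Python
+
+For quick experimentation outside the training wrapper, a pre-trained snapshot from `assets/` can be loaded directly.
+The sketch below is illustrative only: the label set and the dummy mono waveforms (assumed here to be sampled at 44.1 kHz and shaped `[batch, samples]`) are placeholders, not part of the evaluation protocols.
+
+```
+import torch
+
+from model import AudioCLIP
+
+model = AudioCLIP(pretrained='assets/AudioCLIP-Full-Training.pt').eval()
+
+audio = torch.randn(2, 44100 * 5)       # two 5-second placeholder clips
+text = [['dog barking'], ['rain']]      # one list of labels per candidate class
+
+with torch.no_grad():
+    (_, logits), _ = model(audio=audio, text=text)
+
+# logits = (audio x image, audio x text, image x text); row i of the audio x text
+# matrix scores clip i against every textual prompt.
+print(logits[1].softmax(dim=-1))
+```
+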
+### Cite Us
+
+```
+@misc{guzhov2021audioclip,
+ title={AudioCLIP: Extending CLIP to Image, Text and Audio},
+ author={Andrey Guzhov and Federico Raue and Jörn Hees and Andreas Dengel},
+ year={2021},
+ eprint={2106.13043},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD}
+}
+```
diff --git a/assets/AudioCLIP-Full-Training.pt b/assets/AudioCLIP-Full-Training.pt
new file mode 100644
index 0000000..3d85a17
--- /dev/null
+++ b/assets/AudioCLIP-Full-Training.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2441d35b353352c8b1bbfb8f7c687f46314c3d2909e940eaf763b8c17f632c44
+size 537302068
diff --git a/assets/AudioCLIP-Partial-Training.pt b/assets/AudioCLIP-Partial-Training.pt
new file mode 100644
index 0000000..fc67480
--- /dev/null
+++ b/assets/AudioCLIP-Partial-Training.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1921885b6c3b5de0619f2ef4a9700ecbade758a398b34eb064230c30baef0c75
+size 537302068
diff --git a/assets/CLIP.pt b/assets/CLIP.pt
new file mode 100644
index 0000000..f5823d4
--- /dev/null
+++ b/assets/CLIP.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:124e6d59d54c0837c456c953b3147adf438a21570d5c5b01ad83f6ad3e78d15e
+size 408415159
diff --git a/assets/ESRNXFBSP.pt b/assets/ESRNXFBSP.pt
new file mode 100644
index 0000000..ec3bb0a
--- /dev/null
+++ b/assets/ESRNXFBSP.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acfcc6b0a0f07025cc660ed6dcf9b5a05fbd7c8ebf65edaa0a81d88a92ff6834
+size 124795647
diff --git a/assets/README.md b/assets/README.md
new file mode 100644
index 0000000..a84a522
--- /dev/null
+++ b/assets/README.md
@@ -0,0 +1 @@
+This folder contains snapshots of the pre-trained models.
diff --git a/assets/bpe_simple_vocab_16e6.txt.gz b/assets/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000..36a1585
--- /dev/null
+++ b/assets/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/ignite_trainer/README.md b/ignite_trainer/README.md
new file mode 100644
index 0000000..c58920d
--- /dev/null
+++ b/ignite_trainer/README.md
@@ -0,0 +1,3 @@
+# Training Wrapper
+
+Utility code to run training and evaluation of the model.
diff --git a/ignite_trainer/__init__.py b/ignite_trainer/__init__.py
new file mode 100644
index 0000000..f57a10e
--- /dev/null
+++ b/ignite_trainer/__init__.py
@@ -0,0 +1,16 @@
+import os as _os
+import sys as _sys
+
+from ignite_trainer.version import __version__
+from ._trainer import main, run
+from ._utils import load_class
+from ._interfaces import AbstractNet, AbstractTransform
+
+__all__ = [
+ '__version__',
+ 'main', 'run',
+ 'load_class',
+ 'AbstractNet', 'AbstractTransform'
+]
+
+_sys.path.extend([_os.getcwd()])
diff --git a/ignite_trainer/_interfaces.py b/ignite_trainer/_interfaces.py
new file mode 100644
index 0000000..882d425
--- /dev/null
+++ b/ignite_trainer/_interfaces.py
@@ -0,0 +1,37 @@
+import abc
+import torch
+
+from typing import Tuple
+from typing import Union
+from typing import Callable
+from typing import Optional
+
+
+TensorPair = Tuple[torch.Tensor, torch.Tensor]
+TensorOrTwo = Union[torch.Tensor, TensorPair]
+
+
+class AbstractNet(abc.ABC, torch.nn.Module):
+
+ @abc.abstractmethod
+ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> TensorOrTwo:
+ pass
+
+ @abc.abstractmethod
+ def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def loss_fn_name(self) -> str:
+ pass
+
+
+class AbstractTransform(abc.ABC, Callable[[torch.Tensor], torch.Tensor]):
+
+ @abc.abstractmethod
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ pass
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
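+
+
+# Illustrative example (not used by the training pipeline): a minimal concrete
+# transform that satisfies AbstractTransform; the gain value is an arbitrary choice.
+class ExampleGain(AbstractTransform):
+
+    def __init__(self, gain: float = 0.5):
+        self.gain = gain
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return x * self.gain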
diff --git a/ignite_trainer/_trainer.py b/ignite_trainer/_trainer.py
new file mode 100644
index 0000000..b78cac0
--- /dev/null
+++ b/ignite_trainer/_trainer.py
@@ -0,0 +1,763 @@
+import io
+import os
+import glob
+import json
+import time
+import tqdm
+import signal
+import argparse
+import numpy as np
+
+import torch
+import torch.utils.data
+
+import torchvision as tv
+
+import ignite.engine as ieng
+import ignite.metrics as imet
+import ignite.handlers as ihan
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Type
+from typing import Union
+from typing import Optional
+
+from termcolor import colored
+
+from collections import defaultdict
+from collections.abc import Iterable
+
+from ignite_trainer import _utils
+from ignite_trainer import _visdom
+from ignite_trainer import _interfaces
+
+VISDOM_HOST = 'localhost'
+VISDOM_PORT = 8097
+VISDOM_ENV_PATH = os.path.join(os.path.expanduser('~'), 'logs')
+BATCH_TRAIN = 128
+BATCH_TEST = 1024
+WORKERS_TRAIN = 0
+WORKERS_TEST = 0
+EPOCHS = 100
+LOG_INTERVAL = 50
+SAVED_MODELS_PATH = os.path.join(os.path.expanduser('~'), 'saved_models')
+
+
+def run(experiment_name: str,
+ visdom_host: str,
+ visdom_port: int,
+ visdom_env_path: str,
+ model_class: str,
+ model_args: Dict[str, Any],
+ optimizer_class: str,
+ optimizer_args: Dict[str, Any],
+ dataset_class: str,
+ dataset_args: Dict[str, Any],
+ batch_train: int,
+ batch_test: int,
+ workers_train: int,
+ workers_test: int,
+ transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
+ epochs: int,
+ log_interval: int,
+ saved_models_path: str,
+        performance_metrics: Optional[Dict[str, Any]] = None,
+ scheduler_class: Optional[str] = None,
+ scheduler_args: Optional[Dict[str, Any]] = None,
+ model_suffix: Optional[str] = None,
+ setup_suffix: Optional[str] = None,
+ orig_stdout: Optional[io.TextIOBase] = None,
+ skip_train_val: bool = False):
+
+ with _utils.tqdm_stdout(orig_stdout) as orig_stdout:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ num_gpus = torch.cuda.device_count()
+
+ if num_gpus > 1:
+ experiment_name = f'{experiment_name}-x{num_gpus}'
+
+ transforms_train = list()
+ transforms_test = list()
+
+ for idx, transform in enumerate(transforms):
+ use_train = transform.get('train', True)
+ use_test = transform.get('test', True)
+
+ transform = _utils.load_class(transform['class'])(**transform['args'])
+
+ if use_train:
+ transforms_train.append(transform)
+ if use_test:
+ transforms_test.append(transform)
+
+ transforms[idx]['train'] = use_train
+ transforms[idx]['test'] = use_test
+
+ transforms_train = tv.transforms.Compose(transforms_train)
+ transforms_test = tv.transforms.Compose(transforms_test)
+
+ Dataset: Type = _utils.load_class(dataset_class)
+
+ train_loader, eval_loader = _utils.get_data_loaders(
+ Dataset,
+ dataset_args,
+ batch_train,
+ batch_test,
+ workers_train,
+ workers_test,
+ transforms_train,
+ transforms_test
+ )
+
+ Network: Type = _utils.load_class(model_class)
+ model: _interfaces.AbstractNet = Network(**model_args)
+
+ if hasattr(train_loader.dataset, 'class_weights'):
+ model.register_buffer('class_weights', train_loader.dataset.class_weights.clone().exp(), persistent=False)
+ if hasattr(train_loader.dataset, 'label_to_class_idx'):
+ model.label_to_class_idx = {idx: lb for idx, lb in train_loader.dataset.label_to_class_idx.items()}
+
+ model = torch.nn.DataParallel(model, device_ids=range(num_gpus))
+ model = model.to(device)
+
+ # disable all parameters
+ for p in model.parameters():
+ p.requires_grad = False
+
+ # enable only audio-related parameters
+ for p in model.module.audio.parameters():
+ p.requires_grad = True
+
+ # disable fbsp-parameters
+ for p in model.module.audio.fbsp.parameters():
+ p.requires_grad = False
+
+ # disable logit scaling
+ model.module.logit_scale_ai.requires_grad = False
+ model.module.logit_scale_at.requires_grad = False
+
+ # add only enabled parameters to optimizer's list
+ param_groups = [
+ {'params': [p for p in model.module.parameters() if p.requires_grad]}
+ ]
+
+ # enable fbsp-parameters
+ for p in model.module.audio.fbsp.parameters():
+ p.requires_grad = True
+
+ # enable logit scaling
+ model.module.logit_scale_ai.requires_grad = True
+ model.module.logit_scale_at.requires_grad = True
+
+ # add fbsp- and logit scaling parameters to a separate group without weight decay
+ param_groups.append({
+ 'params': [
+ p for p in model.module.audio.fbsp.parameters()
+ ] + [
+ model.module.logit_scale_ai,
+ model.module.logit_scale_at
+ ],
+ 'weight_decay': 0.0
+ })
+
+ Optimizer: Type = _utils.load_class(optimizer_class)
+ optimizer: torch.optim.Optimizer = Optimizer(
+ param_groups,
+ **{**optimizer_args, **{'lr': optimizer_args['lr'] * num_gpus}}
+ )
+
+ if scheduler_class is not None:
+ Scheduler: Type = _utils.load_class(scheduler_class)
+
+ if scheduler_args is None:
+ scheduler_args = dict()
+
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = Scheduler(optimizer, **scheduler_args)
+ else:
+ scheduler = None
+
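+        # Short model tag: the upper-case characters of the network class name
+        # (e.g. AudioCLIP -> ACLIP, ESResNeXtFBSP -> ESRNXFBSP).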
+ model_short_name = ''.join([c for c in Network.__name__ if c == c.upper()])
+ model_name = '{}{}'.format(
+ model_short_name,
+ '-{}'.format(model_suffix) if model_suffix is not None else ''
+ )
+ visdom_env_name = '{}_{}_{}{}'.format(
+ Dataset.__name__,
+ experiment_name,
+ model_name,
+ '-{}'.format(setup_suffix) if setup_suffix is not None else ''
+ )
+
+ vis, vis_pid = _visdom.get_visdom_instance(visdom_host, visdom_port, visdom_env_name, visdom_env_path)
+
+ prog_bar_epochs = tqdm.tqdm(total=epochs, desc='Epochs', file=orig_stdout, dynamic_ncols=True, unit='epoch')
+ prog_bar_iters = tqdm.tqdm(desc='Batches', file=orig_stdout, dynamic_ncols=True)
+
+ num_params_total = sum(p.numel() for p in model.parameters())
+ num_params_train = sum(p.numel() for grp in optimizer.param_groups for p in grp['params'])
+
+ params_total_label = ''
+ params_train_label = ''
+ if num_params_total > 1e6:
+ num_params_total /= 1e6
+ params_total_label = 'M'
+ elif num_params_total > 1e3:
+ num_params_total /= 1e3
+ params_total_label = 'k'
+
+ if num_params_train > 1e6:
+ num_params_train /= 1e6
+ params_train_label = 'M'
+ elif num_params_train > 1e3:
+ num_params_train /= 1e3
+ params_train_label = 'k'
+
+ tqdm.tqdm.write(f'\n{Network.__name__}\n')
+ tqdm.tqdm.write('Total number of parameters: {:.2f}{}'.format(num_params_total, params_total_label))
+ tqdm.tqdm.write('Number of trainable parameters: {:.2f}{}'.format(num_params_train, params_train_label))
+
+ def training_step(engine: ieng.Engine, batch) -> torch.Tensor:
+ model.train()
+ model.epoch = engine.state.epoch
+ model.batch_idx = (engine.state.iteration - 1) % len(train_loader)
+ model.num_batches = len(train_loader)
+
+ optimizer.zero_grad()
+
+ audio, image, text = batch
+ if audio is not None:
+ audio = audio.to(device)
+ if image is not None:
+ image = image.to(device)
+
+ batch_indices = torch.arange(audio.shape[0], dtype=torch.int64, device=device)
+ _, loss = model(audio, image, text, batch_indices)
+
+ if loss.ndim > 0:
+ loss = loss.mean()
+
+ loss.backward(retain_graph=False)
+            optimizer.step()
+
+ return loss.item()
+
+ def eval_step(_: ieng.Engine, batch) -> _interfaces.TensorPair:
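+            # Zero-shot evaluation: embed the batch audio and one textual prompt per class,
+            # then score every clip against every class with the scaled similarity logits.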
+ model.eval()
+
+ with torch.no_grad():
+ audio, _, text = batch
+
+ ((audio_features, _, _), _), _ = model(
+ audio=audio,
+ batch_indices=torch.arange(audio.shape[0], dtype=torch.int64, device=device)
+ )
+ audio_features = audio_features.unsqueeze(1)
+
+ ((_, _, text_features), _), _ = model(
+ text=[
+ [eval_loader.dataset.class_idx_to_label[class_idx]]
+ for class_idx in sorted(eval_loader.dataset.class_idx_to_label.keys())
+ ],
+ batch_indices=torch.arange(
+ len(eval_loader.dataset.class_idx_to_label), dtype=torch.int64, device=device
+ )
+ )
+ text_features = text_features.unsqueeze(1).transpose(0, 1)
+
+ logit_scale_at = torch.clamp(model.module.logit_scale_at.exp(), min=1.0, max=100.0)
+ y_pred = (logit_scale_at * audio_features @ text_features.transpose(-1, -2)).squeeze(1)
+
+ y = torch.zeros(
+ audio.shape[0], len(eval_loader.dataset.class_idx_to_label), dtype=torch.int8, device=device
+ )
+ for item_idx, labels in enumerate(text):
+ class_ids = list(sorted([
+ eval_loader.dataset.label_to_class_idx[lb] for lb in labels
+ ]))
+ y[item_idx][class_ids] = 1
+
+ if model.module.multilabel:
+ y_pred = torch.sigmoid(y_pred / logit_scale_at - 0.5)
+ else:
+ y_pred = torch.softmax(y_pred, dim=-1)
+ y = y.argmax(dim=-1)
+
+ return y_pred, y
+
+ trainer = ieng.Engine(training_step)
+ validator_train = ieng.Engine(eval_step)
+ validator_eval = ieng.Engine(eval_step)
+
+ # placeholder for summary window
+ vis.text(
+ text='',
+ win=experiment_name,
+ env=visdom_env_name,
+ opts={'title': 'Summary', 'width': 940, 'height': 416},
+ append=vis.win_exists(experiment_name, visdom_env_name)
+ )
+
+ default_metrics = {
+ "Loss": {
+ "window_name": None,
+ "x_label": "#Epochs",
+ "y_label": model.loss_fn_name if not isinstance(model, torch.nn.DataParallel) else model.module.loss_fn_name,
+ "width": 940,
+ "height": 416,
+ "lines": [
+ {
+ "line_label": "SMA",
+ "object": imet.RunningAverage(output_transform=lambda x: x),
+ "test": False,
+ "update_rate": "iteration"
+ }
+ ]
+ }
+ }
+
+ performance_metrics = {**default_metrics, **performance_metrics}
+ checkpoint_metrics = list()
+
+ for scope_name, scope in performance_metrics.items():
+ scope['window_name'] = scope.get('window_name', scope_name) or scope_name
+
+ for line in scope['lines']:
+ if 'object' not in line:
+ line['object']: imet.Metric = _utils.load_class(line['class'])(**line['args'])
+
+ line['metric_label'] = '{}: {}'.format(scope['window_name'], line['line_label'])
+
+ line['update_rate'] = line.get('update_rate', 'epoch')
+ line_suffixes = list()
+ if line['update_rate'] == 'iteration':
+ line['object'].attach(trainer, line['metric_label'])
+ line['train'] = False
+ line['test'] = False
+
+ line_suffixes.append(' Train.')
+
+ if line.get('train', True):
+ line['object'].attach(validator_train, line['metric_label'])
+ line_suffixes.append(' Train.')
+ if line.get('test', True):
+ line['object'].attach(validator_eval, line['metric_label'])
+ line_suffixes.append(' Eval.')
+
+ if line.get('is_checkpoint', False):
+ checkpoint_metrics.append(line['metric_label'])
+
+ for line_suffix in line_suffixes:
+ _visdom.plot_line(
+ vis=vis,
+ window_name=scope['window_name'],
+ env=visdom_env_name,
+ line_label=line['line_label'] + line_suffix,
+ x_label=scope['x_label'],
+ y_label=scope['y_label'],
+ width=scope['width'],
+ height=scope['height'],
+ draw_marker=(line['update_rate'] == 'epoch')
+ )
+
+ if checkpoint_metrics:
+ score_name = 'performance'
+
+ def get_score(engine: ieng.Engine) -> float:
+ current_mode = getattr(engine.state.dataloader.iterable.dataset, dataset_args['training']['key'])
+ val_mode = dataset_args['training']['no']
+
+ score = 0.0
+ if current_mode == val_mode:
+ for metric_name in checkpoint_metrics:
+ try:
+ score += engine.state.metrics[metric_name]
+ except KeyError:
+ pass
+
+ return score
+
+ model_saver = ihan.ModelCheckpoint(
+ os.path.join(saved_models_path, visdom_env_name),
+ filename_prefix=visdom_env_name,
+ score_name=score_name,
+ score_function=get_score,
+ n_saved=3,
+ save_as_state_dict=True,
+ require_empty=False,
+ create_dir=True
+ )
+
+ validator_eval.add_event_handler(ieng.Events.EPOCH_COMPLETED, model_saver, {model_name: model})
+
+ if not skip_train_val:
+ @trainer.on(ieng.Events.STARTED)
+ def engine_started(engine: ieng.Engine):
+ log_validation(engine, False)
+
+ @trainer.on(ieng.Events.EPOCH_STARTED)
+ def reset_progress_iterations(engine: ieng.Engine):
+ prog_bar_iters.clear()
+ prog_bar_iters.n = 0
+ prog_bar_iters.last_print_n = 0
+ prog_bar_iters.start_t = time.time()
+ prog_bar_iters.last_print_t = time.time()
+ prog_bar_iters.total = len(engine.state.dataloader)
+
+ @trainer.on(ieng.Events.ITERATION_COMPLETED)
+ def log_training(engine: ieng.Engine):
+ prog_bar_iters.update(1)
+
+ num_iter = (engine.state.iteration - 1) % len(train_loader) + 1
+
+ early_stop = np.isnan(engine.state.output) or np.isinf(engine.state.output)
+
+ if num_iter % log_interval == 0 or num_iter == len(train_loader) or early_stop:
+ tqdm.tqdm.write(
+ 'Epoch[{}] Iteration[{}/{}] Loss: {:.4f}'.format(
+ engine.state.epoch, num_iter, len(train_loader), engine.state.output
+ )
+ )
+
+ x_pos = engine.state.epoch + num_iter / len(train_loader) - 1
+ for scope_name, scope in performance_metrics.items():
+ for line in scope['lines']:
+ if line['update_rate'] == 'iteration':
+ line_label = '{} Train.'.format(line['line_label'])
+ line_value = engine.state.metrics[line['metric_label']]
+
+ if engine.state.epoch >= 1:
+ _visdom.plot_line(
+ vis=vis,
+ window_name=scope['window_name'],
+ env=visdom_env_name,
+ line_label=line_label,
+ x_label=scope['x_label'],
+ y_label=scope['y_label'],
+ x=np.full(1, x_pos),
+ y=np.full(1, line_value)
+ )
+
+ if early_stop:
+ tqdm.tqdm.write(colored('Early stopping due to invalid loss value.', 'red'))
+ trainer.terminate()
+
+ def log_validation(engine: ieng.Engine,
+ train: bool = True):
+
+ if train:
+ run_type = 'Train.'
+ data_loader = train_loader
+ validator = validator_train
+ else:
+ run_type = 'Eval.'
+ data_loader = eval_loader
+ validator = validator_eval
+
+ prog_bar_validation = tqdm.tqdm(
+ data_loader,
+ desc=f'Validation {run_type}',
+ file=orig_stdout,
+ dynamic_ncols=True,
+ leave=False
+ )
+ validator.run(prog_bar_validation)
+ prog_bar_validation.clear()
+ prog_bar_validation.close()
+
+ tqdm_info = [
+ 'Epoch: {}'.format(engine.state.epoch)
+ ]
+ for scope_name, scope in performance_metrics.items():
+ for line in scope['lines']:
+ if line['update_rate'] == 'epoch':
+ try:
+ line_label = '{} {}'.format(line['line_label'], run_type)
+ line_value = validator.state.metrics[line['metric_label']]
+
+ _visdom.plot_line(
+ vis=vis,
+ window_name=scope['window_name'],
+ env=visdom_env_name,
+ line_label=line_label,
+ x_label=scope['x_label'],
+ y_label=scope['y_label'],
+ x=np.full(1, engine.state.epoch),
+ y=np.full(1, line_value),
+ draw_marker=True
+ )
+
+ tqdm_info.append('{}: {:.4f}'.format(line_label, line_value))
+ except KeyError:
+ pass
+
+ tqdm.tqdm.write('{} results - {}'.format(run_type, '; '.join(tqdm_info)))
+
+ if not skip_train_val:
+ @trainer.on(ieng.Events.EPOCH_COMPLETED)
+ def log_validation_train(engine: ieng.Engine):
+ log_validation(engine, True)
+
+ @trainer.on(ieng.Events.EPOCH_COMPLETED)
+ def log_validation_eval(engine: ieng.Engine):
+ log_validation(engine, False)
+
+ if engine.state.epoch == 1:
+ summary = _utils.build_summary_str(
+ experiment_name=experiment_name,
+ model_short_name=model_name,
+ model_class=model_class,
+ model_args=model_args,
+ optimizer_class=optimizer_class,
+ optimizer_args=optimizer_args,
+ dataset_class=dataset_class,
+ dataset_args=dataset_args,
+ transforms=transforms,
+ epochs=epochs,
+ batch_train=batch_train,
+ log_interval=log_interval,
+ saved_models_path=saved_models_path,
+ scheduler_class=scheduler_class,
+ scheduler_args=scheduler_args
+ )
+ _visdom.create_summary_window(
+ vis=vis,
+ visdom_env_name=visdom_env_name,
+ experiment_name=experiment_name,
+ summary=summary
+ )
+
+ vis.save([visdom_env_name])
+
+ prog_bar_epochs.update(1)
+
+ if scheduler is not None:
+ scheduler.step(engine.state.epoch)
+
+ trainer.run(train_loader, max_epochs=epochs)
+
+ if vis_pid is not None:
+ tqdm.tqdm.write('Stopping visdom')
+ os.kill(vis_pid, signal.SIGTERM)
+
+ del vis
+ del train_loader
+ del eval_loader
+
+ prog_bar_iters.clear()
+ prog_bar_iters.close()
+
+ prog_bar_epochs.clear()
+ prog_bar_epochs.close()
+
+ tqdm.tqdm.write('\n')
+
+
+def main():
+ with _utils.tqdm_stdout() as orig_stdout:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-c', '--config', type=str, required=True)
+ parser.add_argument('-H', '--visdom-host', type=str, required=False)
+ parser.add_argument('-P', '--visdom-port', type=int, required=False)
+ parser.add_argument('-E', '--visdom-env-path', type=str, required=False)
+ parser.add_argument('-b', '--batch-train', type=int, required=False)
+ parser.add_argument('-B', '--batch-test', type=int, required=False)
+ parser.add_argument('-w', '--workers-train', type=int, required=False)
+ parser.add_argument('-W', '--workers-test', type=int, required=False)
+ parser.add_argument('-e', '--epochs', type=int, required=False)
+ parser.add_argument('-L', '--log-interval', type=int, required=False)
+ parser.add_argument('-M', '--saved-models-path', type=str, required=False)
+ parser.add_argument('-R', '--random-seed', type=int, required=False)
+ parser.add_argument('-s', '--suffix', type=str, required=False)
+ parser.add_argument('-S', '--skip-train-val', action='store_true', default=False)
+
+ args, unknown_args = parser.parse_known_args()
+
+ if args.batch_test is None:
+ args.batch_test = args.batch_train
+
+ if args.random_seed is not None:
+ args.suffix = '{}r-{}'.format(
+ '{}_'.format(args.suffix) if args.suffix is not None else '',
+ args.random_seed
+ )
+
+ np.random.seed(args.random_seed)
+ torch.random.manual_seed(args.random_seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(args.random_seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ configs_found = list(sorted(glob.glob(os.path.expanduser(args.config))))
+ prog_bar_exps = tqdm.tqdm(
+ configs_found,
+ desc='Experiments',
+ unit='setup',
+ file=orig_stdout,
+ dynamic_ncols=True
+ )
+
+ for config_path in prog_bar_exps:
+ config = json.load(open(config_path))
+
+ if unknown_args:
+ tqdm.tqdm.write('\nParsing additional arguments...')
+
+ args_not_found = list()
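+                # Dotted keys mirror the JSON config structure, e.g. "--Dataset.args.root /path"
+                # overrides config['Dataset']['args']['root'].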
+ for arg in unknown_args:
+ if arg.startswith('--'):
+ keys = arg.strip('-').split('.')
+
+ section = config
+ found = True
+ for key in keys:
+ if key in section:
+ section = section[key]
+ else:
+ found = False
+ break
+
+ if found:
+ override_parser = argparse.ArgumentParser()
+
+ section_nargs = None
+ section_type = type(section) if section is not None else str
+
+                            if section_type is bool:
+                                def infer_bool(x: str) -> bool:
+                                    return x.lower() not in ('0', 'false', 'no')
+
+                                section_type = infer_bool
+
+ if isinstance(section, Iterable) and section_type is not str:
+ section_nargs = '+'
+ section_type = {type(value) for value in section}
+
+ if len(section_type) == 1:
+ section_type = section_type.pop()
+ else:
+ section_type = str
+
+ override_parser.add_argument(arg, nargs=section_nargs, type=section_type)
+ overridden_args, _ = override_parser.parse_known_args(unknown_args)
+ overridden_args = vars(overridden_args)
+
+ overridden_key = arg.strip('-')
+ overriding_value = overridden_args[overridden_key]
+
+ section = config
+ old_value = None
+ for i, key in enumerate(keys, 1):
+ if i == len(keys):
+ old_value = section[key]
+ section[key] = overriding_value
+ else:
+ section = section[key]
+
+ tqdm.tqdm.write(
+ colored(f'Overriding "{overridden_key}": {old_value} -> {overriding_value}', 'magenta')
+ )
+ else:
+ args_not_found.append(arg)
+
+ if args_not_found:
+ tqdm.tqdm.write(
+ colored(
+ '\nThere are unrecognized arguments to override: {}'.format(
+ ', '.join(args_not_found)
+ ),
+ 'red'
+ )
+ )
+
+ config = defaultdict(None, config)
+
+ experiment_name = config['Setup']['name']
+
+ visdom_host = _utils.arg_selector(
+ args.visdom_host, config['Visdom']['host'], VISDOM_HOST
+ )
+ visdom_port = int(_utils.arg_selector(
+ args.visdom_port, config['Visdom']['port'], VISDOM_PORT
+ ))
+ visdom_env_path = _utils.arg_selector(
+ args.visdom_env_path, config['Visdom']['env_path'], VISDOM_ENV_PATH
+ )
+ batch_train = int(_utils.arg_selector(
+ args.batch_train, config['Setup']['batch_train'], BATCH_TRAIN
+ ))
+ batch_test = int(_utils.arg_selector(
+ args.batch_test, config['Setup']['batch_test'], BATCH_TEST
+ ))
+ workers_train = _utils.arg_selector(
+ args.workers_train, config['Setup']['workers_train'], WORKERS_TRAIN
+ )
+ workers_test = _utils.arg_selector(
+ args.workers_test, config['Setup']['workers_test'], WORKERS_TEST
+ )
+ epochs = _utils.arg_selector(
+ args.epochs, config['Setup']['epochs'], EPOCHS
+ )
+ log_interval = _utils.arg_selector(
+ args.log_interval, config['Setup']['log_interval'], LOG_INTERVAL
+ )
+ saved_models_path = _utils.arg_selector(
+ args.saved_models_path, config['Setup']['saved_models_path'], SAVED_MODELS_PATH
+ )
+
+ model_class = config['Model']['class']
+ model_args = config['Model']['args']
+
+ optimizer_class = config['Optimizer']['class']
+ optimizer_args = config['Optimizer']['args']
+
+ if 'Scheduler' in config:
+ scheduler_class = config['Scheduler']['class']
+ scheduler_args = config['Scheduler']['args']
+ else:
+ scheduler_class = None
+ scheduler_args = None
+
+ dataset_class = config['Dataset']['class']
+ dataset_args = config['Dataset']['args']
+
+ transforms = config['Transforms']
+ performance_metrics = config['Metrics']
+
+ tqdm.tqdm.write(f'\nStarting experiment "{experiment_name}"\n')
+
+ run(
+ experiment_name=experiment_name,
+ visdom_host=visdom_host,
+ visdom_port=visdom_port,
+ visdom_env_path=visdom_env_path,
+ model_class=model_class,
+ model_args=model_args,
+ optimizer_class=optimizer_class,
+ optimizer_args=optimizer_args,
+ dataset_class=dataset_class,
+ dataset_args=dataset_args,
+ batch_train=batch_train,
+ batch_test=batch_test,
+ workers_train=workers_train,
+ workers_test=workers_test,
+ transforms=transforms,
+ epochs=epochs,
+ log_interval=log_interval,
+ saved_models_path=saved_models_path,
+ performance_metrics=performance_metrics,
+ scheduler_class=scheduler_class,
+ scheduler_args=scheduler_args,
+ model_suffix=config['Setup']['suffix'],
+ setup_suffix=args.suffix,
+ orig_stdout=orig_stdout,
+ skip_train_val=args.skip_train_val
+ )
+
+ prog_bar_exps.close()
+
+ tqdm.tqdm.write('\n')
diff --git a/ignite_trainer/_utils.py b/ignite_trainer/_utils.py
new file mode 100644
index 0000000..25f7e11
--- /dev/null
+++ b/ignite_trainer/_utils.py
@@ -0,0 +1,221 @@
+import io
+import sys
+import json
+import tqdm
+import datetime
+import importlib
+import contextlib
+
+import numpy as np
+
+import torch
+import torch.utils.data as td
+
+import torchvision as tv
+
+from PIL import Image
+
+from collections import OrderedDict
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Type
+from typing import Tuple
+from typing import Union
+from typing import Callable
+from typing import Optional
+
+
+@contextlib.contextmanager
+def tqdm_stdout(orig_stdout: Optional[io.TextIOBase] = None):
+
+ class DummyFile(object):
+ file = None
+
+ def __init__(self, file):
+ self.file = file
+
+ def write(self, x):
+ if len(x.rstrip()) > 0:
+ tqdm.tqdm.write(x, file=self.file)
+
+ def flush(self):
+ return getattr(self.file, 'flush', lambda: None)()
+
+ orig_out_err = sys.stdout, sys.stderr
+
+ try:
+ if orig_stdout is None:
+ sys.stdout, sys.stderr = map(DummyFile, orig_out_err)
+ yield orig_out_err[0]
+ else:
+ yield orig_stdout
+ except Exception as exc:
+ raise exc
+ finally:
+ sys.stdout, sys.stderr = orig_out_err
+
+
+def load_class(package_name: str, class_name: Optional[str] = None) -> Type:
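+    # Accepts either ('package.module', 'ClassName') or a single dotted path such as
+    # 'torch.optim.Adam', which is split on its last dot.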
+ if class_name is None:
+ package_name, class_name = package_name.rsplit('.', 1)
+
+ importlib.invalidate_caches()
+
+ package = importlib.import_module(package_name)
+ cls = getattr(package, class_name)
+
+ return cls
+
+
+def arg_selector(arg_cmd: Optional[Any], arg_conf: Optional[Any], arg_const: Any) -> Any:
+ if arg_cmd is not None:
+ return arg_cmd
+ else:
+ if arg_conf is not None:
+ return arg_conf
+ else:
+ return arg_const
+
+
+def collate_fn(batch):
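+    # Each item is an (audio, image, text) triple. A modality collapses to None only if it
+    # is missing for every item; otherwise audio/image are stacked and text stays a list.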
+ batch_audio, batch_image, batch_text = zip(*batch)
+
+ keep_ids = [idx for idx, (_, _) in enumerate(zip(batch_audio, batch_image))]
+
+ if not all(audio is None for audio in batch_audio):
+ batch_audio = [batch_audio[idx] for idx in keep_ids]
+ batch_audio = torch.stack(batch_audio)
+ else:
+ batch_audio = None
+
+ if not all(image is None for image in batch_image):
+ batch_image = [batch_image[idx] for idx in keep_ids]
+ batch_image = torch.stack(batch_image)
+ else:
+ batch_image = None
+
+ if not all(text is None for text in batch_text):
+ batch_text = [batch_text[idx] for idx in keep_ids]
+ else:
+ batch_text = None
+
+ return batch_audio, batch_image, batch_text
+
+
+def get_data_loaders(Dataset: Type,
+ dataset_args: Dict[str, Any],
+ batch_train: int = 64,
+ batch_test: int = 1024,
+ workers_train: int = 0,
+ workers_test: int = 0,
+ transforms_train: Optional[Callable[
+ [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
+ ]] = None,
+ transforms_test: Optional[Callable[
+ [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
+ ]] = None) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
+
+ dl_shuffle = dataset_args.pop('dl_shuffle', True)
+
+ dataset_mode_train = {dataset_args['training']['key']: dataset_args['training']['yes']}
+ dataset_mode_test = {dataset_args['training']['key']: dataset_args['training']['no']}
+
+ dataset_args_train = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_train}
+ dataset_args_test = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_test}
+
+ ds_train = Dataset(**{
+ **dataset_args_train,
+ **{'transform_audio': transforms_train},
+ **{'transform_frames': tv.transforms.Compose([
+ tv.transforms.ToTensor(),
+ tv.transforms.Resize(224, interpolation=Image.BICUBIC),
+ tv.transforms.CenterCrop(224),
+ tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])}
+ })
+ train_loader = torch.utils.data.DataLoader(
+ ds_train,
+ batch_size=batch_train,
+ shuffle=dl_shuffle,
+ num_workers=workers_train,
+ pin_memory=True,
+ collate_fn=collate_fn,
+ drop_last=True
+ )
+ ds_eval = Dataset(**{
+ **dataset_args_test,
+ **{'transform_audio': transforms_test},
+ **{'transform_frames': tv.transforms.Compose([
+ tv.transforms.ToTensor(),
+ tv.transforms.Resize(224, interpolation=Image.BICUBIC),
+ tv.transforms.CenterCrop(224),
+ tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])}
+ })
+ eval_loader = torch.utils.data.DataLoader(
+ ds_eval,
+ batch_size=batch_test,
+ num_workers=workers_test,
+ pin_memory=True,
+ collate_fn=collate_fn
+ )
+
+ return train_loader, eval_loader
+
+
+def build_summary_str(experiment_name: str,
+ model_short_name: str,
+ model_class: str,
+ model_args: Dict[str, Any],
+ optimizer_class: str,
+ optimizer_args: Dict[str, Any],
+ dataset_class: str,
+ dataset_args: Dict[str, Any],
+ transforms: List[Dict[str, Union[str, Dict[str, Any]]]],
+ epochs: int,
+ batch_train: int,
+ log_interval: int,
+ saved_models_path: str,
+ scheduler_class: Optional[str] = None,
+ scheduler_args: Optional[Dict[str, Any]] = None) -> str:
+
+ setup_title = '{}-{}'.format(experiment_name, model_short_name)
+
+    summary_window_text = '<div>'
+    summary_window_text += '<h4>{}</h4>'.format(setup_title)
+    summary_window_text += '<pre>'
+
+ summary = OrderedDict({
+ 'Date started': datetime.datetime.now().strftime('%Y-%m-%d @ %H:%M:%S'),
+ 'Model': OrderedDict({model_class: model_args}),
+ 'Setup': OrderedDict({
+ 'epochs': epochs,
+ 'batch': batch_train,
+ 'log_interval': log_interval,
+ 'saved_models_path': saved_models_path
+ }),
+ 'Optimizer': OrderedDict({optimizer_class: optimizer_args}),
+ 'Dataset': OrderedDict({dataset_class: dataset_args}),
+ 'Transforms': OrderedDict({
+ 'Training': OrderedDict({tr['class']: tr['args'] for tr in transforms if tr['train']}),
+ 'Validation': OrderedDict({tr['class']: tr['args'] for tr in transforms if tr['test']})
+ })
+ })
+ if scheduler_class is not None:
+ summary['Scheduler'] = {scheduler_class: scheduler_args}
+ summary_window_text += '{}'.format(
+ json.dumps(summary, indent=2)
+ )
+
+    summary_window_text += '</pre>'
+    summary_window_text += '</div>'
+
+ return summary_window_text
diff --git a/ignite_trainer/_visdom.py b/ignite_trainer/_visdom.py
new file mode 100644
index 0000000..9d2ffa6
--- /dev/null
+++ b/ignite_trainer/_visdom.py
@@ -0,0 +1,191 @@
+import os
+import sys
+import json
+import time
+import tqdm
+import socket
+import subprocess
+import numpy as np
+
+import visdom
+
+from typing import Tuple
+from typing import Optional
+
+
+def calc_ytick_range(vis: visdom.Visdom, window_name: str, env: Optional[str] = None) -> Tuple[float, float]:
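+    # Robust y-axis limits for a Visdom window: quartiles of the plotted series widened
+    # by a 1.5 * IQR fence, computed after discarding per-series outliers.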
+ lower_bound, upper_bound = -1.0, 1.0
+
+ stats = vis.get_window_data(win=window_name, env=env)
+
+ if stats:
+ stats = json.loads(stats)
+
+ stats = [np.array(item['y']) for item in stats['content']['data']]
+ stats = [item[item != np.array([None])].astype(np.float16) for item in stats]
+
+ if stats:
+ q25s = np.array([np.quantile(item, 0.25) for item in stats if len(item) > 0])
+ q75s = np.array([np.quantile(item, 0.75) for item in stats if len(item) > 0])
+
+ if q25s.shape == q75s.shape and len(q25s) > 0:
+ iqrs = q75s - q25s
+
+ lower_bounds = q25s - 1.5 * iqrs
+ upper_bounds = q75s + 1.5 * iqrs
+
+ stats_sanitized = list()
+ idx = 0
+ for item in stats:
+ if len(item) > 0:
+ item_sanitized = item[(item >= lower_bounds[idx]) & (item <= upper_bounds[idx])]
+ stats_sanitized.append(item_sanitized)
+
+ idx += 1
+
+ stats_sanitized = np.array(stats_sanitized)
+
+ q25_sanitized = np.array([np.quantile(item, 0.25) for item in stats_sanitized])
+ q75_sanitized = np.array([np.quantile(item, 0.75) for item in stats_sanitized])
+
+ iqr_sanitized = np.sum(q75_sanitized - q25_sanitized)
+ lower_bound = np.min(q25_sanitized) - 1.5 * iqr_sanitized
+ upper_bound = np.max(q75_sanitized) + 1.5 * iqr_sanitized
+
+ return lower_bound, upper_bound
+
+
+def plot_line(vis: visdom.Visdom,
+ window_name: str,
+ env: Optional[str] = None,
+ line_label: Optional[str] = None,
+ x: Optional[np.ndarray] = None,
+ y: Optional[np.ndarray] = None,
+ x_label: Optional[str] = None,
+ y_label: Optional[str] = None,
+ width: int = 576,
+ height: int = 416,
+ draw_marker: bool = False) -> str:
+
+ empty_call = not vis.win_exists(window_name)
+
+ if empty_call and (x is not None or y is not None):
+ return window_name
+
+ if x is None:
+ x = np.ones(1)
+ empty_call = empty_call & True
+
+ if y is None:
+ y = np.full(1, np.nan)
+ empty_call = empty_call & True
+
+ if x.shape != y.shape:
+ x = np.ones_like(y)
+
+ opts = {
+ 'showlegend': True,
+ 'markers': draw_marker,
+ 'markersize': 5,
+ }
+
+ if empty_call:
+ opts['title'] = window_name
+ opts['width'] = width
+ opts['height'] = height
+
+ window_name = vis.line(
+ X=x,
+ Y=y,
+ win=window_name,
+ env=env,
+ update='append',
+ name=line_label,
+ opts=opts
+ )
+
+ xtickmin, xtickmax = 0.0, np.max(x) * 1.05
+ ytickmin, ytickmax = calc_ytick_range(vis, window_name, env)
+
+ opts = {
+ 'showlegend': True,
+ 'xtickmin': xtickmin,
+ 'xtickmax': xtickmax,
+ 'ytickmin': ytickmin,
+ 'ytickmax': ytickmax,
+ 'xlabel': x_label,
+ 'ylabel': y_label
+ }
+
+ window_name = vis.update_window_opts(win=window_name, opts=opts, env=env)
+
+ return window_name
+
+
+def create_summary_window(vis: visdom.Visdom,
+ visdom_env_name: str,
+ experiment_name: str,
+ summary: str) -> str:
+
+ return vis.text(
+ text=summary,
+ win=experiment_name,
+ env=visdom_env_name,
+ opts={'title': 'Summary', 'width': 576, 'height': 416},
+ append=vis.win_exists(experiment_name, visdom_env_name)
+ )
+
+
+def connection_is_alive(host: str, port: int) -> bool:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ try:
+ sock.connect((host, port))
+ sock.shutdown(socket.SHUT_RDWR)
+
+ return True
+ except socket.error:
+ return False
+
+
+def get_visdom_instance(host: str = 'localhost',
+ port: int = 8097,
+ env_name: str = 'main',
+ env_path: str = 'visdom_env') -> Tuple[visdom.Visdom, Optional[int]]:
+
+ vis_pid = None
+
+ if not connection_is_alive(host, port):
+ if any(host.strip('/').endswith(lh) for lh in ['127.0.0.1', 'localhost']):
+ os.makedirs(env_path, exist_ok=True)
+
+ tqdm.tqdm.write('Starting visdom on port {}'.format(port), end='')
+
+ vis_args = [
+ sys.executable,
+ '-m', 'visdom.server',
+ '-port', str(port),
+ '-env_path', os.path.join(os.getcwd(), env_path)
+ ]
+ vis_proc = subprocess.Popen(vis_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ time.sleep(2.0)
+
+ vis_pid = vis_proc.pid
+ tqdm.tqdm.write('PID -> {}'.format(vis_pid))
+
+ trials_left = 5
+ while not connection_is_alive(host, port):
+ time.sleep(1.0)
+
+ tqdm.tqdm.write('Trying to connect ({} left)...'.format(trials_left))
+
+ trials_left -= 1
+ if trials_left < 1:
+ raise RuntimeError('Visdom server is not running. Please run "python -m visdom.server".')
+
+ vis = visdom.Visdom(
+ server='http://{}'.format(host),
+ port=port,
+ env=env_name
+ )
+
+ return vis, vis_pid
diff --git a/ignite_trainer/version.py b/ignite_trainer/version.py
new file mode 100644
index 0000000..c73c6e0
--- /dev/null
+++ b/ignite_trainer/version.py
@@ -0,0 +1 @@
+__version__ = '0.2.5b5'
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..821b1c2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3.7
+
+import ignite_trainer as it
+
+
+def main():
+ it.main()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..956e839
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,3 @@
+from .clip import *
+from .esresnet import *
+from .audioclip import AudioCLIP
diff --git a/model/audioclip.py b/model/audioclip.py
new file mode 100644
index 0000000..9dd42d7
--- /dev/null
+++ b/model/audioclip.py
@@ -0,0 +1,255 @@
+import os
+
+import torch
+import torch.nn.functional as F
+
+from model.clip import CLIP
+from model.clip.clip import tokenize
+from model.esresnet import ESResNeXtFBSP
+
+from typing import List
+from typing import Tuple
+from typing import Union
+from typing import Optional
+
+
+ClipFeatures = Tuple[
+ Optional[torch.Tensor], # audio
+ Optional[torch.Tensor], # image
+    Optional[torch.Tensor]   # text
+]
+
+
+ClipLogits = Tuple[
+ Optional[torch.Tensor], # audio x image
+ Optional[torch.Tensor], # audio x text
+ Optional[torch.Tensor] # image x text
+]
+
+
+ClipOutput = Tuple[
+ Tuple[ClipFeatures, ClipLogits],
+ Optional[torch.Tensor] # loss
+]
+
+
+class AudioCLIP(CLIP):
+
+ def __init__(self,
+ embed_dim: int = 1024,
+ # vision
+ image_resolution: int = 224,
+ vision_layers: Union[Tuple[int, int, int, int], int] = (3, 4, 6, 3),
+ vision_width: int = 64,
+ vision_patch_size: Optional[int] = None,
+ # text
+ context_length: int = 77,
+ vocab_size: int = 49408,
+ transformer_width: int = 512,
+ transformer_heads: int = 8,
+ transformer_layers: int = 12,
+ # audio
+ n_fft: int = 2048,
+ hop_length: Optional[int] = 561,
+ win_length: Optional[int] = 1654,
+ window: Optional[str] = 'blackmanharris',
+ normalized: bool = True,
+ onesided: bool = True,
+ spec_height: int = -1,
+ spec_width: int = -1,
+ apply_attention: bool = True,
+ multilabel: bool = True,
+ pretrained: Union[bool, str] = True):
+
+ super(AudioCLIP, self).__init__(
+ embed_dim=embed_dim,
+ image_resolution=image_resolution,
+ vision_layers=vision_layers,
+ vision_width=vision_width,
+ vision_patch_size=vision_patch_size,
+ context_length=context_length,
+ vocab_size=vocab_size,
+ transformer_width=transformer_width,
+ transformer_heads=transformer_heads,
+ transformer_layers=transformer_layers
+ )
+
+ self.audio = ESResNeXtFBSP(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ spec_height=spec_height,
+ spec_width=spec_width,
+ num_classes=embed_dim,
+ apply_attention=apply_attention,
+ pretrained=False
+ )
+
+ self.multilabel = multilabel
+ self.pretrained = pretrained
+
+ self.logit_scale_ai = torch.nn.Parameter(torch.log(torch.ones([]) * 100))
+ self.logit_scale_at = torch.nn.Parameter(torch.log(torch.ones([]) * 100))
+
+ if isinstance(self.pretrained, str):
+ self.load_state_dict(torch.load(self.pretrained, map_location='cpu'), strict=False)
+ elif self.pretrained:
+ self.load_state_dict(torch.load(
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'CLIP.pt'),
+ map_location='cpu'
+ ), strict=False)
+ print('Image & Text weights loaded')
+ try:
+ self.audio.load_state_dict(torch.load(
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'ESRNXFBSP.pt'),
+ map_location='cpu'
+ ), strict=False)
+ except RuntimeError as ex:
+ print(ex)
+ print('Audio weights loaded')
+
+ self.embed_dim = embed_dim
+
+ @property
+ def device(self):
+ return self.visual.conv1.weight.device
+
+ def encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
+ return self.audio(audio.to(self.device))
+
+ def encode_text(self,
+ text: List[List[str]],
+ base_str: str = '{}',
+ batch_indices: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+ if batch_indices is not None:
+ text = [text[idx] for idx in batch_indices]
+
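+        # Each sample is a list of entity labels joined into a single prompt
+        # (e.g. ['dog', 'bark'] -> 'dog, bark') before tokenization.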
+ text_joined = [', '.join(entities) for entities in text]
+ text_tokens = torch.cat([
+ tokenize(base_str.format(entities)) for entities in text_joined
+ ])
+ text_tokens = text_tokens.to(self.device)
+
+ return super(AudioCLIP, self).encode_text(text_tokens)
+
+ def forward(self,
+ audio: Optional[torch.Tensor] = None,
+ image: Optional[torch.Tensor] = None,
+ text: Optional[List[List[str]]] = None,
+ batch_indices: Optional[torch.Tensor] = None) -> ClipOutput:
+
+ audio_features = None
+ image_features = None
+ text_features = None
+ sample_weights = None
+
+ if audio is not None:
+ audio_features = self.encode_audio(audio)
+ audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
+
+ if image is not None:
+ image_features = self.encode_image(image)
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+ if text is not None:
+ if batch_indices is None:
+                batch_indices = torch.arange(len(text), dtype=torch.int64, device=self.device)
+
+ text_features = self.encode_text(text, '{}', batch_indices)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ if hasattr(self, 'class_weights') and hasattr(self, 'label_to_class_idx'):
+ sample_weights = torch.stack([
+ sum(self.class_weights[self.label_to_class_idx[label]] for label in entities)
+ for idx, entities in enumerate(text) if idx in batch_indices
+ ])
+
+ features: ClipFeatures = (audio_features, image_features, text_features)
+
+ logit_scale_ai = torch.clamp(self.logit_scale_ai.exp(), min=1.0, max=100.0)
+ logit_scale_at = torch.clamp(self.logit_scale_at.exp(), min=1.0, max=100.0)
+ logit_scale_it = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0)
+
+ logits_audio_image = None
+ logits_audio_text = None
+ logits_image_text = None
+
+ if (audio_features is not None) and (image_features is not None):
+ logits_audio_image = logit_scale_ai * audio_features @ image_features.T
+
+ if (audio_features is not None) and (text_features is not None):
+ logits_audio_text = logit_scale_at * audio_features @ text_features.T
+
+ if (image_features is not None) and (text_features is not None):
+ logits_image_text = logit_scale_it * image_features @ text_features.T
+
+ logits: ClipLogits = (logits_audio_image, logits_audio_text, logits_image_text)
+
+ loss = self.loss_fn(logits, sample_weights)
+ if audio is not None and loss is not None:
+ loss = loss + self.audio.loss_ttf(self.device)
+
+ return (features, logits), loss
+
+ def loss_fn(self, logits: ClipLogits, sample_weights: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
+ logits_audio_image, logits_audio_text, logits_image_text = logits
+
+ if logits_audio_image is not None:
+ batch_size = logits_audio_image.shape[0]
+ elif logits_audio_text is not None:
+ batch_size = logits_audio_text.shape[0]
+ elif logits_image_text is not None:
+ batch_size = logits_image_text.shape[0]
+ else:
+ return None
+
+ reference = torch.arange(
+ batch_size,
+ dtype=torch.int64,
+ device=self.device
+ )
+
+ loss = torch.tensor(0.0, dtype=self.dtype, device=self.device)
+
+ num_modalities: int = 0
+ scale = torch.tensor(1.0, dtype=self.dtype, device=self.device)
+
+ if logits_audio_image is not None:
+ loss_ai = F.cross_entropy(
+ logits_audio_image, reference, weight=sample_weights
+ ) + F.cross_entropy(
+ logits_audio_image.transpose(-1, -2), reference, weight=sample_weights
+ )
+ loss = loss + loss_ai
+ num_modalities += 1
+
+ if logits_audio_text is not None:
+ loss_at = F.cross_entropy(
+ logits_audio_text, reference, weight=sample_weights
+ ) + F.cross_entropy(
+ logits_audio_text.transpose(-1, -2), reference, weight=sample_weights
+ )
+ loss = loss + loss_at
+ num_modalities += 1
+
+ if logits_image_text is not None:
+ loss_it = F.cross_entropy(
+ logits_image_text, reference, weight=sample_weights
+ ) + F.cross_entropy(
+ logits_image_text.transpose(-1, -2), reference, weight=sample_weights
+ )
+ loss = loss + loss_it
+ num_modalities += 1
+
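+        # The accumulated scale equals the factorial of the number of contributing
+        # modality pairs (e.g. all three pairs present -> the loss is divided by 6).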
+ for idx in range(num_modalities):
+ scale = scale * (idx + 1)
+
+ return loss / scale
+
+ @property
+ def loss_fn_name(self) -> str:
+ return 'Cross Entropy'
diff --git a/model/clip/__init__.py b/model/clip/__init__.py
new file mode 100644
index 0000000..b235f83
--- /dev/null
+++ b/model/clip/__init__.py
@@ -0,0 +1,5 @@
+from .model import CLIP
+from .model import convert_weights
+
+
+__all__ = ['CLIP', 'convert_weights']
diff --git a/model/clip/clip.py b/model/clip/clip.py
new file mode 100644
index 0000000..c4776ef
--- /dev/null
+++ b/model/clip/clip.py
@@ -0,0 +1,193 @@
+# CREDITS: https://github.com/openai/CLIP
+
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Union, List
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from utils.simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if name in _MODELS:
+ model_path = _download(_MODELS[name])
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/model/clip/model.py b/model/clip/model.py
new file mode 100644
index 0000000..356482d
--- /dev/null
+++ b/model/clip/model.py
@@ -0,0 +1,433 @@
+# CREDITS: https://github.com/openai/CLIP
+
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)  # keep -inf strictly above the diagonal; zero out the diagonal and the lower triangle
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
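+
+
+# Usage sketch (hedged; `sd` is assumed to be a CLIP-style state dict, e.g. extracted
+# from one of the checkpoints shipped in assets/):
+#
+#   model = build_model(sd)  # infers ViT vs. ResNet visual backbone from the keys
+#   logits_per_image, logits_per_text = model(images, tokenized_texts)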
diff --git a/model/esresnet/__init__.py b/model/esresnet/__init__.py
new file mode 100644
index 0000000..684cdc6
--- /dev/null
+++ b/model/esresnet/__init__.py
@@ -0,0 +1,8 @@
+from .base import ESResNet
+from .base import ESResNeXt
+from .fbsp import ESResNetFBSP
+from .fbsp import ESResNeXtFBSP
+from .attention import Attention2d
+
+
+__all__ = ['ESResNet', 'ESResNeXt', 'ESResNetFBSP', 'ESResNeXtFBSP', 'Attention2d']
diff --git a/model/esresnet/attention.py b/model/esresnet/attention.py
new file mode 100644
index 0000000..fb3c0c3
--- /dev/null
+++ b/model/esresnet/attention.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+
+from typing import Tuple
+
+
+class Attention2d(torch.nn.Module):
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ num_kernels: int,
+ kernel_size: Tuple[int, int],
+ padding_size: Tuple[int, int]):
+
+ super(Attention2d, self).__init__()
+
+ self.conv_depth = torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels * num_kernels,
+ kernel_size=kernel_size,
+ padding=padding_size,
+ groups=in_channels
+ )
+ self.conv_point = torch.nn.Conv2d(
+ in_channels=in_channels * num_kernels,
+ out_channels=out_channels,
+ kernel_size=(1, 1)
+ )
+ self.bn = torch.nn.BatchNorm2d(num_features=out_channels)
+ self.activation = torch.nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor:
+ x = F.adaptive_max_pool2d(x, size)
+ x = self.conv_depth(x)
+ x = self.conv_point(x)
+ x = self.bn(x)
+ x = self.activation(x)
+
+ return x
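+
+
+# Usage sketch (illustrative): the block is used as a multiplicative gate.
+# Given features `x` of shape [N, C_in, H, W] and a target spatial `size`,
+# it returns a sigmoid-activated map of shape [N, C_out, *size] that is
+# multiplied element-wise with the output of the corresponding ResNet stage
+# (see ResNetWithAttention._forward_features in model/esresnet/base.py).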
diff --git a/model/esresnet/base.py b/model/esresnet/base.py
new file mode 100644
index 0000000..a2a0acc
--- /dev/null
+++ b/model/esresnet/base.py
@@ -0,0 +1,708 @@
+import termcolor
+
+import numpy as np
+import scipy.signal as sps
+
+import torch
+import torch.nn.functional as F
+
+import torchvision as tv
+
+import ignite_trainer as it
+
+from model.esresnet import attention
+from utils.transforms import scale
+
+from typing import cast
+from typing import List
+from typing import Type
+from typing import Tuple
+from typing import Union
+from typing import Optional
+
+
+def conv3x3(in_planes: int, out_planes: int, stride=1, groups: int = 1, dilation: Union[int, Tuple[int, int]] = 1):
+ """
+ CREDITS: https://github.com/pytorch/vision
+ 3x3 convolution with padding
+ """
+ return torch.nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation
+ )
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: Union[int, Tuple[int, int]] = 1):
+ """
+ CREDITS: https://github.com/pytorch/vision
+ 1x1 convolution
+ """
+ return torch.nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False
+ )
+
+
+class BasicBlock(torch.nn.Module):
+
+ """
+ CREDITS: https://github.com/pytorch/vision
+ """
+
+ expansion: int = 1
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: Union[int, Tuple[int, int]] = 1,
+ downsample: Optional[torch.nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ norm_layer: Optional[Type[torch.nn.Module]] = None):
+
+ super(BasicBlock, self).__init__()
+
+ if norm_layer is None:
+ norm_layer = torch.nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = torch.nn.ReLU()
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(torch.nn.Module):
+
+ """
+ CREDITS: https://github.com/pytorch/vision
+ """
+
+ expansion: int = 4
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: Union[int, Tuple[int, int]] = 1,
+ downsample: Optional[torch.nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ norm_layer: Optional[Type[torch.nn.Module]] = None):
+
+ super(Bottleneck, self).__init__()
+
+ if norm_layer is None:
+ norm_layer = torch.nn.BatchNorm2d
+
+ width = int(planes * (base_width / 64.0)) * groups
+
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = torch.nn.ReLU()
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetWithAttention(it.AbstractNet):
+
+ """
+ CREDITS: https://github.com/pytorch/vision
+ """
+
+ def __init__(self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ apply_attention: bool = False,
+ num_channels: int = 3,
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: bool = None,
+ norm_layer: Optional[Type[torch.nn.Module]] = None):
+
+ super(ResNetWithAttention, self).__init__()
+
+ self.apply_attention = apply_attention
+
+ if norm_layer is None:
+ norm_layer = torch.nn.BatchNorm2d
+
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ f'replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}'
+ )
+
+ self.groups = groups
+ self.base_width = width_per_group
+
+ self.conv1 = torch.nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = torch.nn.ReLU()
+ self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ if self.apply_attention:
+ self.att1 = attention.Attention2d(
+ in_channels=64,
+ out_channels=64 * block.expansion,
+ num_kernels=1,
+ kernel_size=(3, 1),
+ padding_size=(1, 0)
+ )
+
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ if self.apply_attention:
+ self.att2 = attention.Attention2d(
+ in_channels=64 * block.expansion,
+ out_channels=128 * block.expansion,
+ num_kernels=1,
+ kernel_size=(1, 5),
+ padding_size=(0, 2)
+ )
+
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ if self.apply_attention:
+ self.att3 = attention.Attention2d(
+ in_channels=128 * block.expansion,
+ out_channels=256 * block.expansion,
+ num_kernels=1,
+ kernel_size=(3, 1),
+ padding_size=(1, 0)
+ )
+
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ if self.apply_attention:
+ self.att4 = attention.Attention2d(
+ in_channels=256 * block.expansion,
+ out_channels=512 * block.expansion,
+ num_kernels=1,
+ kernel_size=(1, 5),
+ padding_size=(0, 2)
+ )
+
+ self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+ if self.apply_attention:
+ self.att5 = attention.Attention2d(
+ in_channels=512 * block.expansion,
+ out_channels=512 * block.expansion,
+ num_kernels=1,
+ kernel_size=(3, 5),
+ padding_size=(1, 2)
+ )
+
+ self.fc = torch.nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.GroupNorm)):
+ torch.nn.init.constant_(m.weight, 1)
+ torch.nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ torch.nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ torch.nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ planes: int,
+ blocks: int,
+ stride: Union[int, Tuple[int, int]] = 1,
+ dilate: bool = False) -> torch.nn.Module:
+
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = torch.nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion)
+ )
+
+ layers = list()
+ layers.append(block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer
+ ))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer
+ ))
+
+ return torch.nn.Sequential(*layers)
+
+ def _forward_pre_processing(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.to(torch.get_default_dtype())
+
+ return x
+
+ def _forward_pre_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ return x
+
+ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self._forward_pre_features(x)
+
+ if self.apply_attention:
+ x_att = x.clone()
+ x = self.layer1(x)
+ x_att = self.att1(x_att, x.shape[-2:])
+ x = x * x_att
+
+ x_att = x.clone()
+ x = self.layer2(x)
+ x_att = self.att2(x_att, x.shape[-2:])
+ x = x * x_att
+
+ x_att = x.clone()
+ x = self.layer3(x)
+ x_att = self.att3(x_att, x.shape[-2:])
+ x = x * x_att
+
+ x_att = x.clone()
+ x = self.layer4(x)
+ x_att = self.att4(x_att, x.shape[-2:])
+ x = x * x_att
+ else:
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ return x
+
+ def _forward_reduction(self, x: torch.Tensor) -> torch.Tensor:
+ if self.apply_attention:
+ x_att = x.clone()
+ x = self.avgpool(x)
+ x_att = self.att5(x_att, x.shape[-2:])
+ x = x * x_att
+ else:
+ x = self.avgpool(x)
+
+ x = torch.flatten(x, 1)
+
+ return x
+
+ def _forward_classifier(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.fc(x)
+
+ return x
+
+ def forward(self,
+ x: torch.Tensor,
+ y: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+ x = self._forward_pre_processing(x)
+ x = self._forward_features(x)
+ x = self._forward_reduction(x)
+ y_pred = self._forward_classifier(x)
+
+ loss = None
+ if y is not None:
+ loss = self.loss_fn(y_pred, y).mean()
+
+ return y_pred if loss is None else (y_pred, loss)
+
+ def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ if isinstance(y_pred, tuple):
+ y_pred, *_ = y_pred
+
+ if y_pred.shape == y.shape:
+ loss_pred = F.binary_cross_entropy_with_logits(
+ y_pred,
+ y.to(dtype=y_pred.dtype, device=y_pred.device),
+ reduction='sum'
+ ) / y_pred.shape[0]
+ else:
+ loss_pred = F.cross_entropy(y_pred, y.to(y_pred.device))
+
+ return loss_pred
+
+ @property
+ def loss_fn_name(self) -> str:
+ return 'Cross Entropy'
+
+
+class _ESResNet(ResNetWithAttention):
+
+ @staticmethod
+ def loading_func(*args, **kwargs) -> torch.nn.Module:
+ raise NotImplementedError
+
+ def __init__(self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ apply_attention: bool = False,
+ n_fft: int = 256,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[str] = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ spec_height: int = 224,
+ spec_width: int = 224,
+ num_classes: int = 1000,
+ pretrained: Union[bool, str] = False,
+ lock_pretrained: Optional[Union[bool, List[str]]] = None,
+ zero_init_residual: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: bool = None,
+ norm_layer: Optional[Type[torch.nn.Module]] = None):
+
+ super(_ESResNet, self).__init__(
+ block=block,
+ layers=layers,
+ apply_attention=apply_attention,
+ num_channels=3,
+ num_classes=num_classes,
+ zero_init_residual=zero_init_residual,
+ groups=groups,
+ width_per_group=width_per_group,
+ replace_stride_with_dilation=replace_stride_with_dilation,
+ norm_layer=norm_layer
+ )
+
+ self.num_classes = num_classes
+
+ self.fc = torch.nn.Linear(
+ in_features=self.fc.in_features,
+ out_features=self.num_classes,
+ bias=self.fc.bias is not None
+ )
+
+ if hop_length is None:
+ hop_length = int(np.floor(n_fft / 4))
+
+ if win_length is None:
+ win_length = n_fft
+
+ if window is None:
+ window = 'boxcar'
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+
+ self.normalized = normalized
+ self.onesided = onesided
+
+ self.spec_height = spec_height
+ self.spec_width = spec_width
+
+ self.pretrained = pretrained
+ self._inject_members()
+ if pretrained:
+ err_msg = self.load_pretrained()
+
+ unlocked_weights = list()
+
+ for name, p in self.named_parameters():
+ unlock = True
+ if isinstance(lock_pretrained, bool):
+ if lock_pretrained and name not in err_msg:
+ unlock = False
+ elif isinstance(lock_pretrained, list):
+ if name in lock_pretrained:
+ unlock = False
+
+ p.requires_grad_(unlock)
+ if unlock:
+ unlocked_weights.append(name)
+
+ print(f'The following weights are unlocked: {unlocked_weights}')
+
+ window_buffer: torch.Tensor = torch.from_numpy(
+ sps.get_window(window=window, Nx=win_length, fftbins=True)
+ ).to(torch.get_default_dtype())
+ self.register_buffer('window', window_buffer)
+
+ self.log10_eps = 1e-18
+
+ if self.apply_attention and pretrained and not isinstance(pretrained, str):
+ self._reset_attention()
+
+ def _inject_members(self):
+ pass
+
+ def _reset_attention(self):
+ print(termcolor.colored('Resetting attention blocks', 'green'))
+
+ self.att1.bn.weight.data.fill_(1.0)
+ self.att1.bn.bias.data.fill_(1.0)
+
+ self.att2.bn.weight.data.fill_(1.0)
+ self.att2.bn.bias.data.fill_(1.0)
+
+ self.att3.bn.weight.data.fill_(1.0)
+ self.att3.bn.bias.data.fill_(1.0)
+
+ self.att4.bn.weight.data.fill_(1.0)
+ self.att4.bn.bias.data.fill_(1.0)
+
+ self.att5.bn.weight.data.fill_(1.0)
+ self.att5.bn.bias.data.fill_(1.0)
+
+ def load_pretrained(self) -> str:
+ if isinstance(self.pretrained, bool):
+ state_dict = self.loading_func(pretrained=True).state_dict()
+ else:
+ state_dict = torch.load(self.pretrained, map_location='cpu')
+
+ err_msg = ''
+ try:
+ self.load_state_dict(state_dict=state_dict, strict=True)
+ except RuntimeError as ex:
+ err_msg += f'Some errors occurred while loading the pretrained weights.\n{ex}'
+ print(termcolor.colored(err_msg, 'red'))
+
+ return err_msg
+
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
+ spec = torch.stft(
+ x.view(-1, x.shape[-1]),
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ pad_mode='reflect',
+ normalized=self.normalized,
+ onesided=True
+ )
+
+ if not self.onesided:
+ spec = torch.cat((torch.flip(spec, dims=(-3,)), spec), dim=-3)
+
+ return spec
+
+ def split_spectrogram(self, spec: torch.Tensor, batch_size: int) -> torch.Tensor:
+ spec_height_per_band = spec.shape[-3] // self.conv1.in_channels
+ spec_height_single_band = self.conv1.in_channels * spec_height_per_band
+ spec = spec[:, :spec_height_single_band]
+
+ spec = spec.reshape(batch_size, -1, spec.shape[-3] // self.conv1.in_channels, *spec.shape[-2:])
+
+ return spec
+
+ def spectrogram_to_power(self, spec: torch.Tensor) -> torch.Tensor:
+ spec_height = spec.shape[-3] if self.spec_height < 1 else self.spec_height
+ spec_width = spec.shape[-2] if self.spec_width < 1 else self.spec_width
+
+ pow_spec = spec[..., 0] ** 2 + spec[..., 1] ** 2
+
+ if spec_height != pow_spec.shape[-2] or spec_width != pow_spec.shape[-1]:
+ pow_spec = F.interpolate(
+ pow_spec,
+ size=(spec_height, spec_width),
+ mode='bilinear',
+ align_corners=True
+ )
+
+ return pow_spec
+
+ def _forward_pre_processing(self, x: torch.Tensor) -> torch.Tensor:
+ x = super(_ESResNet, self)._forward_pre_processing(x)
+ x = scale(x, -32768.0, 32767, -1.0, 1.0)
+
+ spec = self.spectrogram(x)
+ spec_split_ch = self.split_spectrogram(spec, x.shape[0])
+ pow_spec_split_ch = self.spectrogram_to_power(spec_split_ch)
+ pow_spec_split_ch = torch.where(
+ cast(torch.Tensor, pow_spec_split_ch > 0.0),
+ pow_spec_split_ch,
+ torch.full_like(pow_spec_split_ch, self.log10_eps)
+ )
+ pow_spec_split_ch = pow_spec_split_ch.reshape(
+ x.shape[0], -1, self.conv1.in_channels, *pow_spec_split_ch.shape[-2:]
+ )
+ x_db = torch.log10(pow_spec_split_ch).mul(10.0)
+
+ return x_db
+
+ def _forward_features(self, x_db: torch.Tensor) -> List[torch.Tensor]:
+ outputs = list()
+ for ch_idx in range(x_db.shape[1]):
+ ch = x_db[:, ch_idx]
+ out = super(_ESResNet, self)._forward_features(ch)
+ outputs.append(out)
+
+ return outputs
+
+ def _forward_reduction(self, x: List[torch.Tensor]) -> torch.Tensor:
+ outputs = list()
+ for ch in x:
+ out = super(_ESResNet, self)._forward_reduction(ch)
+ outputs.append(out)
+ outputs = torch.stack(outputs, dim=-1).sum(dim=-1)
+
+ return outputs
+
+
+class ESResNet(_ESResNet):
+
+ loading_func = staticmethod(tv.models.resnet50)
+
+ def __init__(self,
+ n_fft: int = 256,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[str] = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ spec_height: int = 224,
+ spec_width: int = 224,
+ num_classes: int = 1000,
+ apply_attention: bool = False,
+ pretrained: bool = False,
+ lock_pretrained: Optional[Union[bool, List[str]]] = None):
+
+ super(ESResNet, self).__init__(
+ block=Bottleneck,
+ layers=[3, 4, 6, 3],
+ apply_attention=apply_attention,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ spec_height=spec_height,
+ spec_width=spec_width,
+ num_classes=num_classes,
+ pretrained=pretrained,
+ lock_pretrained=lock_pretrained
+ )
+
+
+class ESResNeXt(_ESResNet):
+
+ loading_func = staticmethod(tv.models.resnext50_32x4d)
+
+ def __init__(self,
+ n_fft: int = 256,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[str] = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ spec_height: int = 224,
+ spec_width: int = 224,
+ num_classes: int = 1000,
+ apply_attention: bool = False,
+ pretrained: Union[bool, str] = False,
+ lock_pretrained: Optional[Union[bool, List[str]]] = None):
+
+ super(ESResNeXt, self).__init__(
+ block=Bottleneck,
+ layers=[3, 4, 6, 3],
+ apply_attention=apply_attention,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ spec_height=spec_height,
+ spec_width=spec_width,
+ num_classes=num_classes,
+ pretrained=pretrained,
+ lock_pretrained=lock_pretrained,
+ groups=32,
+ width_per_group=4
+ )
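+
+
+# Usage sketch (hedged): the models expect raw waveforms scaled to the 16-bit PCM range,
+# shaped [batch, channels, samples]; log-power spectrograms are computed internally.
+#
+#   model = ESResNeXt(n_fft=2048, spec_height=224, spec_width=224, num_classes=50, pretrained=False)
+#   logits = model(torch.zeros(2, 1, 220500))  # shape: [2, 50]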
diff --git a/model/esresnet/fbsp.py b/model/esresnet/fbsp.py
new file mode 100644
index 0000000..15d329d
--- /dev/null
+++ b/model/esresnet/fbsp.py
@@ -0,0 +1,247 @@
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+import torchvision as tv
+
+from utils import transforms
+from model.esresnet.base import _ESResNet
+from model.esresnet.base import Bottleneck
+
+from typing import cast
+from typing import List
+from typing import Tuple
+from typing import Union
+from typing import Optional
+
+
+class LinearFBSP(torch.nn.Module):
+
+ def __init__(self, out_features: int, bias: bool = True, normalized: bool = False):
+ super(LinearFBSP, self).__init__()
+
+ self.out_features = out_features
+ self.normalized = normalized
+ self.eps = 1e-8
+
+ default_dtype = torch.get_default_dtype()
+
+ self.register_parameter('m', torch.nn.Parameter(torch.zeros(self.out_features, dtype=default_dtype)))
+ self.register_parameter('fb', torch.nn.Parameter(torch.ones(self.out_features, dtype=default_dtype)))
+ self.register_parameter('fc', torch.nn.Parameter(torch.arange(self.out_features, dtype=default_dtype)))
+ self.register_parameter(
+ 'bias',
+ torch.nn.Parameter(
+ torch.normal(
+ 0.0, 0.5, (self.out_features, 2), dtype=default_dtype
+ ) if bias else cast(
+ torch.nn.Parameter, None
+ )
+ )
+ )
+
+ self.m.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps))
+ self.fb.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps))
+ self.fc.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps))
+
+ @staticmethod
+ def power(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ magnitudes = (x1[..., 0] ** 2 + x1[..., 1] ** 2) ** 0.5
+ phases = x1[..., 1].atan2(x1[..., 0])
+
+ power_real = x2[..., 0]
+ power_imag = x2[..., 1]
+
+ mag_out = ((magnitudes ** 2) ** (0.5 * power_real) * torch.exp(-power_imag * phases))
+
+ return mag_out.unsqueeze(-1) * torch.stack((
+ (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).cos(),
+ (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).sin()
+ ), dim=-1)
+
+ @staticmethod
+ def sinc(x: torch.Tensor) -> torch.Tensor:
+ return torch.where(cast(torch.Tensor, x == 0), torch.ones_like(x), torch.sin(x) / x)
+
+ def _materialize_weights(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
+ x_is_complex = x.shape[-1] == 2
+ in_features = x.shape[-1 - int(x_is_complex)]
+
+ t = np.pi * torch.linspace(-1.0, 1.0, in_features, dtype=x.dtype, device=x.device).reshape(1, -1, 1) + self.eps
+
+ m = self.m.reshape(-1, 1, 1)
+ fb = self.fb.reshape(-1, 1, 1)
+ fc = self.fc.reshape(-1, 1, 1)
+
+ kernel = torch.cat((torch.cos(fc * t), -torch.sin(fc * t)), dim=-1) # complex
+ scale = fb.sqrt() # real
+ win = self.sinc(fb * t / (m + self.eps)) # real
+ win = self.power(
+ torch.cat((win, torch.zeros_like(win)), dim=-1),
+ torch.cat((m, torch.zeros_like(m)), dim=-1)
+ ) # complex
+
+ weights = scale * torch.cat((
+ win[..., :1] * kernel[..., :1] - win[..., 1:] * kernel[..., 1:],
+ win[..., :1] * kernel[..., 1:] + win[..., 1:] * kernel[..., :1]
+ ), dim=-1)
+
+ if self.normalized:
+ weights = weights / (in_features ** 0.5)
+
+ return weights, x_is_complex
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ weights, x_is_complex = self._materialize_weights(x)
+
+ if x_is_complex:
+ x = torch.stack((
+ F.linear(x[..., 0], weights[..., 0]) - F.linear(x[..., 1], weights[..., 1]),
+ F.linear(x[..., 0], weights[..., 1]) + F.linear(x[..., 1], weights[..., 0])
+ ), dim=-1)
+ else:
+ x = torch.stack((
+ F.linear(x, weights[..., 0]),
+ F.linear(x, weights[..., 1])
+ ), dim=-1)
+
+ if (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)):
+ x = x + self.bias
+
+ return x, weights
+
+ def extra_repr(self) -> str:
+ return 'out_features={}, bias={}, normalized={}'.format(
+ self.out_features,
+ (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)),
+ self.normalized
+ )
+
+
+ttf_weights = dict()
+
+
+class _ESResNetFBSP(_ESResNet):
+
+ def _inject_members(self):
+ self.add_module(
+ 'fbsp',
+ LinearFBSP(
+ out_features=int(round(self.n_fft / 2)) + 1 if self.onesided else self.n_fft,
+ normalized=self.normalized,
+ bias=False
+ )
+ )
+
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
+ with torch.no_grad():
+ frames = transforms.frame_signal(
+ signal=x.view(-1, x.shape[-1]),
+ frame_length=self.win_length,
+ hop_length=self.hop_length,
+ window=self.window
+ )
+
+ if self.n_fft > self.win_length:
+ pad_length = self.n_fft - self.win_length
+ pad_left = pad_length // 2
+ pad_right = pad_length - pad_left
+ frames = F.pad(frames, [pad_left, pad_right])
+
+ spec, ttf_weights_ = self.fbsp(frames)
+
+ spec = spec.transpose(-2, -3)
+ ttf_weights[x.device] = ttf_weights_
+
+ return spec
+
+ def loss_ttf(self, device: torch.device) -> torch.Tensor:
+ ttf_norm = torch.norm(ttf_weights[device], p=2, dim=[-1, -2])
+ loss_ttf_norm = F.mse_loss(
+ ttf_norm,
+ torch.full_like(ttf_norm, 1.0 if self.normalized else self.n_fft ** 0.5)
+ )
+
+ return loss_ttf_norm
+
+ def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ loss_pred = super(_ESResNetFBSP, self).loss_fn(y_pred, y)
+ loss_ttf_norm = self.loss_ttf(y_pred.device)
+ loss = loss_pred + loss_ttf_norm
+
+ return loss
+
+
+class ESResNetFBSP(_ESResNetFBSP):
+
+ loading_func = staticmethod(tv.models.resnet50)
+
+ def __init__(self,
+ n_fft: int = 256,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[str] = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ spec_height: int = 224,
+ spec_width: int = 224,
+ num_classes: int = 1000,
+ apply_attention: bool = False,
+ pretrained: bool = False,
+ lock_pretrained: Optional[Union[bool, List[str]]] = None):
+
+ super(ESResNetFBSP, self).__init__(
+ block=Bottleneck,
+ layers=[3, 4, 6, 3],
+ apply_attention=apply_attention,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ spec_height=spec_height,
+ spec_width=spec_width,
+ num_classes=num_classes,
+ pretrained=pretrained,
+ lock_pretrained=lock_pretrained
+ )
+
+
+class ESResNeXtFBSP(_ESResNetFBSP):
+
+ loading_func = staticmethod(tv.models.resnext50_32x4d)
+
+ def __init__(self,
+ n_fft: int = 256,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[str] = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ spec_height: int = 224,
+ spec_width: int = 224,
+ num_classes: int = 1000,
+ apply_attention: bool = False,
+ pretrained: Union[bool, str] = False,
+ lock_pretrained: Optional[Union[bool, List[str]]] = None):
+
+ super(ESResNeXtFBSP, self).__init__(
+ block=Bottleneck,
+ layers=[3, 4, 6, 3],
+ apply_attention=apply_attention,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ spec_height=spec_height,
+ spec_width=spec_width,
+ num_classes=num_classes,
+ pretrained=pretrained,
+ lock_pretrained=lock_pretrained,
+ groups=32,
+ width_per_group=4
+ )
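+
+
+# Note (summary, not part of the upstream code): the FBSP variants replace the fixed STFT
+# in _ESResNet.spectrogram with the trainable LinearFBSP filter bank defined above, and
+# add a regularization term (loss_ttf) that constrains the norm of the learned filters.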
diff --git a/protocols/README.md b/protocols/README.md
new file mode 100644
index 0000000..48e72c8
--- /dev/null
+++ b/protocols/README.md
@@ -0,0 +1,3 @@
+# Protocols
+
+This directory contains the JSON files that describe the experiment configurations.
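+
+Each protocol defines the following top-level sections (see `audioclip-esc50.json` for a complete example): `Visdom`, `Setup`, `Model`, `Optimizer`, `Scheduler`, `Dataset`, `Transforms` and `Metrics`.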
diff --git a/protocols/audioclip-esc50.json b/protocols/audioclip-esc50.json
new file mode 100644
index 0000000..32e1ada
--- /dev/null
+++ b/protocols/audioclip-esc50.json
@@ -0,0 +1,108 @@
+{
+ "Visdom": {
+ "host": null,
+ "port": null,
+ "env_path": null
+ },
+ "Setup": {
+ "name": "Multimodal-Audio",
+ "suffix": "CV1",
+ "batch_train": 64,
+ "batch_test": 64,
+ "workers_train": 4,
+ "workers_test": 4,
+ "epochs": 50,
+ "log_interval": 10,
+ "saved_models_path": "/path/to/saved/models"
+ },
+ "Model": {
+ "class": "model.audioclip.AudioCLIP",
+ "args": {
+ "multilabel": false,
+ "pretrained": "/path/to/assets/trained_AudioCLIP.pt"
+ }
+ },
+ "Optimizer": {
+ "class": "torch.optim.SGD",
+ "args": {
+ "lr": 5e-5,
+ "momentum": 0.9,
+ "nesterov": true,
+ "weight_decay": 5e-4
+ }
+ },
+ "Scheduler": {
+ "class": "torch.optim.lr_scheduler.ExponentialLR",
+ "args": {
+ "gamma": 0.96
+ }
+ },
+ "Dataset": {
+ "class": "utils.datasets.ESC50",
+ "args": {
+ "dl_shuffle": true,
+ "root": "/path/to/ESC50",
+ "sample_rate": 44100,
+ "fold": 1,
+ "training": {"key": "train", "yes": true, "no": false}
+ }
+ },
+ "Transforms": [
+ {
+ "class": "utils.transforms.ToTensor1D",
+ "args": {}
+ },
+ {
+ "class": "utils.transforms.RandomFlip",
+ "args": {"p": 0.5},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomScale",
+ "args": {"max_scale": 1.50},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomPadding",
+ "args": {"out_len": 220500},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomCrop",
+ "args": {"out_len": 220500},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomNoise",
+ "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomPadding",
+ "args": {"out_len": 220500, "train": false},
+ "train": false
+ },
+ {
+ "class": "utils.transforms.RandomCrop",
+ "args": {"out_len": 220500, "train": false},
+ "train": false
+ }
+ ],
+ "Metrics": {
+ "Performance": {
+ "window_name": null,
+ "x_label": "#Epochs",
+ "y_label": "Accuracy",
+ "width": 1890,
+ "height": 416,
+ "lines": [
+ {
+ "line_label": "Val. Acc.",
+ "class": "ignite.metrics.Accuracy",
+ "args": {},
+ "is_checkpoint": true
+ }
+ ]
+ }
+ }
+}
diff --git a/protocols/audioclip-us8k.json b/protocols/audioclip-us8k.json
new file mode 100644
index 0000000..4c1f649
--- /dev/null
+++ b/protocols/audioclip-us8k.json
@@ -0,0 +1,108 @@
+{
+ "Visdom": {
+ "host": null,
+ "port": null,
+ "env_path": null
+ },
+ "Setup": {
+ "name": "Multimodal-Audio",
+ "suffix": "CV01",
+ "batch_train": 64,
+ "batch_test": 64,
+ "workers_train": 4,
+ "workers_test": 4,
+ "epochs": 50,
+ "log_interval": 25,
+ "saved_models_path": "/path/to/saved/models"
+ },
+ "Model": {
+ "class": "model.audioclip.AudioCLIP",
+ "args": {
+ "multilabel": false,
+ "pretrained": "/path/to/assets/trained_AudioCLIP.pt"
+ }
+ },
+ "Optimizer": {
+ "class": "torch.optim.SGD",
+ "args": {
+ "lr": 1e-5,
+ "momentum": 0.9,
+ "nesterov": true,
+ "weight_decay": 5e-4
+ }
+ },
+ "Scheduler": {
+ "class": "torch.optim.lr_scheduler.ExponentialLR",
+ "args": {
+ "gamma": 0.96
+ }
+ },
+ "Dataset": {
+ "class": "utils.datasets.UrbanSound8K",
+ "args": {
+ "root": "/path/to/UrbanSound8K",
+ "sample_rate": 44100,
+ "fold": 1,
+ "mono": false,
+ "training": {"key": "train", "yes": true, "no": false}
+ }
+ },
+ "Transforms": [
+ {
+ "class": "utils.transforms.ToTensor1D",
+ "args": {}
+ },
+ {
+ "class": "utils.transforms.RandomFlip",
+ "args": {"p": 0.5},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomScale",
+ "args": {"max_scale": 1.50},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomPadding",
+ "args": {"out_len": 176400},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomCrop",
+ "args": {"out_len": 176400},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomNoise",
+ "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25},
+ "test": false
+ },
+ {
+ "class": "utils.transforms.RandomPadding",
+ "args": {"out_len": 176400, "train": false},
+ "train": false
+ },
+ {
+ "class": "utils.transforms.RandomCrop",
+ "args": {"out_len": 176400, "train": false},
+ "train": false
+ }
+ ],
+ "Metrics": {
+ "Performance": {
+ "window_name": null,
+ "x_label": "#Epochs",
+ "y_label": "Accuracy",
+ "width": 1890,
+ "height": 416,
+ "lines": [
+ {
+ "line_label": "Val. Acc.",
+ "class": "ignite.metrics.Accuracy",
+ "args": {},
+ "is_checkpoint": true
+ }
+ ]
+ }
+ }
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f78635a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+librosa==0.7.2
+numpy==1.18.1
+pandas==1.0.3
+pytorch-ignite==0.3.0
+scikit-learn==0.22.1
+scipy==1.4.1
+termcolor==1.1.0
+torch==1.7.1
+torchvision==0.8.2
+tqdm==4.43.0
+visdom==0.1.8.9
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..c5be7f9
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,7 @@
+from . import datasets
+from . import transforms
+
+__all__ = [
+ 'datasets',
+ 'transforms'
+]
diff --git a/utils/datasets/__init__.py b/utils/datasets/__init__.py
new file mode 100644
index 0000000..cba11e7
--- /dev/null
+++ b/utils/datasets/__init__.py
@@ -0,0 +1,4 @@
+from .esc50 import ESC50
+from .us8k import UrbanSound8K
+
+__all__ = ['ESC50', 'UrbanSound8K']
diff --git a/utils/datasets/esc50.py b/utils/datasets/esc50.py
new file mode 100644
index 0000000..17b0111
--- /dev/null
+++ b/utils/datasets/esc50.py
@@ -0,0 +1,128 @@
+import os
+import warnings
+import multiprocessing as mp
+
+import tqdm
+import librosa
+
+import numpy as np
+import pandas as pd
+
+import torch.utils.data as td
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+from typing import Optional
+
+
+class ESC50(td.Dataset):
+
+ def __init__(self,
+ root: str,
+ sample_rate: int = 22050,
+ train: bool = True,
+ fold: Optional[int] = None,
+ transform_audio=None,
+ target_transform=None,
+ **_):
+
+ super(ESC50, self).__init__()
+
+ self.sample_rate = sample_rate
+
+ meta = self.load_meta(os.path.join(root, 'meta', 'esc50.csv'))
+
+ if fold is None:
+ fold = 5
+
+ self.folds_to_load = set(meta['fold'])
+
+ if fold not in self.folds_to_load:
+ raise ValueError(f'fold {fold} does not exist')
+
+ self.train = train
+ self.transform = transform_audio
+
+ if self.train:
+ # in training mode, keep all folds except the test fold
+ self.folds_to_load -= {fold}
+ else:
+ # in evaluation mode, keep only the test fold
+ self.folds_to_load -= self.folds_to_load - {fold}
+
+ self.data: Dict[Union[str, int], Dict[str, Any]] = dict()
+ self.load_data(meta, os.path.join(root, 'audio'))
+ self.indices = list(self.data.keys())
+
+ self.class_idx_to_label = dict()
+ for row in self.data.values():
+ idx = row['target']
+ label = row['category']
+ self.class_idx_to_label[idx] = label
+ self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()}
+
+ self.target_transform = target_transform
+
+ @staticmethod
+ def load_meta(path_to_csv: str) -> pd.DataFrame:
+ meta = pd.read_csv(path_to_csv)
+
+ return meta
+
+ @staticmethod
+ def _load_worker(idx: int, filename: str, sample_rate: Optional[int] = None) -> Tuple[int, int, np.ndarray]:
+ wav, sample_rate = librosa.load(filename, sr=sample_rate, mono=True)
+
+ if wav.ndim == 1:
+ wav = wav[:, np.newaxis]
+
+ wav = wav.T * 32768.0
+
+ return idx, sample_rate, wav.astype(np.float32)
+
+ def load_data(self, meta: pd.DataFrame, base_path: str):
+ items_to_load = dict()
+
+ for idx, row in meta.iterrows():
+ if row['fold'] in self.folds_to_load:
+ items_to_load[idx] = os.path.join(base_path, row['filename']), self.sample_rate
+
+ items_to_load = [(idx, path, sample_rate) for idx, (path, sample_rate) in items_to_load.items()]
+
+ num_processes = os.cpu_count()
+ warnings.filterwarnings('ignore')
+ with mp.Pool(processes=num_processes) as pool:
+ tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})')
+ for idx, sample_rate, wav in pool.starmap(
+ func=self._load_worker,
+ iterable=items_to_load,
+ chunksize=int(np.ceil(len(items_to_load) / num_processes)) or 1
+ ):
+ row = meta.loc[idx]
+
+ self.data[idx] = {
+ 'audio': wav,
+ 'sample_rate': sample_rate,
+ 'target': row['target'],
+ 'category': row['category'].replace('_', ' '),
+ 'fold': row['fold'],
+ 'esc10': row['esc10']
+ }
+
+ def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]:
+ if not (0 <= index < len(self)):
+ raise IndexError
+
+ audio: np.ndarray = self.data[self.indices[index]]['audio']
+ target: str = self.data[self.indices[index]]['category']
+
+ if self.transform is not None:
+ audio = self.transform(audio)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return audio, None, [target]
+
+ def __len__(self) -> int:
+ return len(self.indices)
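+
+
+# Usage sketch (hedged; the path is a placeholder):
+#
+#   dataset = ESC50(root='/path/to/ESC50', sample_rate=44100, fold=1, train=True)
+#   audio, _, (label,) = dataset[0]  # audio: np.ndarray, label: category string such as 'dog'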
diff --git a/utils/datasets/us8k.py b/utils/datasets/us8k.py
new file mode 100644
index 0000000..ff45324
--- /dev/null
+++ b/utils/datasets/us8k.py
@@ -0,0 +1,167 @@
+import os
+import warnings
+import multiprocessing as mp
+
+import tqdm
+import librosa
+import soundfile as sf
+
+import numpy as np
+import pandas as pd
+
+import torch.utils.data as td
+
+import sklearn.model_selection as skms
+
+import utils.transforms as transforms
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Optional
+
+
+class UrbanSound8K(td.Dataset):
+
+ def __init__(self,
+ root: str,
+ sample_rate: int = 22050,
+ train: bool = True,
+ fold: Optional[int] = None,
+ mono: bool = False,
+ transform_audio=None,
+ target_transform=None,
+ **_):
+
+ super(UrbanSound8K, self).__init__()
+
+ self.root = root
+ self.sample_rate = sample_rate
+ self.train = train
+ self.random_split_seed = None
+
+ if fold is None:
+ fold = 1
+
+ if not (1 <= fold <= 10):
+ raise ValueError(f'Expected fold in range [1, 10], got {fold}')
+
+ self.fold = fold
+ self.folds_to_load = set(range(1, 11))
+
+ if self.fold not in self.folds_to_load:
+ raise ValueError(f'fold {fold} does not exist')
+
+ if self.train:
+ # if in training mode, keep all but test fold
+ self.folds_to_load -= {self.fold}
+ else:
+ # if in evaluation mode, keep the test samples only
+ self.folds_to_load -= self.folds_to_load - {self.fold}
+
+ self.mono = mono
+
+ self.transform = transform_audio
+ self.target_transform = target_transform
+
+ self.data: Dict[str, Dict[str, Any]] = dict()
+ self.indices = dict()
+ self.load_data()
+
+ self.class_idx_to_label = dict()
+ for row in self.data.values():
+ idx = row['target']
+ label = row['category']
+ self.class_idx_to_label[idx] = label
+ self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()}
+
+ @staticmethod
+ def _load_worker(fn: str, path_to_file: str, sample_rate: int, mono: bool = False) -> Tuple[str, int, np.ndarray]:
+ wav, sample_rate_ = sf.read(
+ path_to_file,
+ dtype='float32',
+ always_2d=True
+ )
+
+ wav = librosa.resample(wav.T, sample_rate_, sample_rate)
+
+ if wav.shape[0] == 1 and not mono:
+ wav = np.concatenate((wav, wav), axis=0)
+
+ wav = wav[:, :sample_rate * 4]
+ wav = transforms.scale(wav, wav.min(), wav.max(), -32768.0, 32767.0)
+
+ return fn, sample_rate, wav.astype(np.float32)
+
+ def load_data(self):
+ # read metadata
+ meta = pd.read_csv(
+ os.path.join(self.root, 'metadata', 'UrbanSound8K.csv'),
+ sep=',',
+ index_col='slice_file_name'
+ )
+
+ for row_idx, (fn, row) in enumerate(meta.iterrows()):
+ path = os.path.join(self.root, 'audio', 'fold{}'.format(row['fold']), fn)
+ self.data[fn] = path, self.sample_rate, self.mono
+
+ # by default, the official split from the metadata is used
+ files_to_load = list()
+ # if the random seed is not None, the random split is used
+ if self.random_split_seed is not None:
+ # given an integer random seed
+ skf = skms.StratifiedKFold(n_splits=10, shuffle=True, random_state=self.random_split_seed)
+
+ # split the US8K samples into 10 folds
+ for fold_idx, (train_ids, test_ids) in enumerate(skf.split(
+ np.zeros(len(meta)), meta['classID'].values.astype(int)
+ ), 1):
+ # if this is the fold we want to load, add the corresponding files to the list
+ if fold_idx == self.fold:
+ ids = train_ids if self.train else test_ids
+ filenames = meta.iloc[ids].index
+ files_to_load.extend(filenames)
+ break
+ else:
+ # if the random seed is None, use the official split
+ for fn, row in meta.iterrows():
+ if int(row['fold']) in self.folds_to_load:
+ files_to_load.append(fn)
+
+ self.data = {fn: vals for fn, vals in self.data.items() if fn in files_to_load}
+ self.indices = {idx: fn for idx, fn in enumerate(self.data)}
+
+ num_processes = os.cpu_count()
+ warnings.filterwarnings('ignore')
+ with mp.Pool(processes=num_processes) as pool:
+ tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})')
+ for fn, sample_rate, wav in pool.starmap(
+ func=self._load_worker,
+ iterable=[(fn, path, sr, mono) for fn, (path, sr, mono) in self.data.items()],
+ chunksize=int(np.ceil(len(meta) / num_processes)) or 1
+ ):
+ self.data[fn] = {
+ 'audio': wav,
+ 'sample_rate': sample_rate,
+ 'target': meta.loc[fn, 'classID'],
+ 'category': meta.loc[fn, 'class'].replace('_', ' ').strip(' '),
+ 'background': bool(meta.loc[fn, 'salience'] - 1)
+ }
+
+ def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]:
+ if not (0 <= index < len(self)):
+ raise IndexError
+
+ audio: np.ndarray = self.data[self.indices[index]]['audio']
+ target: str = self.data[self.indices[index]]['category']
+
+ if self.transform is not None:
+ audio = self.transform(audio)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return audio, None, [target]
+
+ def __len__(self) -> int:
+ return len(self.data)
diff --git a/utils/simple_tokenizer.py b/utils/simple_tokenizer.py
new file mode 100644
index 0000000..68946a1
--- /dev/null
+++ b/utils/simple_tokenizer.py
@@ -0,0 +1,134 @@
+# CREDITS: https://github.com/openai/CLIP
+
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns a mapping between utf-8 bytes and corresponding unicode strings.
+ The reversible BPE codes work on unicode strings.
+ This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B-token dataset you end up needing around 5K characters for decent coverage.
+ This is a significant percentage of your normal, say, 32K BPE vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings
+ that avoid mapping to whitespace/control characters the BPE code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:  # `first` does not occur in the rest of the word
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
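+
+
+# Minimal usage sketch (illustrative only; assumes the default BPE vocabulary under
+# assets/bpe_simple_vocab_16e6.txt.gz is available via Git LFS):
+#
+#     tokenizer = SimpleTokenizer()
+#     ids = tokenizer.encode('a dog barking')
+#     text = tokenizer.decode(ids)  # -> 'a dog barking ' (note the trailing separator)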
diff --git a/utils/transforms.py b/utils/transforms.py
new file mode 100644
index 0000000..082175e
--- /dev/null
+++ b/utils/transforms.py
@@ -0,0 +1,199 @@
+import math
+
+import numpy as np
+
+import torch
+import torchvision as tv
+
+import ignite_trainer as it
+
+
+def scale(old_value, old_min, old_max, new_min, new_max):
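+    # Linearly re-maps `old_value` from the range [old_min, old_max] to
+    # [new_min, new_max], e.g. scale(5.0, 0.0, 10.0, 0.0, 1.0) == 0.5.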
+ old_range = (old_max - old_min)
+ new_range = (new_max - new_min)
+ new_value = (((old_value - old_min) * new_range) / old_range) + new_min
+
+ return new_value
+
+
+def frame_signal(signal: torch.Tensor,
+ frame_length: int,
+ hop_length: int,
+ window: torch.Tensor = None) -> torch.Tensor:
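+    # Slices the last dimension of `signal` into frames of `frame_length` samples
+    # taken every `hop_length` samples, zero-padding the signal symmetrically so that
+    # an integer number of frames fits, and multiplies each frame by `window`
+    # (a rectangular window by default).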
+
+ if window is None:
+ window = torch.ones(frame_length, dtype=signal.dtype, device=signal.device)
+
+ if window.shape[0] != frame_length:
+        raise ValueError('Wrong `window` length: expected {}, got {}'.format(frame_length, window.shape[0]))
+
+ signal_length = signal.shape[-1]
+
+ if signal_length <= frame_length:
+ num_frames = 1
+ else:
+ num_frames = 1 + int(math.ceil((1.0 * signal_length - frame_length) / hop_length))
+
+ pad_len = int((num_frames - 1) * hop_length + frame_length)
+ if pad_len > signal_length:
+ zeros = torch.zeros(pad_len - signal_length, device=signal.device, dtype=signal.dtype)
+
+ while zeros.dim() < signal.dim():
+ zeros.unsqueeze_(0)
+
+ pad_signal = torch.cat((zeros.expand(*signal.shape[:-1], -1)[..., :zeros.shape[-1] // 2], signal), dim=-1)
+ pad_signal = torch.cat((pad_signal, zeros.expand(*signal.shape[:-1], -1)[..., zeros.shape[-1] // 2:]), dim=-1)
+ else:
+ pad_signal = signal
+
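+    # Build a (num_frames, frame_length) matrix of sample indices and gather all
+    # frames from the padded signal in a single advanced-indexing operation.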
+ indices = torch.arange(0, frame_length, device=signal.device).repeat(num_frames, 1)
+ indices += torch.arange(
+ 0,
+ num_frames * hop_length,
+ hop_length,
+ device=signal.device
+ ).repeat(frame_length, 1).t_()
+ indices = indices.long()
+
+ frames = pad_signal[..., indices]
+ frames = frames * window
+
+ return frames
+
+
+class ToTensor1D(tv.transforms.ToTensor):
+
+ def __call__(self, tensor: np.ndarray):
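+        # torchvision's ToTensor expects an image-like array; append a dummy channel
+        # axis so it converts the audio array, then drop the leading singleton
+        # dimension it introduces.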
+ tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis])
+
+ return tensor_2d.squeeze_(0)
+
+
+class RandomFlip(it.AbstractTransform):
+
+ def __init__(self, p: float = 0.5):
+ super(RandomFlip, self).__init__()
+
+ self.p = p
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ if x.dim() > 2:
+ flip_mask = torch.rand(x.shape[0], device=x.device) <= self.p
+ x[flip_mask] = x[flip_mask].flip(-1)
+ else:
+ if torch.rand(1) <= self.p:
+ x = x.flip(0)
+
+ return x
+
+
+class RandomScale(it.AbstractTransform):
+
+ def __init__(self, max_scale: float = 1.25):
+ super(RandomScale, self).__init__()
+
+ self.max_scale = max_scale
+
+ @staticmethod
+ def random_scale(max_scale: float, signal: torch.Tensor) -> torch.Tensor:
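+        # Draw a random stretch factor in [1/max_scale, max_scale] and resample the
+        # last dimension by linear interpolation between neighbouring samples.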
+ scaling = np.power(max_scale, np.random.uniform(-1, 1))
+ output_size = int(signal.shape[-1] * scaling)
+ ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
+
+ ref1 = ref.clone().type(torch.int64)
+ ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64))
+ r = ref - ref1.type(ref.type())
+ scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
+
+ return scaled_signal
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ return self.random_scale(self.max_scale, x)
+
+
+class RandomCrop(it.AbstractTransform):
+
+ def __init__(self, out_len: int = 44100, train: bool = True):
+ super(RandomCrop, self).__init__()
+
+ self.out_len = out_len
+ self.train = train
+
+ def random_crop(self, signal: torch.Tensor) -> torch.Tensor:
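+        # Crop at a random offset during training (centered otherwise); if the crop
+        # is much quieter than the whole signal (std below half the signal's std),
+        # try the leading segment, and fall back to the trailing segment if the
+        # leading one is louder but still below the threshold.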
+ if self.train:
+ left = np.random.randint(0, signal.shape[-1] - self.out_len)
+ else:
+ left = int(round(0.5 * (signal.shape[-1] - self.out_len)))
+
+ orig_std = signal.float().std() * 0.5
+ output = signal[..., left:left + self.out_len]
+
+ out_std = output.float().std()
+ if out_std < orig_std:
+ output = signal[..., :self.out_len]
+
+ new_out_std = output.float().std()
+ if orig_std > new_out_std > out_std:
+ output = signal[..., -self.out_len:]
+
+ return output
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ return self.random_crop(x) if x.shape[-1] > self.out_len else x
+
+
+class RandomPadding(it.AbstractTransform):
+
+ def __init__(self, out_len: int = 88200, train: bool = True):
+ super(RandomPadding, self).__init__()
+
+ self.out_len = out_len
+ self.train = train
+
+ def random_pad(self, signal: torch.Tensor) -> torch.Tensor:
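+        # Pad the signal to `out_len` samples, placing it at a random offset during
+        # training (centered otherwise) and filling the margins with the mean value
+        # of the first/last sample, respectively.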
+ if self.train:
+ left = np.random.randint(0, self.out_len - signal.shape[-1])
+ else:
+ left = int(round(0.5 * (self.out_len - signal.shape[-1])))
+
+ right = self.out_len - (left + signal.shape[-1])
+
+ pad_value_left = signal[..., 0].float().mean().to(signal.dtype)
+ pad_value_right = signal[..., -1].float().mean().to(signal.dtype)
+ output = torch.cat((
+ torch.zeros(signal.shape[:-1] + (left,), dtype=signal.dtype, device=signal.device).fill_(pad_value_left),
+ signal,
+ torch.zeros(signal.shape[:-1] + (right,), dtype=signal.dtype, device=signal.device).fill_(pad_value_right)
+ ), dim=-1)
+
+ return output
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ return self.random_pad(x) if x.shape[-1] < self.out_len else x
+
+
+class RandomNoise(it.AbstractTransform):
+
+ def __init__(self, snr_min_db: float = -10.0, snr_max_db: float = 100.0, p: float = 0.5):
+ super(RandomNoise, self).__init__()
+
+ self.p = p
+ self.snr_min_db = snr_min_db
+ self.snr_max_db = snr_max_db
+
+ def random_noise(self, signal: torch.Tensor) -> torch.Tensor:
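+        # Add white Gaussian noise at a target SNR drawn uniformly from
+        # [snr_min_db, snr_max_db + 1): the signal power is converted to dB and the
+        # noise variance is chosen so that signal_db - noise_db equals the target SNR.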
+ target_snr = np.random.rand() * (self.snr_max_db - self.snr_min_db + 1.0) + self.snr_min_db
+
+ signal_watts = torch.mean(signal ** 2, dim=(-1, -2))
+ signal_db = 10 * torch.log10(signal_watts)
+
+ noise_db = signal_db - target_snr
+ noise_watts = 10 ** (noise_db / 10)
+ noise = torch.normal(0.0, noise_watts.item() ** 0.5, signal.shape)
+
+ output = signal + noise
+
+ return output
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ return self.random_noise(x) if np.random.rand() <= self.p else x
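+
+
+# Minimal usage sketch (illustrative only; the actual transform chains and lengths
+# are configured in the protocols/*.json files):
+#
+#     import numpy as np
+#     track = np.random.randn(1, 220500).astype(np.float32)  # hypothetical 5 s @ 44.1 kHz clip
+#     augment = tv.transforms.Compose([
+#         ToTensor1D(),
+#         RandomScale(max_scale=1.25),
+#         RandomPadding(out_len=220500),
+#         RandomCrop(out_len=220500),
+#         RandomNoise(),
+#     ])
+#     x = augment(track)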