From 4ede95aaf1d7d90eb2e7eb91058a30cabed12367 Mon Sep 17 00:00:00 2001 From: Akash Chaurasia Date: Fri, 9 Jun 2023 13:43:36 -0700 Subject: [PATCH 1/5] Switch to video dataloader --- .gitignore | 3 +- engine_pretrain.py | 16 ++- main_pretrain.py | 53 +++++---- utils/video_frame_dataset.py | 215 +++++++++++++++++++++++++++++++++++ 4 files changed, 257 insertions(+), 30 deletions(-) create mode 100644 utils/video_frame_dataset.py diff --git a/.gitignore b/.gitignore index 5fccba2..586a8a0 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ summary* run* *.pth *.png -*.sh \ No newline at end of file +*.sh +tags diff --git a/engine_pretrain.py b/engine_pretrain.py index d48fa41..675e140 100644 --- a/engine_pretrain.py +++ b/engine_pretrain.py @@ -19,11 +19,15 @@ import utils.lr_sched as lr_sched -def train_one_epoch(model: torch.nn.Module, - data_loader: Iterable, optimizer: torch.optim.Optimizer, - device: torch.device, epoch: int, loss_scaler, - log_writer=None, - args=None): +def train_one_epoch( + model: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, loss_scaler, + log_writer=None, + args=None, +): model.train(True) metric_logger = misc.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) @@ -80,4 +84,4 @@ def train_one_epoch(model: torch.nn.Module, # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) - return {k: meter.global_avg for k, meter in metric_logger.meters.items()} \ No newline at end of file + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/main_pretrain.py b/main_pretrain.py index e0d38ab..ec16a90 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -25,7 +25,7 @@ import utils.misc as misc from utils.misc import NativeScalerWithGradNormCount as NativeScaler -from utils.datasets import VideoFrameDataset +from utils.video_frame_dataset import VideoFrameDataset import models_mae3d from engine_pretrain import train_one_epoch @@ -65,10 +65,10 @@ def get_args_parser(): help='epochs to warmup LR') # Dataset parameters - parser.add_argument('--data_path', default='/home/cyril/Datasets/MAE/', type=str, + parser.add_argument('--data_path', default='/scratch/users/akashc/steffner_echo_processed', type=str, help='dataset path') - parser.add_argument('--num_segments', default=4, type=int) - parser.add_argument('--frames_per_segment', default=4, type=int) + parser.add_argument('--num_frames', default=32, type=int) + parser.add_argument('--stride', default=8, type=int) parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving') parser.add_argument('--device', default='cuda', @@ -110,17 +110,14 @@ def main(args): cudnn.benchmark = True - # simple augmentation - transform_train = transforms.Compose([ - transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC), # 3 is bicubic - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) - dataset_train = VideoFrameDataset(root_path=os.path.join(args.data_path, 'train'), annotationfile_path=os.path.join(args.data_path, 'ledger.csv'), - num_segments= args.num_segments, - frames_per_segment = args.frames_per_segment, - transform = transform_train, - test_mode = False) + dataset_train = VideoFrameDataset( + root_path=args.data_path, + 
split='train', + num_frames=args.num_frames, + stride=args.stride, + do_augmentation=True, + is_eval=False, + ) print(dataset_train) if True: # args.distributed: @@ -140,16 +137,19 @@ def main(args): pin_memory=args.pin_mem, drop_last=True, ) - + # define the model - model = models_mae3d.__dict__[args.model](num_frames=int(args.num_segments*args.frames_per_segment), norm_pix_loss=args.norm_pix_loss) + model = models_mae3d.__dict__[args.model]( + num_frames=args.num_frames, + norm_pix_loss=args.norm_pix_loss, + ) model.to(device) model_without_ddp = model print("Model = %s" % str(model_without_ddp)) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() - + if args.lr is None: # only base_lr is specified args.lr = args.blr * eff_batch_size / 256 @@ -162,14 +162,19 @@ def main(args): if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module - + # following timm: set wd as 0 for bias and norm layers param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) print(optimizer) loss_scaler = NativeScaler() - misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + misc.load_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) print(f"Start training for {args.epochs} epochs") start_time = time.time() @@ -187,9 +192,11 @@ def main(args): args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) - log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, - 'epoch': epoch,} - + log_stats = { + **{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + } + if misc.is_main_process(): wandb.log(log_stats) diff --git a/utils/video_frame_dataset.py b/utils/video_frame_dataset.py new file mode 100644 index 0000000..2db6e3b --- /dev/null +++ b/utils/video_frame_dataset.py @@ -0,0 +1,215 @@ +# Copyright (c) Akash Chaurasia +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- + +import itertools +import logging +from pathlib import Path +import random +from timeit import default_timer +from typing import Optional, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset +from torchvision.io import VideoReader +from torchvision.transforms import Compose, InterpolationMode, Normalize, RandomResizedCrop + +from util.decoder.exceptions import InsufficientVideoLengthError +from util.decoder.transform import create_random_augment +from util.retry import DataloadFailure, retry_random_idx_on_err + + +def collate_batch(list_of_examples): + xs = torch.stack([l[0] for l in list_of_examples if not isinstance(l, DataloadFailure)]) + ys = torch.stack([l[1] for l in list_of_examples if not isinstance(l, DataloadFailure)]) + + return xs, ys + +class VideoFrameDataset(Dataset): + """ + Uniformly sample across videos, and within each video randomly sample a sequence of frames + according to the given number of frames and stride. + Args: + root_path: The root path in which video folders lie. 
+ split: dataset split from the overall dataset (must be 'train', 'test', or 'val') + num_frames: number of frames from a video constituting one example + stride: number of frames between sampled frames. For example a stride of 4 starting at 0 + means we would take indices 0, 4, 8, etc. + transform: WIP + imagefile_template: The image filename template that video frame files + have inside of their video folders. + index_filename: name of index file in each video's folder. Something like 'index.pkl' + is_eval: If True, we drop augmentations (maybe?) + """ + + # Default parameters for preprocessing (augmentation, crop, etc.) + RAW_SIZE = (224, 224) + AUTOAUGMENT_TYPE = 'rand-m1-n1-mstd0.5-inc0' + INPUT_SIZE = 224 + CROP_SCALE = (0.55, 1) + MEAN = (0.45, 0.45, 0.45) + STD = (0.225, 0.225, 0.225) + + def __init__( + self, + root_path: str, + split: str, + num_frames: int = 8, + stride: int = 4, + do_augmentation=True, + is_eval: bool = False, + ): + super().__init__() + + self.root_path = Path(root_path) + assert split in {'train', 'test', 'val'}, f"Inalid split {split}!" + self.num_frames = num_frames + self.stride = stride + self.is_eval = is_eval or split != 'train' + + # Don't do augmentation for testing splits + self.do_augmentation = do_augmentation and not self.is_eval + self.rand_augment = create_random_augment( + input_size=self.RAW_SIZE, + auto_augment=self.AUTOAUGMENT_TYPE, + interpolation='bicubic', + ) + + xforms = [] + # Not a huge deal probably but for sanity don't to RRC on test examples + if not self.is_eval: + xforms.append( + RandomResizedCrop( + self.INPUT_SIZE, + scale=self.CROP_SCALE, + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ) + ) + + self.transforms = Compose(xforms + [Normalize(self.MEAN, self.STD)]) + + self.csv_path = self.root_path / f'{split}.csv' + self.video_metadata = pd.read_csv(self.csv_path) + logging.info(f"Instantiated {self.__class__.__name__} based on {str(self.csv_path)}") + logging.info(f"Number of examples: {len(self)}") + + def _get_start_index(self, num_frames) -> int: + return random.randint(0, num_frames - ((self.num_frames - 1) * self.stride)) + + @retry_random_idx_on_err(do_retry=True) + def __getitem__(self, idx): + """ + For video with id idx, loads self.NUM_SEGMENTS * self.FRAMES_PER_SEGMENT + frames from evenly chosen locations across the video. + Args: + idx: Video sample index. + Returns: + A tuple of (video, label). Label is either a single + integer or a list of integers in the case of multiple labels. + Video is either 1) a list of PIL images if no transform is used + 2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1] + if the transform "ImglistToTensor" is used + 3) or anything else if a custom transform is used. + """ + + # Get the number of frames in sampled video + meta_row = self.video_metadata.iloc[idx] + video_path, afib_label = Path(meta_row['avi_path']), int(meta_row['postop_afib_label']) + + # Make inputs (C, T, H, W) for Conv3d + return self._get_frames(video_path).permute(1, 0, 2, 3), torch.tensor([afib_label]) + + def _get_frames(self, video_path: Union[Path, str], start_index: Optional[int] = None): + """ + Loads the frames of a video at the corresponding + indices. + Args: + video_path: Path to video for example + start_index: index to start sampling from + Returns: + A tuple of (video, label). Label is either a single + integer or a list of integers in the case of multiple labels. 
+ Video is either 1) a list of PIL images if no transform is used + 2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1] + if the transform "ImglistToTensor" is used + 3) or anything else if a custom transform is used. + """ + + reader = VideoReader(str(video_path)) + video_meta = reader.get_metadata()['video'] + + num_frames = int(video_meta['fps'][0] * video_meta['duration'][0]) + + if num_frames < self.stride * (self.num_frames - 1) + 1: + raise InsufficientVideoLengthError( + f"Video {str(video_path)} has {num_frames} frames, which is " + f"insufficient for parameters {self.num_frames=}, {self.stride=}" + ) + + start_index = self._get_start_index(num_frames) + start_s = start_index / video_meta['fps'][0] + reader.seek(start_s - 1e-5, keyframes_only=True) + + frames = [] + for frame_data in itertools.islice( + reader, + 0, + self.stride * (self.num_frames - 1) + 1, + self.stride, + ): + frames.append(frame_data['data']) + + num_missing_frames = self.num_frames - len(frames) + for _ in range(num_missing_frames): + frames.append(torch.zeros((3, 224, 224))) + + frames = torch.stack(frames).float() / 255.0 + frames = self.transforms(frames) # (T, C, H, W) + + return frames + + def __len__(self): + return len(self.video_metadata) + + def get_debug(self, index=None): + """ This is completely broken as of now """ + index = index if index is not None else random.randint(0, len(self)) + video = self.video_metadata[index] + + start_index = self._get_start_index(video) + + frames = [ + video.frames[idx].load() + for idx in range( + start_index, start_index + self.stride * (self.num_frames - 1) + 1, self.stride + ) + ] + + augmented_frames = self.rand_augment(frames) + + raw_frames = torch.stack([self.to_tensor_transform(f) for f in frames]).float() / 255.0 + augmented_frames = torch.stack([ + self.to_tensor_transform(f) for f in augmented_frames + ]).float() / 255.0 + + preprocessed_frames = self.transforms(raw_frames) # (T, C, H, W) + preprocessed_augmented_frames = self.transforms(augmented_frames) + + return raw_frames, preprocessed_frames, augmented_frames, preprocessed_augmented_frames + + +if __name__ == '__main__': + dataset = VideoFrameDataset('/scratch/users/akashc/test_dataset', split='train') + + start = default_timer() + for i in range(100): + frames, lbl = dataset[i] + print(frames.shape, lbl) + end = default_timer() + + print(f"Time for 100 iterations: {end - start:.3f} ({100 / (end - start):.3f} fps)") From d4992eb8b7bc299a5271715346057391ac21b5cf Mon Sep 17 00:00:00 2001 From: Akash Chaurasia Date: Fri, 9 Jun 2023 13:46:02 -0700 Subject: [PATCH 2/5] fix torch import --- utils/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/misc.py b/utils/misc.py index 19d15cf..c70f3aa 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -18,7 +18,7 @@ import torch import torch.distributed as dist -from torch._six import inf +from torch import inf class SmoothedValue(object): @@ -337,4 +337,4 @@ def all_reduce_mean(x): x_reduce /= world_size return x_reduce.item() else: - return x \ No newline at end of file + return x From 6af2a1418cb58612f8b22a761f5f5ff3a3fdce74 Mon Sep 17 00:00:00 2001 From: Akash Chaurasia Date: Fri, 9 Jun 2023 14:28:25 -0700 Subject: [PATCH 3/5] make training work --- constants.py | 3 + main_pretrain.py | 8 +- utils/misc.py | 2 + utils/rand_augment.py | 522 +++++++++++++++++++++++++++ utils/retry.py | 83 +++++ utils/transform.py | 673 +++++++++++++++++++++++++++++++++++ utils/video_frame_dataset.py | 9 +- 7 files 
changed, 1293 insertions(+), 7 deletions(-) create mode 100644 constants.py create mode 100644 utils/rand_augment.py create mode 100644 utils/retry.py create mode 100644 utils/transform.py diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..6403e2c --- /dev/null +++ b/constants.py @@ -0,0 +1,3 @@ +import os + +DATASET_ERROR_VERBOSITY = int(os.getenv("DATASET_ERROR_VERBOSITY", "0")) diff --git a/main_pretrain.py b/main_pretrain.py index ec16a90..e4f65b5 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -31,7 +31,6 @@ from engine_pretrain import train_one_epoch import wandb -wandb.init(project="SuTr", entity="cyrilzakka") def get_args_parser(): parser = argparse.ArgumentParser('MAE pre-training', add_help=False) @@ -146,7 +145,7 @@ def main(args): model.to(device) model_without_ddp = model - print("Model = %s" % str(model_without_ddp)) + # print("Model = %s" % str(model_without_ddp)) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() @@ -164,7 +163,7 @@ def main(args): model_without_ddp = model.module # following timm: set wd as 0 for bias and norm layers - param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) + param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) print(optimizer) loss_scaler = NativeScaler() @@ -176,6 +175,8 @@ def main(args): loss_scaler=loss_scaler, ) + wandb.init(project="SuTr", entity="akashc") + wandb.config.update(args) print(f"Start training for {args.epochs} epochs") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): @@ -208,7 +209,6 @@ def main(args): if __name__ == '__main__': args = get_args_parser() args = args.parse_args() - wandb.config.update(args) if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) diff --git a/utils/misc.py b/utils/misc.py index c70f3aa..27b6249 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -230,6 +230,8 @@ def init_distributed_mode(args): elif 'SLURM_PROCID' in os.environ: args.rank = int(os.environ['SLURM_PROCID']) args.gpu = args.rank % torch.cuda.device_count() + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' else: print('Not using distributed mode') setup_for_distributed(is_master=True) # hack diff --git a/utils/rand_augment.py b/utils/rand_augment.py new file mode 100644 index 0000000..be54c99 --- /dev/null +++ b/utils/rand_augment.py @@ -0,0 +1,522 @@ +# Copyright (c) Cyril Zakka. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- + + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. + +COMMENT FROM ORIGINAL: +AutoAugment, RandAugment, and AugMix for PyTorch +This code implements the searched ImageNet policies with various tweaks and +improvements and does not include any of the search code. 
AA and RA +Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py +AugMix adapted from: + https://github.com/google-research/augmix +Papers: + AutoAugment: Learning Augmentation Policies from Data + https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection + https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... + https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and + Uncertainty https://arxiv.org/abs/1912.02781 + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +import random +import re + +import PIL +from PIL import Image, ImageEnhance, ImageOps +import numpy as np + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return 
ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the 
posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": _translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. 
+ self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) if self.level_fn is not None else () + ) + + if isinstance(img_list, list): + return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/utils/retry.py b/utils/retry.py new file mode 100644 index 0000000..675fa3b --- /dev/null +++ b/utils/retry.py @@ -0,0 +1,83 @@ +import bdb +from dataclasses import dataclass +import functools +import logging +import time + +import numpy as np +import torch + +from constants import DATASET_ERROR_VERBOSITY + + +@dataclass +class DataloadFailure: + index: int + msg: str + err_type: str + + +def retry_random_idx_on_err(verbosity=DATASET_ERROR_VERBOSITY, do_retry=True): + + assert 0 <= verbosity <= 3 + + rng = None + subprocess_seed_set = False + + def decorator(getitem_fn): + MAX_RETRIES = 100 + BACKOFF_THRESHOLD = 10 + + @functools.wraps(getitem_fn) + def getitem_wrapper(self, index): + + nonlocal rng, do_retry, subprocess_seed_set + do_retry &= not getattr(self, 'is_eval', False) # Try to avoid retrying on eval splits + in_subprocess = torch.utils.data.get_worker_info() is not None + + if rng is None or (in_subprocess and not subprocess_seed_set): + rng = np.random.RandomState() + rng.set_state(np.random.get_state()) + subprocess_seed_set = in_subprocess + + for retry in range(MAX_RETRIES): + if retry > BACKOFF_THRESHOLD: + backoff_duration = (retry - BACKOFF_THRESHOLD) ** 2 + logging.info( + f"Dataload retry {retry}: backing off for {backoff_duration} seconds" + ) + time.sleep(backoff_duration) + + try: + return getitem_fn(self, index) + except bdb.BdbQuit: + raise + except Exception as e: + msg = str(e) + err_type = e.__class__.__qualname__ + + if verbosity == 0: + pass + elif verbosity == 1: + source = '.'.join([type(self).__module__, type(self).__qualname__]) + logging.warning( + f"[retry {retry}] Dataloading error: {err_type} - " + f"{msg} (from {source})", + exc_info=False + ) + else: + logging.warning( + f"[retry {retry}] {err_type}: {msg}, 
{do_retry=}", + exc_info=True, + ) + + if do_retry: + index = rng.randint(0, len(self)) + else: + return DataloadFailure(index, msg, err_type) + + raise Exception(f"Dataloading failed after {MAX_RETRIES} retries") + + return getitem_wrapper + + return decorator diff --git a/utils/transform.py b/utils/transform.py new file mode 100644 index 0000000..81d7a9a --- /dev/null +++ b/utils/transform.py @@ -0,0 +1,673 @@ +# Copyright (c) Cyril Zakka. +# All rights reserved. + + +import math +import random + +from PIL import Image +import numpy as np +import torch +from torchvision import transforms +import torchvision.transforms.functional as F + +from .rand_augment import rand_augment_transform + +_pil_interpolation_to_str = { + Image.NEAREST: "PIL.Image.NEAREST", + Image.BILINEAR: "PIL.Image.BILINEAR", + Image.BICUBIC: "PIL.Image.BICUBIC", + Image.LANCZOS: "PIL.Image.LANCZOS", + Image.HAMMING: "PIL.Image.HAMMING", + Image.BOX: "PIL.Image.BOX", +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == "bicubic": + return Image.BICUBIC + elif method == "lanczos": + return Image.LANCZOS + elif method == "hamming": + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + """ + if inverse_uniform_sampling: + size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or (height <= width and height == size): + return images + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + else: + new_width = int(math.floor((float(width) / height) * size)) + return torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ) + + +def random_crop(images, size): + """ + Perform random spatial crop on the given images. + Args: + images (tensor): images to perform random crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + return cropped + + +def horizontal_flip(prob, images): + """ + Perform horizontal flip on the given images. + Args: + prob (float): probility to flip the images. 
+ images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + """ + if np.random.uniform() < prob: + images = images.flip((-1)) + return images + + +def uniform_crop(images, size, spatial_idx, scale_size=None): + """ + Perform uniform spatial sampling on the images. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. 
+ img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + + jitter = [] + if img_brightness != 0: + jitter.append("brightness") + if img_contrast != 0: + jitter.append("contrast") + if img_saturation != 0: + jitter.append("saturation") + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == "brightness": + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == "contrast": + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == "saturation": + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert len(mean) == images.shape[0], "channel mean not computed properly" + assert len(stddev) == images.shape[0], "channel stddev not computed properly" + elif len(images.shape) == 4: + assert len(mean) == images.shape[1], "channel mean not computed properly" + assert len(stddev) == images.shape[1], "channel stddev not computed properly" + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i : i + h, j : j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind : ind + 1, + i_s[ind] : i_s[ind] + h_s[ind], + j_s[ind] : j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation="bilinear", +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {"translate_const": int(img_size_min * 0.45)} + if interpolation and interpolation != "random": + aa_params["interpolation"] = _pil_interp(interpolation) + if auto_augment.startswith("rand"): + return transforms.Compose([rand_augment_transform(auto_augment, aa_params)]) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
+ """ + assert len(im.shape) == 3, "Currently only support image for random_sized_crop" + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i : i + h, j : j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode="bilinear", + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation="bilinear", + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print("range should be of kind (min, max)") + + if interpolation == "random": + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = " ".join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) + return format_string diff --git a/utils/video_frame_dataset.py b/utils/video_frame_dataset.py index 2db6e3b..3846e44 100644 --- a/utils/video_frame_dataset.py +++ b/utils/video_frame_dataset.py @@ -18,9 +18,12 @@ from torchvision.io import VideoReader from torchvision.transforms import Compose, InterpolationMode, Normalize, RandomResizedCrop -from util.decoder.exceptions import InsufficientVideoLengthError -from util.decoder.transform import create_random_augment -from util.retry import DataloadFailure, retry_random_idx_on_err +from utils.transform import create_random_augment +from utils.retry import DataloadFailure, retry_random_idx_on_err + + +class InsufficientVideoLengthError(Exception): + pass def collate_batch(list_of_examples): From 2b6148f79fa5383c46afe87a08dd9a93c3ebe1f8 Mon Sep 17 00:00:00 2001 From: Akash Chaurasia Date: Fri, 9 Jun 2023 16:11:57 -0700 Subject: [PATCH 4/5] changes to make training work & better --- engine_pretrain.py | 30 +++++++++++++++++++++++++++++- main_pretrain.py | 14 ++++++++------ utils/video_frame_dataset.py | 29 ++++++++++++++--------------- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/engine_pretrain.py b/engine_pretrain.py index 675e140..0c1c7c5 100644 --- a/engine_pretrain.py +++ b/engine_pretrain.py @@ -13,7 +13,9 @@ import sys from typing import Iterable +import wandb import torch +from timeit import default_timer import utils.misc as misc import utils.lr_sched as lr_sched @@ -41,7 +43,9 @@ def train_one_epoch( if log_writer is not None: print('log_dir: {}'.format(log_writer.log_dir)) - for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + epoch_start = data_start = default_timer() + for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + data_time = default_timer() - data_start # we use a per iteration (instead of per epoch) lr scheduler if data_iter_step % accum_iter == 0: @@ -49,6 +53,8 @@ def train_one_epoch( samples = samples.to(device, non_blocking=True) + fwd_bwd_start = default_timer() + with torch.cuda.amp.autocast(): loss, _, _ = model(samples, mask_ratio=args.mask_ratio) @@ -61,9 +67,28 @@ def train_one_epoch( loss /= accum_iter loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0) + fwd_bwd_time = default_timer() - fwd_bwd_start + if (data_iter_step + 1) % accum_iter == 0: optimizer.zero_grad() + num_frames = samples.size(0) * samples.size(2) + fps = num_frames / (fwd_bwd_time + data_time) + print( + f"step {data_iter_step} | loss {loss_value:.4f} | fwdbwd_t {fwd_bwd_time:.4f} | " + f"data_t {data_time:.4f} | fps {fps:.4f}" + ) + if data_iter_step % 
10 == 0 and misc.is_main_process(): + wandb.log( + { + 'loss': loss_value, + 'fwdbwd_t': fwd_bwd_time, + 'data_t': data_time, + 'fps': fps, + }, + step=data_iter_step, + ) + torch.cuda.synchronize() metric_logger.update(loss=loss_value) @@ -80,7 +105,10 @@ def train_one_epoch( log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr', lr, epoch_1000x) + data_start = default_timer() + epoch_time = default_timer() - epoch_start + pritn(f"Epoch time: {epoch_time}") # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) diff --git a/main_pretrain.py b/main_pretrain.py index e4f65b5..53cab39 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -64,8 +64,12 @@ def get_args_parser(): help='epochs to warmup LR') # Dataset parameters - parser.add_argument('--data_path', default='/scratch/users/akashc/steffner_echo_processed', type=str, - help='dataset path') + parser.add_argument( + '--data_path', + default='/scratch/groups/willhies/echo/echoai/combined.csv', + type=str, + help='dataset path', + ) parser.add_argument('--num_frames', default=32, type=int) parser.add_argument('--stride', default=8, type=int) parser.add_argument('--output_dir', default='./output_dir', @@ -110,12 +114,10 @@ def main(args): cudnn.benchmark = True dataset_train = VideoFrameDataset( - root_path=args.data_path, - split='train', + ledger_path=args.data_path, num_frames=args.num_frames, stride=args.stride, do_augmentation=True, - is_eval=False, ) print(dataset_train) @@ -188,7 +190,7 @@ def main(args): log_writer=None, args=args ) - if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): + if args.output_dir: misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) diff --git a/utils/video_frame_dataset.py b/utils/video_frame_dataset.py index 3846e44..b5e2235 100644 --- a/utils/video_frame_dataset.py +++ b/utils/video_frame_dataset.py @@ -59,8 +59,7 @@ class VideoFrameDataset(Dataset): def __init__( self, - root_path: str, - split: str, + ledger_path: str, num_frames: int = 8, stride: int = 4, do_augmentation=True, @@ -68,11 +67,10 @@ def __init__( ): super().__init__() - self.root_path = Path(root_path) - assert split in {'train', 'test', 'val'}, f"Inalid split {split}!" 
+        self.ledger_path = Path(ledger_path)
         self.num_frames = num_frames
         self.stride = stride
-        self.is_eval = is_eval or split != 'train'
+        self.is_eval = is_eval
 
         # Don't do augmentation for testing splits
         self.do_augmentation = do_augmentation and not self.is_eval
@@ -96,12 +94,13 @@ def __init__(
         self.transforms = Compose(xforms + [Normalize(self.MEAN, self.STD)])
 
-        self.csv_path = self.root_path / f'{split}.csv'
-        self.video_metadata = pd.read_csv(self.csv_path)
-        logging.info(f"Instantiated {self.__class__.__name__} based on {str(self.csv_path)}")
+        self.video_metadata = pd.read_csv(self.ledger_path)
+        logging.info(f"Instantiated {self.__class__.__name__} based on {str(self.ledger_path)}")
         logging.info(f"Number of examples: {len(self)}")
 
     def _get_start_index(self, num_frames) -> int:
+        if num_frames <= (self.num_frames - 1) * self.stride + 1:
+            return 0
         return random.randint(0, num_frames - ((self.num_frames - 1) * self.stride))
 
     @retry_random_idx_on_err(do_retry=True)
     def __getitem__(self, idx):
@@ -122,10 +121,10 @@
         # Get the number of frames in sampled video
         meta_row = self.video_metadata.iloc[idx]
 
-        video_path, afib_label = Path(meta_row['avi_path']), int(meta_row['postop_afib_label'])
+        video_path = Path(meta_row['avi_path'])
 
         # Make inputs (C, T, H, W) for Conv3d
-        return self._get_frames(video_path).permute(1, 0, 2, 3), torch.tensor([afib_label])
+        return self._get_frames(video_path).permute(1, 0, 2, 3)
 
     def _get_frames(self, video_path: Union[Path, str], start_index: Optional[int] = None):
         """
@@ -148,11 +147,11 @@ def _get_frames(self, video_path: Union[Path, str], start_index: Optional[int] =
 
         num_frames = int(video_meta['fps'][0] * video_meta['duration'][0])
 
-        if num_frames < self.stride * (self.num_frames - 1) + 1:
-            raise InsufficientVideoLengthError(
-                f"Video {str(video_path)} has {num_frames} frames, which is "
-                f"insufficient for parameters {self.num_frames=}, {self.stride=}"
-            )
+        # if num_frames < self.stride * (self.num_frames - 1) + 1:
+        #     raise InsufficientVideoLengthError(
+        #         f"Video {str(video_path)} has {num_frames} frames, which is "
+        #         f"insufficient for parameters {self.num_frames=}, {self.stride=}"
+        #     )
 
         start_index = self._get_start_index(num_frames)
         start_s = start_index / video_meta['fps'][0]

From 2765030496dff689eaf2b4c8a316b7f4543dcfc1 Mon Sep 17 00:00:00 2001
From: Akash Chaurasia
Date: Fri, 9 Jun 2023 16:29:00 -0700
Subject: [PATCH 5/5] fix LR bug

---
 engine_pretrain.py | 10 +++++-----
 main_pretrain.py   |  2 +-
 utils/lr_sched.py  |  2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/engine_pretrain.py b/engine_pretrain.py
index 0c1c7c5..a9730e2 100644
--- a/engine_pretrain.py
+++ b/engine_pretrain.py
@@ -74,11 +74,11 @@ def train_one_epoch(
 
         num_frames = samples.size(0) * samples.size(2)
         fps = num_frames / (fwd_bwd_time + data_time)
-        print(
-            f"step {data_iter_step} | loss {loss_value:.4f} | fwdbwd_t {fwd_bwd_time:.4f} | "
-            f"data_t {data_time:.4f} | fps {fps:.4f}"
-        )
         if data_iter_step % 10 == 0 and misc.is_main_process():
+            print(
+                f"step {data_iter_step} | loss {loss_value:.4f} | fwdbwd_t {fwd_bwd_time:.4f} | "
+                f"data_t {data_time:.4f} | fps {fps:.4f}"
+            )
             wandb.log(
                 {
                     'loss': loss_value,
@@ -108,7 +108,7 @@ def train_one_epoch(
         data_start = default_timer()
 
     epoch_time = default_timer() - epoch_start
-    pritn(f"Epoch time: {epoch_time}")
+    print(f"Epoch time: {epoch_time}")
     # gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
diff --git a/main_pretrain.py b/main_pretrain.py
index 53cab39..1e0bd54 100644
--- a/main_pretrain.py
+++ b/main_pretrain.py
@@ -60,7 +60,7 @@ def get_args_parser():
                         help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
     parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0')
-    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
+    parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N',
                         help='epochs to warmup LR')
 
     # Dataset parameters
diff --git a/utils/lr_sched.py b/utils/lr_sched.py
index 1544063..d0a5ce3 100644
--- a/utils/lr_sched.py
+++ b/utils/lr_sched.py
@@ -13,7 +13,7 @@
 def adjust_learning_rate(optimizer, epoch, args):
     """Decay the learning rate with half-cycle cosine after warmup"""
     if epoch < args.warmup_epochs:
-        lr = args.lr * epoch / args.warmup_epochs
+        lr = args.lr * epoch / args.warmup_epochs
     else:
         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
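
Note on [PATCH 5/5]: the schedule touched in utils/lr_sched.py is a linear warmup followed by half-cycle
cosine decay down to min_lr. A minimal standalone sketch of that formula, for reference only (not part of
the patch; it takes plain numbers instead of the args namespace, and the function name is made up here):

    import math

    def lr_at_epoch(epoch, base_lr, min_lr, warmup_epochs, total_epochs):
        # Linear warmup: scale the base LR by epoch / warmup_epochs.
        if epoch < warmup_epochs:
            return base_lr * epoch / warmup_epochs
        # Half-cycle cosine decay from base_lr down to min_lr over the remaining epochs.
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # With the new default warmup_epochs=0 the warmup branch is never taken, e.g.:
    # lr_at_epoch(0, 1e-3, 0.0, 0, 100) -> 1e-3, lr_at_epoch(100, 1e-3, 0.0, 0, 100) -> 0.0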
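
Note on [PATCH 4/5]: _get_start_index now returns 0 instead of erroring out when a video is shorter than
the (num_frames - 1) * stride + 1 frames a clip spans. A rough standalone sketch of that sampling logic,
for reference only (function name and arguments are hypothetical, not the dataset class itself):

    import random

    def sample_start_index(video_len, clip_len, stride):
        # A clip of clip_len frames taken every `stride` frames spans (clip_len - 1) * stride + 1 frames.
        span = (clip_len - 1) * stride + 1
        if video_len <= span:
            # Too few frames: start at 0 instead of raising (mirrors the now commented-out check).
            return 0
        return random.randint(0, video_len - (clip_len - 1) * stride)

    # With the new defaults num_frames=32, stride=8 a clip spans 249 frames, so e.g.
    # sample_start_index(300, 32, 8) draws uniformly from [0, 52], while a 200-frame video starts at 0.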