diff --git a/data/base_dataset.py b/data/base_dataset.py index bec36ff30..57d1c6e1e 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -50,6 +50,10 @@ def __init__(self, opt, phase): self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode self.root = opt.dataroot + + if not self.root.endswith("/"): + self.root += "/" + self.sv_dir = os.path.join(opt.checkpoints_dir, opt.name) self.warning_mode = self.opt.warning_mode self.set_dataset_dirs_and_dims() @@ -342,6 +346,37 @@ def get_transform( return transforms.Compose(transform_list) +def get_transform_ref( + opt, + params=None, + grayscale=False, + method=InterpolationMode.BICUBIC, + convert=True, + crop=True, +): + + transform_list = [] + + if grayscale: + transform_list.append(transforms.Grayscale(1)) + + osize = [opt.data_crop_size, opt.data_crop_size] + transform_list.append(transforms.Resize(osize, interpolation=method)) + + if convert: + transform_list += [transforms.ToTensor()] + """if grayscale: + transform_list += [transforms.Normalize((0.5,), (0.5,))] + else: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]""" + + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ) + return transforms.Compose(transform_list) + + def __make_power_2(img, base, method=InterpolationMode.BICUBIC): ow, oh = img.size h = int(round(oh / base) * base) diff --git a/data/image_folder.py b/data/image_folder.py index 178120c92..972afbbc9 100644 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -89,18 +89,24 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")): for line in paths_list: line_split = line.split(" ") - if len(line_split) == 2: + if ( + len(line_split) == 1 and len(line_split[0]) > 0 + ): # we allow B not having a label images.append(line_split[0]) - labels.append(line_split[1]) - if len(line_split) == 3: + + elif len(line_split) == 2: images.append(line_split[0]) - labels.append(line_split[1] + " " + line_split[2]) + labels.append(line_split[1]) - elif ( - len(line_split) == 1 and len(line_split[0]) > 0 - ): # we allow B not having a label + elif len(line_split) > 2: images.append(line_split[0]) + label_line = line_split[1] + for i in range(2, len(line_split)): + label_line += " " + line_split[i] + + labels.append(label_line) + return ( images[: min(max_dataset_size, len(images))], labels[: min(max_dataset_size, len(images))], @@ -123,6 +129,49 @@ def make_dataset_path(dir, paths, max_dataset_size=float("inf")): return images[: min(max_dataset_size, len(images))] +def make_ref_path(dir, paths, max_dataset_size=float("inf")): + ref = {} + assert os.path.isdir(dir), "%s is not a valid directory" % dir + + with open(dir + paths, "r") as f: + paths_list = f.read().split("\n") + + for line in paths_list: + line_split = line.split(" ") + + if len(line_split) == 2: + ref[line_split[0]] = line_split[1] + + return ref + + +def make_ref_path_list(dir, paths, max_dataset_size=float("inf")): + ref = {} + assert os.path.isdir(dir), "%s is not a valid directory" % dir + + with open(dir + paths, "r") as f: + paths_list = f.read().split("\n") + + root = "/".join(dir.split("/")[:-1]) + + for line in paths_list: + line_split = line.split(" ") + + if len(line_split) == 2: + path_to_ref = line_split[1] + + path = os.path.join(root, path_to_ref) + + with open(path, "r") as f: + paths_ref_list = f.read().split("\n") + + paths_ref_list.remove("") + + ref[line_split[0]] = paths_ref_list + + return ref + + def 
default_loader(path): return Image.open(path).convert("RGB") diff --git a/data/online_creation.py b/data/online_creation.py index bb4375477..9a451b0ce 100644 --- a/data/online_creation.py +++ b/data/online_creation.py @@ -402,7 +402,7 @@ def crop_image( int(ref_bbox[3] * (output_dim + margin) / crop_size), ] - return img, mask, ref_bbox + return img, mask, ref_bbox, idx_bbox_ref def fill_mask_with_random(img, mask, cls): diff --git a/data/self_supervised_labeled_mask_online_ref_dataset.py b/data/self_supervised_labeled_mask_online_ref_dataset.py new file mode 100644 index 000000000..03afb8a07 --- /dev/null +++ b/data/self_supervised_labeled_mask_online_ref_dataset.py @@ -0,0 +1,69 @@ +import os.path +from data.unaligned_labeled_mask_online_ref_dataset import ( + UnalignedLabeledMaskOnlineRefDataset, +) +from data.online_creation import fill_mask_with_random, fill_mask_with_color +from PIL import Image +import numpy as np +import torch +import warnings + + +class SelfSupervisedLabeledMaskOnlineRefDataset(UnalignedLabeledMaskOnlineRefDataset): + """ + This dataset class can create paired datasets with mask labels from only one domain. + """ + + def __init__(self, opt, phase): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + super().__init__(opt, phase) + + def get_img( + self, + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + ): + result = super().get_img( + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path, + B_label_mask_path, + B_label_cls, + index, + clamp_semantics=False, + ) + + try: + + if self.opt.data_online_creation_rand_mask_A: + A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) + elif self.opt.data_online_creation_color_mask_A: + A_img = fill_mask_with_color(result["A"], result["A_label_mask"], {}) + else: + raise Exception( + "self supervised dataset: no self supervised method specified" + ) + + result.update( + { + "A": A_img, + "B": result["A"], + "B_img_paths": result["A_img_paths"], + "B_label_mask": result["A_label_mask"].clone(), + } + ) + except Exception as e: + print(e, "self supervised data loading") + return None + + return result diff --git a/data/self_supervised_labeled_mask_ref_dataset.py b/data/self_supervised_labeled_mask_ref_dataset.py new file mode 100644 index 000000000..0c5a52d60 --- /dev/null +++ b/data/self_supervised_labeled_mask_ref_dataset.py @@ -0,0 +1,67 @@ +import os.path +from data.unaligned_labeled_mask_ref_dataset import UnalignedLabeledMaskRefDataset +from data.online_creation import fill_mask_with_random, fill_mask_with_color +from PIL import Image +import numpy as np +import torch +import warnings + + +class SelfSupervisedLabeledMaskRefDataset(UnalignedLabeledMaskRefDataset): + """ + This dataset class can create paired datasets with mask labels from only one domain. + """ + + def __init__(self, opt, phase): + """Initialize this dataset class. 
+ + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + super().__init__(opt, phase) + + def get_img( + self, + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + ): + result = super().get_img( + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path, + B_label_mask_path, + B_label_cls, + index, + clamp_semantics=False, + ) + + try: + + if self.opt.data_online_creation_rand_mask_A: + A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) + elif self.opt.data_online_creation_color_mask_A: + A_img = fill_mask_with_color(result["A"], result["A_label_mask"], {}) + else: + raise Exception( + "self supervised dataset: no self supervised method specified" + ) + + result.update( + { + "A": A_img, + "B": result["A"], + "B_img_paths": result["A_img_paths"], + "B_label_mask": result["A_label_mask"].clone(), + } + ) + except Exception as e: + print(e, "self supervised data loading") + return None + + return result diff --git a/data/unaligned_labeled_mask_online_dataset.py b/data/unaligned_labeled_mask_online_dataset.py index 77510ff46..99aa6e4b0 100644 --- a/data/unaligned_labeled_mask_online_dataset.py +++ b/data/unaligned_labeled_mask_online_dataset.py @@ -174,7 +174,7 @@ def get_img( else: mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio - A_img, A_label_mask, A_ref_bbox = crop_image( + A_img, A_label_mask, A_ref_bbox, A_ref_bbox_id = crop_image( A_img_path, A_label_mask_path, mask_delta=mask_delta_A, @@ -190,6 +190,7 @@ def get_img( inverted_mask=self.opt.data_inverted_mask, single_bbox=self.opt.data_online_single_bbox, ) + self.cat_A_ref_bbox = torch.tensor(A_ref_bbox[0]) A_ref_bbox = A_ref_bbox[1:] @@ -213,6 +214,7 @@ def get_img( "A_img_paths": A_img_path, "A_label_mask": A_label_mask, "A_ref_bbox": A_ref_bbox, + "A_ref_bbox_id": A_ref_bbox_id, } # Domain B @@ -227,7 +229,7 @@ def get_img( mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio if B_label_mask_path is not None: - B_img, B_label_mask, B_ref_bbox = crop_image( + B_img, B_label_mask, B_ref_bbox, B_ref_bbox_id = crop_image( B_img_path, B_label_mask_path, mask_delta=mask_delta_B, @@ -283,6 +285,7 @@ def get_img( { "B_label_mask": B_label_mask, "B_ref_bbox": B_ref_bbox, + "B_ref_bbox_id": B_ref_bbox_id, } ) diff --git a/data/unaligned_labeled_mask_online_ref_dataset.py b/data/unaligned_labeled_mask_online_ref_dataset.py new file mode 100644 index 000000000..4d240af3e --- /dev/null +++ b/data/unaligned_labeled_mask_online_ref_dataset.py @@ -0,0 +1,66 @@ +import os +from PIL import Image + +from data.base_dataset import get_transform_ref +from data.unaligned_labeled_mask_online_dataset import UnalignedLabeledMaskOnlineDataset +from data.image_folder import make_ref_path_list + + +class UnalignedLabeledMaskOnlineRefDataset(UnalignedLabeledMaskOnlineDataset): + def __init__(self, opt, phase): + super().__init__(opt, phase) + + self.A_img_ref = make_ref_path_list(self.dir_A, "/conditions.txt") + + self.transform_ref = get_transform_ref(opt) + + def get_img( + self, + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + clamp_semantics=True, + ): + result = super().get_img( + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path, + B_label_mask_path, + B_label_cls, + index, + clamp_semantics, + ) + + img_path = result["A_img_paths"] + + if self.opt.data_relative_paths: + 
img_path = img_path.replace(self.root, "") + + A_ref_bbox_id = result["A_ref_bbox_id"] + + ref_A_path = self.A_img_ref[img_path][A_ref_bbox_id] + + if self.opt.data_relative_paths: + ref_A_path = os.path.join(self.root, ref_A_path) + + try: + ref_A = Image.open(ref_A_path).convert("RGB") + + except Exception as e: + print( + "failure with reading A domain image ref ", + ref_A_path, + ) + print(e) + return None + + ref_A = self.transform_ref(ref_A) + + result.update({"ref_A": ref_A}) + + return result diff --git a/data/unaligned_labeled_mask_ref_dataset.py b/data/unaligned_labeled_mask_ref_dataset.py new file mode 100644 index 000000000..6917edb4e --- /dev/null +++ b/data/unaligned_labeled_mask_ref_dataset.py @@ -0,0 +1,67 @@ +import os + +from torchvision.transforms.functional import resize +from PIL import Image + +from data.base_dataset import get_transform_ref +from data.unaligned_labeled_mask_dataset import UnalignedLabeledMaskDataset +from data.image_folder import make_ref_path + + +class UnalignedLabeledMaskRefDataset(UnalignedLabeledMaskDataset): + def __init__(self, opt, phase): + super().__init__(opt, phase) + + self.A_img_ref = make_ref_path(self.dir_A, "/conditions.txt") + + self.transform_ref = get_transform_ref(opt) + + def get_img( + self, + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + clamp_semantics=True, + ): + + result = super().get_img( + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path, + B_label_mask_path, + B_label_cls, + index, + clamp_semantics, + ) + + img_path = result["A_img_paths"] + + if self.opt.data_relative_paths: + img_path = img_path.replace(self.root, "") + + ref_A_path = self.A_img_ref[img_path] + + if self.opt.data_relative_paths: + ref_A_path = os.path.join(self.root, ref_A_path) + + try: + ref_A = Image.open(ref_A_path).convert("RGB") + + except Exception as e: + print( + "failure with reading A domain image ref ", + ref_A_path, + ) + print(e) + return None + + ref_A = self.transform_ref(ref_A) + + result.update({"ref_A": ref_A}) + + return result diff --git a/docs/source/dataloaders.rst b/docs/source/dataloaders.rst index 3cb326ec8..7739e828f 100644 --- a/docs/source/dataloaders.rst +++ b/docs/source/dataloaders.rst @@ -44,15 +44,20 @@ List of dataloaders - unaligned_labeled_mask: unaligned with masks - unaligned_labeled_mask_online: unaligned with masks with online croping around masks +- unaligned_labeled_mask_cls: unaligned with masks and classes - unaligned_labeled_mask_cls_online: unaligned with masks and classes with online croping around masks +- unaligned_labeled_mask_ref: unaligned with masks and reference image +- unaligned_labeled_mask_ref_online: unaligned with masks and reference image with online croping around masks -- self_supervised_labeled_cls: with class labels - self_supervised_labeled_mask: with mask labels - self_supervised_labeled_mask_online: with mask labels and online croping around masks +- self_supervised_labeled_mask_cls: with class and mask labels - self_supervised_labeled_mask_cls_online: with class and mask labels, and online croping around masks +- self_supervised_labeled_mask_ref: with a reference image and mask labels +- self_supervised_labeled_mask_online_ref: with a reference image and mask labels, and online croping around masks - temporal: basic temporal (sequential) loader - self_supervised_temporal: self-supervised version of the temporal diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 
6df573e4d..d4b1ce014 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -36,6 +36,7 @@ respectively. Subdirectories ``testA`` and ``testB`` can be added for test data. Example: horse to zebra from two sets of images + Dataset: https://joligen.com/datasets/horse2zebra.zip .. code-block:: bash @@ -59,6 +60,7 @@ Dataset with class label has ``trainA`` and ``trainB`` directories. In images for this class. Example: font number conversion + Dataset: https://joligen.com/datasets/mnist2USPS.zip .. code-block:: bash @@ -223,3 +225,101 @@ https://joligen.com/datasets/daytime2dawn_dusk_lite.zip in this order: ``source image path``, ``image class``, ``image mask``, where ``image class`` in this dataset represents the weather class. + +***************************************************** + Datasets with mask and reference image conditioning +***************************************************** + +Example: inpaint a garment from a catalog image onto a person + +Dataset: +https://www.joligen.com/datasets/viton_mask_ref_mini.zip + +.. code:: + + viton_mask_ref_mini + viton_mask_ref_mini/trainA + viton_mask_ref_mini/trainA/imgs + viton_mask_ref_mini/trainA/imgs/00000_00.jpg # source image, e.g. person with original garment + ... + viton_mask_ref_mini/trainA/mask + viton_mask_ref_mini/trainA/mask/00000_00.png # mask for inpainting zone, e.g. original garment to remove + ... + viton_mask_ref_mini/trainA/ref + viton_mask_ref_mini/trainA/ref/00000_00.jpg # reference image to inpaint, e.g. catalog image + ... + viton_mask_ref_mini/trainA/paths.txt # list of associated source / mask image + viton_mask_ref_mini/trainA/conditions.txt # list of associated source / reference image + +``paths.txt`` format: + +.. code:: + + cat trainA/paths.txt + trainA/imgs/00000_00.jpg trainA/mask/00000_00.png + +in this order: ``source image path``, ``image mask``. + +``conditions.txt`` format: + +.. code:: + + cat trainA/conditions.txt + trainA/imgs/00000_00.jpg trainA/ref/00000_00.jpg + +in this order: ``source image path`` (same as ``paths.txt``), ``reference image``. + +***************************************************** + Datasets with bbox and reference image conditioning +***************************************************** + +Example: inpaint garments from a catalog image onto a person + +Dataset: +https://www.joligen.com/datasets/viton_bbox_ref_mini.zip + +.. code:: + + viton_bbox_ref_mini + viton_bbox_ref_mini/trainA + viton_bbox_ref_mini/trainA/imgs + viton_bbox_ref_mini/trainA/imgs/00000_00.jpg # source image, e.g. person with original garments + ... + viton_bbox_ref_mini/trainA/bbox + viton_bbox_ref_mini/trainA/bbox/00000_00.txt # list of bboxes for inpainting zone, e.g. original garments to remove + ... + viton_bbox_ref_mini/trainA/cond + viton_bbox_ref_mini/trainA/cond/00000_00.txt # list of reference images to inpaint for each bbox + ... + viton_bbox_ref_mini/trainA/ref + viton_bbox_ref_mini/trainA/ref/00000_00.jpg # reference image to inpaint, e.g. catalog image + ... + viton_bbox_ref_mini/trainA/paths.txt # list of associated source / bboxes + viton_bbox_ref_mini/trainA/conditions.txt # list of associated source / reference images + +``paths.txt`` format: + +.. code:: + + cat trainA/paths.txt + trainA/imgs/00000_00.jpg trainA/bbox/00000_00.txt + +in this order: ``source image path``, ``bboxes file``. + +Bounding box format is the :ref:`same as above`. + +``conditions.txt`` format: + +.. 
code:: + + cat trainA/conditions.txt + trainA/imgs/00000_00.jpg trainA/cond/00000_00.txt + +in this order: ``source image path`` (same as ``paths.txt``), ``file containing list of reference images``. + +List of reference images file format: + +.. code:: + + cat trainA/cond/00000_00.txt + trainA/ref/00000_00.jpg # path to reference image (same number of lines and order as corresponding bbox file) diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py index bea5d8b4b..23cfc7872 100644 --- a/models/diffusion_networks.py +++ b/models/diffusion_networks.py @@ -36,6 +36,7 @@ def define_G( alg_palette_sampling_method, alg_palette_conditioning, alg_palette_cond_embed_dim, + alg_palette_ref_embed_net, model_prior_321_backwardcompatibility, dropout=0, channel_mults=(1, 2, 4, 8), @@ -154,6 +155,7 @@ def define_G( denoise_fn = PaletteDenoiseFn( model=model, cond_embed_dim=cond_embed_dim, + ref_embed_net=alg_palette_ref_embed_net, conditioning=alg_palette_conditioning, nclasses=f_s_semantic_nclasses, ) diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index f90aa5306..e7b86d765 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -64,7 +64,7 @@ def __init__( else: self.cond_embed_dim = cond_embed_dim - if "class" in self.denoise_fn.conditioning: + if any(cond in self.denoise_fn.conditioning for cond in ["class", "ref"]): self.cond_embed_gammas = self.cond_embed_dim // 2 else: self.cond_embed_gammas = self.cond_embed_dim @@ -86,6 +86,7 @@ def restoration( mask=None, sample_num=8, cls=None, + ref=None, guidance_scale=0.0, ddim_num_steps=10, ddim_eta=0.5, @@ -99,6 +100,7 @@ def restoration( sample_num=sample_num, cls=cls, guidance_scale=guidance_scale, + ref=ref, ) elif self.sampling_method == "ddim": return self.restoration_ddim( @@ -117,12 +119,13 @@ def restoration( def restoration_ddpm( self, y_cond, - y_t=None, - y_0=None, - mask=None, - sample_num=8, - cls=None, - guidance_scale=0.0, + y_t, + y_0, + mask, + sample_num, + cls, + guidance_scale, + ref, ): phase = "test" @@ -149,6 +152,7 @@ def restoration_ddpm( phase=phase, cls=cls, mask=mask, + ref=ref, guidance_scale=guidance_scale, ) @@ -180,6 +184,7 @@ def p_mean_variance( clip_denoised: bool, cls, mask, + ref, y_cond=None, guidance_scale=0.0, ): @@ -197,7 +202,11 @@ def p_mean_variance( y_t, t=t, noise=self.denoise_fn( - input, torch.zeros_like(embed_noise_level), cls=None, mask=None + input, + torch.zeros_like(embed_noise_level), + cls=None, + mask=None, + ref=ref, ), phase=phase, ) @@ -206,7 +215,9 @@ def p_mean_variance( self.denoise_fn.model, y_t, t=t, - noise=self.denoise_fn(input, embed_noise_level, cls=cls, mask=mask), + noise=self.denoise_fn( + input, embed_noise_level, cls=cls, mask=mask, ref=ref + ), phase=phase, ) @@ -232,6 +243,7 @@ def p_sample( phase, cls, mask, + ref, clip_denoised=True, y_cond=None, guidance_scale=0.0, @@ -245,6 +257,7 @@ def p_sample( phase=phase, cls=cls, mask=mask, + ref=ref, guidance_scale=guidance_scale, ) @@ -409,7 +422,7 @@ def ddim_p_mean_variance( return model_mean, posterior_log_variance - def forward(self, y_0, y_cond, mask, noise, cls, dropout_prob=0.0): + def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): b, *_ = y_0.shape t = torch.randint( @@ -438,7 +451,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, dropout_prob=0.0): input = torch.cat([y_cond, y_noisy], dim=1) - noise_hat = self.denoise_fn(input, embed_sample_gammas, cls=cls, mask=mask) + noise_hat = self.denoise_fn( 
+ input, embed_sample_gammas, cls=cls, mask=mask, ref=ref + ) return noise, noise_hat diff --git a/models/modules/image_bind/__init__.py b/models/modules/image_bind/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/models/modules/image_bind/helpers.py b/models/modules/image_bind/helpers.py new file mode 100644 index 000000000..d1ae971cf --- /dev/null +++ b/models/modules/image_bind/helpers.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import einops +import numpy as np +import torch +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=self.dim, p=2) + + +class LearnableLogitScaling(nn.Module): + def __init__( + self, + logit_scale_init: float = 1 / 0.07, + learnable: bool = True, + max_logit_scale: float = 100, + ) -> None: + super().__init__() + self.max_logit_scale = max_logit_scale + self.logit_scale_init = logit_scale_init + self.learnable = learnable + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, x): + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x + + def extra_repr(self): + st = ( + f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," + f" max_logit_scale={self.max_logit_scale}" + ) + return st + + +class EinOpsRearrange(nn.Module): + def __init__(self, rearrange_expr: str, **kwargs) -> None: + super().__init__() + self.rearrange_expr = rearrange_expr + self.kwargs = kwargs + + def forward(self, x): + assert isinstance(x, torch.Tensor) + return einops.rearrange(x, self.rearrange_expr, **self.kwargs) + + +class VerboseNNModule(nn.Module): + """ + Wrapper around nn.Module that prints registered buffers and parameter names. + """ + + @staticmethod + def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: + st = ( + "(" + + name + + "): " + + "tensor(" + + str(tuple(tensor[1].shape)) + + ", requires_grad=" + + str(tensor[1].requires_grad) + + ")\n" + ) + return st + + def extra_repr(self) -> str: + named_modules = set() + for p in self.named_modules(): + named_modules.update([p[0]]) + named_modules = list(named_modules) + + string_repr = "" + for p in self.named_parameters(): + name = p[0].split(".")[0] + if name not in named_modules: + string_repr += self.get_readable_tensor_repr(name, p) + + for p in self.named_buffers(): + name = p[0].split(".")[0] + string_repr += self.get_readable_tensor_repr(name, p) + + return string_repr + + +def cast_if_src_dtype( + tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype +): + updated = False + if tensor.dtype == src_dtype: + tensor = tensor.to(dtype=tgt_dtype) + updated = True + return tensor, updated + + +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class SelectElement(nn.Module): + def __init__(self, index) -> None: + super().__init__() + self.index = index + + def forward(self, x): + assert x.ndim >= 3 + return x[:, self.index, ...] 
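
The helper modules introduced in ``helpers.py`` above (``SelectElement``, ``Normalize``, ``LearnableLogitScaling``) are the building blocks of the per-modality heads assembled later in ``imagebind_model.py``. As a rough orientation, here is a minimal sketch, with illustrative dimensions and an assumed CLS token at position 0, of how they compose into a vision-style projection head that maps a ``B x L x D`` token sequence to a normalized, temperature-scaled embedding:

.. code-block:: python

    import torch
    import torch.nn as nn

    from models.modules.image_bind.helpers import (
        LearnableLogitScaling,
        Normalize,
        SelectElement,
    )

    embed_dim, out_dim = 1024, 768  # illustrative sizes, not taken from this patch

    # Head in the style of the ImageBind vision head defined further down:
    # keep the CLS token, project it, L2-normalize, then apply a learnable
    # temperature (logit scale).
    head = nn.Sequential(
        nn.LayerNorm(embed_dim, eps=1e-6),
        SelectElement(index=0),                 # token 0 assumed to be the CLS token
        nn.Linear(embed_dim, out_dim, bias=False),
        Normalize(dim=-1),                      # unit-norm embedding
        LearnableLogitScaling(learnable=True),  # scaled for contrastive logits
    )

    tokens = torch.randn(2, 197, embed_dim)  # B x L x D trunk output (dummy data)
    embedding = head(tokens)                 # B x out_dim
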
+ + +class SelectEOSAndProject(nn.Module): + """ + Text Pooling used in OpenCLIP + """ + + def __init__(self, proj: nn.Module) -> None: + super().__init__() + self.proj = proj + + def forward(self, x, seq_len): + assert x.ndim == 3 + # x is of shape B x L x D + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), seq_len] + x = self.proj(x) + return x diff --git a/models/modules/image_bind/imagebind_model.py b/models/modules/image_bind/imagebind_model.py new file mode 100644 index 000000000..f0d3fd146 --- /dev/null +++ b/models/modules/image_bind/imagebind_model.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from functools import partial +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from .helpers import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from .multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, +) +from .transformer import MultiheadAttention, SimpleTransformer + +ModalityType = SimpleNamespace( + VISION="vision", + TEXT="text", + AUDIO="audio", + THERMAL="thermal", + DEPTH="depth", + IMU="imu", +) + + +class ImageBindModel(nn.Module): + def __init__( + self, + video_frames=2, + kernel_size=(2, 14, 14), + audio_kernel_size=16, + audio_stride=10, + out_embed_dim=768, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_num_mel_bins=128, + audio_target_len=204, + audio_drop_path=0.1, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + depth_embed_dim=384, + depth_kernel_size=16, + depth_num_blocks=12, + depth_num_heads=8, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_kernel_size=16, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_kernel_size=8, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + super().__init__() + + self.modality_preprocessors = self._create_modality_preprocessors( + video_frames, + vision_embed_dim, + kernel_size, + text_embed_dim, + audio_embed_dim, + audio_kernel_size, + audio_stride, + audio_num_mel_bins, + audio_target_len, + depth_embed_dim, + depth_kernel_size, + thermal_embed_dim, + thermal_kernel_size, + imu_embed_dim, + ) + + self.modality_trunks = self._create_modality_trunks( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + text_embed_dim, + text_num_blocks, + text_num_heads, + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + audio_drop_path, + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + depth_drop_path, + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + thermal_drop_path, + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + imu_drop_path, + ) + + self.modality_heads = self._create_modality_heads( + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ) + + self.modality_postprocessors = self._create_modality_postprocessors( + out_embed_dim + ) + + def _create_modality_preprocessors( + self, + 
video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = { + ModalityType.VISION: rgbt_preprocessor, + ModalityType.TEXT: text_preprocessor, + ModalityType.AUDIO: audio_preprocessor, + ModalityType.DEPTH: depth_preprocessor, + ModalityType.THERMAL: thermal_preprocessor, + ModalityType.IMU: imu_preprocessor, + } + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + 
imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity(), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ): + modality_heads = {} + + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + 
modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + return nn.ModuleDict(modality_postprocessors) + + def forward(self, inputs): + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + head_inputs = modality_value["head"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + modality_value = self.modality_heads[modality_key]( + modality_value, **head_inputs + ) + modality_value = self.modality_postprocessors[modality_key]( + modality_value + ) + + if reduce_list: + modality_value = modality_value.reshape(B, S, -1) + modality_value = modality_value.mean(dim=1) + + outputs[modality_key] = modality_value + + return outputs + + +def imagebind_huge(pretrained=False): + model = ImageBindModel( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + + if pretrained: + + path = ".models/configs/bind/pretrain" + + file_name = "imagebind_huge.pth" + + if not os.path.exists(os.path.join(path, file_name)): + print( + "Downloading imagebind weights to %s ..." + % os.path.join(path, file_name) + ) + os.makedirs(path, exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + os.path.join(path, file_name), + progress=True, + ) + + model.load_state_dict(torch.load(os.path.join(path, file_name))) + + return model diff --git a/models/modules/image_bind/multimodal_preprocessors.py b/models/modules/image_bind/multimodal_preprocessors.py new file mode 100644 index 000000000..c89777a4a --- /dev/null +++ b/models/modules/image_bind/multimodal_preprocessors.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
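
For reference, ``imagebind_huge`` above downloads the pretrained ImageBind checkpoint into ``.models/configs/bind/pretrain`` on first use. A minimal sketch of embedding a single reference image with it follows; the 224x224 resize and CLIP normalization statistics mirror the values appearing in ``get_transform_ref``, and feeding this embedding to the denoiser as the ``ref`` conditioning is an assumption about how ``alg_palette_ref_embed_net`` is wired, which this patch does not show:

.. code-block:: python

    import torch
    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    from models.modules.image_bind.imagebind_model import ModalityType, imagebind_huge

    # CLIP-style normalization, matching the statistics used in get_transform_ref.
    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    model = imagebind_huge(pretrained=True).eval()

    # "ref.jpg" is a placeholder path for a reference (e.g. catalog) image.
    ref = preprocess(Image.open("ref.jpg").convert("RGB")).unsqueeze(0)  # 1 x 3 x 224 x 224
    with torch.no_grad():
        out = model({ModalityType.VISION: ref})

    ref_embedding = out[ModalityType.VISION]  # 1 x 1024 for imagebind_huge
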
+ +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional, Tuple + +import ftfy +import numpy as np +import regex as re +import torch +import torch.nn as nn + +# from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ + +from .helpers import VerboseNNModule, cast_if_src_dtype + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + return pos_embed + + assert ( + patches_layout[-1] == patches_layout[-2] + ), "Interpolation of pos embed not supported for non-square layouts" + + class_emb = pos_embed[:, :first_patch_idx] + pos_embed = pos_embed[:, first_patch_idx:] + + if input_shape is None or patches_layout[0] == 1: + # simple 2D pos embedding, no temporal component + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) + elif patches_layout[0] > 1: + # pos embed has a temporal component + assert len(input_shape) == 4, "temporal interpolation not supported" + # we only support 2D interpolation in this case + num_frames = patches_layout[0] + num_spatial_tokens = patches_layout[1] * patches_layout[2] + pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) + # interpolate embedding for zeroth frame + pos_embed = interpolate_pos_encoding_2d( + npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) + ) + else: + raise ValueError("This type of interpolation isn't implemented") + + return torch.cat((class_emb, pos_embed), dim=1) + + +def _get_pos_embedding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape, + first_patch_idx=1, +): + pos_embed = interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=input_shape, + first_patch_idx=first_patch_idx, + ) + return pos_embed + + +class PatchEmbedGeneric(nn.Module): + """ + PatchEmbed from Hydra + """ + + def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): + super().__init__() + + if len(proj_stem) > 1: + self.proj = 
nn.Sequential(*proj_stem) + else: + # Special case to be able to load pre-trained models that were + # trained with a standard stem + self.proj = proj_stem[0] + self.norm_layer = norm_layer + + def get_patch_layout(self, img_size): + with torch.no_grad(): + dummy_img = torch.zeros( + [ + 1, + ] + + img_size + ) + dummy_out = self.proj(dummy_img) + embed_dim = dummy_out.shape[1] + patches_layout = tuple(dummy_out.shape[2:]) + num_patches = np.prod(patches_layout) + return patches_layout, num_patches, embed_dim + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + x = x.flatten(2).transpose(1, 2) + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): + def __init__( + self, + patches_layout: List, + num_patches: int, + num_cls_tokens: int, + embed_dim: int, + learnable: bool, + ) -> None: + super().__init__() + self.num_cls_tokens = num_cls_tokens + self.patches_layout = patches_layout + self.num_patches = num_patches + self.num_tokens = num_cls_tokens + num_patches + self.learnable = learnable + if self.learnable: + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + else: + self.register_buffer( + "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) + ) + + def get_pos_embedding(self, vision_input, all_vision_tokens): + input_shape = vision_input.shape + pos_embed = _get_pos_embedding( + all_vision_tokens.size(1) - self.num_cls_tokens, + pos_embed=self.pos_embed, + patches_layout=self.patches_layout, + input_shape=input_shape, + first_patch_idx=self.num_cls_tokens, + ) + return pos_embed + + +class RGBDTPreprocessor(VerboseNNModule): + def __init__( + self, + rgbt_stem: PatchEmbedGeneric, + depth_stem: Optional[PatchEmbedGeneric], + img_size: Tuple = (3, 224, 224), + num_cls_tokens: int = 1, + pos_embed_fn: Optional[Callable] = None, + use_type_embed: bool = False, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = rgbt_stem if rgbt_stem is not None else depth_stem + ( + self.patches_layout, + self.num_patches, + self.embed_dim, + ) = stem.get_patch_layout(img_size) + self.rgbt_stem = rgbt_stem + self.depth_stem = depth_stem + self.use_pos_embed = pos_embed_fn is not None + self.use_type_embed = use_type_embed + self.num_cls_tokens = num_cls_tokens + + if self.use_pos_embed: + self.pos_embedding_helper = pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = 
stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if vision is not None and depth is not None: + final_tokens = vision_tokens + depth_tokens + else: + final_tokens = vision_tokens if vision is not None else depth_tokens + return_dict = { + "trunk": { + "tokens": final_tokens, + }, + "head": {}, + } + return return_dict + + +class AudioPreprocessor(RGBDTPreprocessor): + def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) + + def forward(self, audio=None): + return super().forward(vision=audio) + + +class ThermalPreprocessor(RGBDTPreprocessor): + def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) + + def forward(self, thermal=None): + return super().forward(vision=thermal) + + +def build_causal_attention_mask(context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length, requires_grad=False) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + +class TextPreprocessor(VerboseNNModule): + def __init__( + self, + vocab_size: int, + context_length: int, + embed_dim: int, + causal_masking: bool, + supply_seq_len_to_head: bool = True, + num_cls_tokens: int = 0, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.pos_embed = nn.Parameter( + torch.empty(1, self.context_length + num_cls_tokens, embed_dim) + ) + self.causal_masking = causal_masking + if self.causal_masking: + mask = build_causal_attention_mask(self.context_length) + # register the mask as a buffer so it can be moved to the right device + self.register_buffer("mask", mask) + + self.supply_seq_len_to_head = supply_seq_len_to_head + self.num_cls_tokens = num_cls_tokens + self.embed_dim = embed_dim + if num_cls_tokens > 0: + assert self.causal_masking is False, "Masking + CLS token isn't implemented" + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style="openclip"): + # OpenCLIP style initialization + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.num_cls_tokens > 0: + 
nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def forward(self, text): + # text tokens are of shape B x L x D + text_tokens = self.token_embedding(text) + # concat CLS tokens if any + if self.num_cls_tokens > 0: + B = text_tokens.shape[0] + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + text_tokens = torch.cat((class_tokens, text_tokens), dim=1) + text_tokens = text_tokens + self.pos_embed + return_dict = { + "trunk": { + "tokens": text_tokens, + }, + "head": {}, + } + # Compute sequence length after adding CLS tokens + if self.supply_seq_len_to_head: + text_lengths = text.argmax(dim=-1) + return_dict["head"] = { + "seq_len": text_lengths, + } + if self.causal_masking: + return_dict["trunk"].update({"attn_mask": self.mask}) + return return_dict + + +class Im2Video(nn.Module): + """Convert an image into a trivial video.""" + + def __init__(self, time_dim=2): + super().__init__() + self.time_dim = time_dim + + def forward(self, x): + if x.ndim == 4: + # B, C, H, W -> B, C, T, H, W + return x.unsqueeze(self.time_dim) + elif x.ndim == 5: + return x + else: + raise ValueError(f"Dimension incorrect {x.shape}") + + +class PadIm2Video(Im2Video): + def __init__(self, ntimes, pad_type, time_dim=2): + super().__init__(time_dim=time_dim) + assert ntimes > 0 + assert pad_type in ["zero", "repeat"] + self.ntimes = ntimes + self.pad_type = pad_type + + def forward(self, x): + x = super().forward(x) + if x.shape[self.time_dim] == 1: + if self.pad_type == "repeat": + new_shape = [1] * len(x.shape) + new_shape[self.time_dim] = self.ntimes + x = x.repeat(new_shape) + elif self.pad_type == "zero": + padarg = [0, 0] * len(x.shape) + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] + x = nn.functional.pad(x, padarg) + return x + + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with open(bpe_path, "rb") as fh: # g_pathmgr. + bpe_bytes = io.BytesIO(fh.read()) + merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class IMUPreprocessor(VerboseNNModule): + def 
__init__( + self, + kernel_size: int, + imu_stem: PatchEmbedGeneric, + embed_dim: int, + img_size: Tuple = (6, 2000), + num_cls_tokens: int = 1, + pos_embed_fn: Optional[Callable] = None, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.imu_stem = imu_stem + self.embed_dim = embed_dim + self.use_pos_embed = pos_embed_fn is not None + self.num_cls_tokens = num_cls_tokens + self.kernel_size = kernel_size + self.pos_embed = nn.Parameter( + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict diff --git a/models/modules/image_bind/transformer.py b/models/modules/image_bind/transformer.py new file mode 100644 index 000000000..4950f8a57 --- /dev/null +++ b/models/modules/image_bind/transformer.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
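
As a side note on the IMU preprocessor just above: its ``forward`` patchifies the raw signal with ``Tensor.unfold`` before the linear stem. A small sketch of that reshaping on dummy data, using the defaults above (``img_size=[6, 2000]``, ``kernel_size=8``):

.. code-block:: python

    import torch

    B, C, T, k = 2, 6, 2000, 8            # batch, IMU channels, timesteps, kernel_size
    imu = torch.randn(B, C, T)            # dummy signal

    patches = imu.unfold(-1, k, k)        # B x C x (T // k) x k, non-overlapping windows
    patches = patches.permute(0, 2, 1, 3) # B x (T // k) x C x k
    patches = patches.reshape(B, T // k, C * k)  # B x 250 x 48, matching nn.Linear(in_features=48, ...)
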
+ +# Code modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; +# https://github.com/facebookresearch/deit/blob/main/models.py +# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py + + +from functools import partial +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_ + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, + # can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiheadAttention(nn.MultiheadAttention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + +class ViTAttention(Attention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + assert attn_mask is None + return super().forward(x) + + +class BlockWithMasking(nn.Module): + def __init__( + self, + dim: int, + attn_target: Callable, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ffn_dropout_rate: float = 0.0, + drop_path: float = 0.0, + layer_scale_type: Optional[str] = None, + layer_scale_init_value: float = 1e-4, + ): + super().__init__() + + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" 
+ self.attn = attn_target() + if drop_path > 0.0: + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm_1 = norm_layer(dim) + mlp_hidden_dim = int(mlp_ratio * dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=ffn_dropout_rate, + ) + self.norm_2 = norm_layer(dim) + self.layer_scale_type = layer_scale_type + if self.layer_scale_type is not None: + assert self.layer_scale_type in [ + "per_channel", + "scalar", + ], f"Found Layer scale type {self.layer_scale_type}" + if self.layer_scale_type == "per_channel": + # one gamma value per channel + gamma_shape = [1, 1, dim] + elif self.layer_scale_type == "scalar": + # single gamma value for all channels + gamma_shape = [1, 1, 1] + # two gammas: for each part of the fwd in the encoder + self.layer_scale_gamma1 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + self.layer_scale_gamma2 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + if self.layer_scale_type is None: + x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path(self.mlp(self.norm_2(x))) + else: + x = ( + x + + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + * self.layer_scale_gamma1 + ) + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 + return x + + +_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) + + +class SimpleTransformer(nn.Module): + def __init__( + self, + attn_target: Callable, + embed_dim: int, + num_blocks: int, + block: Callable = BlockWithMasking, + pre_transformer_layer: Optional[Callable] = None, + post_transformer_layer: Optional[Callable] = None, + drop_path_rate: float = 0.0, + drop_path_type: str = "progressive", + norm_layer: Callable = _LAYER_NORM, + mlp_ratio: int = 4, + ffn_dropout_rate: float = 0.0, + layer_scale_type: Optional[ + str + ] = None, # from cait; possible values are None, "per_channel", "scalar" + layer_scale_init_value: float = 1e-4, # from cait; float + weight_init_style: str = "jax", # possible values jax or pytorch + ): + """ + Simple Transformer with the following features + 1. Supports masked attention + 2. Supports DropPath + 3. Supports LayerScale + 4. Supports Dropout in Attention and FFN + 5. 
Makes few assumptions about the input except that it is a Tensor + """ + super().__init__() + self.pre_transformer_layer = pre_transformer_layer + if drop_path_type == "progressive": + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] + elif drop_path_type == "uniform": + dpr = [drop_path_rate for i in range(num_blocks)] + else: + raise ValueError(f"Unknown drop_path_type: {drop_path_type}") + + self.blocks = nn.Sequential( + *[ + block( + dim=embed_dim, + attn_target=attn_target, + mlp_ratio=mlp_ratio, + ffn_dropout_rate=ffn_dropout_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + layer_scale_type=layer_scale_type, + layer_scale_init_value=layer_scale_init_value, + ) + for i in range(num_blocks) + ] + ) + self.post_transformer_layer = post_transformer_layer + self.weight_init_style = weight_init_style + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if self.weight_init_style == "jax": + # Based on MAE and official Jax ViT implementation + torch.nn.init.xavier_uniform_(m.weight) + elif self.weight_init_style == "pytorch": + # PyTorch ViT uses trunc_normal_ + trunc_normal_(m.weight, std=0.02) + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + tokens: torch.Tensor, + attn_mask: torch.Tensor = None, + use_checkpoint: bool = False, + checkpoint_every_n: int = 1, + checkpoint_blk_ids: Optional[List[int]] = None, + ): + """ + Inputs + - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) + - attn: mask of shape L x L + + Output + - x: data of shape N x L x D (or L x N x D depending on the attention implementation) + """ + if self.pre_transformer_layer: + tokens = self.pre_transformer_layer(tokens) + if use_checkpoint and checkpoint_blk_ids is None: + checkpoint_blk_ids = [ + blk_id + for blk_id in range(len(self.blocks)) + if blk_id % checkpoint_every_n == 0 + ] + if checkpoint_blk_ids: + checkpoint_blk_ids = set(checkpoint_blk_ids) + for blk_id, blk in enumerate(self.blocks): + if use_checkpoint and blk_id in checkpoint_blk_ids: + tokens = checkpoint.checkpoint( + blk, tokens, attn_mask, use_reentrant=False + ) + else: + tokens = blk(tokens, attn_mask=attn_mask) + if self.post_transformer_layer: + tokens = self.post_transformer_layer(tokens) + return tokens diff --git a/models/modules/palette_denoise_fn.py b/models/modules/palette_denoise_fn.py index 3655a7218..b2b8b3d66 100644 --- a/models/modules/palette_denoise_fn.py +++ b/models/modules/palette_denoise_fn.py @@ -1,8 +1,12 @@ import torch from torch import nn - +from torchvision import transforms from einops import rearrange +from .image_bind import imagebind_model +from .image_bind.imagebind_model import ModalityType +import clip + class LabelEmbedder(nn.Module): """ @@ -26,12 +30,13 @@ def forward(self, labels): class PaletteDenoiseFn(nn.Module): - def __init__(self, model, cond_embed_dim, conditioning, nclasses): + def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses): super().__init__() self.model = model self.conditioning = conditioning self.cond_embed_dim = cond_embed_dim + self.ref_embed_net = ref_embed_net # Label embedding if "class" in conditioning: @@ -51,20 +56,54 @@ def __init__(self, model, cond_embed_dim, conditioning, nclasses): ) nn.init.normal_(self.netl_embedder_mask.embedding_table.weight, std=0.02) - def forward(self, input, embed_noise_level, cls, 
mask):
-        cls_embed, mask_embed = self.compute_cond(input, cls, mask)
+        # Instantiate model
+        if "ref" in conditioning:
+            cond_embed_class = cond_embed_dim // 2
+
+            self.ref_transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        224, interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.CenterCrop(224),
+                ]
+            )
+
+            if ref_embed_net == "clip":
+                model_name = "ViT-B/16"
+                self.freezenetClip, _ = clip.load(model_name)
+                self.freezenetClip = self.freezenetClip.visual.float()
+                ref_embed_dim = 512
+
+            elif ref_embed_net == "imagebind":
+                self.freezenetImageBin = imagebind_model.imagebind_huge(pretrained=True)
+                self.freezenetImageBin.eval()
+                ref_embed_dim = 1024
+
+            else:
+                raise NotImplementedError(ref_embed_net)
+
+            self.emb_layers = nn.Sequential(
+                torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
+            )
+
+    def forward(self, input, embed_noise_level, cls, mask, ref):
+        cls_embed, mask_embed, ref_embed = self.compute_cond(input, cls, mask, ref)
 
         if "class" in self.conditioning:
             embedding = torch.cat((embed_noise_level, cls_embed), dim=1)
         else:
             embedding = embed_noise_level
 
+        if "ref" in self.conditioning:
+            embedding = torch.cat((embedding, ref_embed), dim=1)
+
         if "mask" in self.conditioning:
             input = torch.cat([input, mask_embed], dim=1)
 
         return self.model(input, embedding)
 
-    def compute_cond(self, input, cls, mask):
+    def compute_cond(self, input, cls, mask, ref):
         if "class" in self.conditioning and cls is not None:
             cls_embed = self.netl_embedder_class(cls)
         else:
@@ -80,4 +119,22 @@ def compute_cond(self, input, cls, mask):
         else:
             mask_embed = None
 
-        return cls_embed, mask_embed
+        if "ref" in self.conditioning:
+            ref = self.ref_transform(ref)
+
+            if self.ref_embed_net == "clip":
+                ref_embed = self.freezenetClip(ref)
+
+            elif self.ref_embed_net == "imagebind":
+                input_ref = {ModalityType.VISION: ref}
+                ref_embed = self.freezenetImageBin(input_ref)["vision"]
+
+            else:
+                raise NotImplementedError(self.ref_embed_net)
+
+            ref_embed = self.emb_layers(ref_embed)
+
+        else:
+            ref_embed = None
+
+        return cls_embed, mask_embed, ref_embed
diff --git a/models/palette_model.py b/models/palette_model.py
index 0dfed42d7..dff765119 100644
--- a/models/palette_model.py
+++ b/models/palette_model.py
@@ -72,6 +72,7 @@ def modify_commandline_options(parser, is_train=True):
                 "previous_frame",
                 "computed_sketch",
                 "low_res",
+                "ref",
             ],
             help="how cond_image is created",
         )
@@ -210,7 +211,7 @@ def modify_commandline_options(parser, is_train=True):
             "--alg_palette_conditioning",
             type=str,
             default="",
-            choices=["", "mask", "class", "mask_and_class"],
+            choices=["", "mask", "class", "mask_and_class", "ref"],
             help="whether to use conditioning or not",
         )
@@ -227,6 +228,14 @@ def modify_commandline_options(parser, is_train=True):
             help="whether to generate samples of each images",
         )
 
+        parser.add_argument(
+            "--alg_palette_ref_embed_net",
+            type=str,
+            default="clip",
+            choices=["clip", "imagebind"],
+            help="embedding network to use for ref conditioning",
+        )
+
        return parser
 
    def __init__(self, opt, rank):
@@ -284,6 +293,15 @@ def __init__(self, opt, rank):
            for i in range(self.nb_classes_inference):
                self.gen_visual_names.append("output_" + str(i + 1) + "_")
+
+        elif (
+            self.opt.alg_palette_cond_image_creation == "ref"
+            or "ref" in self.opt.alg_palette_conditioning
+        ):
+            for i in range(self.inference_num):
+                self.gen_visual_names.append("cond_ref_" + str(i + 1) + "_")
+                self.gen_visual_names.append("output_" + str(i + 1) + "_")
+
        else:
            self.gen_visual_names.append("output_")
 
@@ -446,6 +464,12 @@ def
set_input(self, data): else: self.cls = None + if ( + "ref" in self.opt.alg_palette_conditioning + or self.opt.alg_palette_cond_image_creation == "ref" + ): + self.ref_A = data["ref_A"].to(self.device) + if self.opt.alg_palette_cond_image_creation == "y_t": self.cond_image = self.y_t elif self.opt.alg_palette_cond_image_creation == "previous_frame": @@ -526,6 +550,9 @@ def set_input(self, data): self.cond_image = self.transform_lr(self.gt_image) # bilinear interpolation self.cond_image = self.transform_hr(self.cond_image) # let's get it back + elif self.opt.alg_palette_cond_image_creation == "ref": + self.cond_image = self.ref_A + self.batch_size = self.cond_image.shape[0] self.real_A = self.cond_image @@ -559,8 +586,16 @@ def compute_palette_loss(self): # the highest class is the unconditionned one. cls = torch.where(drop_ids, self.num_classes - 1, cls) + if ( + self.opt.alg_palette_cond_image_creation == "ref" + or "ref" in self.opt.alg_palette_conditioning + ): + ref = self.ref_A + else: + ref = None + noise, noise_hat = self.netG_A( - y_0=y_0, y_cond=y_cond, noise=noise, mask=mask, cls=cls + y_0=y_0, y_cond=y_cond, noise=noise, mask=mask, cls=cls, ref=ref ) if mask is not None: @@ -619,8 +654,57 @@ def inference(self): cls=cur_class, ddim_num_steps=self.ddim_num_steps, ddim_eta=self.ddim_eta, + ref=self.ref_A, + ) + + name = "output_" + str(i + 1) + setattr(self, name, output) + + name = "visuals_" + str(i + 1) + setattr(self, name, visuals) + + self.fake_B = self.output_1 + self.visuals = self.visuals_1 + + elif ( + self.opt.alg_palette_cond_image_creation == "ref" + or self.opt.alg_palette_conditioning == "ref" + ): + for i in range(self.inference_num): + if self.cls is not None: + cls = self.cls[: self.inference_num] + else: + cls = self.cls + + if self.mask is not None: + mask = self.mask[: self.inference_num] + else: + mask = self.mask + + cur_ref = self.cond_image[i : i + 1].expand( + self.inference_num, -1, -1, -1 ) + if self.opt.alg_palette_cond_image_creation == "ref": + + y_cond = cur_ref + + else: + y_cond = self.cond_image[: self.inference_num] + + output, visuals = netG.restoration( + y_cond=y_cond, + y_t=self.y_t[: self.inference_num], + y_0=self.gt_image[: self.inference_num], + mask=mask, + sample_num=self.sample_num, + cls=cls, + ref=cur_ref, + ) + + name = "cond_ref_" + str(i + 1) + setattr(self, name, y_cond) + name = "output_" + str(i + 1) setattr(self, name, output) @@ -717,12 +801,17 @@ def get_dummy_input(self, device=None): else: dummy_cls = None + dummy_ref = torch.ones( + 1, input_nc, self.opt.data_crop_size, self.opt.data_crop_size, device=device + ) + dummy_input = ( dummy_y_0, dummy_y_cond, dummy_mask, dummy_noise, dummy_cls, + dummy_ref, ) return dummy_input diff --git a/options/base_options.py b/options/base_options.py index 6c893f57e..553b4cc24 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -603,6 +603,10 @@ def initialize(self, parser): "temporal_labeled_mask_online", "self_supervised_temporal", "single", + "unaligned_labeled_mask_ref", + "self_supervised_labeled_mask_ref", + "unaligned_labeled_mask_online_ref", + "self_supervised_labeled_mask_online_ref", ], help="chooses how datasets are loaded.", ) diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 7b40d9b24..054c25816 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -145,6 +145,40 @@ if [ $OUT != 0 ]; then exit 1 fi +####### mask ref test +echo "Running mask ref training tests" +URL=https://joligen.com/datasets/viton_mask_ref_mini.zip 
+ZIP_FILE=$DIR/viton_mask_ref_mini.zip +TARGET_MASK_REF_DIR=$DIR/viton_mask_ref_mini +wget -N $URL -O $ZIP_FILE +mkdir $TARGET_MASK_REF_DIR +unzip $ZIP_FILE -d $DIR +rm $ZIP_FILE + +python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_mask_ref.py" --dataroot "$TARGET_MASK_REF_DIR" +OUT=$? + +if [ $OUT != 0 ]; then + exit 1 +fi + +####### mask ref online test +echo "Running mask ref online training tests" +URL=https://joligen.com/datasets/viton_bbox_ref_mini.zip +ZIP_FILE=$DIR/viton_bbox_ref_mini.zip +TARGET_MASK_ONLINE_REF_DIR=$DIR/viton_bbox_ref_mini +wget -N $URL -O $ZIP_FILE +mkdir $TARGET_MASK_ONLINE_REF_DIR +unzip $ZIP_FILE -d $DIR +rm $ZIP_FILE + +python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_mask_online_ref.py" --dataroot "$TARGET_MASK_ONLINE_REF_DIR" +OUT=$? + +if [ $OUT != 0 ]; then + exit 1 +fi + echo "Deleting target dir $DIR" rm -rf $DIR/* diff --git a/tests/test_run_mask_online_ref.py b/tests/test_run_mask_online_ref.py new file mode 100644 index 000000000..4d9b916a1 --- /dev/null +++ b/tests/test_run_mask_online_ref.py @@ -0,0 +1,59 @@ +import pytest +import torch.multiprocessing as mp +import sys +from itertools import product + +sys.path.append(sys.path[0] + "/..") +import train +from options.train_options import TrainOptions +from data import create_dataset + +json_like_dict = { + "name": "joligen_utest_mask_online_ref", + "output_display_env": "joligen_utest_mask_online_ref", + "output_display_id": 0, + "gpu_ids": "0", + "data_load_size": 128, + "data_crop_size": 128, + "train_n_epochs": 1, + "train_n_epochs_decay": 0, + "data_max_dataset_size": 10, + "data_relative_paths": True, + "train_G_ema": True, + "dataaug_no_rotate": True, + "G_unet_mha_num_head_channels": 16, + "G_unet_mha_channel_mults": [1, 2], + "G_nblocks": 1, + "G_padding_type": "reflect", + "data_online_creation_rand_mask_A": True, + "f_s_semantic_nclasses": 100, + "model_type": "palette", + "G_netG": "unet_mha", +} + +models_datasets = [ + ["palette", "self_supervised_labeled_mask_online_ref"], + ["cut", "unaligned_labeled_mask_online_ref"], +] +conditionings = [ + "alg_palette_conditioning", + "alg_palette_cond_image_creation", +] + +product_list = product( + models_datasets, + conditionings, +) + + +def test_mask_online_ref(dataroot): + json_like_dict["dataroot"] = dataroot + json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) + + for (model, dataset), conditioning in product_list: + json_like_dict_c = json_like_dict.copy() + json_like_dict_c["data_dataset_mode"] = dataset + json_like_dict_c["model_type"] = model + json_like_dict_c[conditioning] = "ref" + opt = TrainOptions().parse_json(json_like_dict_c, save_config=True) + train.launch_training(opt) diff --git a/tests/test_run_mask_ref.py b/tests/test_run_mask_ref.py new file mode 100644 index 000000000..ebb0568f9 --- /dev/null +++ b/tests/test_run_mask_ref.py @@ -0,0 +1,59 @@ +import pytest +import torch.multiprocessing as mp +import sys +from itertools import product + +sys.path.append(sys.path[0] + "/..") +import train +from options.train_options import TrainOptions +from data import create_dataset + +json_like_dict = { + "name": "joligen_utest_mask_ref", + "output_display_env": "joligen_utest_mask_ref", + "output_display_id": 0, + "gpu_ids": "0", + "data_load_size": 128, + "data_crop_size": 128, + "train_n_epochs": 1, + "train_n_epochs_decay": 0, + "data_max_dataset_size": 10, + "data_relative_paths": True, + "train_G_ema": True, + "dataaug_no_rotate": True, + 
"G_unet_mha_num_head_channels": 16, + "G_unet_mha_channel_mults": [1, 2], + "G_nblocks": 1, + "G_padding_type": "reflect", + "data_online_creation_rand_mask_A": True, + "f_s_semantic_nclasses": 100, + "model_type": "palette", + "G_netG": "unet_mha", +} + +models_datasets = [ + ["palette", "self_supervised_labeled_mask_ref"], + ["cut", "unaligned_labeled_mask_ref"], +] +conditionings = [ + "alg_palette_conditioning", + "alg_palette_cond_image_creation", +] + +product_list = product( + models_datasets, + conditionings, +) + + +def test_mask_ref(dataroot): + json_like_dict["dataroot"] = dataroot + json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) + + for (model, dataset), conditioning in product_list: + json_like_dict_c = json_like_dict.copy() + json_like_dict_c["data_dataset_mode"] = dataset + json_like_dict_c["model_type"] = model + json_like_dict_c[conditioning] = "ref" + opt = TrainOptions().parse_json(json_like_dict_c, save_config=True) + train.launch_training(opt)