diff --git a/models/modules/image_bind/__init__.py b/models/modules/image_bind/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/models/modules/image_bind/helpers.py b/models/modules/image_bind/helpers.py new file mode 100644 index 000000000..d1ae971cf --- /dev/null +++ b/models/modules/image_bind/helpers.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import einops +import numpy as np +import torch +import torch.nn as nn + + +class Normalize(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.normalize(x, dim=self.dim, p=2) + + +class LearnableLogitScaling(nn.Module): + def __init__( + self, + logit_scale_init: float = 1 / 0.07, + learnable: bool = True, + max_logit_scale: float = 100, + ) -> None: + super().__init__() + self.max_logit_scale = max_logit_scale + self.logit_scale_init = logit_scale_init + self.learnable = learnable + log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) + if learnable: + self.log_logit_scale = nn.Parameter(log_logit_scale) + else: + self.register_buffer("log_logit_scale", log_logit_scale) + + def forward(self, x): + return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x + + def extra_repr(self): + st = ( + f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," + f" max_logit_scale={self.max_logit_scale}" + ) + return st + + +class EinOpsRearrange(nn.Module): + def __init__(self, rearrange_expr: str, **kwargs) -> None: + super().__init__() + self.rearrange_expr = rearrange_expr + self.kwargs = kwargs + + def forward(self, x): + assert isinstance(x, torch.Tensor) + return einops.rearrange(x, self.rearrange_expr, **self.kwargs) + + +class VerboseNNModule(nn.Module): + """ + Wrapper around nn.Module that prints registered buffers and parameter names. + """ + + @staticmethod + def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: + st = ( + "(" + + name + + "): " + + "tensor(" + + str(tuple(tensor[1].shape)) + + ", requires_grad=" + + str(tensor[1].requires_grad) + + ")\n" + ) + return st + + def extra_repr(self) -> str: + named_modules = set() + for p in self.named_modules(): + named_modules.update([p[0]]) + named_modules = list(named_modules) + + string_repr = "" + for p in self.named_parameters(): + name = p[0].split(".")[0] + if name not in named_modules: + string_repr += self.get_readable_tensor_repr(name, p) + + for p in self.named_buffers(): + name = p[0].split(".")[0] + string_repr += self.get_readable_tensor_repr(name, p) + + return string_repr + + +def cast_if_src_dtype( + tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype +): + updated = False + if tensor.dtype == src_dtype: + tensor = tensor.to(dtype=tgt_dtype) + updated = True + return tensor, updated + + +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class SelectElement(nn.Module): + def __init__(self, index) -> None: + super().__init__() + self.index = index + + def forward(self, x): + assert x.ndim >= 3 + return x[:, self.index, ...] 
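+
+# Example: SelectElement(index=0) picks a single position out of a B x L x D
+# token sequence, e.g. SelectElement(0)(torch.randn(2, 5, 8)) has shape (2, 8);
+# the ImageBind modality heads use it to read out the CLS token.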
+ + +class SelectEOSAndProject(nn.Module): + """ + Text Pooling used in OpenCLIP + """ + + def __init__(self, proj: nn.Module) -> None: + super().__init__() + self.proj = proj + + def forward(self, x, seq_len): + assert x.ndim == 3 + # x is of shape B x L x D + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), seq_len] + x = self.proj(x) + return x diff --git a/models/modules/image_bind/imagebind_model.py b/models/modules/image_bind/imagebind_model.py new file mode 100644 index 000000000..f0d3fd146 --- /dev/null +++ b/models/modules/image_bind/imagebind_model.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from functools import partial +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from .helpers import ( + EinOpsRearrange, + LearnableLogitScaling, + Normalize, + SelectElement, + SelectEOSAndProject, +) +from .multimodal_preprocessors import ( + AudioPreprocessor, + IMUPreprocessor, + PadIm2Video, + PatchEmbedGeneric, + RGBDTPreprocessor, + SpatioTemporalPosEmbeddingHelper, + TextPreprocessor, + ThermalPreprocessor, +) +from .transformer import MultiheadAttention, SimpleTransformer + +ModalityType = SimpleNamespace( + VISION="vision", + TEXT="text", + AUDIO="audio", + THERMAL="thermal", + DEPTH="depth", + IMU="imu", +) + + +class ImageBindModel(nn.Module): + def __init__( + self, + video_frames=2, + kernel_size=(2, 14, 14), + audio_kernel_size=16, + audio_stride=10, + out_embed_dim=768, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_num_mel_bins=128, + audio_target_len=204, + audio_drop_path=0.1, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + depth_embed_dim=384, + depth_kernel_size=16, + depth_num_blocks=12, + depth_num_heads=8, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_kernel_size=16, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + imu_kernel_size=8, + imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + super().__init__() + + self.modality_preprocessors = self._create_modality_preprocessors( + video_frames, + vision_embed_dim, + kernel_size, + text_embed_dim, + audio_embed_dim, + audio_kernel_size, + audio_stride, + audio_num_mel_bins, + audio_target_len, + depth_embed_dim, + depth_kernel_size, + thermal_embed_dim, + thermal_kernel_size, + imu_embed_dim, + ) + + self.modality_trunks = self._create_modality_trunks( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + text_embed_dim, + text_num_blocks, + text_num_heads, + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + audio_drop_path, + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + depth_drop_path, + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + thermal_drop_path, + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + imu_drop_path, + ) + + self.modality_heads = self._create_modality_heads( + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ) + + self.modality_postprocessors = self._create_modality_postprocessors( + out_embed_dim + ) + + def _create_modality_preprocessors( + self, + 
video_frames=2, + vision_embed_dim=1024, + kernel_size=(2, 14, 14), + text_embed_dim=768, + audio_embed_dim=768, + audio_kernel_size=16, + audio_stride=10, + audio_num_mel_bins=128, + audio_target_len=204, + depth_embed_dim=768, + depth_kernel_size=16, + thermal_embed_dim=768, + thermal_kernel_size=16, + imu_embed_dim=512, + ): + rgbt_stem = PatchEmbedGeneric( + proj_stem=[ + PadIm2Video(pad_type="repeat", ntimes=2), + nn.Conv3d( + in_channels=3, + kernel_size=kernel_size, + out_channels=vision_embed_dim, + stride=kernel_size, + bias=False, + ), + ] + ) + rgbt_preprocessor = RGBDTPreprocessor( + img_size=[3, video_frames, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=rgbt_stem, + depth_stem=None, + ) + + text_preprocessor = TextPreprocessor( + context_length=77, + vocab_size=49408, + embed_dim=text_embed_dim, + causal_masking=True, + ) + + audio_stem = PatchEmbedGeneric( + proj_stem=[ + nn.Conv2d( + in_channels=1, + kernel_size=audio_kernel_size, + stride=audio_stride, + out_channels=audio_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim), + ) + audio_preprocessor = AudioPreprocessor( + img_size=[1, audio_num_mel_bins, audio_target_len], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + audio_stem=audio_stem, + ) + + depth_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=depth_kernel_size, + in_channels=1, + out_channels=depth_embed_dim, + stride=depth_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim), + ) + + depth_preprocessor = RGBDTPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + rgbt_stem=None, + depth_stem=depth_stem, + ) + + thermal_stem = PatchEmbedGeneric( + [ + nn.Conv2d( + kernel_size=thermal_kernel_size, + in_channels=1, + out_channels=thermal_embed_dim, + stride=thermal_kernel_size, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim), + ) + thermal_preprocessor = ThermalPreprocessor( + img_size=[1, 224, 224], + num_cls_tokens=1, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + thermal_stem=thermal_stem, + ) + + imu_stem = PatchEmbedGeneric( + [ + nn.Linear( + in_features=48, + out_features=imu_embed_dim, + bias=False, + ), + ], + norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim), + ) + + imu_preprocessor = IMUPreprocessor( + img_size=[6, 2000], + num_cls_tokens=1, + kernel_size=8, + embed_dim=imu_embed_dim, + pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True), + imu_stem=imu_stem, + ) + + modality_preprocessors = { + ModalityType.VISION: rgbt_preprocessor, + ModalityType.TEXT: text_preprocessor, + ModalityType.AUDIO: audio_preprocessor, + ModalityType.DEPTH: depth_preprocessor, + ModalityType.THERMAL: thermal_preprocessor, + ModalityType.IMU: imu_preprocessor, + } + + return nn.ModuleDict(modality_preprocessors) + + def _create_modality_trunks( + self, + vision_embed_dim=1024, + vision_num_blocks=24, + vision_num_heads=16, + text_embed_dim=768, + text_num_blocks=12, + text_num_heads=12, + audio_embed_dim=768, + audio_num_blocks=12, + audio_num_heads=12, + audio_drop_path=0.0, + depth_embed_dim=768, + depth_num_blocks=12, + depth_num_heads=12, + depth_drop_path=0.0, + thermal_embed_dim=768, + thermal_num_blocks=12, + thermal_num_heads=12, + thermal_drop_path=0.0, + imu_embed_dim=512, + 
imu_num_blocks=6, + imu_num_heads=8, + imu_drop_path=0.7, + ): + def instantiate_trunk( + embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path + ): + return SimpleTransformer( + embed_dim=embed_dim, + num_blocks=num_blocks, + ffn_dropout_rate=0.0, + drop_path_rate=drop_path, + attn_target=partial( + MultiheadAttention, + embed_dim=embed_dim, + num_heads=num_heads, + bias=True, + add_bias_kv=add_bias_kv, + ), + pre_transformer_layer=nn.Sequential( + nn.LayerNorm(embed_dim, eps=1e-6) + if pre_transformer_ln + else nn.Identity(), + EinOpsRearrange("b l d -> l b d"), + ), + post_transformer_layer=EinOpsRearrange("l b d -> b l d"), + ) + + modality_trunks = {} + modality_trunks[ModalityType.VISION] = instantiate_trunk( + vision_embed_dim, + vision_num_blocks, + vision_num_heads, + pre_transformer_ln=True, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.TEXT] = instantiate_trunk( + text_embed_dim, + text_num_blocks, + text_num_heads, + pre_transformer_ln=False, + add_bias_kv=False, + drop_path=0.0, + ) + modality_trunks[ModalityType.AUDIO] = instantiate_trunk( + audio_embed_dim, + audio_num_blocks, + audio_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=audio_drop_path, + ) + modality_trunks[ModalityType.DEPTH] = instantiate_trunk( + depth_embed_dim, + depth_num_blocks, + depth_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=depth_drop_path, + ) + modality_trunks[ModalityType.THERMAL] = instantiate_trunk( + thermal_embed_dim, + thermal_num_blocks, + thermal_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=thermal_drop_path, + ) + modality_trunks[ModalityType.IMU] = instantiate_trunk( + imu_embed_dim, + imu_num_blocks, + imu_num_heads, + pre_transformer_ln=False, + add_bias_kv=True, + drop_path=imu_drop_path, + ) + + return nn.ModuleDict(modality_trunks) + + def _create_modality_heads( + self, + out_embed_dim, + vision_embed_dim, + text_embed_dim, + audio_embed_dim, + depth_embed_dim, + thermal_embed_dim, + imu_embed_dim, + ): + modality_heads = {} + + modality_heads[ModalityType.VISION] = nn.Sequential( + nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(vision_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.TEXT] = SelectEOSAndProject( + proj=nn.Sequential( + nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6), + nn.Linear(text_embed_dim, out_embed_dim, bias=False), + ) + ) + + modality_heads[ModalityType.AUDIO] = nn.Sequential( + nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(audio_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.DEPTH] = nn.Sequential( + nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(depth_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.THERMAL] = nn.Sequential( + nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Linear(thermal_embed_dim, out_embed_dim, bias=False), + ) + + modality_heads[ModalityType.IMU] = nn.Sequential( + nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6), + SelectElement(index=0), + nn.Dropout(p=0.5), + nn.Linear(imu_embed_dim, out_embed_dim, bias=False), + ) + + return nn.ModuleDict(modality_heads) + + def _create_modality_postprocessors(self, out_embed_dim): + modality_postprocessors = {} + + modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1) + 
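+        # Only text keeps a learnable CLIP-style logit scale below; audio, depth,
+        # thermal and IMU use fixed per-modality scales, and vision is only L2-normalized.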
modality_postprocessors[ModalityType.TEXT] = nn.Sequential( + Normalize(dim=-1), LearnableLogitScaling(learnable=True) + ) + modality_postprocessors[ModalityType.AUDIO] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=20.0, learnable=False), + ) + modality_postprocessors[ModalityType.DEPTH] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + modality_postprocessors[ModalityType.THERMAL] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=10.0, learnable=False), + ) + modality_postprocessors[ModalityType.IMU] = nn.Sequential( + Normalize(dim=-1), + LearnableLogitScaling(logit_scale_init=5.0, learnable=False), + ) + + return nn.ModuleDict(modality_postprocessors) + + def forward(self, inputs): + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + head_inputs = modality_value["head"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + modality_value = self.modality_heads[modality_key]( + modality_value, **head_inputs + ) + modality_value = self.modality_postprocessors[modality_key]( + modality_value + ) + + if reduce_list: + modality_value = modality_value.reshape(B, S, -1) + modality_value = modality_value.mean(dim=1) + + outputs[modality_key] = modality_value + + return outputs + + +def imagebind_huge(pretrained=False): + model = ImageBindModel( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + + if pretrained: + + path = ".models/configs/bind/pretrain" + + file_name = "imagebind_huge.pth" + + if not os.path.exists(os.path.join(path, file_name)): + print( + "Downloading imagebind weights to %s ..." + % os.path.join(path, file_name) + ) + os.makedirs(path, exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + os.path.join(path, file_name), + progress=True, + ) + + model.load_state_dict(torch.load(os.path.join(path, file_name))) + + return model diff --git a/models/modules/image_bind/multimodal_preprocessors.py b/models/modules/image_bind/multimodal_preprocessors.py new file mode 100644 index 000000000..c89777a4a --- /dev/null +++ b/models/modules/image_bind/multimodal_preprocessors.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
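+
+# Modality-specific preprocessors: each turns raw inputs (video frames, BPE token
+# ids, audio spectrograms, depth/thermal maps, IMU signals) into a token sequence
+# with CLS and positional embeddings for the shared transformer trunks.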
+ +import gzip +import html +import io +import math +from functools import lru_cache +from typing import Callable, List, Optional, Tuple + +import ftfy +import numpy as np +import regex as re +import torch +import torch.nn as nn + +# from iopath.common.file_io import g_pathmgr +from timm.models.layers import trunc_normal_ + +from .helpers import VerboseNNModule, cast_if_src_dtype + + +def get_sinusoid_encoding_table(n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def interpolate_pos_encoding_2d(target_spatial_size, pos_embed): + N = pos_embed.shape[1] + if N == target_spatial_size: + return pos_embed + dim = pos_embed.shape[-1] + # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32 + pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32) + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(target_spatial_size / N), + mode="bicubic", + ) + if updated: + pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=None, + first_patch_idx=1, +): + assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none" + N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists + if npatch_per_img == N: + return pos_embed + + assert ( + patches_layout[-1] == patches_layout[-2] + ), "Interpolation of pos embed not supported for non-square layouts" + + class_emb = pos_embed[:, :first_patch_idx] + pos_embed = pos_embed[:, first_patch_idx:] + + if input_shape is None or patches_layout[0] == 1: + # simple 2D pos embedding, no temporal component + pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed) + elif patches_layout[0] > 1: + # pos embed has a temporal component + assert len(input_shape) == 4, "temporal interpolation not supported" + # we only support 2D interpolation in this case + num_frames = patches_layout[0] + num_spatial_tokens = patches_layout[1] * patches_layout[2] + pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1) + # interpolate embedding for zeroth frame + pos_embed = interpolate_pos_encoding_2d( + npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0) + ) + else: + raise ValueError("This type of interpolation isn't implemented") + + return torch.cat((class_emb, pos_embed), dim=1) + + +def _get_pos_embedding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape, + first_patch_idx=1, +): + pos_embed = interpolate_pos_encoding( + npatch_per_img, + pos_embed, + patches_layout, + input_shape=input_shape, + first_patch_idx=first_patch_idx, + ) + return pos_embed + + +class PatchEmbedGeneric(nn.Module): + """ + PatchEmbed from Hydra + """ + + def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None): + super().__init__() + + if len(proj_stem) > 1: + self.proj = 
nn.Sequential(*proj_stem) + else: + # Special case to be able to load pre-trained models that were + # trained with a standard stem + self.proj = proj_stem[0] + self.norm_layer = norm_layer + + def get_patch_layout(self, img_size): + with torch.no_grad(): + dummy_img = torch.zeros( + [ + 1, + ] + + img_size + ) + dummy_out = self.proj(dummy_img) + embed_dim = dummy_out.shape[1] + patches_layout = tuple(dummy_out.shape[2:]) + num_patches = np.prod(patches_layout) + return patches_layout, num_patches, embed_dim + + def forward(self, x): + x = self.proj(x) + # B C (T) H W -> B (T)HW C + x = x.flatten(2).transpose(1, 2) + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class SpatioTemporalPosEmbeddingHelper(VerboseNNModule): + def __init__( + self, + patches_layout: List, + num_patches: int, + num_cls_tokens: int, + embed_dim: int, + learnable: bool, + ) -> None: + super().__init__() + self.num_cls_tokens = num_cls_tokens + self.patches_layout = patches_layout + self.num_patches = num_patches + self.num_tokens = num_cls_tokens + num_patches + self.learnable = learnable + if self.learnable: + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + else: + self.register_buffer( + "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim) + ) + + def get_pos_embedding(self, vision_input, all_vision_tokens): + input_shape = vision_input.shape + pos_embed = _get_pos_embedding( + all_vision_tokens.size(1) - self.num_cls_tokens, + pos_embed=self.pos_embed, + patches_layout=self.patches_layout, + input_shape=input_shape, + first_patch_idx=self.num_cls_tokens, + ) + return pos_embed + + +class RGBDTPreprocessor(VerboseNNModule): + def __init__( + self, + rgbt_stem: PatchEmbedGeneric, + depth_stem: Optional[PatchEmbedGeneric], + img_size: Tuple = (3, 224, 224), + num_cls_tokens: int = 1, + pos_embed_fn: Optional[Callable] = None, + use_type_embed: bool = False, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + stem = rgbt_stem if rgbt_stem is not None else depth_stem + ( + self.patches_layout, + self.num_patches, + self.embed_dim, + ) = stem.get_patch_layout(img_size) + self.rgbt_stem = rgbt_stem + self.depth_stem = depth_stem + self.use_pos_embed = pos_embed_fn is not None + self.use_type_embed = use_type_embed + self.num_cls_tokens = num_cls_tokens + + if self.use_pos_embed: + self.pos_embedding_helper = pos_embed_fn( + patches_layout=self.patches_layout, + num_cls_tokens=num_cls_tokens, + num_patches=self.num_patches, + embed_dim=self.embed_dim, + ) + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + if self.use_type_embed: + self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.use_pos_embed: + nn.init.normal_(self.pos_embedding_helper.pos_embed) + self.pos_embedding_helper.pos_embed *= scale + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + if self.use_type_embed: + nn.init.normal_(self.type_embed) + + def tokenize_input_and_cls_pos(self, input, stem, mask): + # tokens is of shape B x L x D + tokens = 
stem(input) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens) + tokens = tokens + pos_embed + if self.use_type_embed: + tokens = tokens + self.type_embed.expand(B, -1, -1) + return tokens + + def forward(self, vision=None, depth=None, patch_mask=None): + if patch_mask is not None: + raise NotImplementedError() + + if vision is not None: + vision_tokens = self.tokenize_input_and_cls_pos( + vision, self.rgbt_stem, patch_mask + ) + + if depth is not None: + depth_tokens = self.tokenize_input_and_cls_pos( + depth, self.depth_stem, patch_mask + ) + + # aggregate tokens + if vision is not None and depth is not None: + final_tokens = vision_tokens + depth_tokens + else: + final_tokens = vision_tokens if vision is not None else depth_tokens + return_dict = { + "trunk": { + "tokens": final_tokens, + }, + "head": {}, + } + return return_dict + + +class AudioPreprocessor(RGBDTPreprocessor): + def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) + + def forward(self, audio=None): + return super().forward(vision=audio) + + +class ThermalPreprocessor(RGBDTPreprocessor): + def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None: + super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs) + + def forward(self, thermal=None): + return super().forward(vision=thermal) + + +def build_causal_attention_mask(context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length, requires_grad=False) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + +class TextPreprocessor(VerboseNNModule): + def __init__( + self, + vocab_size: int, + context_length: int, + embed_dim: int, + causal_masking: bool, + supply_seq_len_to_head: bool = True, + num_cls_tokens: int = 0, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.context_length = context_length + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.pos_embed = nn.Parameter( + torch.empty(1, self.context_length + num_cls_tokens, embed_dim) + ) + self.causal_masking = causal_masking + if self.causal_masking: + mask = build_causal_attention_mask(self.context_length) + # register the mask as a buffer so it can be moved to the right device + self.register_buffer("mask", mask) + + self.supply_seq_len_to_head = supply_seq_len_to_head + self.num_cls_tokens = num_cls_tokens + self.embed_dim = embed_dim + if num_cls_tokens > 0: + assert self.causal_masking is False, "Masking + CLS token isn't implemented" + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style="openclip"): + # OpenCLIP style initialization + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + if self.num_cls_tokens > 0: + 
nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def forward(self, text): + # text tokens are of shape B x L x D + text_tokens = self.token_embedding(text) + # concat CLS tokens if any + if self.num_cls_tokens > 0: + B = text_tokens.shape[0] + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + text_tokens = torch.cat((class_tokens, text_tokens), dim=1) + text_tokens = text_tokens + self.pos_embed + return_dict = { + "trunk": { + "tokens": text_tokens, + }, + "head": {}, + } + # Compute sequence length after adding CLS tokens + if self.supply_seq_len_to_head: + text_lengths = text.argmax(dim=-1) + return_dict["head"] = { + "seq_len": text_lengths, + } + if self.causal_masking: + return_dict["trunk"].update({"attn_mask": self.mask}) + return return_dict + + +class Im2Video(nn.Module): + """Convert an image into a trivial video.""" + + def __init__(self, time_dim=2): + super().__init__() + self.time_dim = time_dim + + def forward(self, x): + if x.ndim == 4: + # B, C, H, W -> B, C, T, H, W + return x.unsqueeze(self.time_dim) + elif x.ndim == 5: + return x + else: + raise ValueError(f"Dimension incorrect {x.shape}") + + +class PadIm2Video(Im2Video): + def __init__(self, ntimes, pad_type, time_dim=2): + super().__init__(time_dim=time_dim) + assert ntimes > 0 + assert pad_type in ["zero", "repeat"] + self.ntimes = ntimes + self.pad_type = pad_type + + def forward(self, x): + x = super().forward(x) + if x.shape[self.time_dim] == 1: + if self.pad_type == "repeat": + new_shape = [1] * len(x.shape) + new_shape[self.time_dim] = self.ntimes + x = x.repeat(new_shape) + elif self.pad_type == "zero": + padarg = [0, 0] * len(x.shape) + padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim] + x = nn.functional.pad(x, padarg) + return x + + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with open(bpe_path, "rb") as fh: # g_pathmgr. + bpe_bytes = io.BytesIO(fh.read()) + merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + + +class IMUPreprocessor(VerboseNNModule): + def 
__init__( + self, + kernel_size: int, + imu_stem: PatchEmbedGeneric, + embed_dim: int, + img_size: Tuple = (6, 2000), + num_cls_tokens: int = 1, + pos_embed_fn: Optional[Callable] = None, + init_param_style: str = "openclip", + ) -> None: + super().__init__() + self.imu_stem = imu_stem + self.embed_dim = embed_dim + self.use_pos_embed = pos_embed_fn is not None + self.num_cls_tokens = num_cls_tokens + self.kernel_size = kernel_size + self.pos_embed = nn.Parameter( + torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim) + ) + + if self.num_cls_tokens > 0: + self.cls_token = nn.Parameter( + torch.zeros(1, self.num_cls_tokens, self.embed_dim) + ) + + self.init_parameters(init_param_style) + + @torch.no_grad() + def init_parameters(self, init_param_style): + nn.init.normal_(self.pos_embed, std=0.01) + + if init_param_style == "openclip": + # OpenCLIP style initialization + scale = self.embed_dim**-0.5 + + if self.num_cls_tokens > 0: + nn.init.normal_(self.cls_token) + self.cls_token *= scale + elif init_param_style == "vit": + self.cls_token.data.fill_(0) + else: + raise ValueError(f"Unknown init {init_param_style}") + + def tokenize_input_and_cls_pos(self, input, stem): + # tokens is of shape B x L x D + tokens = stem.norm_layer(stem.proj(input)) + assert tokens.ndim == 3 + assert tokens.shape[2] == self.embed_dim + B = tokens.shape[0] + if self.num_cls_tokens > 0: + class_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole class_tokens impl from Phil Wang, thanks + tokens = torch.cat((class_tokens, tokens), dim=1) + if self.use_pos_embed: + tokens = tokens + self.pos_embed + return tokens + + def forward(self, imu): + # Patchify + imu = imu.unfold( + -1, + self.kernel_size, + self.kernel_size, + ).permute(0, 2, 1, 3) + imu = imu.reshape(imu.size(0), imu.size(1), -1) + + imu_tokens = self.tokenize_input_and_cls_pos( + imu, + self.imu_stem, + ) + + return_dict = { + "trunk": { + "tokens": imu_tokens, + }, + "head": {}, + } + return return_dict diff --git a/models/modules/image_bind/transformer.py b/models/modules/image_bind/transformer.py new file mode 100644 index 000000000..4950f8a57 --- /dev/null +++ b/models/modules/image_bind/transformer.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# Portions Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Code modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ; +# https://github.com/facebookresearch/deit/blob/main/models.py +# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py + + +from functools import partial +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_ + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, + # can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MultiheadAttention(nn.MultiheadAttention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + +class ViTAttention(Attention): + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + assert attn_mask is None + return super().forward(x) + + +class BlockWithMasking(nn.Module): + def __init__( + self, + dim: int, + attn_target: Callable, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ffn_dropout_rate: float = 0.0, + drop_path: float = 0.0, + layer_scale_type: Optional[str] = None, + layer_scale_init_value: float = 1e-4, + ): + super().__init__() + + assert not isinstance( + attn_target, nn.Module + ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!" 
+ self.attn = attn_target() + if drop_path > 0.0: + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm_1 = norm_layer(dim) + mlp_hidden_dim = int(mlp_ratio * dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=ffn_dropout_rate, + ) + self.norm_2 = norm_layer(dim) + self.layer_scale_type = layer_scale_type + if self.layer_scale_type is not None: + assert self.layer_scale_type in [ + "per_channel", + "scalar", + ], f"Found Layer scale type {self.layer_scale_type}" + if self.layer_scale_type == "per_channel": + # one gamma value per channel + gamma_shape = [1, 1, dim] + elif self.layer_scale_type == "scalar": + # single gamma value for all channels + gamma_shape = [1, 1, 1] + # two gammas: for each part of the fwd in the encoder + self.layer_scale_gamma1 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + self.layer_scale_gamma2 = nn.Parameter( + torch.ones(size=gamma_shape) * layer_scale_init_value, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + if self.layer_scale_type is None: + x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + x = x + self.drop_path(self.mlp(self.norm_2(x))) + else: + x = ( + x + + self.drop_path(self.attn(self.norm_1(x), attn_mask)) + * self.layer_scale_gamma1 + ) + x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 + return x + + +_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6) + + +class SimpleTransformer(nn.Module): + def __init__( + self, + attn_target: Callable, + embed_dim: int, + num_blocks: int, + block: Callable = BlockWithMasking, + pre_transformer_layer: Optional[Callable] = None, + post_transformer_layer: Optional[Callable] = None, + drop_path_rate: float = 0.0, + drop_path_type: str = "progressive", + norm_layer: Callable = _LAYER_NORM, + mlp_ratio: int = 4, + ffn_dropout_rate: float = 0.0, + layer_scale_type: Optional[ + str + ] = None, # from cait; possible values are None, "per_channel", "scalar" + layer_scale_init_value: float = 1e-4, # from cait; float + weight_init_style: str = "jax", # possible values jax or pytorch + ): + """ + Simple Transformer with the following features + 1. Supports masked attention + 2. Supports DropPath + 3. Supports LayerScale + 4. Supports Dropout in Attention and FFN + 5. 
Makes few assumptions about the input except that it is a Tensor + """ + super().__init__() + self.pre_transformer_layer = pre_transformer_layer + if drop_path_type == "progressive": + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] + elif drop_path_type == "uniform": + dpr = [drop_path_rate for i in range(num_blocks)] + else: + raise ValueError(f"Unknown drop_path_type: {drop_path_type}") + + self.blocks = nn.Sequential( + *[ + block( + dim=embed_dim, + attn_target=attn_target, + mlp_ratio=mlp_ratio, + ffn_dropout_rate=ffn_dropout_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + layer_scale_type=layer_scale_type, + layer_scale_init_value=layer_scale_init_value, + ) + for i in range(num_blocks) + ] + ) + self.post_transformer_layer = post_transformer_layer + self.weight_init_style = weight_init_style + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if self.weight_init_style == "jax": + # Based on MAE and official Jax ViT implementation + torch.nn.init.xavier_uniform_(m.weight) + elif self.weight_init_style == "pytorch": + # PyTorch ViT uses trunc_normal_ + trunc_normal_(m.weight, std=0.02) + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + tokens: torch.Tensor, + attn_mask: torch.Tensor = None, + use_checkpoint: bool = False, + checkpoint_every_n: int = 1, + checkpoint_blk_ids: Optional[List[int]] = None, + ): + """ + Inputs + - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) + - attn: mask of shape L x L + + Output + - x: data of shape N x L x D (or L x N x D depending on the attention implementation) + """ + if self.pre_transformer_layer: + tokens = self.pre_transformer_layer(tokens) + if use_checkpoint and checkpoint_blk_ids is None: + checkpoint_blk_ids = [ + blk_id + for blk_id in range(len(self.blocks)) + if blk_id % checkpoint_every_n == 0 + ] + if checkpoint_blk_ids: + checkpoint_blk_ids = set(checkpoint_blk_ids) + for blk_id, blk in enumerate(self.blocks): + if use_checkpoint and blk_id in checkpoint_blk_ids: + tokens = checkpoint.checkpoint( + blk, tokens, attn_mask, use_reentrant=False + ) + else: + tokens = blk(tokens, attn_mask=attn_mask) + if self.post_transformer_layer: + tokens = self.post_transformer_layer(tokens) + return tokens
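
For reference, a minimal usage sketch of the model added above. This is an illustration, not part of the patch: it assumes the repository root is on PYTHONPATH and feeds hand-built dummy tensors in place of the real image/text loaders and the BPE tokenizer.

import torch

from models.modules.image_bind.imagebind_model import ModalityType, imagebind_huge

# Randomly initialized weights; pass pretrained=True to download the checkpoint.
model = imagebind_huge(pretrained=False).eval()

# Dummy inputs with the shapes the preprocessors expect:
#   vision: B x 3 x 224 x 224 images (PadIm2Video repeats each into a 2-frame clip)
#   text:   B x 77 BPE token ids (normally produced by SimpleTokenizer)
inputs = {
    ModalityType.VISION: torch.randn(2, 3, 224, 224),
    ModalityType.TEXT: torch.randint(0, 49408, (2, 77)),
}

with torch.no_grad():
    embeddings = model(inputs)  # {modality: tensor of shape B x 1024}

# Vision embeddings are unit-norm; text embeddings additionally carry the
# learnable logit scale, so this matmul yields logit-scaled cosine similarities.
sims = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
print(sims.shape)  # torch.Size([2, 2])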