Skip to content

Commit

Permalink
feat: add bind files
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Jul 4, 2023
1 parent 963b5de commit 350578e
Show file tree
Hide file tree
Showing 5 changed files with 1,630 additions and 0 deletions.
Empty file.
142 changes: 142 additions & 0 deletions models/modules/image_bind/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import einops
import numpy as np
import torch
import torch.nn as nn


class Normalize(nn.Module):
    """Scale the input to unit L2 norm along a fixed dimension."""

    def __init__(self, dim: int) -> None:
        """Store the dimension along which inputs are normalised.

        Args:
            dim: Dimension passed to ``torch.nn.functional.normalize``.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` L2-normalised (``p=2``) along ``self.dim``."""
        normalized = nn.functional.normalize(x, p=2, dim=self.dim)
        return normalized


class LearnableLogitScaling(nn.Module):
    """Multiply the input by a clamped, optionally trainable scale.

    The scale is kept in log space under the name ``log_logit_scale``. It is
    an ``nn.Parameter`` when ``learnable`` is true and a registered buffer
    otherwise. In ``forward`` the exponentiated value is capped at
    ``max_logit_scale`` before being applied.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        """Create the log-space scale.

        Args:
            logit_scale_init: Initial (linear-space) scale; its natural log
                is what gets stored.
            learnable: Whether the stored value is a trainable parameter.
            max_logit_scale: Upper bound applied to the linear-space scale.
        """
        super().__init__()
        self.learnable = learnable
        self.logit_scale_init = logit_scale_init
        self.max_logit_scale = max_logit_scale

        # 0-dim tensor holding log(logit_scale_init).
        initial_log_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if not learnable:
            # Fixed scale: still part of the state dict, but not trained.
            self.register_buffer("log_logit_scale", initial_log_scale)
            return
        self.log_logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, x):
        """Return ``x`` multiplied by ``min(exp(log_logit_scale), max_logit_scale)``."""
        scale = self.log_logit_scale.exp()
        capped_scale = torch.clip(scale, max=self.max_logit_scale)
        return capped_scale * x

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return (
            f"logit_scale_init={self.logit_scale_init},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )


class EinOpsRearrange(nn.Module):
    """Module wrapper that applies a fixed ``einops.rearrange`` pattern."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        """Remember the pattern and its keyword arguments.

        Args:
            rearrange_expr: Pattern string handed to ``einops.rearrange``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``einops.rearrange`` on every call.
        """
        super().__init__()
        self.kwargs = kwargs
        self.rearrange_expr = rearrange_expr

    def forward(self, x):
        """Rearrange ``x`` with the stored pattern; ``x`` must be a tensor."""
        assert isinstance(x, torch.Tensor)
        pattern, axes = self.rearrange_expr, self.kwargs
        return einops.rearrange(x, pattern, **axes)


class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        """Format one tensor as a single ``repr`` line.

        Args:
            name: Label printed in parentheses at the start of the line.
            tensor: The tensor to describe. For backward compatibility a
                ``(name, tensor)`` pair, as yielded by ``named_parameters()``
                or ``named_buffers()``, is also accepted; its second item is
                used.

        Returns:
            ``"(<name>): tensor(<shape tuple>, requires_grad=<bool>)\\n"``.
        """
        # Earlier versions always indexed ``tensor[1]``, which only worked
        # when a (name, tensor) pair was passed despite the annotation.
        # Real tensors are now used as-is; pairs are still unpacked.
        if not isinstance(tensor, torch.Tensor):
            tensor = tensor[1]
        return (
            f"({name}): tensor({tuple(tensor.shape)}, "
            f"requires_grad={tensor.requires_grad})\n"
        )

    def extra_repr(self) -> str:
        """Describe the parameters and buffers visible from this module.

        Parameters are listed only when the first component of their dotted
        name is not the name of a (sub)module, i.e. parameters owned by
        submodules are skipped. Buffers are listed without that filter, each
        under the first component of its dotted name.

        Returns:
            The concatenated per-tensor lines (empty string if none).
        """
        module_names = {module_name for module_name, _ in self.named_modules()}

        lines = []
        for qualified_name, param in self.named_parameters():
            top_level = qualified_name.split(".")[0]
            if top_level not in module_names:
                lines.append(self.get_readable_tensor_repr(top_level, param))

        for qualified_name, buf in self.named_buffers():
            top_level = qualified_name.split(".")[0]
            lines.append(self.get_readable_tensor_repr(top_level, buf))

        return "".join(lines)


def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast ``tensor`` to ``tgt_dtype`` only if it currently has ``src_dtype``.

    Args:
        tensor: Tensor to inspect.
        src_dtype: Dtype that triggers the cast.
        tgt_dtype: Dtype to cast to when triggered.

    Returns:
        A ``(tensor, updated)`` pair. ``updated`` is ``True`` when the dtype
        matched ``src_dtype`` and the cast was applied; otherwise the input
        tensor is returned unchanged with ``False``.
    """
    if tensor.dtype != src_dtype:
        return tensor, False
    return tensor.to(dtype=tgt_dtype), True


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        """Apply the activation element-wise to ``x``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class SelectElement(nn.Module):
    """Pick a fixed index along dimension 1 of the input."""

    def __init__(self, index) -> None:
        """Store the index that ``forward`` applies to dimension 1."""
        super().__init__()
        self.index = index

    def forward(self, x):
        """Return ``x[:, self.index, ...]``; ``x`` must have at least 3 dims."""
        assert x.ndim >= 3
        # Keep dim 0 whole, index dim 1, leave all trailing dims untouched.
        selector = (slice(None), self.index, Ellipsis)
        return x[selector]


class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP
    """

    def __init__(self, proj: nn.Module) -> None:
        """Store the projection module applied to the selected features."""
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        """Gather one position per batch row and project it.

        Args:
            x: 3-D tensor (B x L x D per the original comment).
            seq_len: Per-row index into dimension 1 selecting the position
                to keep (presumably the end-of-text token -- confirm with
                the caller).

        Returns:
            ``self.proj`` applied to the gathered rows.
        """
        assert x.ndim == 3
        # Pair each batch row i with its own position seq_len[i].
        batch_rows = torch.arange(x.shape[0])
        eos_features = x[batch_rows, seq_len]
        return self.proj(eos_features)
Loading

0 comments on commit 350578e

Please sign in to comment.