
Commit

[video_patch_linear_flatten]
Kye committed Apr 6, 2024
1 parent 6d2b9fe commit e1afe6c
Showing 4 changed files with 134 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.3.5"
version = "2.3.7"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
1 change: 0 additions & 1 deletion zeta/models/andromeda.py
@@ -1,7 +1,6 @@
# the best llm ever made
from torch.nn import Module

from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.transformer import Decoder, Transformer


4 changes: 4 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -206,6 +206,8 @@
from zeta.nn.modules.patch_linear_flatten import (
vit_output_head,
patch_linear_flatten,
cls_tokens,
video_patch_linear_flatten,
)
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm

@@ -416,4 +418,6 @@
"vit_output_head",
"posemb_sincos_2d",
"ChanLayerNorm",
"cls_tokens",
"video_patch_linear_flatten",
]
130 changes: 129 additions & 1 deletion zeta/nn/modules/patch_linear_flatten.py
@@ -1,6 +1,7 @@
import torch
from torch import nn, Tensor
from einops.layers.torch import Rearrange
from einops import repeat


def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
@@ -23,7 +24,9 @@ def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
return pe.type(dtype)


def vit_output_head(x: Tensor, dim: int, num_classes: int = None):
def vit_output_head(
x: Tensor, dim: int, num_classes: int = None, pooling: str = "mean"
):
"""
Applies a Vision Transformer (ViT) output head to the input tensor.
@@ -35,6 +38,15 @@ def vit_output_head(x: Tensor, dim: int, num_classes: int = None):
Returns:
Tensor: The output tensor after applying the ViT output head.
"""
if pooling == "mean":
x = x.mean(dim=1)
elif pooling == "cls":
x = x[:, 0]
elif pooling == "max":
x = x.max(dim=1).values
    elif pooling == "none":
        pass  # keep the full token sequence unpooled
return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))(x)
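
A quick sketch (editorial, not part of this diff) of how the new pooling argument changes the head's output, assuming 8 tokens of width 512 and 10 classes:

import torch

feats = torch.randn(2, 8, 512)                                 # (batch, tokens, dim)
print(vit_output_head(feats, 512, 10, pooling="mean").shape)   # torch.Size([2, 10])
print(vit_output_head(feats, 512, 10, pooling="cls").shape)    # torch.Size([2, 10])
print(vit_output_head(feats, 512, 10, pooling="none").shape)   # torch.Size([2, 8, 10])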


@@ -86,3 +98,119 @@ def patch_linear_flatten(
        to_patch_embeddings += pos_embeddings

return to_patch_embeddings


def video_patch_linear_flatten(
x: Tensor,
patch_size: int,
dim: int,
image_size: int,
channels: int = 3,
add_pos_embeddings: bool = False,
frame_patch_size: int = 1,
frames: int = None,
seqlen: int = None,
*args,
**kwargs,
):
"""
Applies patch embedding to the input tensor and flattens it.
Args:
x (Tensor): Input tensor of shape (batch_size, channels, image_height, image_width).
patch_size (int): Size of the square patch.
dim (int): Dimension of the output tensor.
image_size (int): Size of the input image (assumed to be square).
channels (int, optional): Number of input channels. Defaults to 3.
add_pos_embeddings (bool, optional): Whether to add positional embeddings. Defaults to False.
Returns:
Tensor: Flattened tensor of shape (batch_size, num_patches, dim).
"""
image_height, image_width = image_size, image_size
patch_height, patch_width = patch_size, patch_size

assert (
image_height % patch_height == 0 and image_width % patch_width == 0
), "Image dimensions must be divisible by the patch size."
assert (
frames % frame_patch_size == 0
), "Frames must be divisible by frame patch size"

# calculate number of patches
num_patches = (
(image_height // patch_height)
* (image_width // patch_width)
* (frames // frame_patch_size)
)
patch_dim = channels * patch_height * patch_width * frame_patch_size

# Patch Embedding layer
to_patch_embeddings = nn.Sequential(
Rearrange(
"b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)",
p1=patch_height,
p2=patch_width,
pf=frame_patch_size,
),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)(x)

if add_pos_embeddings is not False:
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
to_patch_embeddings += pos_embedding[:, : (seqlen + 1)]

return to_patch_embeddings


def cls_tokens(
x: Tensor,
dropout: float = 0.0,
num_patches: int = None,
pos_emb: bool = False,
):
"""
Adds class tokens to the input tensor and applies dropout and positional embeddings if specified.
Args:
x (Tensor): The input tensor of shape (batch_size, sequence_length, hidden_dim).
dropout (float, optional): The dropout probability. Defaults to 0.0.
num_patches (int, optional): The number of patches. Defaults to None.
pos_emb (bool, optional): Whether to apply positional embeddings. Defaults to False.
Returns:
Tensor: The modified input tensor with class tokens added.
"""
b, s, d = x.shape

    cls_token = nn.Parameter(torch.randn(1, 1, d))  # learnable class token
    cls_tokens = repeat(cls_token, "1 1 d -> b 1 d", b=b)
x = torch.cat((cls_tokens, x), dim=1)

if dropout is not None:
x = nn.Dropout(dropout)(x)

if pos_emb:
pos_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, d))
x += pos_embeddings[:, : (s + 1)]

return x
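
A minimal usage sketch for cls_tokens (editorial, not part of this commit), assuming a batch of patch embeddings shaped (batch, seq, dim); the helper prepends one class token per sample:

import torch

tokens = torch.randn(2, 196, 512)   # (batch, seq, dim) patch embeddings
tokens = cls_tokens(tokens)         # -> (2, 197, 512) with a class token prepended
print(tokens.shape)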


# # video: b, c, f, h, w
# x = torch.randn(1, 3, 16, 224, 224)

# # patch size
# patch_size = 16
# frames = 16
# frame_patch_size = 1
# dim = 512
# image_size = 224
# channels = 3
# model = video_patch_linear_flatten(
# x, patch_size, dim, image_size, channels, frames=frames, frame_patch_size=frame_patch_size
# )

# print(model.shape)
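
Putting the pieces together, a hedged end-to-end sketch (editorial, not part of the commit) that chains the three helpers exported from zeta.nn.modules; num_classes=1000 and pooling="cls" are arbitrary illustration values:

import torch
from zeta.nn.modules import (
    video_patch_linear_flatten,
    cls_tokens,
    vit_output_head,
)

video = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, frames, height, width)

# spatio-temporal patches -> token embeddings: (1, 3136, 512)
tokens = video_patch_linear_flatten(
    video, patch_size=16, dim=512, image_size=224,
    channels=3, frames=16, frame_patch_size=1,
)

# prepend a learnable class token: (1, 3137, 512)
tokens = cls_tokens(tokens)

# pool on the class token and project to logits: (1, 1000)
logits = vit_output_head(tokens, dim=512, num_classes=1000, pooling="cls")
print(logits.shape)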
