Iterative self-attention with PreNorm
Kye committed Nov 29, 2023
1 parent 6e66149 commit 6f029ba
Showing 4 changed files with 178 additions and 1 deletion.
4 changes: 4 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -53,6 +53,8 @@
from zeta.nn.modules.flexible_mlp import CustomMLP
from zeta.nn.modules.fractorial_net import FractalBlock, FractalNetwork
from zeta.nn.modules.polymorphic_activation import PolymorphicActivation
from zeta.nn.modules.prenorm import PreNorm
from zeta.nn.modules.itca import IterativeCrossSelfAttention

__all__ = [
    "CNNNew",
@@ -96,4 +98,6 @@
    "PolymorphicNeuronLayer",
    "CustomMLP",
    "PolymorphicActivation",
    "PreNorm",
    "IterativeCrossSelfAttention",
]
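
With these exports in place, both new modules can be imported straight from zeta.nn.modules. A minimal sketch (not part of the commit) of the new import path, assuming the zeta package from this repository is installed; the sizes below are for illustration only:

import torch

# New top-level exports added by this commit.
from zeta.nn.modules import IterativeCrossSelfAttention, PreNorm

attn = IterativeCrossSelfAttention(dim=512, depth=2, heads=8, dim_head=64)
x = torch.randn(1, 16, 512)
# With no context argument, the module attends over x itself.
print(attn(x).shape)  # torch.Size([1, 16, 512])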
4 changes: 3 additions & 1 deletion zeta/nn/modules/decision_tree.py
@@ -39,7 +39,9 @@ class SimpleDecisionTree(nn.Module):
        grad_fn=<AddmmBackward>)]
    """

    def __init__(self, input_size, output_size, depth, heads):
    def __init__(
        self, input_size: int, output_size: int, depth: int, heads: int
    ):
        super(SimpleDecisionTree, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
145 changes: 145 additions & 0 deletions zeta/nn/modules/itca.py
@@ -0,0 +1,145 @@
import torch
from torch import nn


# Pre-normalization wrapper applied around each attention layer.
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module.

    Args:
        dim (int): Feature dimension to normalize over.
        fn (nn.Module): Module invoked on the normalized input; it must
            accept a ``context`` keyword argument.
    """

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, context=None):
        """Normalize ``x`` and apply the wrapped module.

        Args:
            x (torch.Tensor): Input of shape ``(batch, seq_len, dim)``.
            context (torch.Tensor, optional): Passed through to ``fn``.
        """
        return self.fn(self.norm(x), context=context)


class CrossAttention(nn.Module):
    """Single cross-attention layer.

    Attends from ``x`` (queries) to ``context`` (keys/values); falls back to
    self-attention when no context is given.

    Args:
        dim (int): Input/output feature dimension.
        heads (int): Number of attention heads folded into the inner dimension.
        dim_head (int): Dimension per head.
        dropout (float): Dropout applied to the output projection.
        qk_norm (bool): If True, apply LayerNorm to queries and keys.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        qk_norm: bool = True,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.qk_norm = qk_norm

        self.attend = nn.Softmax(dim=-1)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout)
        )

        # Normalizes queries and keys over the full inner dimension.
        self._qk_norm = nn.LayerNorm(inner_dim)

    def forward(self, x, context=None):
        # Self-attention when no external context is provided.
        if context is None:
            context = x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        if self.qk_norm:
            q, k = self._qk_norm(q), self._qk_norm(k)

        # Attention is computed over the full inner dimension
        # (heads are not split into separate subspaces here).
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        return self.to_out(out)


class IterativeCrossSelfAttention(nn.Module):
    """Stack of pre-normalized cross-attention layers with residual connections.

    Args:
        dim (int): Feature dimension of the inputs.
        depth (int): Number of attention layers applied iteratively.
        heads (int): Number of attention heads per layer.
        dim_head (int): Dimension per head.
        dropout (float, optional): Dropout on each layer's output projection.
            Defaults to 0.1.

    Examples:
        >>> attn = IterativeCrossSelfAttention(dim=512, depth=6, heads=8, dim_head=64)
        >>> x = torch.randn(8, 16, 512)
        >>> attn(x).shape
        torch.Size([8, 16, 512])
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                PreNorm(
                    dim,
                    CrossAttention(
                        dim, heads=heads, dim_head=dim_head, dropout=dropout
                    ),
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor = None):
        """Apply each pre-normalized attention layer with a residual connection.

        Args:
            x (torch.Tensor): Input embeddings of shape ``(batch, seq_len, dim)``.
            context (torch.Tensor, optional): Key/value source of shape
                ``(batch, context_len, dim)``. Defaults to ``x`` (self-attention).

        Returns:
            torch.Tensor: Output of shape ``(batch, seq_len, dim)``.
        """
        for layer in self.layers:
            x = layer(x, context=context) + x
        return x


if __name__ == "__main__":
    # Example usage of the IterativeCrossSelfAttention class.
    batch_size = 8
    seq_len = 16  # Sequence length of the input embeddings
    latent_seq_len = 16  # Sequence length of the latent array (may differ from seq_len)
    dim = 512  # Dimensionality of the input embeddings and latent array
    heads = 8  # Number of attention heads
    dim_head = 64  # Dimensionality of each attention head
    depth = 6  # Number of cross-attention layers

    # Initialize the IterativeCrossSelfAttention module
    iter_cs_attn = IterativeCrossSelfAttention(dim, depth, heads, dim_head)

    # Create random tensors for the input embeddings and the latent array
    input_embeddings = torch.rand(batch_size, seq_len, dim)
    latent_array = torch.rand(batch_size, latent_seq_len, dim)

    # Pass the input embeddings and the latent array through the module
    output_embeddings = iter_cs_attn(input_embeddings, latent_array)

    print("Output embeddings shape:", output_embeddings.shape)
26 changes: 26 additions & 0 deletions zeta/nn/modules/prenorm.py
@@ -0,0 +1,26 @@

from torch import nn


# Pre-normalization wrapper: LayerNorm applied before the wrapped module.
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module.

    Args:
        dim (int): Feature dimension to normalize over.
        fn (nn.Module): Module invoked on the normalized input; it must
            accept a ``context`` keyword argument.
    """

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, context=None):
        """Normalize ``x`` and apply the wrapped module.

        Args:
            x (torch.Tensor): Input of shape ``(batch, seq_len, dim)``.
            context (torch.Tensor, optional): Passed through to ``fn``.
        """
        return self.fn(self.norm(x), context=context)
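
For reference, a minimal sketch (not part of the commit) of PreNorm wrapping a context-aware sub-module; ToyBlock below is a hypothetical stand-in for an attention layer:

import torch
from torch import nn

from zeta.nn.modules.prenorm import PreNorm


class ToyBlock(nn.Module):
    """Hypothetical wrapped module; PreNorm always forwards a ``context`` kwarg."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, context=None):
        # Ignores context here; real attention layers would use it.
        return self.proj(x)


block = PreNorm(512, ToyBlock(512))
out = block(torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 16, 512])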
