Commit

[CLEANUP]

Kye Gomez authored and Kye Gomez committed May 25, 2024
1 parent 4e6a194 commit 726468e
Showing 7 changed files with 217 additions and 26 deletions.
24 changes: 7 additions & 17 deletions multi_head_latent_attention.py
@@ -40,22 +40,12 @@ def __init__(
# KV
self.latent_kv = nn.Parameter(torch.randn(batch_size, seqlen, dim))

def forward(self, x: Tensor) -> Tensor:
device = x.device
k_r_t, scale = self.rope(self.seqlen, device)
print(k_r_t)
x = k_r_t + x
# Output
self.to_out = nn.Linear(dim, dim)

def forward(
self, x: Tensor, mask: Tensor = None, *args, **kwargs
) -> Tensor:
b, s, d = x.shape

# # Example
# x = torch.randn(1, 100, 10)

# # Attention
# model = MultiHeadLatentAttention(
# 10,
# 8,
# )

# # Apply the model
# out = model(x)
# print(out.shape)
return x
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.5.1"
version = "2.5.2"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
6 changes: 4 additions & 2 deletions zeta/nn/embeddings/rope.py
@@ -67,13 +67,15 @@ def forward(self, seq_len, device):
return freqs, scale


def rotate_half(x):
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)  # split along the j=2 axis introduced by the rearrange above
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs, scale=1):
def apply_rotary_pos_emb(
t: torch.Tensor, freqs: torch.Tensor, scale: float = 1
) -> torch.Tensor:
seq_len = t.shape[-2]
freqs = freqs[-seq_len:, :]
return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
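
For context, a minimal sketch of how apply_rotary_pos_emb is typically applied to query/key tensors (illustrative only, not part of the diff). The shapes and the stubbed freqs/scale values are assumptions; in the library they come from the rotary embedding module's forward(seq_len, device) defined earlier in this file.

import torch
from zeta.nn.embeddings.rope import apply_rotary_pos_emb

# Assumed shapes: (batch, heads, seq_len, dim_head); dim_head must be even so
# rotate_half can split it into two halves.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)

# Stub freqs/scale with matching shapes; in practice: freqs, scale = rope(seq_len, device)
freqs = torch.randn(128, 64)
scale = 1.0

q = apply_rotary_pos_emb(q, freqs, scale)
k = apply_rotary_pos_emb(k, freqs, scale)
print(q.shape, k.shape)  # torch.Size([1, 8, 128, 64]) each
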
10 changes: 4 additions & 6 deletions zeta/nn/modules/__init__.py
@@ -122,8 +122,6 @@
from zeta.nn.modules.pyro import hyper_optimize
from zeta.nn.modules.qformer import QFormer
from zeta.nn.modules.qkv_norm import qk_norm, qkv_norm

#######
from zeta.nn.modules.quantized_layernorm import QuantizedLN
from zeta.nn.modules.recursive_block import RecursiveBlock
from zeta.nn.modules.residual import Residual
@@ -134,14 +132,10 @@
from zeta.nn.modules.sig_lip import SigLipLoss
from zeta.nn.modules.simple_attention import simple_attention
from zeta.nn.modules.simple_feedforward import SimpleFeedForward

######
from zeta.nn.modules.simple_mamba import Mamba, MambaBlock
from zeta.nn.modules.simple_res_block import SimpleResBlock
from zeta.nn.modules.skipconnection import SkipConnection
from zeta.nn.modules.slerp_model_merger import SLERPModelMerger

####
from zeta.nn.modules.space_time_unet import (
ContinuousPositionBias,
Downsample,
@@ -223,6 +217,8 @@
SparseTokenIntegration,
SparseChannelIntegration,
)
from zeta.nn.modules.simple_lstm import SimpleLSTM
from zeta.nn.modules.simple_rnn import SimpleRNN

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -442,4 +438,6 @@
"SigLipSigmoidLoss",
"SparseTokenIntegration",
"SparseChannelIntegration",
"SimpleLSTM",
"SimpleRNN",
]
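
For reference, a minimal sketch of using the newly exported SimpleLSTM through the package-level import added above (illustrative only; the constructor signature follows the new file below, and the concrete sizes are assumptions).

import torch
from zeta.nn.modules import SimpleLSTM

x = torch.randn(2, 16, 32)  # (batch, seq_len, dim)

# SimpleLSTM(dim, hidden_dim, depth, output_dim)
model = SimpleLSTM(32, 64, 2, 10)
out = model(x)
print(out.shape)  # expected: torch.Size([2, 16, 10])
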
159 changes: 159 additions & 0 deletions zeta/nn/modules/simple_lstm.py
@@ -0,0 +1,159 @@
from typing import Tuple

import torch
from torch import nn, Tensor


class SimpleLSTMCell(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
"""
Simple LSTM cell implementation.
Args:
dim (int): The input dimension.
hidden_dim (int): The hidden dimension.
"""
super(SimpleLSTMCell, self).__init__()
self.dim = dim
self.hidden_dim = hidden_dim

# Linear layers for input gate, forget gate, output gate, and cell state
self.W_i = nn.Linear(dim, hidden_dim)
self.U_i = nn.Linear(hidden_dim, hidden_dim)

self.W_f = nn.Linear(dim, hidden_dim)
self.U_f = nn.Linear(hidden_dim, hidden_dim)

self.W_o = nn.Linear(dim, hidden_dim)
self.U_o = nn.Linear(hidden_dim, hidden_dim)

self.W_c = nn.Linear(dim, hidden_dim)
self.U_c = nn.Linear(hidden_dim, hidden_dim)

    def forward(
        self, x: Tensor, h: Tensor, c: Tensor
    ) -> Tuple[Tensor, Tensor]:
"""
Forward pass of the Simple LSTM cell.
Args:
x (Tensor): The input tensor of shape (batch_size, input_dim).
h (Tensor): The previous hidden state tensor of shape (batch_size, hidden_dim).
c (Tensor): The previous cell state tensor of shape (batch_size, hidden_dim).
Returns:
Tensor: The next hidden state tensor.
Tensor: The next cell state tensor.
"""
# Compute input gate
i = torch.sigmoid(self.W_i(x) + self.U_i(h))

# Compute forget gate
f = torch.sigmoid(self.W_f(x) + self.U_f(h))

# Compute output gate
o = torch.sigmoid(self.W_o(x) + self.U_o(h))

# Compute new cell candidate
c_tilde = torch.tanh(self.W_c(x) + self.U_c(h))

# Update cell state
c_next = f * c + i * c_tilde

# Update hidden state
h_next = o * torch.tanh(c_next)

return h_next, c_next


class SimpleLSTM(nn.Module):
"""
Simple LSTM implementation.
Args:
dim (int): The input dimension.
hidden_dim (int): The hidden dimension.
depth (int): The number of LSTM layers.
output_dim (int): The output dimension.
"""

def __init__(self, dim: int, hidden_dim: int, depth: int, output_dim: int):
super(SimpleLSTM, self).__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.depth = depth

# LSTM cells
self.cells = nn.ModuleList(
[
SimpleLSTMCell(dim if i == 0 else hidden_dim, hidden_dim)
for i in range(depth)
]
)

        # Final output projection: maps the last layer's hidden states
        # (hidden_dim features) to the output dimension
        self.sequential = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the stacked LSTM.
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_length, dim).
        Returns:
            Tensor: Output tensor of shape (batch_size, seq_length, output_dim).
        """
        batch_size, seq_length, _ = x.shape

# Init hidden and cell states with zeros
h = [
torch.zeros(batch_size, self.hidden_dim).to(x.device)
for _ in range(self.depth)
]
c = [
torch.zeros(batch_size, self.hidden_dim).to(x.device)
for _ in range(self.depth)
]

# Collect outputs for each time step
outputs = []

# Iterate through each time step in the sequence
for t in range(seq_length):
# Extract the input for the current time step
x_t = x[:, t, :]

# Pass through each LSTM cell
for layer in range(self.depth):
h[layer], c[layer] = self.cells[layer](x_t, h[layer], c[layer])
x_t = h[layer]

# Collect the output from the final LSTM layer
outputs.append(h[-1].unsqueeze(1))

        # Concatenate the outputs along the time dimension
        outputs = torch.cat(outputs, dim=1)

        # Apply the output projection defined in __init__ (building a new
        # nn.Sequential here would re-initialize its weights on every call)
        outputs = self.sequential(outputs)

        return outputs


# # Example usage:
# if __name__ == "__main__":
# batch_size = 32
# seq_length = 10
# input_dim = 50
# hidden_dim = 100
# num_layers = 2
# output_dim = 30

# model = SimpleLSTM(input_dim, hidden_dim, num_layers, output_dim)
# inputs = torch.randn(batch_size, seq_length, input_dim)
# outputs = model(inputs)
# print(outputs) # Expected output shape: (batch_size, seq_length, output_dim)
42 changes: 42 additions & 0 deletions zeta/nn/modules/simple_rnn.py
@@ -0,0 +1,42 @@
# Replace some of the activation functions from sigmoid with the exponential function e^x
# Memory saving: make the memory larger --> associative memory --> increase


import torch
from torch import nn, Tensor


class SimpleRNN(nn.Module):
"""
A simple recurrent neural network module.
Args:
dim (int): The input dimension.
hidden_dim (int): The dimension of the hidden state.
"""

    def __init__(
        self,
        dim: int = None,
        hidden_dim: int = None,
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim

        # Input-to-hidden and hidden-to-hidden projections for the recurrence
        self.W_x = nn.Linear(dim, hidden_dim)
        self.W_h = nn.Linear(hidden_dim, hidden_dim)

        self.act = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the simple RNN module.
        Args:
            x (Tensor): The input tensor of shape (batch_size, sequence_length, input_dim).
        Returns:
            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_dim).
        """
        b, s, d = x.shape

        # Initialize the hidden state with zeros
        h = torch.zeros(b, self.hidden_dim, device=x.device)

        # Update the hidden state at each time step and collect the outputs
        outputs = []
        for t in range(s):
            h = self.act(self.W_x(x[:, t, :]) + self.W_h(h))
            outputs.append(h.unsqueeze(1))

        return torch.cat(outputs, dim=1)
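
A minimal usage sketch for SimpleRNN, mirroring the commented example in simple_lstm.py (illustrative only; the sizes are assumptions).

import torch
from zeta.nn.modules.simple_rnn import SimpleRNN

batch_size, seq_length, dim, hidden_dim = 4, 12, 32, 64

model = SimpleRNN(dim=dim, hidden_dim=hidden_dim)
x = torch.randn(batch_size, seq_length, dim)
out = model(x)
print(out.shape)  # expected: torch.Size([4, 12, 64])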
