From 726468e7815ad2bb64c4df64ec4eff9991a90dc3 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Sat, 25 May 2024 19:59:13 -0400 Subject: [PATCH] [CLEANUP] --- multi_head_latent_attention.py | 24 +-- .../modules/fractoral_norm.py | 0 pyproject.toml | 2 +- zeta/nn/embeddings/rope.py | 6 +- zeta/nn/modules/__init__.py | 10 +- zeta/nn/modules/simple_lstm.py | 159 ++++++++++++++++++ zeta/nn/modules/simple_rnn.py | 42 +++++ 7 files changed, 217 insertions(+), 26 deletions(-) rename fractoral_norm.py => playground/modules/fractoral_norm.py (100%) create mode 100644 zeta/nn/modules/simple_lstm.py create mode 100644 zeta/nn/modules/simple_rnn.py diff --git a/multi_head_latent_attention.py b/multi_head_latent_attention.py index 3c8745d1..889832e7 100644 --- a/multi_head_latent_attention.py +++ b/multi_head_latent_attention.py @@ -40,22 +40,12 @@ def __init__( # KV self.latent_kv = nn.Parameter(torch.randn(batch_size, seqlen, dim)) - def forward(self, x: Tensor) -> Tensor: - device = x.device - k_r_t, scale = self.rope(self.seqlen, device) - print(k_r_t) - x = k_r_t + x + # Output + self.to_out = nn.Linear(dim, dim) + def forward( + self, x: Tensor, mask: Tensor = None, *args, **kwargs + ) -> Tensor: + b, s, d = x.shape -# # Example -# x = torch.randn(1, 100, 10) - -# # Attention -# model = MultiHeadLatentAttention( -# 10, -# 8, -# ) - -# # Apply the model -# out = model(x) -# print(out.shape) + return x diff --git a/fractoral_norm.py b/playground/modules/fractoral_norm.py similarity index 100% rename from fractoral_norm.py rename to playground/modules/fractoral_norm.py diff --git a/pyproject.toml b/pyproject.toml index f177a42d..d1c8c557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zetascale" -version = "2.5.1" +version = "2.5.2" description = "Rapidly Build, Optimize, and Deploy SOTA AI Models" authors = ["Zeta Team "] license = "MIT" diff --git a/zeta/nn/embeddings/rope.py b/zeta/nn/embeddings/rope.py index 579d94aa..10a0edfa 100644 --- a/zeta/nn/embeddings/rope.py +++ b/zeta/nn/embeddings/rope.py @@ -67,13 +67,15 @@ def forward(self, seq_len, device): return freqs, scale -def rotate_half(x): +def rotate_half(x: torch.Tensor) -> torch.Tensor: x = rearrange(x, "... (j d) -> ... j d", j=2) x1, x2 = x.unbind(dim=-1) return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(t, freqs, scale=1): +def apply_rotary_pos_emb( + t: torch.Tensor, freqs: torch.Tensor, scale: float = 1 +) -> torch.Tensor: seq_len = t.shape[-2] freqs = freqs[-seq_len:, :] return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index 68cdd8e7..1b67c747 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -122,8 +122,6 @@ from zeta.nn.modules.pyro import hyper_optimize from zeta.nn.modules.qformer import QFormer from zeta.nn.modules.qkv_norm import qk_norm, qkv_norm - -####### from zeta.nn.modules.quantized_layernorm import QuantizedLN from zeta.nn.modules.recursive_block import RecursiveBlock from zeta.nn.modules.residual import Residual @@ -134,14 +132,10 @@ from zeta.nn.modules.sig_lip import SigLipLoss from zeta.nn.modules.simple_attention import simple_attention from zeta.nn.modules.simple_feedforward import SimpleFeedForward - -###### from zeta.nn.modules.simple_mamba import Mamba, MambaBlock from zeta.nn.modules.simple_res_block import SimpleResBlock from zeta.nn.modules.skipconnection import SkipConnection from zeta.nn.modules.slerp_model_merger import SLERPModelMerger - -#### from zeta.nn.modules.space_time_unet import ( ContinuousPositionBias, Downsample, @@ -223,6 +217,8 @@ SparseTokenIntegration, SparseChannelIntegration, ) +from zeta.nn.modules.simple_lstm import SimpleLSTM +from zeta.nn.modules.simple_rnn import SimpleRNN # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features @@ -442,4 +438,6 @@ "SigLipSigmoidLoss", "SparseTokenIntegration", "SparseChannelIntegration", + "SimpleLSTM", + "SimpleRNN", ] diff --git a/zeta/nn/modules/simple_lstm.py b/zeta/nn/modules/simple_lstm.py new file mode 100644 index 00000000..7d6e5e0e --- /dev/null +++ b/zeta/nn/modules/simple_lstm.py @@ -0,0 +1,159 @@ +import torch +from torch import nn, Tensor + + +class SimpleLSTMCell(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + """ + Simple LSTM cell implementation. + + Args: + dim (int): The input dimension. + hidden_dim (int): The hidden dimension. + """ + super(SimpleLSTMCell, self).__init__() + self.dim = dim + self.hidden_dim = hidden_dim + + # Linear layers for input gate, forget gate, output gate, and cell state + self.W_i = nn.Linear(dim, hidden_dim) + self.U_i = nn.Linear(hidden_dim, hidden_dim) + + self.W_f = nn.Linear(dim, hidden_dim) + self.U_f = nn.Linear(hidden_dim, hidden_dim) + + self.W_o = nn.Linear(dim, hidden_dim) + self.U_o = nn.Linear(hidden_dim, hidden_dim) + + self.W_c = nn.Linear(dim, hidden_dim) + self.U_c = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tensor: + """ + Forward pass of the Simple LSTM cell. + + Args: + x (Tensor): The input tensor of shape (batch_size, input_dim). + h (Tensor): The previous hidden state tensor of shape (batch_size, hidden_dim). + c (Tensor): The previous cell state tensor of shape (batch_size, hidden_dim). + + Returns: + Tensor: The next hidden state tensor. + Tensor: The next cell state tensor. + """ + # Compute input gate + i = torch.sigmoid(self.W_i(x) + self.U_i(h)) + + # Compute forget gate + f = torch.sigmoid(self.W_f(x) + self.U_f(h)) + + # Compute output gate + o = torch.sigmoid(self.W_o(x) + self.U_o(h)) + + # Compute new cell candidate + c_tilde = torch.tanh(self.W_c(x) + self.U_c(h)) + + # Update cell state + c_next = f * c + i * c_tilde + + # Update hidden state + h_next = o * torch.tanh(c_next) + + return h_next, c_next + + +class SimpleLSTM(nn.Module): + """ + Simple LSTM implementation. + + Args: + dim (int): The input dimension. + hidden_dim (int): The hidden dimension. + depth (int): The number of LSTM layers. + output_dim (int): The output dimension. + """ + + def __init__(self, dim: int, hidden_dim: int, depth: int, output_dim: int): + super(SimpleLSTM, self).__init__() + self.dim = dim + self.hidden_dim = hidden_dim + self.depth = depth + + # LSTM cells + self.cells = nn.ModuleList( + [ + SimpleLSTMCell(dim if i == 0 else hidden_dim, hidden_dim) + for i in range(depth) + ] + ) + + # Final output layer + # self.fc = nn.Linear(hidden_dim, output_dim) + self.sequential = nn.Sequential( + nn.Linear(dim, dim), + nn.LayerNorm(dim), + nn.SiLU(), + nn.Linear(dim, output_dim), + nn.Softmax(dim=1), + ) + + def forward(self, x: Tensor) -> Tensor: + batch_size, seq_length, _ = x.shape + + # Init hidden and cell states with zeros + h = [ + torch.zeros(batch_size, self.hidden_dim).to(x.device) + for _ in range(self.depth) + ] + c = [ + torch.zeros(batch_size, self.hidden_dim).to(x.device) + for _ in range(self.depth) + ] + + # Collect outputs for each time step + outputs = [] + + # Iterate through each time step in the sequence + for t in range(seq_length): + # Extract the input for the current time step + x_t = x[:, t, :] + + # Pass through each LSTM cell + for layer in range(self.depth): + h[layer], c[layer] = self.cells[layer](x_t, h[layer], c[layer]) + x_t = h[layer] + + # Collect the output from the final LSTM layer + outputs.append(h[-1].unsqueeze(1)) + + # Concatenate the outputs along the time dimension + outputs = torch.cat(outputs, dim=1) + print(outputs.shape) + b, s, d = outputs.shape + + # Apply the fully connected layer + # outputs = self.sequential(outputs) + outputs = nn.Sequential( + nn.Linear(d, self.dim), + nn.LayerNorm(self.dim), + nn.SiLU(), + nn.Linear(self.dim, self.dim), + # nn.Softmax(dim=1), + )(outputs) + + return outputs + + +# # Example usage: +# if __name__ == "__main__": +# batch_size = 32 +# seq_length = 10 +# input_dim = 50 +# hidden_dim = 100 +# num_layers = 2 +# output_dim = 30 + +# model = SimpleLSTM(input_dim, hidden_dim, num_layers, output_dim) +# inputs = torch.randn(batch_size, seq_length, input_dim) +# outputs = model(inputs) +# print(outputs) # Expected output shape: (batch_size, seq_length, output_dim) diff --git a/zeta/nn/modules/simple_rnn.py b/zeta/nn/modules/simple_rnn.py new file mode 100644 index 00000000..c6da2de6 --- /dev/null +++ b/zeta/nn/modules/simple_rnn.py @@ -0,0 +1,42 @@ +# replace some of the activation functions from sigmoid to exponential function - e ^ x +# Memory saving: make the memory larger --> associate memory --> increase + + +from torch import nn, Tensor + + +class SimpleRNN(nn.Module): + """ + A simple recurrent neural network module. + + Args: + dim (int): The input dimension. + hidden_dim (int): The dimension of the hidden state. + """ + + def __init__( + self, + dim: int = None, + hidden_dim: int = None, + ): + super().__init__() + self.dim = dim + self.hidden_dim = hidden_dim + + self.act = nn.Tanh() + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the simple RNN module. + + Args: + x (Tensor): The input tensor of shape (batch_size, sequence_length, input_dim). + + Returns: + Tensor: The output tensor of shape (batch_size, sequence_length, hidden_dim). + """ + b, s, d = x.shape + + h = self.act(x) + + return h