From 295c4f148f225c7a5c21035295d67295b5584504 Mon Sep 17 00:00:00 2001
From: Kye Gomez
Date: Mon, 22 Jul 2024 17:00:57 -0700
Subject: [PATCH] [FEAT][Snake]

---
Review revisions (the commit id above and the index hashes below predate them):
 * Snake: accept an integer alpha, guard the 1 / alpha division, add docs.
 * PretrainedT5Embedder: move inputs to the model device, mask-aware pooling.

 zeta/nn/modules/__init__.py          |  2 +
 zeta/nn/modules/pretrained_t_five.py | 60 ++++++++++++++++++++++++++++
 zeta/nn/modules/snake_act.py         | 33 +++++++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 zeta/nn/modules/pretrained_t_five.py
 create mode 100644 zeta/nn/modules/snake_act.py

diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 442bab74..727afdd8 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -222,6 +222,7 @@
 from zeta.nn.modules.cope import CoPE
 from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention
 from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention
+from zeta.nn.modules.snake_act import Snake
 
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -447,4 +448,5 @@
     "MultiLayerKeyValueAttention",
     "GatedMoECrossAttn",
     "GatedXAttention",
+    "Snake",
 ]
diff --git a/zeta/nn/modules/pretrained_t_five.py b/zeta/nn/modules/pretrained_t_five.py
new file mode 100644
index 00000000..aabba931
--- /dev/null
+++ b/zeta/nn/modules/pretrained_t_five.py
@@ -0,0 +1,60 @@
+import torch
+from transformers import T5Tokenizer, T5EncoderModel
+from loguru import logger
+
+
+class PretrainedT5Embedder:
+    """
+    Embeds text with the encoder of a pre-trained T5 model.
+
+    Each input is tokenized, passed through the encoder, and the final
+    hidden states are mean-pooled over the sequence dimension.
+    """
+
+    def __init__(self, model_name: str = "t5-small", *args, **kwargs):
+        """
+        Initializes the PretrainedT5Embedder with a specified T5 model.
+
+        Args:
+            model_name (str): The name of the pre-trained T5 model to use.
+            *args: Extra positional arguments forwarded to
+                ``T5EncoderModel.from_pretrained``.
+            **kwargs: Extra keyword arguments forwarded to
+                ``T5EncoderModel.from_pretrained``.
+        """
+        logger.info(
+            f"Initializing the T5 tokenizer and model with {model_name}."
+        )
+        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
+        self.model = T5EncoderModel.from_pretrained(model_name, *args, **kwargs)
+
+    def run(self, text: str, *args, **kwargs) -> torch.Tensor:
+        """
+        Encodes the input text using the T5 model and returns the embeddings.
+
+        Args:
+            text (str): The input text to be embedded.
+            *args: Accepted for interface compatibility; currently unused.
+            **kwargs: Accepted for interface compatibility; currently unused.
+
+        Returns:
+            torch.Tensor: The mean-pooled encoder output, with shape
+                ``(batch, hidden_size)``.
+        """
+        logger.info(f"Encoding the text: {text}")
+        inputs = self.tokenizer(
+            text, return_tensors="pt", padding=True, truncation=True
+        )
+        # The tokenizer returns CPU tensors; move them to wherever the model
+        # lives so a GPU-resident model does not hit a device mismatch.
+        inputs = inputs.to(self.model.device)
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        hidden = outputs.last_hidden_state
+        # Average over real tokens only. For a single string the mask is all
+        # ones, so this equals a plain mean; for a padded batch it keeps the
+        # padding positions from diluting the embedding.
+        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
+        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+        logger.info("Text successfully embedded.")
+        return embeddings
diff --git a/zeta/nn/modules/snake_act.py b/zeta/nn/modules/snake_act.py
new file mode 100644
index 00000000..6c1ea02d
--- /dev/null
+++ b/zeta/nn/modules/snake_act.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+
+class Snake(nn.Module):
+    """
+    Snake activation: ``x + (1 / alpha) * sin(alpha * x) ** 2``.
+
+    ``alpha`` is a single learnable scalar applied to every element of the
+    input.
+
+    Args:
+        alpha (float): Initial value of the learnable ``alpha``.
+        eps (float): Added to ``alpha`` in the denominator so that an
+            ``alpha`` of zero yields the identity instead of NaN.
+
+    Example:
+        snake = Snake()
+        x = torch.randn(10, 100, 100)  # Example input tensor
+        output = snake(x)  # same shape as x
+    """
+
+    def __init__(self, alpha: float = 1.0, eps: float = 1e-9):
+        super().__init__()
+        # Force a floating dtype: an integer ``alpha`` would create an integer
+        # tensor, which cannot require gradients and makes nn.Parameter raise.
+        self.alpha = nn.Parameter(torch.tensor(float(alpha)))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the Snake activation element-wise to ``x``."""
+        scale = 1 / (self.alpha + self.eps)
+        return x + scale * torch.sin(self.alpha * x) ** 2