Skip to content

Commit

Permalink
[FEAT][Snake]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye Gomez authored and Kye Gomez committed Jul 23, 2024
1 parent 4ff5d90 commit 295c4f1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@
from zeta.nn.modules.cope import CoPE
from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention
from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention
from zeta.nn.modules.snake_act import Snake

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
Expand Down Expand Up @@ -447,4 +448,5 @@
"MultiLayerKeyValueAttention",
"GatedMoECrossAttn",
"GatedXAttention",
"Snake",
]
38 changes: 38 additions & 0 deletions zeta/nn/modules/pretrained_t_five.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from transformers import T5Tokenizer, T5EncoderModel
from loguru import logger


class PretrainedT5Embedder:
def __init__(self, model_name: str = "t5-small", *args, **kwargs):
"""
Initializes the PretrainedT5Embedder with a specified T5 model.
Args:
model_name (str): The name of the pre-trained T5 model to use.
"""
logger.info(
f"Initializing the T5 tokenizer and model with {model_name}."
)
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.model = T5EncoderModel.from_pretrained(model_name, *args, **kwargs)

def run(self, text: str, *args, **kwargs) -> torch.Tensor:
"""
Encodes the input text using the T5 model and returns the embeddings.
Args:
text (str): The input text to be embedded.
Returns:
torch.Tensor: The embedded representation of the input text.
"""
logger.info(f"Encoding the text: {text}")
inputs = self.tokenizer(
text, return_tensors="pt", padding=True, truncation=True
)
with torch.no_grad():
outputs = self.model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
logger.info("Text successfully embedded.")
return embeddings
18 changes: 18 additions & 0 deletions zeta/nn/modules/snake_act.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class Snake(nn.Module):
def __init__(self, alpha: float = 1.0):
super(Snake, self).__init__()
self.alpha = nn.Parameter(torch.tensor(alpha))

def forward(self, x):
return x + (1 / self.alpha) * torch.sin(self.alpha * x) ** 2


# # Example usage
# snake = Snake()
# x = torch.randn(10, 100, 100) # Example input tensor
# output = snake(x)
# print(output)

0 comments on commit 295c4f1

Please sign in to comment.