Commit

[OUTPUTHEAD]
Kye committed Apr 30, 2024
1 parent cfef940 commit 21c3163
Showing 4 changed files with 137 additions and 29 deletions.
56 changes: 30 additions & 26 deletions playground/models/toka_master_gpt.py
@@ -4,6 +4,7 @@
from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn import OutputHead


class TokaTransformerBlock(nn.Module):
"""
Transformer block used in the Toka model.
@@ -35,21 +36,21 @@ def __init__(
ff_mult: int,
dropout: float = 0.1,
*args,
**kwargs
**kwargs,
):
super().__init__()
self.dim = dim
self.dim_head = dim_head
self.heads = heads
self.ff_mult = ff_mult
self.dropout = dropout

# Attention
self.attn = MultiQueryAttention(
dim,
heads,
)

# FFN (feed-forward network)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * ff_mult),
@@ -60,10 +61,10 @@ def __init__(
nn.LayerNorm(dim),
nn.Linear(dim, dim),
)

# LayerNorm
self.norm = nn.LayerNorm(dim)

def forward(self, x: Tensor):
"""
Forward pass of the TokaTransformerBlock.
@@ -77,18 +78,18 @@ def forward(self, x: Tensor):
"""
skip = x
x, _, _ = self.attn(x)

# Add with the skip connection
x = x + skip
x = self.norm(x)
skip_two = x

# MLP
x = self.mlp(x)
x = x + skip_two
return self.norm(x)


class TokaTransformer(nn.Module):
"""
A transformer model based on the Toka architecture.
@@ -121,23 +122,26 @@ def __init__(
dropout: float = 0.1,
depth: int = 6,
*args,
**kwargs
**kwargs,
):
super().__init__()
self.dim = dim
self.dim_head = dim_head
self.heads = heads
self.ff_mult = ff_mult
self.dropout = dropout

# Transformer layer
self.layers = nn.ModuleList([
TokaTransformerBlock(dim, dim_head, heads, ff_mult, dropout) for _ in range(depth)
])

self.layers = nn.ModuleList(
[
TokaTransformerBlock(dim, dim_head, heads, ff_mult, dropout)
for _ in range(depth)
]
)

# Norm
self.norm = nn.LayerNorm(dim)

def forward(self, x: Tensor):
"""
Forward pass of the TokaTransformer.
@@ -150,10 +154,10 @@ def forward(self, x: Tensor):
"""
x = self.norm(x)

for layer in self.layers:
x = layer(x)

return OutputHead(self.dim, 1)(x)


@@ -190,7 +194,9 @@ def __init__(

self.act = nn.Tanh()

self.lstm_head = nn.LSTM(dim, dim, num_layers=num_layers, dropout=dropout)
self.lstm_head = nn.LSTM(
dim, dim, num_layers=num_layers, dropout=dropout
)
self.transformer = TokaTransformer(
dim,
dropout=dropout,
@@ -203,7 +209,7 @@ def __init__(
nn.ELU(),
nn.Linear(dim * ff_mult, dim),
nn.LayerNorm(dim),
)q
)

def forward(self, x: Tensor) -> Tensor:
"""
@@ -223,9 +229,9 @@ def forward(self, x: Tensor) -> Tensor:
# LSTM
if self.transformer is True:
x = self.transformer(x)
else:
x, _ = self.lstm_head(x)

print(x.shape)

# Concatenate
@@ -268,7 +274,7 @@ class TokaPolicyBlock(nn.Module):
Attributes:
dim (int): The dimension of the input and output tensors.
dropout (float): The dropout probability.
ff_mult (int): The multiplier for the dimension of the hidden layer in the MLP.
e ff_mult (int): The multiplier for the dimension of the hidden layer in the MLP.
actions (int): The number of output actions.
proj (nn.Linear): The linear projection layer.
norm (nn.LayerNorm): The layer normalization layer.
@@ -319,11 +325,10 @@ def __init__(

# Softplus
self.soft = nn.Softplus()

# Final proj
self.final_proj = nn.Linear(dim, actions)


# Initialize weights using truncated normal distribution
nn.init.trunc_normal_(self.proj.weight, std=1 / (dim**0.5))
nn.init.trunc_normal_(self.mlp[0].weight, std=1 / (dim**0.5))
@@ -338,7 +343,6 @@ def __init__(
self.mlp[4].bias.data.zero_()
self.final_proj.bias.data.zero_()


def forward(self, x: Tensor) -> Tensor:
"""
Performs the forward pass of the policy block.
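
A minimal smoke-test sketch for the updated TokaTransformer, assuming the constructor arguments visible in the hunks above (dim, dim_head, heads, ff_mult, dropout, depth) and an import path matching this file's location; the sketch is not part of the diff, and all hyperparameters and shapes are illustrative.

import torch

# Assumed import path, mirroring playground/models/toka_master_gpt.py.
from playground.models.toka_master_gpt import TokaTransformer

# Illustrative hyperparameters, not taken from the repository.
model = TokaTransformer(
    dim=512,       # model width
    dim_head=64,   # per-head dimension
    heads=8,       # attention heads
    ff_mult=4,     # MLP expansion factor
    dropout=0.1,
    depth=6,       # number of TokaTransformerBlock layers
)

x = torch.randn(2, 128, 512)   # (batch, sequence, dim)
out = model(x)                 # norm -> stacked blocks -> OutputHead(dim, 1)
print(out.shape)
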
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.4.2"
version = "2.4.3"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
95 changes: 95 additions & 0 deletions zeta/nn/modules/mixtape.py
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Mixtape(nn.Module):
def __init__(self, vocab_size, d_model, d1, d2, num_gates=4):
super(Mixtape, self).__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.d1 = d1
self.d2 = d2
self.num_gates = num_gates

# Parameters for computing pre-activation gate priors
self.U = nn.Parameter(torch.randn(self.num_gates, self.d2, self.d1))
self.v = nn.Parameter(torch.randn(self.vocab_size, self.d2))
self.u = nn.Parameter(torch.randn(self.num_gates, self.d1))
self.b = nn.Parameter(torch.randn(self.vocab_size, self.num_gates))

# Parameters for context embeddings
self.H = nn.Parameter(
torch.randn(self.num_gates, self.d_model, self.d1)
)

# Token embeddings (not specified in the abstract, assuming needed)
self.token_embeddings = nn.Parameter(
torch.randn(self.vocab_size, self.d_model)
)

def forward(self, gc):
batch_size, seq_length, _ = gc.shape

# Compute context embeddings for each gate
# Expanded gc to [batch_size, seq_length, 1, d1] for broadcasting
hc = torch.tanh(
torch.einsum("kij,btj->btki", self.H, gc)
) # (batch_size, seq_length, num_gates, d_model)

# Compute pre-activation gate priors for each token and gate
# Expanded gc for broadcasting with different parameters
lc = (
torch.einsum(
"ij,btj->bti",
self.v,
torch.tanh(torch.einsum("kij,btj->btki", self.U, gc)),
)
+ torch.einsum("ij,btj->bti", self.u, gc)
+ self.b[None, None, :, :]
) # (batch_size, seq_length, vocab_size, num_gates)

# Sigmoid tree decomposition
gamma = torch.sigmoid(
lc[..., :-1]
) # (batch_size, seq_length, vocab_size, num_gates-1)
pis = [None] * self.num_gates
pis[0] = gamma[..., 0] * gamma[..., 1]
pis[1] = gamma[..., 0] * (1 - gamma[..., 1])
pis[2] = (1 - gamma[..., 0]) * gamma[..., 2]
pis[3] = (1 - gamma[..., 0]) * (1 - gamma[..., 2])

# Convert list to tensor
pi = torch.stack(
pis, dim=-1
) # (batch_size, seq_length, vocab_size, num_gates)
print(pi.shape)

# Compute the logit sum for each token using vector gating
logits = torch.einsum(
"btki,btik->bti",
hc,
torch.einsum("btik,bjk->btikj", pi, self.token_embeddings),
)
print(logits.shape)
probs = F.softmax(
logits, dim=-1
) # (batch_size, seq_length, vocab_size)

return probs


# Example usage
d_model = 512
d1 = 256
d2 = 128
vocab_size = 10000
seq_length = 20

model = Mixtape(vocab_size=vocab_size, d_model=d_model, d1=d1, d2=d2)
gc = torch.randn(
10, seq_length, d1
) # Simulated last-layer hidden states for a batch of 10 with sequence length 20
print(gc.shape)
output = model(gc)
print(output)
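
A side note on the sigmoid tree decomposition in Mixtape.forward above: the three sigmoid activations define four gate priors that partition the unit interval, which is what lets the gating avoid a per-gate softmax. A standalone sanity check of that identity (not part of this commit):

import torch

# With g0, g1, g2 in (0, 1):
#   pi0 = g0*g1, pi1 = g0*(1-g1), pi2 = (1-g0)*g2, pi3 = (1-g0)*(1-g2)
# so pi0+pi1 = g0 and pi2+pi3 = 1-g0, and the four priors sum to 1.
gamma = torch.sigmoid(torch.randn(10, 20, 100, 3))  # (batch, seq, vocab, 3)

pi0 = gamma[..., 0] * gamma[..., 1]
pi1 = gamma[..., 0] * (1 - gamma[..., 1])
pi2 = (1 - gamma[..., 0]) * gamma[..., 2]
pi3 = (1 - gamma[..., 0]) * (1 - gamma[..., 2])

total = pi0 + pi1 + pi2 + pi3
assert torch.allclose(total, torch.ones_like(total))
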
13 changes: 11 additions & 2 deletions zeta/nn/modules/multi_input_multi_output.py
@@ -106,7 +106,14 @@ def forward(self, x: Tensor):


class OutputHead(nn.Module):
def __init__(self, dim: int, dim_range: int, *args, **kwargs):
def __init__(
self,
dim: int,
dim_range: int = 1,
vocab_size: int = 20000,
*args,
**kwargs,
):
"""
Initializes an OutputHead module.
@@ -123,8 +130,10 @@ def __init__(self, dim: int, dim_range: int, *args, **kwargs):
# Linear layer for each output
self.output_layers = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim),
nn.Linear(dim, vocab_size),
nn.Softmax(dim_range),
*args,
**kwargs,
)

def forward(self, x: Tensor):
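
A hedged usage sketch for the revised OutputHead signature, assuming the class is re-exported from zeta.nn (as the import in playground/models/toka_master_gpt.py suggests) and that forward applies the output_layers stack to its input; shapes are illustrative.

import torch

from zeta.nn import OutputHead  # import path as used in toka_master_gpt.py

# dim_range now defaults to 1 and the projection maps dim -> vocab_size.
head = OutputHead(dim=512, vocab_size=20000)

x = torch.randn(2, 128, 512)  # (batch, seq, dim) hidden states
probs = head(x)               # assumed: LayerNorm -> Linear(dim, vocab_size) -> Softmax
print(probs.shape)            # expected: (2, 128, 20000)
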
