Commit

yolo

Kye committed Oct 11, 2023
1 parent 8ffaccb commit ebcd79a
Showing 9 changed files with 87 additions and 58 deletions.
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -19,3 +19,5 @@
from zeta.nn.modules.cnn_text import CNNNew
from zeta.nn.modules.fast_text import FastTextNew
from zeta.nn.modules.simple_attention import simple_attention
from zeta.nn.modules.spacial_transformer import SpacialTransformer
from zeta.nn.modules.yolo import yolo
37 changes: 16 additions & 21 deletions zeta/nn/modules/cnn_text.py
@@ -1,6 +1,7 @@
from einops import rearrange, reduce
from torch import nn


class CNNNew(nn.Module):
"""
CNN for language
@@ -23,36 +24,30 @@ class CNNNew(nn.Module):
dropout=0.5,
)
net(x)
"""

def __init__(
self,
vocab_size,
embedding_dim,
n_filters,
filter_sizes,
output_dim,
dropout
self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.convs = nn.ModuleList([
nn.Conv2d(
embedding_dim,
n_filters,
kernel_size=size
) for size in filter_sizes
])
self.convs = nn.ModuleList(
[
nn.Conv2d(embedding_dim, n_filters, kernel_size=size)
for size in filter_sizes
]
)
self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
"""
Forward pass of CNNNew
"""
x = rearrange(x, 'b t -> b t')
emb = rearrange(self.embedding(x), 't b c -> b c t')
pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
concatenated = rearrange(pooled, 'filter b c -> b (filter c)')
return self.fc(self.dropout(concatenated))
x = rearrange(x, "b t -> b t")
emb = rearrange(self.embedding(x), "t b c -> b c t")
pooled = [reduce(conv(emb), "b c t -> b c", "max") for conv in self.convs]
concatenated = rearrange(pooled, "filter b c -> b (filter c)")
return self.fc(self.dropout(concatenated))
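
For reference, a standalone check of the einops patterns used in the reformatted CNNNew.forward; the shapes and filter count below are illustrative assumptions, not taken from this commit:

import torch
from einops import rearrange, reduce

emb = torch.randn(4, 32, 100)                                    # (batch, channels, time), assumed layout
pooled = [reduce(emb, "b c t -> b c", "max") for _ in range(3)]  # global max-pool, one entry per filter size
concatenated = rearrange(pooled, "filter b c -> b (filter c)")   # (4, 3 * 32), ready for the linear head
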
17 changes: 7 additions & 10 deletions zeta/nn/modules/fast_text.py
@@ -1,11 +1,8 @@
from torch import nn
from einops.layers.torch import Rearrange, Reduce

def FastTextNew(
vocab_size,
embedding_dim,
output_dim
):

def FastTextNew(vocab_size, embedding_dim, output_dim):
"""
FastText for language
@@ -21,12 +18,12 @@ def FastTextNew(
output_dim=10,
)
net(x)
"""
return nn.Sequential(
Rearrange('t b -> t b '),
Rearrange("t b -> t b "),
nn.Embedding(vocab_size, embedding_dim),
Reduce('t b c -> b c', 'mean'),
Reduce("t b c -> b c", "mean"),
nn.Linear(embedding_dim, output_dim),
Rearrange('b c -> b c'),
)
Rearrange("b c -> b c"),
)
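
A minimal usage sketch of FastTextNew as reformatted here; the vocabulary size and the (time, batch) input layout are assumptions inferred from the rearrange patterns:

import torch
from zeta.nn.modules.fast_text import FastTextNew

net = FastTextNew(vocab_size=20, embedding_dim=32, output_dim=10)
x = torch.randint(0, 20, (100, 4))  # (time, batch) token ids, layout assumed
out = net(x)                        # (batch, output_dim) == (4, 10)
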
3 changes: 2 additions & 1 deletion zeta/nn/modules/resnet.py
@@ -38,8 +38,9 @@ class ResNet(nn.Module):
net(x)
"""

def __init__(self, block, layers, num_classes=1000):
super(ResNet, self).__init__()

15 changes: 8 additions & 7 deletions zeta/nn/modules/rnn_nlp.py
@@ -1,7 +1,8 @@
import torch
from torch import nn
from einops import rearrange


class RNNL(nn.Module):
"""
RNN for language
@@ -27,6 +28,7 @@ class RNNL(nn.Module):
net(x)
"""

def __init__(
self,
vocab_size,
@@ -38,19 +40,19 @@ def __init__(
dropout,
):
super().__init__()

self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.rnn = nn.LSTM(
embedding_dim,
hidden_dim,
num_layers=n_layers,
bidirectional=bidirectional,
dropout=dropout
dropout=dropout,
)
self.dropout = nn.Dropout(dropout)
self.directions = 2 if bidirectional else 1
self.fc = nn.Linear(hidden_dim * self.directions, output_dim)

def forward(self, x):
"""
Forward pass of the network.
@@ -64,7 +66,6 @@ def forward(self, x):
"(layer dir) b c -> layer b (dir c)",
dir=self.directions,
)
#take the final layers hidden

# take the final layers hidden
        return self.fc(self.dropout(hidden[-1]))
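
A standalone check of the hidden-state rearrange used in RNNL.forward; the layer, direction, and size values below are illustrative assumptions:

import torch
from einops import rearrange

n_layers, directions, batch, hidden_dim = 2, 2, 4, 64
h_n = torch.randn(n_layers * directions, batch, hidden_dim)  # LSTM final hidden state layout
h_n = rearrange(h_n, "(layer dir) b c -> layer b (dir c)", dir=directions)
final = h_n[-1]  # (batch, directions * hidden_dim), fed to the linear output layer
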

2 changes: 1 addition & 1 deletion zeta/nn/modules/simple_attention.py
@@ -1,10 +1,10 @@
import torch
import torch.nn.functional as F


def simple_attention(K, V, Q):
_, n_channels, _ = K.shape
A = torch.einsum("bct,bc1->bt1", [K, Q])
A = F.softmax(A * n_channels ** (-0.5), 1)
R = torch.einsum("bct, bt1->bc1", [V, A])
return torch.cat((R, Q), dim=1)
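
An equivalent standalone sketch of the computation in simple_attention, using letter subscripts in place of the "1" axis label; the shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

b, c, t = 2, 32, 10                        # batch, channels, timesteps (assumed)
K, V = torch.randn(b, c, t), torch.randn(b, c, t)
Q = torch.randn(b, c, 1)                   # a single query per batch element

A = torch.einsum("bct,bcq->btq", K, Q)     # attention logits over timesteps
A = F.softmax(A * c ** (-0.5), dim=1)      # scaled softmax across t
R = torch.einsum("bct,btq->bcq", V, A)     # attended values
out = torch.cat((R, Q), dim=1)             # (b, 2 * c, 1)
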

26 changes: 12 additions & 14 deletions zeta/nn/modules/spacial_transformer.py
@@ -1,8 +1,9 @@
import torch
from torch import nn
from einops.layers.torch import Rearrange
import torch.nn.functional as F


class SpacialTransformer(nn.Module):
"""
Spacial Transformer Network
@@ -12,22 +13,19 @@ class SpacialTransformer(nn.Module):
Usage:
>>> stn = SpacialTransformer()
>>> stn.stn(x)
"""
def __init__(
self
):

def __init__(self):
super(SpacialTransformer, self).__init__()

#spatial transformer localization-network
# spatial transformer localization-network
linear = nn.Linear(32, 3 * 2)

#initialize the weights/bias with identity transformation
# initialize the weights/bias with identity transformation
linear.weight.data.zero_()

linear.bias.data.copy_(
torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
)
linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

self.compute_theta = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
@@ -36,16 +34,16 @@ def __init__(
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
Rearrange('b c h w -> b (c h w)', h=3, w=3),
Rearrange("b c h w -> b (c h w)", h=3, w=3),
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
linear,
Rearrange('b (row col) -> b row col', row=2, col=3),
Rearrange("b (row col) -> b row col", row=2, col=3),
)

def stn(self, x):
"""
stn module
"""
grid = F.affine_grid(self.compute_theta(x), x.size())
return F.grid_sample(x, grid)
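
A minimal usage sketch of SpacialTransformer; the single-channel 28x28 input size is an assumption inferred from the 10 * 3 * 3 flatten in the localization network:

import torch
from zeta.nn.modules.spacial_transformer import SpacialTransformer

stn = SpacialTransformer()
x = torch.randn(4, 1, 28, 28)  # (batch, channels, height, width), size assumed
warped = stn.stn(x)            # same shape as x, resampled through the predicted affine grid
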
32 changes: 32 additions & 0 deletions zeta/nn/modules/yolo.py
@@ -0,0 +1,32 @@
import torch
from torch import nn
from einops.layers.torch import Rearrange
from einops import rearrange


def yolo(input, num_classes, num_anchors, anchors, stride_h, stride_w):
raw_predictions = rearrange(
input,
"b (anchor prediction) h w -> prediction b anchor h w",
anchor=num_anchors,
prediction=5 + num_classes,
)
anchors = torch.FloatTensor(anchors).to(input.device)
anchor_sizes = rearrange(anchors, "anchor dim -> dim () anchor () ()")

_, _, _, in_h, in_w = raw_predictions.shape
grid_h = rearrange(torch.arange(in_h).float(), "h -> () () h ()").to(input.device)
grid_w = rearrange(torch.arange(in_w).float(), "w -> () () () w").to(input.device)

predicted_bboxes = torch.zeros_like(raw_predictions)
predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w # center x
predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h # center y
predicted_bboxes[2:4] = (
raw_predictions[2:4].exp()
) * anchor_sizes # bbox width and height
predicted_bboxes[4] = raw_predictions[4].sigmoid() # confidence
predicted_bboxes[5:] = raw_predictions[5:].sigmoid() # class predictions
# merging all predicted bboxes for each image
return rearrange(
predicted_bboxes, "prediction b anchor h w -> b (anchor h w) prediction"
)
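
A minimal usage sketch of the new yolo helper; the class count, anchor sizes, grid size, and strides below are illustrative assumptions:

import torch
from zeta.nn.modules.yolo import yolo

num_classes, num_anchors = 20, 3
anchors = [[10, 13], [16, 30], [33, 23]]                          # per-anchor (w, h), assumed values
feats = torch.randn(2, num_anchors * (5 + num_classes), 13, 13)   # (b, anchor * (5 + classes), h, w)
boxes = yolo(feats, num_classes, num_anchors, anchors, stride_h=32, stride_w=32)
# boxes: (b, anchor * h * w, 5 + num_classes) decoded predictions
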
11 changes: 7 additions & 4 deletions zeta/ops/main.py
@@ -497,11 +497,14 @@ def channel_shuffle_new(x, groups):
)



#GLOW depth to space
# GLOW depth to space
def unsqueeze_2d_new(input, factor=2):
return rearrange(input, "b (c h2 w2) h w -> b c (h h2) (w w2)", h2=factor, w2=factor)
return rearrange(
input, "b (c h2 w2) h w -> b c (h h2) (w w2)", h2=factor, w2=factor
)


def squeeze_2d_new(input, factor=2):
return rearrange(input, "b c (h h2) (w w2) -> b (c h2 w2) h w", h2=factor, w2=factor)
return rearrange(
input, "b c (h h2) (w w2) -> b (c h2 w2) h w", h2=factor, w2=factor
)
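
A quick sketch showing that the two reformatted helpers are inverse depth-to-space / space-to-depth operations; the import path and tensor shape are assumptions:

import torch
from zeta.ops.main import squeeze_2d_new, unsqueeze_2d_new

x = torch.randn(2, 12, 8, 8)         # channels must be divisible by factor ** 2
up = unsqueeze_2d_new(x, factor=2)   # depth-to-space: (2, 3, 16, 16)
down = squeeze_2d_new(up, factor=2)  # space-to-depth: back to (2, 12, 8, 8)
assert torch.equal(down, x)          # round trip recovers the original tensor
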
