Skip to content

Commit

Permalink
SpacialTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Oct 11, 2023
1 parent 03c561b commit 8ffaccb
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 1 deletion.
2 changes: 1 addition & 1 deletion zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
from zeta.nn.modules.rnn_nlp import RNNL
from zeta.nn.modules.cnn_text import CNNNew
from zeta.nn.modules.fast_text import FastTextNew

from zeta.nn.modules.simple_attention import simple_attention
10 changes: 10 additions & 0 deletions zeta/nn/modules/simple_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
import torch.nn.functional as F

def simple_attention(K, V, Q):
_, n_channels, _ = K.shape
A = torch.einsum("bct,bc1->bt1", [K, Q])
A = F.softmax(A * n_channels ** (-0.5), 1)
R = torch.einsum("bct, bt1->bc1", [V, A])
return torch.cat((R, Q), dim=1)

51 changes: 51 additions & 0 deletions zeta/nn/modules/spacial_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
from torch import nn
from einops.layers.torch import Rearrange
import torch.nn.functional as F

class SpacialTransformer(nn.Module):
"""
Spacial Transformer Network
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
Usage:
>>> stn = SpacialTransformer()
>>> stn.stn(x)
"""
def __init__(
self
):
super(SpacialTransformer, self).__init__()

#spatial transformer localization-network
linear = nn.Linear(32, 3 * 2)

#initialize the weights/bias with identity transformation
linear.weight.data.zero_()

linear.bias.data.copy_(
torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
)

self.compute_theta = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
Rearrange('b c h w -> b (c h w)', h=3, w=3),
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
linear,
Rearrange('b (row col) -> b row col', row=2, col=3),
)

def stn(self, x):
"""
stn module
"""
grid = F.affine_grid(self.compute_theta(x), x.size())
return F.grid_sample(x, grid)
10 changes: 10 additions & 0 deletions zeta/ops/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,13 @@ def channel_shuffle_new(x, groups):
"b (c1 c2) h w -> b (c2 c1) h w",
c1=groups,
)



#GLOW depth to space
def unsqueeze_2d_new(input, factor=2):
return rearrange(input, "b (c h2 w2) h w -> b c (h h2) (w w2)", h2=factor, w2=factor)


def squeeze_2d_new(input, factor=2):
return rearrange(input, "b c (h h2) (w w2) -> b (c h2 w2) h w", h2=factor, w2=factor)

0 comments on commit 8ffaccb

Please sign in to comment.