hypernetwork_modules.py
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class HyperNetwork(nn.Module):
    """Generates the weights of an f_size x f_size convolution from a layer embedding z."""

    def __init__(self, f_size=3, z_dim=64, out_size=16, in_size=16):
        super(HyperNetwork, self).__init__()
        self.z_dim = z_dim
        self.f_size = f_size
        self.out_size = out_size
        self.in_size = in_size

        # Two-layer linear hypernetwork. Weights are drawn from a standard normal,
        # truncated to (-2, 2) via fmod, and created directly on the GPU.
        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).cuda(), 2))
        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).cuda(), 2))
        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).cuda(), 2))
        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).cuda(), 2))

    def forward(self, z):
        # First projection: embedding z of shape (z_dim,) -> one z_dim-vector per input channel.
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(self.in_size, self.z_dim)

        # Second projection: each per-channel vector -> out_size * f_size * f_size kernel weights.
        h_final = torch.matmul(h_in, self.w1) + self.b1
        kernel = h_final.view(self.out_size, self.in_size, self.f_size, self.f_size)

        return kernel
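
A minimal usage sketch (not part of the original file, shown only to illustrate the shapes): the generated tensor has the layout expected by F.conv2d, and the embedding z would normally be supplied by a surrounding embedding module. This assumes a CUDA device is available, since the parameters above are created with .cuda().

    import torch
    import torch.nn.functional as F

    # Hypothetical usage: generate a (16, 16, 3, 3) conv kernel from a 64-dim embedding.
    hyper = HyperNetwork(f_size=3, z_dim=64, out_size=16, in_size=16)
    z = torch.randn(64).cuda()               # layer embedding of shape (z_dim,)
    kernel = hyper(z)                        # -> (out_size, in_size, f_size, f_size) = (16, 16, 3, 3)

    x = torch.randn(1, 16, 32, 32).cuda()    # dummy feature map with in_size channels
    y = F.conv2d(x, kernel, padding=1)       # -> (1, 16, 32, 32)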