# periodic_activations.py — Time2Vec periodic activation layers (sine/cosine).
import torch
from torch import nn
import numpy as np
import math
def t2v(tau, f, out_features, w, b, w0, b0, arg=None):
    """Time2Vec transform: concatenate a periodic and a linear projection of tau.

    Args:
        tau: input tensor of times, shape (..., in_features) — assumed 2-D
            (batch, in_features) in the callers here; TODO confirm for other uses.
        f: periodic activation applied element-wise (e.g. ``torch.sin``).
        out_features: total output width; kept in the signature for API
            compatibility (the width is actually determined by ``w`` and ``w0``).
        w, b: weight/bias of the periodic part, producing ``out_features - 1`` columns.
        w0, b0: weight/bias of the linear part, producing 1 column.
        arg: optional extra argument forwarded to ``f``.

    Returns:
        Tensor of shape (..., out_features): ``[f(tau @ w + b), tau @ w0 + b0]``
        concatenated along the last dimension.
    """
    # Fix: test `arg is not None` rather than truthiness — the old `if arg:`
    # silently dropped falsy-but-meaningful values (0, 0.0) and raised
    # "Boolean value of Tensor is ambiguous" for multi-element tensor args.
    if arg is not None:
        v1 = f(torch.matmul(tau, w) + b, arg)
    else:
        v1 = f(torch.matmul(tau, w) + b)
    # Linear (non-periodic) component.
    v2 = torch.matmul(tau, w0) + b0
    return torch.cat([v1, v2], -1)
class SineActivation(nn.Module):
    """Time2Vec layer using sine as the periodic activation.

    Maps an input of shape (..., in_features) to (..., out_features):
    one linear column plus ``out_features - 1`` sine-transformed columns.
    """

    def __init__(self, in_features, out_features):
        super(SineActivation, self).__init__()
        self.out_features = out_features
        # Linear (non-periodic) component: tau @ w0 + b0 -> 1 column.
        self.w0 = nn.Parameter(torch.randn(in_features, 1))
        self.b0 = nn.Parameter(torch.randn(1))
        # Periodic component: sin(tau @ w + b) -> out_features - 1 columns.
        self.w = nn.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.Parameter(torch.randn(out_features - 1))
        self.f = torch.sin

    def forward(self, tau):
        """Return the (..., out_features) Time2Vec embedding of ``tau``."""
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)
class CosineActivation(nn.Module):
    """Time2Vec layer using cosine as the periodic activation.

    Maps an input of shape (..., in_features) to (..., out_features):
    one linear column plus ``out_features - 1`` cosine-transformed columns.
    """

    def __init__(self, in_features, out_features):
        super(CosineActivation, self).__init__()
        self.out_features = out_features
        # Linear (non-periodic) component: tau @ w0 + b0 -> 1 column.
        self.w0 = nn.Parameter(torch.randn(in_features, 1))
        self.b0 = nn.Parameter(torch.randn(1))
        # Periodic component: cos(tau @ w + b) -> out_features - 1 columns.
        self.w = nn.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.Parameter(torch.randn(out_features - 1))
        self.f = torch.cos

    def forward(self, tau):
        """Return the (..., out_features) Time2Vec embedding of ``tau``."""
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)
if __name__ == "__main__":
sineact = SineActivation(1, 64)
cosact = CosineActivation(1, 64)
print(sineact(torch.Tensor([[7]])).shape)
print(cosact(torch.Tensor([[7]])).shape)