-
Notifications
You must be signed in to change notification settings - Fork 0
/
SphericalOptimizer.py
93 lines (81 loc) · 4.4 KB
/
SphericalOptimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import math
import torch
from torch.optim import Optimizer
# Spherical Optimizer Class
# Uses the first two dimensions as batch information
# Optimizes over the surface of a sphere using the initial radius throughout
#
# Example Usage:
# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
class SphericalOptimizer(Optimizer):
    """Wrap a torch optimizer and keep every parameter on its initial sphere.

    After each step of the wrapped optimizer, every parameter is rescaled so
    that its L2 norm taken over dimensions 2 and up equals the norm it had
    when this object was constructed. Dimensions 0 and 1 are treated as batch
    dimensions, so each ``param[i, j]`` slice keeps its own radius.

    Example:
        opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)

    Args:
        optimizer: Optimizer class or factory, called as
            ``optimizer(params, **kwargs)``.
        params: Iterable of tensors to optimize (a list or a generator).
        **kwargs: Forwarded unchanged to ``optimizer``.

    Note:
        ``Optimizer.__init__`` is not invoked by this wrapper; the wrapped
        optimizer is available as ``self.opt`` for anything other than
        ``step`` (e.g. clearing gradients).
    """

    def __init__(self, optimizer, params, **kwargs):
        # Materialize once: a generator would be exhausted by the wrapped
        # optimizer, leaving nothing to record radii for or to project later.
        params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        with torch.no_grad():
            # Radius of every parameter at construction time, keyed by tensor.
            self.radii = {param: self._batched_norm(param) for param in params}

    @staticmethod
    def _batched_norm(tensor):
        """Return the L2 norm over dims 2+ (kept as size-1 dims for broadcasting)."""
        reduce_dims = tuple(range(2, tensor.ndim))
        # The small epsilon keeps the norm strictly positive for all-zero slices.
        return (tensor.pow(2).sum(reduce_dims, keepdim=True) + 1e-9).sqrt()

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then project onto the spheres.

        Args:
            closure: Optional closure forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)
        for param in self.params:
            # Normalize to unit length and restore the stored radius in one go.
            param.mul_(self.radii[param] / self._batched_norm(param))
        return loss
# Optimizer class for performing projected gradient descent with spherical constraint only on the intermediate latent vectors
class SphericalOptimizerStyle(Optimizer):
    """Projected gradient descent with a spherical constraint on ``params[0]``.

    Only the first parameter (the intermediate latent vectors) is constrained:
    after each step of the wrapped optimizer it is rescaled so that its L2
    norm over dimensions 2 and up equals the norm it had at construction
    time. All other parameters are updated by the wrapped optimizer only.

    Args:
        optimizer: Optimizer class or factory, called as
            ``optimizer(params, **kwargs)``.
        params: Iterable of tensors; the first one is the constrained latent.
        **kwargs: Forwarded unchanged to ``optimizer``.

    Note:
        ``Optimizer.__init__`` is not invoked by this wrapper; the wrapped
        optimizer is available as ``self.opt``.
    """

    def __init__(self, optimizer, params, **kwargs):
        # Materialize so that generators work and ``params[0]`` is indexable.
        params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        # Record the radii without tracking gradients; otherwise the stored
        # tensor would keep an autograd graph attached to the latent.
        with torch.no_grad():
            self.radii = self._batched_norm(params[0])

    @staticmethod
    def _batched_norm(tensor):
        """Return the L2 norm over dims 2+ (kept as size-1 dims for broadcasting)."""
        reduce_dims = tuple(range(2, tensor.ndim))
        # The small epsilon keeps the norm strictly positive for all-zero slices.
        return (tensor.pow(2).sum(reduce_dims, keepdim=True) + 1e-9).sqrt()

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then re-project the latent.

        Args:
            closure: Optional closure forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)
        latent = self.params[0]
        # One broadcasted rescale handles every batch entry and layer at once.
        latent.mul_(self.radii / self._batched_norm(latent))
        return loss
# Optimizer class for performing projected gradient descent with hollow-ball constraint only on the intermediate latent vectors
class HollowBallOptimizerDelta(Optimizer):
    """Projected gradient descent with a hollow-ball constraint on ``params[0]``.

    After each step of the wrapped optimizer, every slice of the first
    parameter (the intermediate latent vectors) whose L2 norm over dimensions
    2 and up is at most ``delta_min`` is rescaled to norm ``delta_min``, and
    every slice whose norm is at least ``delta_max`` is rescaled to norm
    ``delta_max``. Slices already strictly between the two are left
    untouched. Other parameters are updated by the wrapped optimizer only.

    Args:
        optimizer: Optimizer class or factory, called as
            ``optimizer(params, **kwargs)``.
        params: Iterable of tensors; the first one is the constrained latent.
        delta_max: Outer radius of the hollow ball.
        delta_min: Inner radius of the hollow ball.
        **kwargs: Forwarded unchanged to ``optimizer``.

    Note:
        ``Optimizer.__init__`` is not invoked by this wrapper; the wrapped
        optimizer is available as ``self.opt``.
    """

    def __init__(self, optimizer, params, delta_max, delta_min, **kwargs):
        # Materialize so that generators work and ``params[0]`` is indexable.
        params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        self.delta_min = delta_min
        self.delta_max = delta_max

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then project the latent.

        Args:
            closure: Optional closure forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)
        latent = self.params[0]
        reduce_dims = tuple(range(2, latent.ndim))
        # The small epsilon keeps the norm strictly positive for all-zero slices.
        norm = (latent.pow(2).sum(reduce_dims, keepdim=True) + 1e-9).sqrt()
        # Per-slice scale factor, chosen element-wise so that any batch size
        # works (a Python ``if`` on the norm tensor is ambiguous for batch > 1).
        scale = torch.where(
            norm >= self.delta_max, self.delta_max / norm, torch.ones_like(norm)
        )
        # Applied last so the inner-radius test takes priority over the outer one.
        scale = torch.where(norm <= self.delta_min, self.delta_min / norm, scale)
        latent.mul_(scale)
        return loss
# Optimizer class for performing projected gradient descent with hollow-ball constraint on the intermediate latent vectors
# and spherical constraint on the latent noise vectors
class HollowBallOptimizerDelta2(Optimizer):
    """Hollow-ball constraint on ``params[0]``, spherical constraint on the rest.

    After each step of the wrapped optimizer:

    * every slice of the first parameter (the intermediate latent vectors)
      whose L2 norm over dimensions 2 and up is at most ``delta_min`` is
      rescaled to norm ``delta_min``, and every slice whose norm is at least
      ``delta_max`` is rescaled to norm ``delta_max``;
    * every remaining parameter (the latent noise vectors) is rescaled so its
      L2 norm over dimensions 2 and up equals the norm it had at
      construction time.

    Args:
        optimizer: Optimizer class or factory, called as
            ``optimizer(params, **kwargs)``.
        params: Iterable of tensors; the first is the style latent, the
            others are noise tensors.
        delta_max: Outer radius of the hollow ball.
        delta_min: Inner radius of the hollow ball.
        **kwargs: Forwarded unchanged to ``optimizer``.

    Note:
        ``Optimizer.__init__`` is not invoked by this wrapper; the wrapped
        optimizer is available as ``self.opt``.
    """

    def __init__(self, optimizer, params, delta_max, delta_min, **kwargs):
        # Materialize so that generators work and slicing/indexing is possible.
        params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        self.delta_min = delta_min
        self.delta_max = delta_max
        with torch.no_grad():
            # Construction-time radius of every noise tensor, keyed by tensor.
            self.radii_noise = {
                param: self._batched_norm(param) for param in params[1:]
            }

    @staticmethod
    def _batched_norm(tensor):
        """Return the L2 norm over dims 2+ (kept as size-1 dims for broadcasting)."""
        reduce_dims = tuple(range(2, tensor.ndim))
        # The small epsilon keeps the norm strictly positive for all-zero slices.
        return (tensor.pow(2).sum(reduce_dims, keepdim=True) + 1e-9).sqrt()

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then apply both projections.

        Args:
            closure: Optional closure forwarded to the wrapped optimizer.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)

        # Hollow ball projection of style vectors. The scale is chosen
        # element-wise so that any batch size works (a Python ``if`` on the
        # norm tensor is ambiguous for batch > 1).
        latent = self.params[0]
        norm = self._batched_norm(latent)
        scale = torch.where(
            norm >= self.delta_max, self.delta_max / norm, torch.ones_like(norm)
        )
        # Applied last so the inner-radius test takes priority over the outer one.
        scale = torch.where(norm <= self.delta_min, self.delta_min / norm, scale)
        latent.mul_(scale)

        # Spherical projection of noise latent vectors.
        for param in self.params[1:]:
            param.mul_(self.radii_noise[param] / self._batched_norm(param))
        return loss