import math

import numpy as np
import torch

from utils import *  # expected to supply DEVICE and the remaining experiment helpers


class TAG(object):
    """
    Implementation of our proposed TAG optimizer
    """

    def __init__(self, model, args, num_tasks, optim='rms', lr=None, b=5):
        """
        Gathers all the arguments necessary for initialization
        :param model: Current model
        :param args: All arguments for experiment configuration
        :param num_tasks: Total number of tasks
        :param optim: Base optimizer to be used: {'rms': TAG-RMSProp, 'adagrad': TAG-Adagrad, 'adam': TAG-Adam}
        :param lr: Learning rate (eta)
        :param b: Hyperparameter regulating alpha - a higher value of b puts more emphasis on preventing forgetting
        """
        self.optim = optim
        self.args = args
        self.iters = 0
        self.model = model
        self.b = b
        self.weight_decay = 0.0
        if self.optim == 'adam':
            self.beta1, self.beta2 = 0.9, 0.999
        else:
            self.beta1, self.beta2 = 0.9, 0.99
        self.lr = lr
        self.alpha_add_ = {}
        self.v, self.v_t = {}, {}
        self.m, self.m_t = {}, {}
        self.m_t_norms = {}
        for task in range(num_tasks):
            self.v_t[task] = {}
            self.m_t[task] = {}
            self.m_t_norms[task] = {}
            self.alpha_add_[task] = {}
            for (name, param) in model.named_parameters():
                if task == 0:
                    # Global (task-independent) moments are initialized only once
                    self.v[name] = torch.zeros_like(param).to(args.device)
                    self.m[name] = torch.zeros_like(param).to(args.device)
                self.alpha_add_[task][name] = np.array([1])
                self.v_t[task][name] = torch.zeros_like(param).to(args.device)
                self.m_t[task][name] = torch.zeros_like(param).to(args.device)
                self.m_t_norms[task][name] = torch.zeros_like(param).to(args.device)
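
    # State layout sketch (illustrative, not from the original file): for a
    # parameter 'fc.weight' of shape (h, w), self.m_t[t]['fc.weight'] and
    # self.v_t[t]['fc.weight'] are (h, w) tensors holding task t's first and
    # second moments, while self.alpha_add_[t]['fc.weight'] becomes a
    # length-(t + 1) numpy array of alphas once update_tag() has run.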

    def zero_grad(self):
        return self.model.zero_grad()

    def update_all(self, task_id):
        """
        Normalize the current task-based first moments (these will remain fixed)
        """
        for name, _ in self.model.named_parameters():
            self.m_t_norms[task_id][name] = self.m_t[task_id][name].reshape(-1) / torch.norm(self.m_t[task_id][name])

    def update_naive(self, param_name, param_grad):
        """
        Apply the naive base-optimizer update (no task-based statistics)
        :param param_name: Parameter identity
        :param param_grad: Gradient associated with the given parameter
        :return: New update to the given parameter
        """
        if self.optim == 'rms':
            # Exponential moving average of the squared gradients
            self.v[param_name] = self.beta2 * self.v[param_name] + (1 - self.beta2) * param_grad ** 2
        else:
            # Adagrad-style accumulation of the squared gradients
            self.v[param_name] += param_grad ** 2
        denom = torch.sqrt(self.v[param_name]) + 1e-8
        return - (self.lr * param_grad / denom)
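
    # Worked example (illustrative): with lr=0.1, beta2=0.99, v=0 and a scalar
    # gradient g=2.0, the 'rms' branch above gives v = 0.01 * 4.0 = 0.04 and an
    # update of -0.1 * 2.0 / (sqrt(0.04) + 1e-8) ~ -1.0; the non-'rms' branch
    # accumulates g ** 2 into v instead of keeping an exponential moving average.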

    def update_tag(self, param_name, param_grad, task_id):
        """
        Update the task-based accumulated gradients, compute the alphas and return the new update
        :param param_name: Parameter identity
        :param param_grad: Gradient associated with the given parameter
        :param task_id: Current task identity
        :return: New update to the given parameter
        """
        bias_corr1, bias_corr2 = 1, 1
        new_v = None
        # Update the task-based first moment
        self.m_t[task_id][param_name] = self.beta1 * self.m_t[task_id][param_name] + (1 - self.beta1) * param_grad
        # Choose the numerator based on the optimizer
        if self.optim == 'adam':
            bias_corr1, bias_corr2 = 1 - self.beta1 ** (self.iters + 1), 1 - self.beta2 ** (self.iters + 1)
            numer = self.m_t[task_id][param_name] / bias_corr1
        else:
            numer = param_grad
        # Update the task-based second moments based on the optimizer
        if self.optim == 'rms' or self.optim == 'adam':
            self.v_t[task_id][param_name] = self.beta2 * self.v_t[task_id][param_name] + (1 - self.beta2) * param_grad ** 2
        else:
            self.v_t[task_id][param_name] = self.v_t[task_id][param_name] + param_grad ** 2
        # Get the new alphas by computing correlations between the task-based first moments
        if task_id > 0:
            alpha_add = []
            for t in range(task_id):
                corr = torch.dot(self.m_t[task_id][param_name].reshape(-1) / torch.norm(self.m_t[task_id][param_name]),
                                 self.m_t_norms[t][param_name])
                alpha_add += [(-corr).cpu().numpy()]
            # The current task's correlation with itself is 1, hence the fixed -1 entry
            alpha_add += [-1.]
            alpha_add = torch.from_numpy(np.array(alpha_add)).to(DEVICE)
            alpha_add_ = torch.exp(self.b * alpha_add).float()
        else:
            alpha_add_ = torch.from_numpy(np.array([1.0] * (task_id + 1))).to(DEVICE)
        self.alpha_add_[task_id][param_name] = alpha_add_.cpu().numpy()
        # Stack all the task-based second moments along a new leading task dimension
        for t in range(task_id):
            new_v = self.v_t[t][param_name].unsqueeze(0) \
                if t == 0 \
                else torch.cat((new_v, self.v_t[t][param_name].unsqueeze(0)), dim=0)
        new_v = self.v_t[task_id][param_name].unsqueeze(0) \
            if new_v is None \
            else torch.cat((new_v, self.v_t[task_id][param_name].unsqueeze(0)), dim=0)
        # Compute the inner product of the alphas and the stacked task-based second moments
        # with torch.einsum(); eq handles the varying number of parameter dimensions per layer.
        eq = {1: 'n,nh->h', 2: 'n,nhw->hw', 3: 'n,nhwc->hwc', 4: 'n,nhwvd->hwvd', 5: 'n,nhwzxc->hwzxc'}[len(param_grad.shape)]
        denom = (torch.sqrt(torch.einsum(eq, alpha_add_.float(), new_v)) / math.sqrt(bias_corr2)) + 1e-8
        return - (self.lr * numer / denom)
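
    # einsum sketch (illustrative): for a 2-D weight of shape (h, w) trained on
    # tasks 0..t, new_v stacks the per-task second moments into a (t + 1, h, w)
    # tensor and 'n,nhw->hw' contracts it with the length-(t + 1) alpha vector,
    # i.e. an elementwise weighted sum over tasks: sum_n alpha_n * v_n.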

    def step(self, model, task_id, step):
        """
        Perform the update over the parameters
        :param model: Current model
        :param task_id: Current task id (t)
        :param step: Current step (n)
        :return: Updated state dict
        """
        self.iters = step
        params = dict(model.named_parameters())
        state_dict = model.state_dict()
        for name, param in state_dict.items():
            # Skip batch-norm running statistics; they carry no gradients
            if name.split('.')[-1] in ['running_mean', 'num_batches_tracked', 'running_var']:
                continue
            v = params.get(name)
            if v is None or v.grad is None:
                continue
            update = self.update_tag(name, v.grad, task_id)
            state_dict[name].data.copy_(param + update.reshape(param.shape))
        return state_dict

def store_alpha(tag_optimizer, task_id, iter, alpha_mean=None):
    """
    Collects alpha values for the given task (t) and current step (n)
    :param tag_optimizer: Object of the class TAG
    :param task_id: Current task identity
    :param iter: Current step in the epoch
    :param alpha_mean: Running means from previous steps (pass None on the first call)
    :return: alpha_mean: Dictionary keyed by parameter name; each value is the running mean of that parameter's alphas (one entry per task seen so far)
    """
    if alpha_mean is None:
        alpha_mean = {}
    for tau in tag_optimizer.alpha_add_[task_id]:
        alphas = tag_optimizer.alpha_add_[task_id][tau]
        if iter == 0:
            alpha_mean[tau] = alphas
        else:
            # Incremental running mean over the steps seen so far
            alpha_mean[tau] = (alpha_mean[tau] * iter + alphas) / (iter + 1)
    return alpha_mean
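
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how TAG might be
# driven from a continual-learning training loop. `criterion`, `task_loaders`
# and `args` are hypothetical stand-ins for the experiment's own objects.
#
#   tag = TAG(model, args, num_tasks=len(task_loaders), optim='rms', lr=1e-3, b=5)
#   for task_id, loader in enumerate(task_loaders):
#       for step, (x, y) in enumerate(loader):
#           tag.zero_grad()
#           criterion(model(x), y).backward()
#           model.load_state_dict(tag.step(model, task_id, step))
#       tag.update_all(task_id)  # fix this task's normalized first moments
# ---------------------------------------------------------------------------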