-
Notifications
You must be signed in to change notification settings - Fork 4
/
linear.py
76 lines (62 loc) · 3.12 KB
/
linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
from torch.autograd import Function
class LinearFAFunction(Function):
    """Linear-layer autograd function implementing feedback alignment (FA).

    The forward pass is a standard linear map, ``input @ weight.T + bias``.
    The backward pass differs from the reference linear function in one place:
    the gradient propagated to ``input`` is computed with the separate feedback
    matrix ``weight_fa`` instead of ``weight``. Gradients for ``weight`` and
    ``bias`` are the ordinary ones; no gradient is ever produced for
    ``weight_fa``.

    Expected shapes: ``input`` is ``(*, in_features)`` (any number of leading
    dimensions), ``weight`` and ``weight_fa`` are ``(out_features,
    in_features)``, and the optional ``bias`` is ``(out_features,)``.
    """

    @staticmethod
    def forward(context, input, weight, weight_fa, bias=None):
        # Same as the reference linear function, but weight_fa is also saved
        # so backward can route the input gradient through it.
        context.save_for_backward(input, weight, weight_fa, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            # Broadcasting adds bias along the last dimension for any input rank.
            output = output + bias
        return output

    @staticmethod
    def backward(context, grad_output):
        # `saved_variables` is deprecated/removed in modern PyTorch;
        # `saved_tensors` is the supported accessor.
        input, weight, weight_fa, bias = context.saved_tensors
        grad_input = grad_weight = grad_weight_fa = grad_bias = None

        if context.needs_input_grad[0]:
            # All of the FA logic resides here: propagate the error through the
            # fixed feedback matrix rather than the "correct" forward weight.
            grad_input = grad_output.matmul(weight_fa)

        # Collapse any leading (batch/sequence) dimensions so the weight and
        # bias gradients are plain 2-D reductions over every sample.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])

        if context.needs_input_grad[1]:
            # Ordinary linear-layer weight gradient, using the (FA'd)
            # grad_output that arrived from the downstream layer.
            input_2d = input.reshape(-1, input.shape[-1])
            grad_weight = grad_output_2d.t().matmul(input_2d)

        if bias is not None and context.needs_input_grad[3]:
            # Sum over samples only. No squeeze: with out_features == 1 the
            # gradient must keep bias's shape (1,), not collapse to a scalar.
            grad_bias = grad_output_2d.sum(0)

        return grad_input, grad_weight, grad_weight_fa, grad_bias
class LinearKPFunction(LinearFAFunction):
    """Autograd function for the Kolen-Pollack (KP) variant of feedback alignment.

    The forward pass and the input, weight, and bias gradients are inherited
    from ``LinearFAFunction``. The only addition is that the feedback matrix
    ``weight_fa`` is handed exactly the same gradient tensor as ``weight``, so
    gradient-based optimizers see identical gradients for both matrices.
    """

    @staticmethod
    def backward(context, grad_output):
        # The FA backward always reports None for weight_fa; that slot is
        # filled in below, so its value is discarded here.
        grad_input, grad_weight, _, grad_bias = LinearFAFunction.backward(
            context, grad_output)
        # Kolen-Pollack: mirror the forward-weight gradient onto the
        # backward (feedback) matrix.
        return grad_input, grad_weight, grad_weight, grad_bias
class FALinear(nn.Module):
    """Linear layer trained with feedback alignment (FA).

    The forward pass behaves like a standard linear layer; the backward pass
    propagates the error to the input through a separate random matrix
    ``weight_fa`` (via ``LinearFAFunction``).

    Args:
        input_features: size of each input sample.
        output_features: size of each output sample.
        bias: if ``False``, the layer has no additive bias (``self.bias`` is
            registered as ``None``).
    """

    def __init__(self, input_features, output_features, bias=True):
        super(FALinear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # Forward weight, stored as (out_features, in_features); forward()
        # multiplies the input by its transpose.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        # Random feedback matrix used only by the backward pass. It is an
        # nn.Parameter, which defaults to requires_grad=True, but
        # LinearFAFunction never returns a gradient for it, so under plain FA
        # it stays fixed. LinearKPFunction does return a gradient for it.
        self.weight_fa = nn.Parameter(
            torch.rand(output_features, input_features).to(self.weight.device))

        # Weight initialization.
        torch.nn.init.kaiming_uniform_(self.weight)
        torch.nn.init.kaiming_uniform_(self.weight_fa)
        # Without a bias self.bias is None, and init.constant_ would fail on it.
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 1)

    def forward(self, input):
        """Apply the layer; the input-gradient path uses ``weight_fa``."""
        return LinearFAFunction.apply(input, self.weight, self.weight_fa, self.bias)
class KPLinear(FALinear):
    """Linear layer trained with the Kolen-Pollack variant of feedback alignment.

    Construction (``input_features``, ``output_features``, ``bias``) is
    inherited unchanged from ``FALinear``. Only the autograd function differs:
    ``LinearKPFunction`` additionally gives ``weight_fa`` the same gradient as
    ``weight``.
    """

    def forward(self, input):
        kp_linear = LinearKPFunction.apply
        return kp_linear(input, self.weight, self.weight_fa, self.bias)