linear_svd.py
import torch
import numpy as np
import logging

# Loggers are keyed by name, not by file path; a FileHandler would be needed
# to actually write to ./logs/.
logger = logging.getLogger(__name__)


# Normalize the rows of V in-place so later code doesn't need to divide by norms.
def normalize(V):
    d = V.shape[0]
    norms = torch.norm(V, 2, dim=1)
    V[:, :] = V / norms.view(d, 1)
    return norms
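
# A minimal sanity check for normalize (an added sketch, not part of the original
# file): after the in-place update every row of V has unit norm, and the returned
# tensor holds the original row norms.
def _normalize_demo():
    V = torch.randn(4, 4)
    old_norms = normalize(V)
    assert torch.allclose(torch.norm(V, 2, dim=1), torch.ones(4), atol=1e-6)
    return old_norms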
# fasthpp: fast product of d Householder matrices, O(d/t + log2(t)) sequential
# operations instead of applying the reflections one at a time.
def fasthpp(V, X, stop_recursion=3):
    """
    V: matrix holding the weights of the Householder matrices (d, d); the
       unit-norm columns of V are the Householder vectors.
    X: rectangular matrix (d, bs); computes H(V) @ X.
    stop_recursion: integer controlling how many merge iterations run before
       recursion stops. If None, merging continues to the base case.
    """
    d = V.shape[0]
    with torch.cuda.device_of(V):
        # WY-style representation: each block of reflections is kept as
        # I + W^T Y; initially every "block" is a single H_i = I - 2 v_i v_i^T.
        Y_ = V.clone().T
        W_ = -2 * Y_.clone()
        # Only works for powers of two.
        assert (d & (d - 1)) == 0 and d != 0, "d should be a power of two. You can just pad the matrix."
        # Step 1: compute the (Y, W) pairs by merging adjacent blocks, doubling
        # the block size k each iteration:
        #   W_odd <- W_odd + (W_odd @ Y_even^T) @ W_even
        # so that (I + W_even^T Y_even)(I + W_odd^T Y_odd) = I + W^T Y.
        k = 1
        for c in range(int(np.log2(d))):
            k_2 = k
            k *= 2
            m1_ = Y_.view(d // k_2, k_2, d)[0::2] @ torch.transpose(W_.view(d // k_2, k_2, d)[1::2], 1, 2)
            m2_ = torch.transpose(W_.view(d // k_2, k_2, d)[0::2], 1, 2) @ m1_
            W_ = W_.view(d // k_2, k_2, d)
            W_[1::2] += torch.transpose(m2_, 1, 2)
            W_ = W_.view(d, d)
            if stop_recursion is not None and c == stop_recursion:
                break
        # Step 2: apply the accumulated blocks to X.
        if stop_recursion is None:
            return X + W_.T @ (Y_ @ X)
        else:
            # Multiply X with each remaining (W, Y) pair, last block first.
            for i in range(d // k - 1, -1, -1):
                X = X + W_[i * k: (i + 1) * k].T @ (Y_[i * k: (i + 1) * k] @ X)
            return X
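
# Reference implementation for checking fasthpp (an added sketch, not part of the
# original file): apply the d reflections H_i = I - 2 u_i u_i^T one at a time,
# where u_i is the i-th unit-norm column of V, so H(V) = H_0 @ H_1 @ ... @ H_{d-1}.
# For a V with unit-norm columns, fasthpp(V, X) should match this up to float error.
def sequential_householder(V, X):
    U = V.T  # rows of U are the Householder vectors
    for i in range(U.shape[0] - 1, -1, -1):
        u = U[i:i + 1]                # (1, d)
        X = X - 2 * u.T @ (u @ X)     # apply (I - 2 u u^T), rightmost factor first
    return X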
# Orthogonal layer: parameterizes an orthogonal matrix as a product of
# Householder reflections and applies it with fasthpp.
class Orthogonal(torch.nn.Module):
    def __init__(self, d, device="cuda" if torch.cuda.is_available() else "cpu"):
        super(Orthogonal, self).__init__()
        V = torch.zeros((d, d)).normal_(0, 1)
        normalize(V.T)  # unit-normalize the columns of V (the Householder vectors)
        # Register as a buffer so V moves with the module and is saved in state_dict.
        self.register_buffer("V", V.to(device))

    def forward(self, X):
        # X: (d, bs) -> H(V) @ X
        return fasthpp(self.V, X, stop_recursion=4)
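
# Quick orthogonality check (an added sketch, not part of the original file):
# a product of Householder reflections is orthogonal, so pushing the identity
# through the layer yields Q with Q @ Q.T ≈ I.
def _check_orthogonal(d=8):
    layer = Orthogonal(d, device="cpu")
    Q = layer(torch.eye(d))
    assert torch.allclose(Q @ Q.T, torch.eye(d), atol=1e-5)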
# LinearSVD: a linear layer kept in SVD-like factored form. The forward pass
# applies an orthogonal map, a diagonal scaling, then another orthogonal map,
# so the entries of D are the singular values of the effective matrix.
class LinearSVD(torch.nn.Module):
    def __init__(self, d, device="cuda" if torch.cuda.is_available() else "cpu"):
        super(LinearSVD, self).__init__()
        self.d = d
        self.U = Orthogonal(d, device)
        # Singular values, initialized near 1; kept on the same device as U and V.
        self.register_buffer("D", torch.empty(d, 1).uniform_(0.99, 1.01).to(device))
        self.V = Orthogonal(d, device)

    def forward(self, X):
        X = self.U(X)
        X = self.D * X  # scale row i by singular value D[i]
        X = self.V(X)
        return X
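
# Singular-value check (an added sketch; assumes torch.linalg.svdvals, available
# in PyTorch >= 1.9): the effective matrix is orthogonal @ diag(D) @ orthogonal,
# so its singular values should equal the entries of D up to ordering.
def _check_singular_values(d=8):
    layer = LinearSVD(d, device="cpu")
    A = layer(torch.eye(d))  # effective matrix of the layer
    sv = torch.linalg.svdvals(A)
    expected = layer.D.flatten().abs().sort(descending=True).values
    assert torch.allclose(sv, expected, atol=1e-4)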
if __name__ == "__main__":
    # Example usage
    d = 512
    bs = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    neuralSVD = LinearSVD(d=d)
    X = torch.zeros(d, bs, device=device).normal_()
    result = neuralSVD(X)
    print(result.shape)  # torch.Size([512, 32])