# ADMM.py
# Script that loads a PSF image and a raw sensor image and runs an ADMM
# reconstruction loop on them (parameters are read from admm_config.yml).
import numpy as np
import numpy.fft as fft
from PIL import Image
import matplotlib.pyplot as plt
import yaml
def loadData(show_im=True):
    """Load the PSF and the raw sensor image and preprocess both.

    Reads the module-level names `psfname`, `imgname` (image paths) and `f`
    (downsampling factor). Both images are converted to float32, have the
    same background level subtracted, are downsampled by repeated 2x2
    averaging and are finally scaled to unit Euclidean norm.

    Args:
        show_im: when true, display the processed PSF and raw data with
            matplotlib (blocks in `plt.show()`).

    Returns:
        Tuple `(psf, data)` of the preprocessed arrays.
    """
    # Context managers close the underlying files once the pixels have been
    # copied into numpy arrays.
    with Image.open(psfname) as psf_file:
        psf = np.array(psf_file, dtype='float32')
    with Image.open(imgname) as data_file:
        data = np.array(data_file, dtype='float32')

    # In the picamera, there is a non-trivial background (even in the dark)
    # that must be subtracted. It is estimated from a small patch of the PSF.
    bg = np.mean(psf[5:15, 5:15])
    psf -= bg
    data -= bg

    # Resize to a more manageable size to do reconstruction on. Because
    # resizing is downsampling, it is subject to aliasing (artifacts produced
    # by the periodic nature of sampling). Demosaicing is an attempt to
    # account for/reduce the aliasing caused. In this application, we do the
    # simplest possible demosaicing algorithm: smoothing/blurring the image
    # with a box filter.
    def resize(img, factor):
        """Halve `img` int(-log2(factor)) times using a 2x2 box average."""
        num = int(-np.log2(factor))
        for _ in range(num):
            # Drop a trailing odd row/column so the four interleaved
            # sub-grids below always have identical shapes.
            rows = img.shape[0] - img.shape[0] % 2
            cols = img.shape[1] - img.shape[1] % 2
            img = img[:rows, :cols, ...]
            img = 0.25*(img[::2, ::2, ...] + img[1::2, ::2, ...]
                        + img[::2, 1::2, ...] + img[1::2, 1::2, ...])
        return img

    psf = resize(psf, f)
    data = resize(data, f)

    # Now we normalize the images so they have the same total power.
    # Technically not a necessary step, but the optimal hyperparameters are a
    # function of the total power in the PSF (among other things), so it
    # makes sense to standardize it.
    psf /= np.linalg.norm(psf.ravel())
    data /= np.linalg.norm(data.ravel())

    if show_im:
        plt.figure()
        plt.imshow(psf, cmap='gray')
        plt.title('PSF')
        plt.figure()
        plt.imshow(data, cmap='gray')
        plt.title('Raw data')
        plt.show()
    return psf, data
def U_update(eta, image_est, tau):
    """Return SoftThresh(Psi(image_est) + eta/mu2, tau/mu2).

    Uses the module-level penalty parameter `mu2`.
    """
    shifted_differences = Psi(image_est) + eta/mu2
    threshold = tau/mu2
    return SoftThresh(shifted_differences, threshold)
def SoftThresh(x, tau):
    """Element-wise soft threshold: shrink |x| by tau, clip at 0, keep the sign."""
    # numpy broadcasts every operation below over the whole array
    shrunk_magnitude = np.abs(x) - tau
    clipped = np.maximum(0, shrunk_magnitude)
    return np.sign(x)*clipped
def Psi(v):
    """Stack the circular differences of `v` along axis 0 and axis 1.

    Each difference is `np.roll(v, 1, axis) - v`; the two results are stacked
    along a new axis 2.
    """
    diff_axis0 = np.roll(v, 1, axis=0) - v
    diff_axis1 = np.roll(v, 1, axis=1) - v
    return np.stack((diff_axis0, diff_axis1), axis=2)
def X_update(xi, image_est, H_fft, sensor_reading, X_divmat):
    """Return X_divmat * (xi + mu1*M(image_est, H_fft) + CT(sensor_reading)).

    `X_divmat` is expected to be the array returned by precompute_X_divmat();
    `mu1` is read from module scope.
    """
    numerator = xi + mu1*M(image_est, H_fft) + CT(sensor_reading)
    return X_divmat * numerator
def M(vk, H_fft):
    """Filter `vk` in the frequency domain with `H_fft`.

    `vk` is ifftshift-ed, transformed with fft2, multiplied element-wise by
    `H_fft`, transformed back with ifft2 and fftshift-ed; only the real part
    is returned.
    """
    spectrum = fft.fft2(fft.ifftshift(vk))
    filtered = fft.ifft2(spectrum*H_fft)
    return np.real(fft.fftshift(filtered))
def C(M):
    """Return the centered `sensor_size` window of a `full_size` array.

    The result is a basic slice of `M`, so it is a view of the input rather
    than a copy. `full_size` and `sensor_size` are read from module scope.
    """
    # Image stored as matrix (row-column rather than x-y)
    row_window = slice((full_size[0] - sensor_size[0])//2,
                       (full_size[0] + sensor_size[0])//2)
    col_window = slice((full_size[1] - sensor_size[1])//2,
                       (full_size[1] + sensor_size[1])//2)
    return M[row_window, col_window]
def CT(b):
    """Zero-pad a `sensor_size` array so that it has shape `full_size`.

    The leading pad on each axis is `(full_size - sensor_size)//2`, the same
    offset C() uses for its crop, and the trailing pad takes whatever
    remains. Padding the same amount on both sides would come up one
    row/column short whenever `full_size - sensor_size` is odd.

    `full_size` and `sensor_size` are read from module scope.
    """
    v_total = full_size[0] - sensor_size[0]
    h_total = full_size[1] - sensor_size[1]
    v_pad = v_total//2
    h_pad = h_total//2
    pad_widths = ((v_pad, v_total - v_pad), (h_pad, h_total - h_pad))
    return np.pad(b, pad_widths, 'constant', constant_values=(0, 0))
def precompute_X_divmat():
    """Build the element-wise divisor used by X_update.

    Only call this function once! Keep the result in a variable and reuse
    that variable in every update step.
    """
    padded_ones = CT(np.ones(sensor_size))
    return 1./(padded_ones + mu1)
def W_update(rho, image_est):
    """Return `rho/mu3 + image_est` with negative entries clipped to zero."""
    candidate = rho/mu3 + image_est
    return np.maximum(candidate, 0)
def r_calc(w, rho, u, eta, x, xi, H_fft):
    """Return (mu3*w - rho) + PsiT(mu2*u - eta) + MT(mu1*x - xi, H_fft).

    This is the array V_update transforms to the frequency domain.
    """
    w_term = mu3*w - rho
    u_term = PsiT(mu2*u - eta)
    x_term = MT(mu1*x - xi, H_fft)
    return w_term + u_term + x_term
def V_update(w, rho, u, eta, x, xi, H_fft, R_divmat):
    """Compute the next V by scaling r_calc(...) in the frequency domain.

    The result of r_calc is ifftshift-ed and transformed with fft2,
    multiplied by `R_divmat`, then transformed back and fftshift-ed. Only the
    real part is returned.
    """
    rhs = r_calc(w, rho, u, eta, x, xi, H_fft)
    freq_space_result = R_divmat*fft.fft2(fft.ifftshift(rhs))
    spatial_result = fft.fftshift(fft.ifft2(freq_space_result))
    return np.real(spatial_result)
def PsiT(U):
    """Combine the two channels of `U` into one array.

    Channel 0 contributes `roll(-1, axis=0) - itself` and channel 1
    contributes `roll(-1, axis=1) - itself`; the two are summed.
    """
    channel0, channel1 = U[..., 0], U[..., 1]
    from_axis0 = np.roll(channel0, -1, axis=0) - channel0
    from_axis1 = np.roll(channel1, -1, axis=1) - channel1
    return from_axis0 + from_axis1
def MT(x, H_fft):
    """Filter `x` in the frequency domain with the complex conjugate of `H_fft`.

    Same pipeline as M() (ifftshift, fft2, multiply, ifft2, fftshift, real
    part) except that the multiplier is `np.conj(H_fft)`.
    """
    spectrum = fft.fft2(fft.ifftshift(x))
    filtered = fft.ifft2(spectrum*np.conj(H_fft))
    return np.real(fft.fftshift(filtered))
def precompute_PsiTPsi():
    """Return the 2-D FFT of a `full_size` kernel.

    The kernel holds 4 at [0, 0] and -1 at [0, 1], [1, 0], [0, -1] and
    [-1, 0]; every other entry is zero.
    """
    kernel = np.zeros(full_size)
    kernel[0, 0] = 4
    for neighbour in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        kernel[neighbour] = -1
    return fft.fft2(kernel)
def precompute_R_divmat(H_fft, PsiTPsi):
    """Build the frequency-space divisor used by V_update.

    Only call this function once! Keep the result in a variable and reuse
    that variable in every update step.

    The returned matrix is a mask in frequency space, so it should only be
    applied to images that have already been transformed via an fft.
    """
    denominator = (mu1*(np.abs(np.conj(H_fft)*H_fft))
                   + mu2*np.abs(PsiTPsi)
                   + mu3)
    return 1./denominator
def xi_update(xi, V, H_fft, X):
    """Return `xi` advanced by mu1 times the mismatch M(V, H_fft) - X."""
    mismatch = M(V, H_fft) - X
    return xi + mu1*mismatch
def eta_update(eta, V, U):
    """Return `eta` advanced by mu2 times the mismatch Psi(V) - U."""
    mismatch = Psi(V) - U
    return eta + mu2*mismatch
def rho_update(rho, V, W):
    """Return `rho` advanced by mu3 times the mismatch V - W."""
    mismatch = V - W
    return rho + mu3*mismatch
def init_Matrices(H_fft):
    """Create the all-zero starting arrays for the iteration.

    X, V and W have shape `full_size`; U has an extra trailing axis of
    length 2. xi, eta and rho are zero arrays shaped like M(V, H_fft),
    Psi(V) and W respectively.

    Returns:
        Tuple `(X, U, V, W, xi, eta, rho)`.
    """
    X, V, W = np.zeros(full_size), np.zeros(full_size), np.zeros(full_size)
    U = np.zeros((full_size[0], full_size[1], 2))

    xi = np.zeros_like(M(V, H_fft))
    eta = np.zeros_like(Psi(V))
    rho = np.zeros_like(W)
    return X, U, V, W, xi, eta, rho
def precompute_H_fft(psf):
    """Return the 2-D FFT of the zero-padded, ifftshift-ed PSF."""
    padded_psf = CT(psf)
    return fft.fft2(fft.ifftshift(padded_psf))
def ADMM_Step(X,U,V,W,xi,eta,rho, precomputed):
    """Run one full iteration and return the updated state tuple.

    Args:
        X, U, V, W, xi, eta, rho: current state arrays.
        precomputed: sequence `(H_fft, data, X_divmat, R_divmat)`.

    Returns:
        Tuple `(X, U, V, W, xi, eta, rho)` after the step. The threshold
        `tau` is read from module scope.
    """
    H_fft, data, X_divmat, R_divmat = precomputed

    # U and X are computed from the previous V, then V from the new U and X,
    # then W from the new V.
    U = U_update(eta, V, tau)
    X = X_update(xi, V, H_fft, data, X_divmat)
    V = V_update(W, rho, U, eta, X, xi, H_fft, R_divmat)
    W = W_update(rho, V)

    # xi, eta and rho each depend only on their own previous value and on
    # the arrays computed above, so they can be evaluated side by side.
    xi, eta, rho = (
        xi_update(xi, V, H_fft, X),
        eta_update(eta, V, U),
        rho_update(rho, V, W),
    )
    return X, U, V, W, xi, eta, rho
def runADMM(psf, data):
    """Run `iters` ADMM iterations and return the final reconstruction.

    The module-level `iters` sets the number of iterations and `disp_pic`
    sets how often (in iterations) an intermediate reconstruction is shown.

    Args:
        psf: preprocessed point spread function (sensor-sized array).
        data: preprocessed sensor measurement (sensor-sized array).

    Returns:
        The sensor-sized crop of V after the last iteration, with negative
        values clipped to zero.
    """
    H_fft = precompute_H_fft(psf)
    X, U, V, W, xi, eta, rho = init_Matrices(H_fft)
    X_divmat = precompute_X_divmat()
    PsiTPsi = precompute_PsiTPsi()
    R_divmat = precompute_R_divmat(H_fft, PsiTPsi)
    precomputed = [H_fft, data, X_divmat, R_divmat]

    for i in range(iters):
        X, U, V, W, xi, eta, rho = ADMM_Step(X, U, V, W, xi, eta, rho, precomputed)

        if i % disp_pic == 0:
            print(i)
            # C(V) is a view into V. np.maximum allocates a new array, so
            # clipping for display never writes back into the solver state
            # (an in-place clamp would make results depend on disp_pic).
            image = np.maximum(C(V), 0)
            plt.figure(1)
            plt.imshow(image, cmap='gray')
            plt.title('Reconstruction after iteration {}'.format(i))
            plt.show()

    # Build the result from the final V rather than from the last displayed
    # snapshot; this also gives a defined result when iters == 0.
    return np.maximum(C(V), 0)
if __name__ == "__main__":
    ### Reading in params from config file (don't mess with parameter names!)
    # safe_load only builds plain Python objects; yaml.load without an
    # explicit Loader is rejected by PyYAML >= 6. The context manager closes
    # the config file.
    with open("admm_config.yml") as config_file:
        params = yaml.safe_load(config_file)

    # Every config key becomes a module-level global, because the functions
    # in this file read their parameters from module scope. Updating
    # globals() directly avoids exec-ing text that comes from the config file.
    globals().update(params)

    ### Loading images and initializing the required arrays
    psf, data = loadData(True)
    sensor_size = np.array(psf.shape)
    full_size = 2*sensor_size

    ### Running the algorithm
    final_im = runADMM(psf, data)
    plt.imshow(final_im, cmap='gray')
    plt.title('Final reconstructed image after {} iterations'.format(iters))
    plt.show()

    saveim = input('Save final image? (y/n) ')
    if saveim == 'y':
        filename = input('Name of file: ')
        plt.imshow(final_im, cmap='gray')
        plt.axis('off')
        plt.savefig(filename+'.png', bbox_inches='tight')