Skip to content

Commit

Permalink
Sped up nabla_f_w
Browse files Browse the repository at this point in the history
  • Loading branch information
shivamgupta2 committed May 4, 2021
1 parent 7758d10 commit 035feec
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion robustlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib import rc
import ast
import mpld3
import time
#mpld3.enable_notebook()

def err_rspca(a,b): return LA.norm(np.outer(a,a)-np.outer(b, b))
Expand Down Expand Up @@ -212,10 +213,15 @@ def alg(self, S, indicator):
w = np.ones(m) / m
X = S
eps_m = round(eps * m)
nabla_f_w_total_time = 0
sigma_w_total_time = 0
start_time = time.time()
for i in range(nItrs):
if self.sparse:
Xw = X.T @ w
sigma_w_start_time = time.time()
Sigma_w = (X.T @ spdiags(w, 0, m, m) @ X) - (Xw @ Xw.T)
sigma_w_total_time += time.time() - sigma_w_start_time
Sigma_w_minus_I = Sigma_w - np.eye(d)
#find indices of largest k entries of each row of Sigma_w_minus_I
largest_k_each_row_index_array = np.argpartition(Sigma_w_minus_I, kth=-k, axis=-1)[:, -k:]
Expand All @@ -231,7 +237,10 @@ def alg(self, S, indicator):
psi_w[largest_rows_index_array, largest_k_each_row_index_array.T] = 1
Z_w = psi_w * Sigma_w_minus_I

nabla_f_w = ((X @ (Z_w @ X.T)).diagonal() - (X @ Z_w @ (X.T @ w)) - (X @ (Z_w.T @ (X.T @ w)))) / LA.norm(Z_w)
nabla_f_w_start_time = time.time()
nabla_f_w = ((X @ (Z_w @ X.T)).diagonal() - (X @ (Z_w @ (X.T @ w))) - (X @ (Z_w.T @ (X.T @ w)))) / LA.norm(Z_w)
#nabla_f_w = ((X @ (Z_w @ (X.T @ w))) - (X @ (Z_w.T @ (X.T @ w)))) / LA.norm(Z_w)
nabla_f_w_total_time += time.time() - nabla_f_w_start_time
else:
Xw = np.matmul(X.T, w)
Sigma_w_fun = lambda v: np.matmul(X.T, w * np.matmul(X, v)) - Xw *np.matmul(Xw.T, v)
Expand All @@ -246,6 +255,9 @@ def alg(self, S, indicator):
w = w - step_size * nabla_f_w/ LA.norm(nabla_f_w)
w = self.project_onto_capped_simplex_simple(w, (1/(m - eps_m)))
#print(w.shape)
print('Time to run GD ', time.time() - start_time)
print('Time to compute Sigma_w ', sigma_w_total_time)
print('Time to copute nabla_f_w', nabla_f_w_total_time)
print(w.shape)
print(X.shape)
mu_gd = np.sum(w * X.T, axis=1)
Expand Down

0 comments on commit 035feec

Please sign in to comment.