Skip to content

Commit

Permalink
code aesthetics
Browse files Browse the repository at this point in the history
  • Loading branch information
fabian-sp committed Dec 4, 2024
1 parent ea99461 commit 776f23f
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 158 deletions.
162 changes: 72 additions & 90 deletions src/gglasso/solver/ext_admm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,35 @@

import numpy as np
import time
import copy
import warnings
from typing import Optional

from numba import njit
from numba.typed import List


from gglasso.solver.ggl_helper import phiplus, prox_od_1norm, prox_2norm, prox_rank_norm
from gglasso.solver.ggl_helper import phiplus, prox_od_1norm, prox_rank_norm
from gglasso.helper.ext_admm_helper import check_G


def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\
X0 = None, X1 = None, tol = 1e-5 , rtol = 1e-4, stopping_criterion = 'boyd',\
rho= 1., max_iter = 1000, verbose = False, measure = False, latent = False, mu1 = None):
def ext_ADMM_MGL(S: dict,
lambda1: float,
lambda2: float,
reg: str,
Omega_0: dict,
G: np.ndarray,
X0: Optional[dict]=None,
X1: Optional[dict]=None,
tol: float=1e-5,
rtol: float=1e-4,
stopping_criterion: str='boyd',
rho: float=1.,
max_iter: int=1000,
verbose: bool=False,
measure: bool=False,
latent: bool=False,
mu1: Optional[float]=None
):
"""
This is an ADMM algorithm for solving the Group Graphical Lasso problem
where not all instances have the same number of dimensions, i.e. some variables are present in some instances and not in others.
Expand Down Expand Up @@ -108,7 +123,7 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\
"""
K = len(S.keys())
p = np.zeros(K, dtype= int)
p = np.zeros(K, dtype=int)
for k in np.arange(K):
p[k] = S[k].shape[0]

Expand All @@ -128,14 +143,13 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\

assert rho > 0, "ADMM penalization parameter must be positive."


# initialize
Omega_t = Omega_0.copy()
Theta_t = Omega_0.copy()
L_t = dict()

for k in np.arange(K):
L_t[k] = np.zeros((p[k],p[k]))
L_t[k] = np.zeros((p[k], p[k]))

# helper and dual variables
Lambda_t = Omega_0.copy()
Expand All @@ -144,18 +158,17 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\
if X0 is None:
X0_t = dict()
for k in np.arange(K):
X0_t[k] = np.zeros((p[k],p[k]))
X0_t[k] = np.zeros((p[k], p[k]))
else:
X0_t = X0.copy()

if X1 is None:
X1_t = dict()
for k in np.arange(K):
X1_t[k] = np.zeros((p[k],p[k]))
X1_t[k] = np.zeros((p[k], p[k]))
else:
X1_t = X1.copy()


runtime = np.zeros(max_iter)
residual = np.zeros(max_iter)
status = ''
Expand Down Expand Up @@ -183,22 +196,22 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\
# Omega Update
Omega_t_1 = Omega_t.copy()
for k in np.arange(K):
W_t = Theta_t[k] - L_t[k] - X0_t[k] - (1/rho) * S[k]
W_t = Theta_t[k] - L_t[k] - X0_t[k] - (1/rho)*S[k]
eigD, eigQ = np.linalg.eigh(W_t)
Omega_t[k] = phiplus(beta = 1/rho, D = eigD, Q = eigQ)
Omega_t[k] = phiplus(beta=1/rho, D=eigD, Q=eigQ)

# Theta Update
for k in np.arange(K):
V_t = (Omega_t[k] + L_t[k] + X0_t[k] + Lambda_t[k] - X1_t[k]) * 0.5
Theta_t[k] = prox_od_1norm(V_t, lambda1[k]/(2*rho))

#L Update
# L Update
if latent:
for k in np.arange(K):
C_t = Theta_t[k] - X0_t[k] - Omega_t[k]
C_t = (C_t.T + C_t)/2
eigD, eigQ = np.linalg.eigh(C_t)
L_t[k] = prox_rank_norm(C_t, mu1[k]/rho, D = eigD, Q = eigQ)
L_t[k] = prox_rank_norm(C_t, mu1[k]/rho, D=eigD, Q=eigQ)

# Lambda Update
Lambda_t_1 = Lambda_t.copy()
Expand All @@ -218,10 +231,19 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\

# Stopping condition
if stopping_criterion == 'boyd':
r_t,s_t,e_pri,e_dual = ADMM_stopping_criterion(Omega_t, Omega_t_1, Theta_t, L_t, Lambda_t, Lambda_t_1, X0_t, X1_t,\
S, rho, p, tol, rtol, latent)
r_t, s_t, e_pri, e_dual = ADMM_stopping_criterion(Omega_t,
Omega_t_1,
Theta_t,
L_t,
Lambda_t,
Lambda_t_1,
X0_t,
X1_t,
S,
rho, p, tol, rtol, latent
)

residual[iter_t] = max(r_t,s_t)
residual[iter_t] = max(r_t, s_t)

if verbose:
print(out_fmt % (iter_t,r_t,s_t,e_pri,e_dual))
Expand All @@ -231,18 +253,25 @@ def ext_ADMM_MGL(S, lambda1, lambda2, reg , Omega_0, G,\
break

elif stopping_criterion == 'kkt':
eta_A = kkt_stopping_criterion(Omega_t, Theta_t, L_t, Lambda_t, dict((k, rho*v) for k,v in X0_t.items()), dict((k, rho*v) for k,v in X1_t.items()),\
S , G, lambda1, lambda2, reg, latent, mu1)
eta_A = kkt_stopping_criterion(Omega_t,
Theta_t,
L_t,
Lambda_t,
dict((k, rho*v) for k,v in X0_t.items()),
dict((k, rho*v) for k,v in X1_t.items()),
S,
G,
lambda1, lambda2, reg, latent, mu1
)
residual[iter_t] = eta_A

if verbose:
print(out_fmt % (iter_t,eta_A))
print(out_fmt % (iter_t, eta_A))

if eta_A <= tol:
status = 'optimal'
break


##################################################################
### MAIN LOOP FINISHED
##################################################################
Expand Down Expand Up @@ -301,7 +330,6 @@ def ADMM_stopping_criterion(Omega, Omega_t_1, Theta, L, Lambda, Lambda_t_1, X0,
for k in np.arange(K):
assert np.all(L[k]==0)


dim = ((p ** 2 + p) / 2).sum() # number of elements of off-diagonal matrix

D1 = np.sqrt(sum([np.linalg.norm(Omega[k])**2 + np.linalg.norm(Lambda[k])**2 for k in np.arange(K)] ))
Expand All @@ -311,13 +339,12 @@ def ADMM_stopping_criterion(Omega, Omega_t_1, Theta, L, Lambda, Lambda_t_1, X0,
e_pri = dim * eps_abs + eps_rel * np.maximum(D1, D2)
e_dual = dim * eps_abs + eps_rel * rho * D3


r = np.sqrt(sum([np.linalg.norm(Omega[k] - Theta[k] + L[k])**2 + np.linalg.norm(Lambda[k] - Theta[k])**2 for k in np.arange(K)] ))
s = rho * np.sqrt(sum([np.linalg.norm(Omega[k] - Omega_t_1[k])**2 + np.linalg.norm(Lambda[k] - Lambda_t_1[k])**2 for k in np.arange(K)] ))

return r,s,e_pri,e_dual
return r, s, e_pri, e_dual

def kkt_stopping_criterion(Omega, Theta, L, Lambda, X0, X1, S , G, lambda1, lambda2, reg, latent = False, mu1 = None):
def kkt_stopping_criterion(Omega, Theta, L, Lambda, X0, X1, S , G, lambda1, lambda2, reg, latent=False, mu1=None):
# X0, X1 are inputed as UNscaled dual variables(!)
K = len(S.keys())

Expand All @@ -337,26 +364,31 @@ def kkt_stopping_criterion(Omega, Theta, L, Lambda, X0, X1, S , G, lambda1, lamb
eigD, eigQ = np.linalg.eigh(Omega[k] - S[k] - X0[k])
proxk = phiplus(beta = 1, D = eigD, Q = eigQ)
# primal varibale optimality
term1[k] = np.linalg.norm(Omega[k] - proxk) / (1 + np.linalg.norm(Omega[k]))
term2[k] = np.linalg.norm(Theta[k] - prox_od_1norm(Theta[k] + X0[k] - X1[k] , lambda1[k])) / (1 + np.linalg.norm(Theta[k]))
term1[k] = np.linalg.norm(Omega[k] - proxk) / (1+np.linalg.norm(Omega[k]))
term2[k] = np.linalg.norm(Theta[k] - prox_od_1norm(Theta[k] + X0[k] - X1[k] , lambda1[k])) / (1+np.linalg.norm(Theta[k]))

if latent:
eigD, eigQ = np.linalg.eigh(L[k] - X0[k])
proxk = prox_rank_norm(L[k] - X0[k], beta = mu1[k], D = eigD, Q = eigQ)
term3[k] = np.linalg.norm(L[k] - proxk) / (1 + np.linalg.norm(L[k]))
proxk = prox_rank_norm(L[k] - X0[k], beta=mu1[k], D=eigD, Q=eigQ)
term3[k] = np.linalg.norm(L[k] - proxk) / (1+np.linalg.norm(L[k]))

V[k] = Lambda[k] + X1[k]

# equality constraints
term5[k] = np.linalg.norm(Omega[k] - Theta[k] + L[k]) / (1 + np.linalg.norm(Theta[k]))
term6[k] = np.linalg.norm(Lambda[k] - Theta[k]) / (1 + np.linalg.norm(Theta[k]))

term5[k] = np.linalg.norm(Omega[k] - Theta[k] + L[k]) / (1+np.linalg.norm(Theta[k]))
term6[k] = np.linalg.norm(Lambda[k] - Theta[k]) / (1+np.linalg.norm(Theta[k]))

V = prox_2norm_G(V, G, lambda2)
for k in np.arange(K):
term4[k] = np.linalg.norm(V[k] - Lambda[k]) / (1 + np.linalg.norm(Lambda[k]))

res = max(np.linalg.norm(term1), np.linalg.norm(term2), np.linalg.norm(term3), np.linalg.norm(term4), np.linalg.norm(term5), np.linalg.norm(term6) )
term4[k] = np.linalg.norm(V[k] - Lambda[k]) / (1+np.linalg.norm(Lambda[k]))

res = max(np.linalg.norm(term1),
np.linalg.norm(term2),
np.linalg.norm(term3),
np.linalg.norm(term4),
np.linalg.norm(term5),
np.linalg.norm(term6)
)
return res

def prox_2norm_G(X, G, l2):
Expand All @@ -375,14 +407,13 @@ def prox_2norm_G(X, G, l2):
assert d[0] == 2
assert d[2] == K

group_size = (G[0,:,:] != -1).sum(axis = 1)
group_size = (G[0,:,:] != -1).sum(axis=1)

tmpX = List()
for k in np.arange(K):
tmpX.append(X[k].copy())

X1 = prox_G_inner(G, tmpX, l2, group_size)

X1 = dict(zip(np.arange(K), X1))

return X1
Expand All @@ -402,13 +433,12 @@ def prox_G_inner(G, X, l2, group_size):
v0[k] = np.nan
else:
v0[k] = X[k][G[0,l,k], G[1,l,k]]



v = v0[~np.isnan(v0)]
# scale with square root of the group size
lam = l2 * np.sqrt(group_size[l])
a = max(np.sqrt((v**2).sum()), lam)
z0 = v * (a - lam) / a
z0 = v * (a-lam)/a

v0[~np.isnan(v0)] = z0

Expand All @@ -420,52 +450,4 @@ def prox_G_inner(G, X, l2, group_size):
# lower triangular
X[k][G[1,l,k], G[0,l,k]] = v0[k]

return X


#%%
# prox operato in case numba version does not work

# def prox_2norm_G(X, G, l2):
# """
# calculates the proximal operator at points X for the group penalty induced by G
# G: 2xLxK matrix where the -th row contains the (i,j)-index of the element in Theta^k which contains to group l
# if G has a entry -1 no element is contained in the group for this Theta^k
# X: dictionary with X^k at key k, each X[k] is assumed to be symmetric
# """
# assert l2 > 0
# K = len(X.keys())
# for k in np.arange(K):
# assert abs(X[k] - X[k].T).max() <= 1e-5, "X[k] has to be symmetric"

# d = G.shape
# assert d[0] == 2
# assert d[2] == K
# L = d[1]

# X1 = copy.deepcopy(X)
# group_size = (G[0,:,:] != -1).sum(axis = 1)

# for l in np.arange(L):
# # for each group construct v, calculate prox, and insert the result. Ignore -1 entries of G
# v0 = np.zeros(K)
# for k in np.arange(K):
# if G[0,l,k] == -1:
# v0[k] = np.nan
# else:
# v0[k] = X[k][G[0,l,k], G[1,l,k]]

# v = v0[~np.isnan(v0)]
# # scale with square root of the group size
# z0 = prox_2norm(v,l2 * np.sqrt(group_size[l]))
# v0[~np.isnan(v0)] = z0

# for k in np.arange(K):
# if G[0,l,k] == -1:
# continue
# else:
# X1[k][G[0,l,k], G[1,l,k]] = v0[k]
# # lower triangular
# X1[k][G[1,l,k], G[0,l,k]] = v0[k]

# return X1
return X
Loading

0 comments on commit 776f23f

Please sign in to comment.