Skip to content

Commit

Permalink
Fix tdscf symmetric_orth function for complex vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Dec 16, 2024
1 parent 3b1254f commit 59a8cb5
Showing 1 changed file with 42 additions and 25 deletions.
67 changes: 42 additions & 25 deletions pyscf/tdscf/_lr_eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,19 @@ def _qr(xs, lindep=1e-14):
return xs[:nv], idx

def _symmetric_orth(xt, lindep=1e-6):
xt = np.asarray(xt)
if xt.dtype == np.float64:
return _symmetric_orth_real(xt, lindep)
else:
return _symmetric_orth_cmplx(xt, lindep)

def _symmetric_orth_real(xt, lindep=1e-6):
'''
Symmetric orthogonalization for xt = {[X, Y]},
and its dual basis vectors {[Y, X]}
'''
xt = np.asarray(xt)
x0_size = xt.shape[1]
s11 = xt.conj().dot(xt.T)
s11 = xt.dot(xt.T)
s21 = _conj_dot(xt, xt)
# Symmetric orthogonalize s, where
# s = [[s11, s21.conj().T],
Expand All @@ -806,15 +812,9 @@ def _symmetric_orth(xt, lindep=1e-6):
n = csc.shape[0]
for i in range(n):
_s21 = csc[i:,i:]
if _s21.dtype == np.float64:
# s21 is symmetric for real vectors
w, u = np.linalg.eigh(_s21)
mask = 1 - abs(w) > lindep
else:
# svd(s[:n,n:]) => svd(_s21.conj().T) => u, w
w2, u = np.linalg.eigh(_s21.conj().T.dot(_s21))
mask = 1 - w2**.5 > lindep
w = np.einsum('pi,pi->i', u.conj(), _s21.dot(u))
# s21 is symmetric for real vectors
w, u = np.linalg.eigh(_s21)
mask = 1 - abs(w) > lindep
if np.any(mask):
c = c[:,i:]
break
Expand All @@ -829,22 +829,16 @@ def _symmetric_orth(xt, lindep=1e-6):
e, c = np.linalg.eigh(c_orth.T.dot(s11).dot(c_orth))
c *= e**-.5
c_orth = c_orth.dot(c)
if s21.dtype == np.float64:
csc = c_orth.T.dot(s21).dot(c_orth)
w, u = np.linalg.eigh(csc)
c_orth = c_orth.dot(u)
else:
sc = s21.dot(c_orth)
w2, u = np.linalg.eigh(sc.conj().T.dot(sc))
c_orth = c_orth.dot(u)
w = np.einsum('pi,pi->i', c_orth.conj(), sc.dot(u))
csc = c_orth.T.dot(s21).dot(c_orth)
w, u = np.linalg.eigh(csc)
c_orth = c_orth.dot(u)

# Symmetric diagonalize
# [1 w] => c = [a b]
# [w 1] [b a]
# [1 w.conj()] => c = [a b]
# [w 1 ] [b a]
# where
# a = ((1+w)**-.5 + (1-w)**-.5)/2
# b = ((1+w)**-.5 - (1-w)**-.5)/2
# b = (phase*(1+w)**-.5 - phase*(1-w)**-.5)/2
a1 = (1 + w)**-.5
a2 = (1 - w)**-.5
a = (a1 + a2) / 2
Expand All @@ -853,8 +847,31 @@ def _symmetric_orth(xt, lindep=1e-6):
m = xt.shape[1] // 2
x_orth = (c_orth * a).T.dot(xt)
# Contribution from the conjugated basis
x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:].conj())
x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m].conj())
x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:])
x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m])
return x_orth

def _symmetric_orth_cmplx(xt, lindep=1e-6):
n, m = xt.shape
if n == 0:
raise RuntimeError('Linear dependency in trial bases')
m = m // 2
# The conjugated basis np.hstack([xt[:,m:], xt[:,:m]]).conj()
s11 = xt.conj().dot(xt.T)
s21 = _conj_dot(xt, xt)
s = np.block([[s11, s21.conj().T],
[s21, s11.conj() ]])
e, c = scipy.linalg.eigh(s)
if e[0] < lindep:
if n == 1:
return xt
return _symmetric_orth_cmplx(xt[:-1], lindep)

c_orth = (c * e**-.5).dot(c[:n].conj().T)
x_orth = c_orth[:n].T.dot(xt)
# Contribution from the conjugated basis
x_orth[:,:m] += c_orth[n:].T.dot(xt[:,m:].conj())
x_orth[:,m:] += c_orth[n:].T.dot(xt[:,:m].conj())
return x_orth

def _sym_dot(V, U1, m0, m1):
Expand Down

0 comments on commit 59a8cb5

Please sign in to comment.