Skip to content

Commit

Permalink
refactorize normalize
Browse files Browse the repository at this point in the history
  • Loading branch information
jjren authored and liwt31 committed Apr 7, 2022
1 parent 73d9bfb commit 45e77ce
Show file tree
Hide file tree
Showing 19 changed files with 135 additions and 101 deletions.
2 changes: 1 addition & 1 deletion renormalizer/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, dof, omega, nbas, x0=0., dvr=False, general_xp_power=False):
self.dvr = True

def __repr__(self):
return f"(x0: {self.x0}, omega: {self.omega}, nbas: {self.nbas})"
return f"(dof: {self.dof}, x0: {self.x0}, omega: {self.omega}, nbas: {self.nbas})"

def op_mat(self, op: Union[Op, str]):
if not isinstance(op, Op):
Expand Down
11 changes: 8 additions & 3 deletions renormalizer/mps/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,16 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
mps : renormalizer.mps.Mps
optimized ground state MPS.
Note it's not the same with the overwritten input MPS.
See Also
--------
renormalizer.utils.configs.OptimizeConfig : The optimization configuration.
Note
----
When On-the-fly swapping algorithm is used, the site ordering of the returned
MPS is changed and the original MPO will not correspond to it and should be
updated with returned mps.model.
"""

assert mps.optimize_config.method in ["2site", "1site"]
Expand Down Expand Up @@ -130,10 +135,10 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
assert res_mps is not None
# remove the redundant basis near the edge
if mps.optimize_config.nroots == 1:
res_mps = res_mps.normalize().ensure_left_canon().canonicalise()
res_mps = res_mps.normalize("mps_only").ensure_left_canon().canonicalise()
logger.info(f"{res_mps}")
else:
res_mps = [mp.normalize().ensure_left_canon().canonicalise() for mp in res_mps]
res_mps = [mp.normalize("mps_only").ensure_left_canon().canonicalise() for mp in res_mps]
logger.info(f"{res_mps[0]}")
return macro_iteration_result, res_mps

Expand Down
37 changes: 31 additions & 6 deletions renormalizer/mps/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,30 @@ def _get_big_qn(self, cidx: List[int], swap=False):
qnmat = np.add.outer(qnbigl, qnbigr)
return qnbigl, qnbigr, qnmat

@property
def mp_norm(self) -> float:
# the fast version in the comment rarely makes sense because in a lot of cases
# the mps is not canonicalised (though qnidx is set)
"""
if self.is_left_canon:
assert self.check_left_canonical()
return np.linalg.norm(np.ravel(self[-1]))
else:
assert self.check_right_canonical()
return np.linalg.norm(np.ravel(self[0]))
"""
res = self.conj().dot(self).real
if res < 0:
assert np.abs(res) < 1e-8
res = 0
res = np.sqrt(res)

return float(res)

def add(self, other: "MatrixProduct"):
assert self.qntot == other.qntot
assert self.site_num == other.site_num

# note that the coeff should be the same
if self.is_mps or self.is_mpdm:
assert np.allclose(self.coeff, other.coeff)

new_mps = self.metacopy()
if other.dtype == backend.complex_dtype:
new_mps.dtype = backend.complex_dtype
Expand Down Expand Up @@ -1080,8 +1096,17 @@ def append(self, array):
self._mp.append(new_mt)

def __str__(self):
template_str = "current size: {}, Matrix product bond dim:{}"
return template_str.format(sizeof_fmt(self.total_bytes), self.bond_dims,)
if self.is_mps:
string = "mps"
elif self.is_mpo:
string = "mpo"
elif self.is_mpdm:
string = "mpdm"
else:
assert False
template_str = "{} current size: {}, Matrix product bond dim:{}"

return template_str.format(string, sizeof_fmt(self.total_bytes), self.bond_dims,)

def __del__(self):
dir_with_id = os.path.join(self.compress_config.dump_matrix_dir, str(id(self)))
Expand Down
2 changes: 1 addition & 1 deletion renormalizer/mps/mpdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def max_entangled_ex(cls, model, normalize=True):

ex_mps = ex_mpo @ mps
if normalize:
ex_mps.normalize(1.0)
ex_mps.normalize("mps_and_coeff")
return cls.from_mps(ex_mps)

@classmethod
Expand Down
6 changes: 1 addition & 5 deletions renormalizer/mps/mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,5 @@ def from_mp(cls, model, mp):
mpo.append(mt)
mpo.build_empty_qn()
return mpo

@property
def dmrg_norm(self) -> float:
res = np.sqrt(self.conj().dot(self).real)
return float(res.real)


120 changes: 64 additions & 56 deletions renormalizer/mps/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,15 @@ def adaptive_fun(self: "Mps", mpo, evolve_target_t):
f"guess_dt: {config.guess_dt}, try time step size: {dt}"
)

# note that the wavefunction is normalized
mps_half1 = fun(cur_mps, mpo, dt / 2).normalize()
mps_half2 = fun(mps_half1, mpo, dt / 2).normalize()
mps = fun(cur_mps, mpo, dt).normalize()
mps_half1 = fun(cur_mps, mpo, dt / 2)
mps_half2 = fun(mps_half1, mpo, dt / 2)
mps = fun(cur_mps, mpo, dt)
dis = mps.distance(mps_half2)

# prevent bug. save "some" memory.
del mps_half1, mps

p = (0.75 * config.adaptive_rtol / (dis + 1e-30)) ** (1./3)
p = (0.75 * config.adaptive_rtol / (dis/mps_half2.mp_norm + 1e-30)) ** (1./3)
logger.debug(f"distance: {dis}, enlarge p parameter: {p}")
if p < p_min:
p = p_min
Expand Down Expand Up @@ -425,23 +424,9 @@ def nexciton(self):

@property
def norm(self):
# return self.dmrg_norm * self.hartree_norm
return np.linalg.norm(self.coeff)

@property
def dmrg_norm(self) -> float:
# the fast version in the comment rarely makes sense because in a lot of cases
# the mps is not canonicalised (though qnidx is set)
"""
if self.is_left_canon:
assert self.check_left_canonical()
return np.linalg.norm(np.ravel(self[-1]))
else:
assert self.check_right_canonical()
return np.linalg.norm(np.ravel(self[0]))
"""
res = np.sqrt(self.conj().dot(self).real)
return float(res.real)
'''the norm of the total wavefunction
'''
return np.linalg.norm(self.coeff) * self.mp_norm

def _expectation_path(self):
# S--a--S--e--S
Expand Down Expand Up @@ -564,23 +549,34 @@ def metacopy(self) -> "Mps":
new.evolve_config = self.evolve_config.copy()
return new

def normalize(self, norm=None):
# real time propagation: dmrg should be normalized, coefficient is not changed,
# use norm=None
# imag time propagation: dmrg should be normalized, coefficient is normalized to 1.0
# applied by a operator then normalize: dmrg should be normalized,
# coefficient is set to the length
# these two cases should set `norm` equals to corresponding value
self.scale(1.0 / self.dmrg_norm, inplace=True)
if norm is not None:
self.coeff *= norm / (np.linalg.norm(self.coeff))
return self

def canonical_normalize(self):
# applied by a operator then normalize: dmrg should be normalized,
# tdh should be normalized, coefficient is set to the length
# suppose length is only determined by dmrg part
return self.normalize(self.dmrg_norm)
def normalize(self, kind):
r''' normalize the wavefunction
Parameters
----------
kind: str
"mps_only": the mps part is normalized and coeff is not modified;
"mps_norm_to_coeff": the mps part is normalized and the norm is multiplied to coeff;
"mps_and_coeff": both mps and coeff is normalized
Returns
-------
``self`` is overwritten.
'''

if kind == "mps_only":
new_coeff = self.coeff
elif kind == "mps_and_coeff":
new_coeff = self.coeff / np.linalg.norm(self.coeff)
elif kind == "mps_norm_to_coeff":
new_coeff = self.coeff * self.mp_norm
else:
raise ValueError(f"kind={kind} is valid.")
new_mps = self.scale(1.0 / self.mp_norm, inplace=True)
new_mps.coeff = new_coeff

return new_mps


def expand_bond_dimension(self, hint_mpo=None, coef=1e-10, include_ex=True):
"""
Expand Down Expand Up @@ -639,10 +635,10 @@ def expand_bond_dimension(self, hint_mpo=None, coef=1e-10, include_ex=True):
lastone = lastone.canonicalise().compress(
m_target // hint_mpo.bond_dims_mean + 1
)
lastone = (hint_mpo @ lastone).normalize()
lastone = (hint_mpo @ lastone).normalize("mps_and_coeff")
logger.debug(f"expander bond dimension: {expander.bond_dims}")
self.compress_config.bond_dim_max_value += self.bond_dims_mean
return (self + expander.scale(coef, inplace=True)).canonicalise().canonicalise().canonical_normalize()
return (self + expander.scale(coef*self.norm, inplace=True)).canonicalise().canonicalise().normalize("mps_norm_to_coeff")

def evolve(self, mpo, evolve_dt, normalize=True) -> "Mps":

Expand All @@ -659,9 +655,9 @@ def evolve(self, mpo, evolve_dt, normalize=True) -> "Mps":
new_mps = method(mpo, evolve_dt)
if normalize:
if np.iscomplex(evolve_dt):
new_mps.normalize(1.0)
new_mps.normalize("mps_and_coeff")
else:
new_mps.normalize(None)
new_mps.normalize("mps_only")
return new_mps

def _evolve_prop_and_compress_tdrk4(self, mpo, evolve_dt) -> "Mps":
Expand Down Expand Up @@ -742,10 +738,7 @@ def sub_time_step_evolve(y,tau,t0):
for istage in range(rk_config.stage) if not \
np.allclose(b[0,istage],b[1,istage])])

error_norm2 = error.conj().dot(error).real
if error_norm2 < 0:
error_norm2 = 0
error = np.sqrt(error_norm2) / new_mps.dmrg_norm
error = error.norm / new_mps.norm
else:
assert len(rk_config.order) == 1
error = 0
Expand All @@ -766,12 +759,12 @@ def sub_time_step_evolve(y,tau,t0):
dt = min_abs(new_mps.evolve_config.guess_dt, evolve_dt-evolved_dt)
logger.debug(f"guess_dt: {new_mps.evolve_config.guess_dt}, try time step size: {dt}")
new_mps, error = sub_time_step_evolve(new_mps, dt, evolved_dt)
p = (new_mps.evolve_config.adaptive_rtol / (error + 1e-30)) ** 0.2
logger.debug(f"RK45 relative error: {error}, enlarge p parameter: {p}")
p = (new_mps.evolve_config.adaptive_rtol / (error + 1e-30)) ** (1/rk_config.order[0])
logger.debug(f"RKsolver:{rk_config.method} relative error: {error}, enlarge p parameter: {p}")

if p < p_restart:
# not accurate, will restart
new_mps.config.guess_dt = dt * max(p_min, p)
new_mps.evolve_config.guess_dt = dt * max(p_min, p)
logger.debug(
f"evolution not converged, new guess_dt: {new_mps.evolve_config.guess_dt}"
)
Expand Down Expand Up @@ -839,15 +832,13 @@ def _evolve_prop_and_compress(self, mpo, evolve_dt) -> "Mps":
scale = (-1.0j * dt) ** idx * propagation_c[idx]
scaled_termlist.append(term.scale(scale))
del term
# Note that the wavefunction is normalized
new_mps1 = compressed_sum(scaled_termlist[:-1]).normalize()

new_mps1 = compressed_sum(scaled_termlist[:-1])
new_mps2 = compressed_sum(
[new_mps1, scaled_termlist[-1]]
).normalize()
)
dis = new_mps1.distance(new_mps2)
# here the norm of mps is 1,
# otherwise the relative value dis/norm should be used
p = (config.adaptive_rtol / (dis + 1e-30)) ** (1/(order-1))
p = (config.adaptive_rtol / (dis/new_mps2.mp_norm + 1e-30)) ** (1/order)
logger.debug(f"RK45 error distance: {dis}, enlarge p parameter: {p}")

if xp.allclose(dt, evolve_dt):
Expand Down Expand Up @@ -1702,6 +1693,23 @@ def dump(self, fname):
def __setitem__(self, key, value):
return super().__setitem__(key, value)


def add(self, other):
if not np.allclose(self.coeff, other.coeff):
self.scale(self.coeff, inplace=True)
other.scale(other.coeff, inplace=True)
self.coeff = 1
other.coeff = 1
return super().add(other)

def distance(self, other) -> float:
if not np.allclose(self.coeff, other.coeff):
self.scale(self.coeff, inplace=True)
other.scale(other.coeff, inplace=True)
self.coeff = 1
other.coeff = 1
return super().distance(other)


def projector(
ms: xp.ndarray, left: bool, Ovlp_inv1: xp.ndarray = None, Ovlp0: xp.ndarray = None
Expand Down
2 changes: 1 addition & 1 deletion renormalizer/mps/tda.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def kernel(self, restart=False, include_psi0=False):

if not restart:
# make sure that M is not redundant near the edge
mps = self.mps.ensure_right_canon().canonicalise().normalize().canonicalise()
mps = self.mps.ensure_right_canon().canonicalise().normalize("mps_and_coeff").canonicalise()
logger.debug(f"reference mps shape, {mps}")
mps_r_cano = mps.copy()
assert mps.to_right
Expand Down
8 changes: 5 additions & 3 deletions renormalizer/mps/tests/test_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def test_ofs():
mps.compress_config.ofs = OFS.ofs_s
energies, mps_opt = optimize_mps(mps.copy(), mpo)
assert energies[-1] == pytest.approx(GS_E, rel=1e-5)
mpo = Mpo(mps_opt.model)
assert mps_opt.expectation(mpo) == pytest.approx(GS_E, rel=1e-5)


@pytest.mark.parametrize("with_ofs", (True, False))
def test_qc(with_ofs):
"""
Expand All @@ -119,10 +119,12 @@ def test_qc(with_ofs):
fci_e = -3.23747673055271 - nuc

nelec = 6
M = 12
M = 20
procedure = [[M, 0.4], [M, 0.2], [M, 0.1], [M, 0], [M, 0], [M, 0], [M, 0]]
mps = Mps.random(model, nelec, M, percent=1.0)

hf = Mps.hartree_product_state(model, {i:1 for i in range(nelec)})
mps = mps.scale(1e-8)+hf
#print("hf energy", mps.expectation(mpo))
mps.optimize_config.procedure = procedure
mps.optimize_config.method = "2site"
if with_ofs:
Expand Down
Loading

0 comments on commit 45e77ce

Please sign in to comment.