Skip to content

Commit

Permalink
update: imporve entropy 2dofs/sites performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ansatzX authored and liwt31 committed Jul 14, 2024
1 parent ec0e361 commit 1d5eb05
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 97 deletions.
5 changes: 3 additions & 2 deletions renormalizer/tn/tests/test_tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def test_rdm_entropy_holstein():
mps_idx1, mps_idx2 = 1, 3
dof1 = model.basis[mps_idx1].dof
dof2 = model.basis[mps_idx2].dof
ttns_mutual_info = ttns.calc_2dof_mutual_info(dof1, dof2)
ttns_mutual_infos = ttns.calc_2dof_mutual_info((dof1, dof2))
ttns_mutual_info = ttns_mutual_infos[(dof1, dof2)]
np.testing.assert_allclose(ttns_mutual_info, mps_mutual_info[mps_idx1, mps_idx2], atol=1e-4)


Expand All @@ -258,7 +259,7 @@ def test_2dof_rdm(basis_tree, dofs):
# test 2 dof rdm
dof1, dof2 = dofs

rdm1 = ttns.calc_2dof_rdm(dof1, dof2).reshape(4, 4)
rdm1 = ttns.calc_2dof_rdm((dof1, dof2))[(dof1, dof2)].reshape(4, 4)
rdm2 = mps.calc_2site_rdm()[(dof1, dof2)].reshape(4, 4)
#np.testing.assert_allclose(rdm1, rdm2, atol=1e-10)

Expand Down
256 changes: 162 additions & 94 deletions renormalizer/tn/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,13 +1020,14 @@ def calc_1dof_entropy(self, dof: Union[Any, List[Any]]=None) -> Dict[Any, float]
rdm = self.calc_1dof_rdm(dof)
return {key: calc_vn_entropy_dm(dm) for key, dm in rdm.items()}

def calc_2site_rdm(self, idx1, idx2) -> np.ndarray:
def calc_2site_rdm(self, idxs: Union[Tuple[int, int], List[Tuple[int, int]]]=None) -> Dict[Tuple[int, int], np.ndarray]:
r""" Calculate 2-site reduced density matrix
:math:`\rho_{ij} = \textrm{Tr}_{k \neq i, k \neq j} | \Psi \rangle \langle \Psi |`.
Parameters
----------
idxs: list(tuple), optional
idx1: int
site index (in terms of ``self.node_list``) of the first site.
idx2: int
Expand All @@ -1039,106 +1040,157 @@ def calc_2site_rdm(self, idx1, idx2) -> np.ndarray:
"""
ttno_dummy = TTNO.dummy(self.basis)
ttne = TTNEnviron(self, ttno_dummy)
path = self.find_path(self.node_list[idx1], self.node_list[idx2])
assert path[0] is self.node_list[idx1]
assert path[-1] is self.node_list[idx2]
args = []
# put the nodes for RDM in the arguments
for snode in [path[0], path[-1]]:
args.append(snode.tensor.conj())
args.append(self.get_node_indices(snode, conj=True))

args.append(snode.tensor)
args.append(self.get_node_indices(snode))
if isinstance(idxs, tuple):
idxs = [idxs]
else:
assert isinstance(idxs, list)

# put the nodes in the path in the arguments
for snode in path[1:-1]:
args.append(snode.tensor.conj())
args.append(self.get_node_indices(snode, conj=True))
rdm = {}
for idx_pair in idxs:
idx1 = idx_pair[0]
idx2 = idx_pair[1]

args.append(snode.tensor)
# set ttno to ttno_dummy so that the physical indices are contracted directly
args.append(self.get_node_indices(snode, ttno=ttno_dummy))
path = self.find_path(self.node_list[idx1], self.node_list[idx2])
assert path[0] is self.node_list[idx1]
assert path[-1] is self.node_list[idx2]
args = []
# put the nodes for RDM in the arguments
for snode in [path[0], path[-1]]:
args.append(snode.tensor.conj())
args.append(self.get_node_indices(snode, conj=True))

args.append(snode.tensor)
args.append(self.get_node_indices(snode))

# put the nodes in the path in the arguments
for snode in path[1:-1]:
args.append(snode.tensor.conj())
args.append(self.get_node_indices(snode, conj=True))

args.append(snode.tensor)
# set ttno to ttno_dummy so that the physical indices are contracted directly
args.append(self.get_node_indices(snode, ttno=ttno_dummy))

# put all environment tensors in the arguments
for i, node in enumerate(path):
# skip some of the environments because they are included in the path

if i == 0:
neighbour_nodes = [path[i+1]]
elif i == len(path)-1:
neighbour_nodes = [path[i-1]]
else:
neighbour_nodes = [path[i-1], path[i+1]]

skip_child_idx_list: List[int] = []
skip_parent: bool = False
for neighbour_node in neighbour_nodes:
if neighbour_node.parent is node:
skip_child_idx_list.append(neighbour_node.idx_as_child)
elif node.parent is neighbour_node:
skip_parent = True

enode = ttne.node_list[self.node_idx[node]]
# put all children environments in the arguments
for j, child_tensor in enumerate(enode.environ_children):
if j in skip_child_idx_list:
continue
indices = ttne.get_child_indices(enode, j, self, ttno_dummy)
args.extend([child_tensor, indices])
# put the parent environment in the arguments
if not skip_parent:
args.append(enode.environ_parent)
args.append(ttne.get_parent_indices(enode, self, ttno_dummy))

# put all environment tensors in the arguments
for i, node in enumerate(path):
# skip some of the environments because they are included in the path
# the indices for the output tensor
indices_ket = []
indices_bra = []
for snode in [path[0], path[-1]]:
for dofs in self.tn2dofs[snode]:
indices_ket.append(("down", str(dofs)))
indices_bra.append(("up", str(dofs)))
args.append(indices_ket + indices_bra)
res = oe.contract(*asxp_oe_args(args))
rdm[idx_pair] = res
# perform the contraction
return rdm

if i == 0:
neighbour_nodes = [path[i+1]]
elif i == len(path)-1:
neighbour_nodes = [path[i-1]]
else:
neighbour_nodes = [path[i-1], path[i+1]]
def calc_2site_entropy(self, idxs: Union[Tuple[int, int], List[Tuple[int, int]]]) -> Dict[tuple, float]:

skip_child_idx_list: List[int] = []
skip_parent: bool = False
for neighbour_node in neighbour_nodes:
if neighbour_node.parent is node:
skip_child_idx_list.append(neighbour_node.idx_as_child)
elif node.parent is neighbour_node:
skip_parent = True
if isinstance(idxs, tuple):
idxs = [idxs]
else:
assert isinstance(idxs, list)

enode = ttne.node_list[self.node_idx[node]]
# put all children environments in the arguments
for j, child_tensor in enumerate(enode.environ_children):
if j in skip_child_idx_list:
continue
indices = ttne.get_child_indices(enode, j, self, ttno_dummy)
args.extend([child_tensor, indices])
# put the parent environment in the arguments
if not skip_parent:
args.append(enode.environ_parent)
args.append(ttne.get_parent_indices(enode, self, ttno_dummy))

# the indices for the output tensor
indices_ket = []
indices_bra = []
for snode in [path[0], path[-1]]:
for dofs in self.tn2dofs[snode]:
indices_ket.append(("down", str(dofs)))
indices_bra.append(("up", str(dofs)))
args.append(indices_ket + indices_bra)
rdm = self.calc_2site_rdm(idxs)
entropy = {key: calc_vn_entropy_dm(dm) for key, dm in rdm.items()}
return entropy

# perform the contraction
return oe.contract(*asxp_oe_args(args))
def calc_2dof_rdm(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], np.ndarray]:

def calc_2site_entropy(self, idx1, idx2) -> float:
rdm = self.calc_2site_rdm(idx1, idx2)
return calc_vn_entropy_dm(rdm)

def calc_2dof_rdm(self, dof1, dof2) -> np.ndarray:
site_idx1 = self.basis.dof2idx[dof1]
site_idx2 = self.basis.dof2idx[dof2]
if site_idx1 == site_idx2:
# two dofs on the same site
rdm = self.calc_1site_rdm(site_idx1)[site_idx1]
basis_node: TreeNodeBasis = self.basis.node_list[site_idx1]
n_sets = basis_node.n_sets
basis_idx1 = basis_node.basis_sets.index(self.basis.dof2basis[dof1])
basis_idx2 = basis_node.basis_sets.index(self.basis.dof2basis[dof2])
assert basis_idx1 != basis_idx2
if isinstance(dofs, tuple):
dofs = [dofs]
else:
# two dofs on different sites
rdm = self.calc_2site_rdm(site_idx1, site_idx2)
basis_node1: TreeNodeBasis = self.basis.node_list[site_idx1]
basis_node2: TreeNodeBasis = self.basis.node_list[site_idx2]
n_sets = basis_node1.n_sets + basis_node2.n_sets
basis_idx1 = basis_node1.basis_sets.index(self.basis.dof2basis[dof1])
basis_idx2 = basis_node1.n_sets + basis_node2.basis_sets.index(self.basis.dof2basis[dof2])

indices = [(0, i) for i in range(n_sets)] * 2
indices[basis_idx1] = (1, 0)
indices[basis_idx2] = (1, 1)
indices[n_sets + basis_idx1] = (1, 2)
indices[n_sets + basis_idx2] = (1, 3)
return oe.contract(rdm, indices, [(1, i) for i in range(4)])

def calc_2dof_entropy(self, dof1, dof2) -> float:
rdm = self.calc_2dof_rdm(dof1, dof2)
return calc_vn_entropy_dm(rdm)

def calc_2dof_mutual_info(self, dof1, dof2) -> float:
assert isinstance(dofs, list)

rdm_ = {}
rdm_1site_idx_lst = []
rdm_2site_idx_lst = []
for dof_pair in dofs:
dof1 = dof_pair[0]
dof2 = dof_pair[1]

site_idx1 = self.basis.dof2idx[dof1]
site_idx2 = self.basis.dof2idx[dof2]
if site_idx1 == site_idx2:
rdm_1site_idx_lst.append(site_idx1)
rdm_1site_idx_lst.append(site_idx2)
else:
rdm_2site_idx_lst.append((site_idx1, site_idx2))
if len(rdm_1site_idx_lst) > 0:
rdm_1sites = self.calc_1site_rdm(rdm_1site_idx_lst)
if len(rdm_2site_idx_lst) > 0:
rdm_2sites = self.calc_2site_rdm(rdm_2site_idx_lst)

for dof_pair in dofs:
dof1 = dof_pair[0]
dof2 = dof_pair[1]

site_idx1 = self.basis.dof2idx[dof1]
site_idx2 = self.basis.dof2idx[dof2]
if site_idx1 == site_idx2:
# two dofs on the same site
rdm = rdm_1sites[site_idx1]
basis_node: TreeNodeBasis = self.basis.node_list[site_idx1]
n_sets = basis_node.n_sets
basis_idx1 = basis_node.basis_sets.index(self.basis.dof2basis[dof1])
basis_idx2 = basis_node.basis_sets.index(self.basis.dof2basis[dof2])
assert basis_idx1 != basis_idx2
else:
# two dofs on different sites
rdm = rdm_2sites[(site_idx1, site_idx2)]
basis_node1: TreeNodeBasis = self.basis.node_list[site_idx1]
basis_node2: TreeNodeBasis = self.basis.node_list[site_idx2]
n_sets = basis_node1.n_sets + basis_node2.n_sets
basis_idx1 = basis_node1.basis_sets.index(self.basis.dof2basis[dof1])
basis_idx2 = basis_node1.n_sets + basis_node2.basis_sets.index(self.basis.dof2basis[dof2])

indices = [(0, i) for i in range(n_sets)] * 2
indices[basis_idx1] = (1, 0)
indices[basis_idx2] = (1, 1)
indices[n_sets + basis_idx1] = (1, 2)
indices[n_sets + basis_idx2] = (1, 3)
res = oe.contract(rdm, indices, [(1, i) for i in range(4)])
rdm_[dof_pair] = res
return rdm_

def calc_2dof_entropy(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]:
rdm = self.calc_2dof_rdm(dofs)
entropy = {key: calc_vn_entropy_dm(dm) for key, dm in rdm.items()}
return entropy

def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]:
r"""
Calculate mutual information between two DOFs.
Expand All @@ -1150,10 +1202,26 @@ def calc_2dof_mutual_info(self, dof1, dof2) -> float:
-------
mutual_info : float
mutual information between the two DOFs
mutual_infos : Dict[Any, float]
"""
entropy_1site = self.calc_1dof_entropy([dof1, dof2])
entropy_2site = self.calc_2dof_entropy(dof1, dof2)
return (entropy_1site[dof1] + entropy_1site[dof2] - entropy_2site) / 2
if isinstance(dofs, tuple):
dofs = [dofs]
else:
assert isinstance(dofs, list)

mutual_infos = {}
dofs_lst = []
for dof_pair in dofs:
dofs_lst.append(dof_pair[0])
dofs_lst.append(dof_pair[1])
entropy_1site = self.calc_1dof_entropy(dofs_lst)
entropy_2site = self.calc_2dof_entropy(dofs)
for dof_pair in dofs:
dof1 = dof_pair[0]
dof2 = dof_pair[1]
mutual_info = (entropy_1site[dof1] + entropy_1site[dof2] - entropy_2site[dof_pair]) / 2
mutual_infos[dof_pair] = mutual_info
return mutual_infos

def calc_bond_entropy(self) -> np.ndarray:
r"""
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"scipy",
"h5py",
"opt_einsum",
"sympy"
"sympy",
"print-tree2"
]

setuptools.setup(
Expand Down

0 comments on commit 1d5eb05

Please sign in to comment.