diff --git a/renormalizer/tn/tests/test_tn.py b/renormalizer/tn/tests/test_tn.py index bffc6b63..78853117 100644 --- a/renormalizer/tn/tests/test_tn.py +++ b/renormalizer/tn/tests/test_tn.py @@ -233,7 +233,7 @@ def test_rdm_entropy_holstein(): mps_idx1, mps_idx2 = 1, 3 dof1 = model.basis[mps_idx1].dof dof2 = model.basis[mps_idx2].dof - ttns_mutual_infos = ttns.calc_2dof_mutual_info((dof1, dof2)) + ttns_mutual_infos,_ = ttns.calc_2dof_mutual_info((dof1, dof2)) ttns_mutual_info = ttns_mutual_infos[(dof1, dof2)] np.testing.assert_allclose(ttns_mutual_info, mps_mutual_info[mps_idx1, mps_idx2], atol=1e-4) diff --git a/renormalizer/tn/tree.py b/renormalizer/tn/tree.py index 61ed310f..a35804a7 100644 --- a/renormalizer/tn/tree.py +++ b/renormalizer/tn/tree.py @@ -1185,12 +1185,14 @@ def calc_2dof_rdm(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> rdm_[dof_pair] = res return rdm_ - def calc_2dof_entropy(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]: - rdm = self.calc_2dof_rdm(dofs) + def calc_2dof_entropy(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]], rdm: Dict[Any, np.ndarray]=None) -> Dict[Tuple[Any, Any], float]: + if rdm is None: + rdm = self.calc_2dof_rdm(dofs) + entropy = {key: calc_vn_entropy_dm(dm) for key, dm in rdm.items()} return entropy - def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]]) -> Dict[Tuple[Any, Any], float]: + def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any]]], rdm_2dof: Dict[Any, np.ndarray]=None) -> Dict[Tuple[Any, Any], float]: r""" Calculate mutual information between two DOFs. @@ -1214,14 +1216,16 @@ def calc_2dof_mutual_info(self, dofs: Union[Tuple[Any, Any], List[Tuple[Any, Any for dof_pair in dofs: dofs_lst.append(dof_pair[0]) dofs_lst.append(dof_pair[1]) - entropy_1site = self.calc_1dof_entropy(dofs_lst) - entropy_2site = self.calc_2dof_entropy(dofs) + entropy_1dof = self.calc_1dof_entropy(dofs_lst) + entropy_2dof = self.calc_2dof_entropy(dofs, rdm_2dof) for dof_pair in dofs: dof1 = dof_pair[0] dof2 = dof_pair[1] - mutual_info = (entropy_1site[dof1] + entropy_1site[dof2] - entropy_2site[dof_pair]) / 2 + mutual_info = (entropy_1dof[dof1] + entropy_1dof[dof2] - entropy_2dof[dof_pair]) / 2 mutual_infos[dof_pair] = mutual_info - return mutual_infos + + entropy_tuple = (entropy_1dof, entropy_2dof) + return mutual_infos, entropy_tuple def calc_bond_entropy(self) -> np.ndarray: r"""