Skip to content

Commit

Permalink
add 1site rdm
Browse files Browse the repository at this point in the history
  • Loading branch information
liwt31 committed Jun 25, 2024
1 parent 0f4ba0b commit 02dc123
Show file tree
Hide file tree
Showing 5 changed files with 462 additions and 305 deletions.
18 changes: 17 additions & 1 deletion renormalizer/tn/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ def add_child(self, node: Union["TreeNode", Sequence["TreeNode"]]) -> "TreeNode"

@property
def idx_as_child(self) -> int:
"""
Returns the index of this node as a child of its parent
"""
assert self.parent
return self.parent.children.index(self)

@property
def is_leaf(self) -> bool:
return len(self.children) == 0


class TreeNodeBasis(TreeNode):
# tree node whose data is basis sets
Expand All @@ -52,8 +59,17 @@ def copy(self):


class TreeNodeTensor(TreeNode):
# tree node whose data is numerical tensors for each TTN node/site
def __init__(self, tensor, qn=None):
"""
Tree node whose data is numerical tensors for each TTN node/site.
The indices of the tensor are ordered as follows:
[child1, child2, ..., childN, physical1, physical2, ..., physicalN, parent]
Parameters
----------
tensor: The numerical tensor
qn: The quantum number from the tensor to its parent.
"""
super().__init__()
self.tensor: np.ndarray = tensor
self.qn: np.ndarray = qn
Expand Down
2 changes: 1 addition & 1 deletion renormalizer/tn/tests/test_evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ def test_save_load(ttns_and_ttno):
ttns2 = ttns2.evolve(ttno, tau)
assert ttns2.coeff == ttns1.coeff
exp2 = [ttns2.expectation(o) for o in op_n_list]
np.testing.assert_allclose(exp2, exp1)
np.testing.assert_allclose(exp2, exp1, atol=1e-7)
os.remove(fname)
9 changes: 9 additions & 0 deletions renormalizer/tn/tests/test_tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ def test_partial_ttno(basis_tree):
e2 = ttns.expectation(ttno2)
np.testing.assert_allclose(e, e2)

@pytest.mark.parametrize("basis_tree", [basis_binary, basis_multi_basis])
def test_site_entropy(basis_tree):
ttns = TTNS.random(basis_tree, 0, 5, 1)
bond_entropy = ttns.calc_bond_entropy()
site1_entropy = ttns.calc_1site_entropy()
for i, node in enumerate(ttns):
if node.is_leaf:
np.testing.assert_allclose(bond_entropy[i], site1_entropy[i], atol=1e-10)


@pytest.mark.parametrize("basis", [basis_binary, basis_multi_basis])
def test_print(basis):
Expand Down
Loading

0 comments on commit 02dc123

Please sign in to comment.