Skip to content

Commit

Permalink
fix dump load coeff bug
Browse files Browse the repository at this point in the history
  • Loading branch information
liwt31 committed May 8, 2024
1 parent 75ebe3f commit 37145a7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
4 changes: 3 additions & 1 deletion renormalizer/mps/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,13 @@ def pair_tensor_contract(
return tensordot(view_left, view_right, axes=(left_pos, right_pos))


def asnumpy(array: Union[np.ndarray, xp.ndarray, Matrix]) -> np.ndarray:
def asnumpy(array: Union[np.ndarray, xp.ndarray, Matrix, List]) -> np.ndarray:
if array is None:
return None
if isinstance(array, Matrix):
return array.array
if isinstance(array, List):
return np.array(array)
if not USE_GPU:
assert isinstance(array, np.ndarray)
return array
Expand Down
4 changes: 2 additions & 2 deletions renormalizer/tn/tests/test_evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def add_ttno_offset(ttns: TTNS, ttno: TTNO):
return TTNO(ttno.basis, ham_terms)



def construct_ttns_and_ttno_chain():
basis, ttns, ttno = from_mps(init_mps)
op_n_list = [TTNO(basis, [Op(r"a^\dagger a", i)]) for i in range(3)]
Expand Down Expand Up @@ -167,6 +166,7 @@ def test_save_load(ttns_and_ttno):
ttns2.dump(fname)
ttns2 = TTNS.load(ttns.basis, fname)
ttns2 = ttns2.evolve(ttno, tau)
assert ttns2.coeff == ttns1.coeff
exp2 = [ttns2.expectation(o) for o in op_n_list]
np.testing.assert_allclose(exp1, exp2)
np.testing.assert_allclose(exp2, exp1)
os.remove(fname)
28 changes: 23 additions & 5 deletions renormalizer/tn/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TTNBase(Tree):
# A tree whose tree node is TreeNodeTensor

@classmethod
def load(cls, basis: BasisTree, fname: str):
def load(cls, basis: BasisTree, fname: str, other_attrs=None):
npload = np.load(fname, allow_pickle=True)
assert npload["version"] == "0.1"

Expand All @@ -38,7 +38,10 @@ def load(cls, basis: BasisTree, fname: str):
qn = npload[f"qn_{i}"]
nodes.append(TreeNodeTensor(tensor, qn))
copy_connection(basis.node_list, nodes)
return cls(basis, root=nodes[0])
instance = cls(basis, root=nodes[0])
for attr in other_attrs:
setattr(instance, attr, npload[attr])
return instance

def __init__(self, basis: BasisTree, root: TreeNodeTensor):
self.basis = basis
Expand All @@ -47,13 +50,15 @@ def __init__(self, basis: BasisTree, root: TreeNodeTensor):
def dump(self, fname: str, other_attrs=None):
if other_attrs is None:
other_attrs = []

Check warning on line 52 in renormalizer/tn/tree.py

View check run for this annotation

Codecov / codecov/patch

renormalizer/tn/tree.py#L52

Added line #L52 was not covered by tests
elif isinstance(other_attrs, str):
other_attrs = [other_attrs]
assert isinstance(other_attrs, list)

data_dict = {
"version": "0.1",
"nsites": len(self),
}

for attr in other_attrs:
data_dict[attr] = getattr(self, attr)

for i, node in enumerate(self.node_list):
data_dict[f"tensor_{i}"] = node.tensor
data_dict[f"qn_{i}"] = node.qn
Expand Down Expand Up @@ -321,6 +326,13 @@ def __matmul__(self, other):

class TTNS(TTNBase):

@classmethod
def load(cls, basis: BasisTree, fname: str, other_attrs=None):
if other_attrs is None:
other_attrs = []
other_attrs = other_attrs + ["coeff"]
return super().load(basis, fname, other_attrs)

@classmethod
def random(cls, basis: BasisTree, qntot, m_max, percent=1.0):
"""
Expand Down Expand Up @@ -1069,6 +1081,12 @@ def scale(self, val, inplace=False):
new_mp.root.tensor *= val
return new_mp

def dump(self, fname, other_attrs=None):
if other_attrs is None:
other_attrs = []
other_attrs = other_attrs + ["coeff"]
super().dump(fname, other_attrs)

def __add__(self, other: "TTNS"):
return self.add(other)

Expand Down

0 comments on commit 37145a7

Please sign in to comment.