From d8e56e7b23c408111df2751df8af02bcfa739437 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Sun, 24 Nov 2024 16:46:20 -0800 Subject: [PATCH] fix: correcting expected matrix type and None comparison --- matsciml/datasets/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 1749f2c0..0e7b7286 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -699,7 +699,7 @@ def make_pymatgen_periodic_structure( is_frac = True else: is_frac = not is_cartesian # TODO this is logically confusing - if not lattice: + if lattice is None: if lat_angles is None or lat_abc is None: raise ValueError( "Unable to construct Lattice object without parameters:" @@ -894,6 +894,10 @@ def calculate_ase_periodic_shifts( frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() coords = torch.from_numpy(atoms.positions).float() + # convert numpy cells to torch in advance for einsum + if isinstance(cell, np.ndarray): + cell = torch.from_numpy(cell).float() + return_dict = { "src_nodes": torch.LongTensor(all_src), "dst_nodes": torch.LongTensor(all_dst),