diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 1749f2c0..0e7b7286 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -699,7 +699,7 @@ def make_pymatgen_periodic_structure( is_frac = True else: is_frac = not is_cartesian # TODO this is logically confusing - if not lattice: + if lattice is None: if lat_angles is None or lat_abc is None: raise ValueError( "Unable to construct Lattice object without parameters:" @@ -894,6 +894,10 @@ def calculate_ase_periodic_shifts( frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() coords = torch.from_numpy(atoms.positions).float() + # convert numpy cells to torch in advance for einsum + if isinstance(cell, np.ndarray): + cell = torch.from_numpy(cell).float() + return_dict = { "src_nodes": torch.LongTensor(all_src), "dst_nodes": torch.LongTensor(all_dst),