diff --git a/pyscf/pbc/df/ft_ao.py b/pyscf/pbc/df/ft_ao.py index 2fb27f2180..a1a1e14c56 100644 --- a/pyscf/pbc/df/ft_ao.py +++ b/pyscf/pbc/df/ft_ao.py @@ -632,21 +632,23 @@ def strip_basis(self, rcut): dim = rs_cell.dimension if dim == 0: return self - supmol_bas_coords = self.atom_coords()[self._bas[:,gto.ATOM_OF]] - rb = np.linalg.norm(supmol_bas_coords[:,:dim], axis=1) - a = rs_cell.lattice_vectors() + + # Search the shortest distance to the reference cell for each atom in the supercell. + atom_coords = self.atom_coords() + d = np.linalg.norm(atom_coords[:,None] - rs_cell.atom_coords(), axis=2) + shortest_dist = np.min(d, axis=1) + bas_dist = shortest_dist[self._bas[:,gto.ATOM_OF]] # filter _bas nbas0 = self._bas.shape[0] - if rb.size == self.bas_mask.size: - dr = rb - np.linalg.norm(a[:dim]) - dr = dr.reshape(self.bas_mask.shape) - self.bas_mask = bas_mask = dr < rcut[:,None] + if bas_dist.size == self.bas_mask.size: + bas_dist = bas_dist.reshape(self.bas_mask.shape) + self.bas_mask = bas_mask = bas_dist < rcut[:,None] self._bas = self._bas[bas_mask.ravel()] else: dr = np.empty(self.bas_mask.shape) dr[:] = 1e9 - dr[self.bas_mask] = rb - np.linalg.norm(a[:dim]) + dr[self.bas_mask] = bas_dist bas_mask = dr < rcut[:,None] self._bas = self._bas[bas_mask[self.bas_mask]] self.bas_mask = bas_mask