More efficient PBC-CDERI loading (pyscf#2392)
* More efficient PBC-CDERI loading

* correctness and readability

* workaround for h5py bug

* make requested changes

* Correct slicing for non-datasets

* Update docstring of _hstack_datasets

---------

Co-authored-by: Qiming Sun <[email protected]>
Authored by chillenb and sunqm on Sep 2, 2024
Parent: 2231cec · Commit: 5417cc2
Showing 2 changed files with 83 additions and 11 deletions.
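Before the per-file diffs, the essence of the change: instead of materializing a temporary array for every dataset slice and concatenating the temporaries with numpy.hstack, preallocate a single output buffer and let h5py's read_direct fill each column block in place. A minimal, self-contained sketch of that pattern (the file name, dataset names, and shapes are illustrative, not pyscf's):

import h5py
import numpy

# Two datasets sharing the row dimension but with different widths,
# mimicking the segmented CDERI storage ('0', '1', ...) in pyscf.
with h5py.File('demo.h5', 'w') as f:
    f['0'] = numpy.arange(12.0).reshape(4, 3)
    f['1'] = numpy.arange(8.0).reshape(4, 2)

with h5py.File('demo.h5', 'r') as f:
    dsets = [f['0'], f['1']]
    rows = numpy.s_[1:3]

    # Old approach: each d[rows] allocates a temporary, then hstack
    # copies all of them into yet another new array.
    expected = numpy.hstack([d[rows] for d in dsets])

    # New approach: allocate once, read each column block directly.
    nrow = len(range(*rows.indices(dsets[0].shape[0])))
    ncol = sum(d.shape[1] for d in dsets)
    out = numpy.empty((nrow, ncol),
                      dtype=numpy.result_type(*[d.dtype for d in dsets]))
    col1 = 0
    for d in dsets:
        col0, col1 = col1, col1 + d.shape[1]
        if col1 > col0:  # skip zero-width segments (h5py quirk, see below)
            d.read_direct(out, source_sel=rows,
                          dest_sel=numpy.s_[:, col0:col1])

    assert numpy.array_equal(out, expected)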
10 changes: 6 additions & 4 deletions pyscf/ao2mo/outcore.py
@@ -487,15 +487,17 @@ def _load_from_h5g(h5group, row0, row1, out=None):
         col1 = 0
         for key in range(nkeys):
             col0, col1 = col1, col1 + h5group[str(key)].shape[1]
-            h5group[str(key)].read_direct(out, dest_sel=numpy.s_[:,col0:col1],
-                                          source_sel=numpy.s_[row0:row1])
+            if col1 > col0:
+                h5group[str(key)].read_direct(out, dest_sel=numpy.s_[:,col0:col1],
+                                              source_sel=numpy.s_[row0:row1])
     else:  # multiple components
         out = numpy.ndarray((dat.shape[0], row1-row0, ncol), dat.dtype, buffer=out)
         col1 = 0
         for key in range(nkeys):
             col0, col1 = col1, col1 + h5group[str(key)].shape[2]
-            h5group[str(key)].read_direct(out, dest_sel=numpy.s_[:,:,col0:col1],
-                                          source_sel=numpy.s_[:,row0:row1])
+            if col1 > col0:
+                h5group[str(key)].read_direct(out, dest_sel=numpy.s_[:,:,col0:col1],
+                                              source_sel=numpy.s_[:,row0:row1])
     return out

 def _transpose_to_h5g(h5group, key, dat, blksize, chunks=None):
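The new `if col1 > col0:` guard skips zero-width segments before calling read_direct: as the comment added to df.py below notes, h5py has trouble with zero-size selections (https://github.com/h5py/h5py/issues/1455). A sketch of the case the guard avoids; the exact exception depends on the h5py version, so treat this as illustrative:

import h5py
import numpy

with h5py.File('empty_seg.h5', 'w') as f:
    f['seg'] = numpy.zeros((4, 0))  # a segment with zero columns

out = numpy.empty((2, 0))
with h5py.File('empty_seg.h5', 'r') as f:
    try:
        # On affected h5py versions this errors out instead of being
        # a harmless no-op, hence the col1 > col0 check above.
        f['seg'].read_direct(out, source_sel=numpy.s_[1:3],
                             dest_sel=numpy.s_[:, 0:0])
    except Exception:
        pass  # zero-size selection rejected by h5py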
84 changes: 77 additions & 7 deletions pyscf/pbc/df/df.py
@@ -637,12 +637,12 @@ def _load_one(self, ki, kj, slices):
         if self.aosym == 's1' or kikj == kjki:
             dat = self.j3c[str(kikj)]
             nsegs = len(dat)
-            out = numpy.hstack([dat[str(i)][slices] for i in range(nsegs)])
+            out = _hstack_datasets([dat[str(i)] for i in range(nsegs)], slices)
         elif self.aosym == 's2':
             dat_ij = self.j3c[str(kikj)]
             dat_ji = self.j3c[str(kjki)]
-            tril = numpy.hstack([dat_ij[str(i)][slices] for i in range(len(dat_ij))])
-            triu = numpy.hstack([dat_ji[str(i)][slices] for i in range(len(dat_ji))])
+            tril = _hstack_datasets([dat_ij[str(i)] for i in range(len(dat_ij))], slices)
+            triu = _hstack_datasets([dat_ji[str(i)] for i in range(len(dat_ji))], slices)
             assert tril.dtype == numpy.complex128
             naux = self.naux
             nao = self.nao
@@ -807,7 +807,7 @@ def __init__(self, dat, hermi):
     def __getitem__(self, s):
         dat = self.dat
         if isinstance(dat, h5py.Group):
-            v = numpy.hstack([dat[str(i)][s] for i in range(len(dat))])
+            v = _hstack_datasets([dat[str(i)] for i in range(len(dat))], s)
         else:  # For mpi4pyscf, pyscf-1.5.1 or older
             v = numpy.asarray(dat[s])

@@ -831,6 +831,76 @@ def shape(self):
         else:  # For mpi4pyscf, pyscf-1.5.1 or older
             return dat.shape

+def _hstack_datasets(data_to_stack, slices=numpy.s_[:]):
+    """Faster version of the operation
+    np.hstack([x[slices] for x in data_to_stack]) for h5py datasets.
+
+    Parameters
+    ----------
+    data_to_stack : list of h5py.Dataset or np.ndarray
+        Datasets/arrays to be stacked along axis 1.
+    slices : tuple or list of slices, a slice, or ().
+        The slices (or indices) to select data from each H5 dataset.
+
+    Returns
+    -------
+    numpy.ndarray
+        The stacked data, equal to
+        numpy.hstack([dset[slices] for dset in data_to_stack]).
+    """
+    # Step 1. Calculate the shape of the output array, and store it
+    # in res_shape.
+    res_shape = list(data_to_stack[0].shape)
+    dset_shapes = [x.shape for x in data_to_stack]
+
+    if not (isinstance(slices, tuple) or isinstance(slices, list)):
+        # If slices is not a tuple, we assume it is a single slice
+        # acting on axis 0 only.
+        slices = (slices,)
+
+    def len_of_slice(arraylen, s):
+        start, stop, step = s.indices(arraylen)
+        r = range(start, stop, step)
+        # Python has a very fast builtin method to get the length of a range.
+        return len(r)
+
+    for i, cur_slice in enumerate(slices):
+        if not isinstance(cur_slice, slice):
+            return numpy.hstack([dset[slices] for dset in data_to_stack])
+        if i == 1:
+            ax1widths_sliced = [len_of_slice(shp[1], cur_slice)
+                                for shp in dset_shapes]
+        else:
+            # Except along axis 1, we assume the dimensions of all datasets
+            # are the same. If they aren't, an error gets raised later.
+            res_shape[i] = len_of_slice(res_shape[i], cur_slice)
+    if len(slices) <= 1:
+        ax1widths_sliced = [shp[1] for shp in dset_shapes]
+
+    # Final dim along axis 1 is the sum of the post-slice axis 1 widths.
+    res_shape[1] = sum(ax1widths_sliced)
+
+    # Step 2. Allocate the output buffer.
+    out = numpy.empty(res_shape,
+                      dtype=numpy.result_type(*[dset.dtype for dset in data_to_stack]))
+
+    # Step 3. Read data into the output buffer.
+    ax1ind = 0
+    for i, dset in enumerate(data_to_stack):
+        ax1width = ax1widths_sliced[i]
+        dest_sel = numpy.s_[:, ax1ind:ax1ind + ax1width]
+        if hasattr(dset, 'read_direct'):
+            # h5py has issues with zero-size selections, see
+            # https://github.com/h5py/h5py/issues/1455,
+            # so we check for that here.
+            if out[dest_sel].size > 0:
+                dset.read_direct(out,
+                                 source_sel=slices,
+                                 dest_sel=dest_sel)
+        else:
+            # For array-like objects.
+            out[dest_sel] = dset[slices]
+        ax1ind += ax1width
+    return out

 class _KPair3CLoader:
     def __init__(self, dat, ki, kj, nkpts, aosym):
         self.dat = dat
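A quick equivalence check for the new helper, written as a sketch: it assumes a pyscf build that contains this commit and uses throwaway file and dataset names.

import h5py
import numpy
from pyscf.pbc.df.df import _hstack_datasets

with h5py.File('j3c_demo.h5', 'w') as f:
    f['0'] = numpy.random.rand(10, 4)
    f['1'] = numpy.random.rand(10, 7)

with h5py.File('j3c_demo.h5', 'r') as f:
    dsets = [f['0'], f['1']]
    s = numpy.s_[2:9]  # a single row slice over axis 0
    # len_of_slice resolves s to 7 rows via slice.indices() + len(range(...))
    assert numpy.array_equal(_hstack_datasets(dsets, s),
                             numpy.hstack([d[s] for d in dsets]))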
@@ -843,12 +913,12 @@ def __init__(self, dat, ki, kj, nkpts, aosym):
     def __getitem__(self, s):
         if self.aosym == 's1' or self.kikj == self.kjki:
             dat = self.dat[str(self.kikj)]
-            out = numpy.hstack([dat[str(i)][s] for i in range(self.nsegs)])
+            out = _hstack_datasets([dat[str(i)] for i in range(self.nsegs)], s)
         elif self.aosym == 's2':
             dat_ij = self.dat[str(self.kikj)]
             dat_ji = self.dat[str(self.kjki)]
-            tril = numpy.hstack([dat_ij[str(i)][s] for i in range(self.nsegs)])
-            triu = numpy.hstack([dat_ji[str(i)][s] for i in range(self.nsegs)])
+            tril = _hstack_datasets([dat_ij[str(i)] for i in range(self.nsegs)], s)
+            triu = _hstack_datasets([dat_ji[str(i)] for i in range(self.nsegs)], s)
             assert tril.dtype == numpy.complex128
             naux, nao_pair = tril.shape
             nao = int((nao_pair * 2)**.5)
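In the s2 branch above, each dataset stores only the packed lower triangle of the AO-pair index, so nao_pair = nao*(nao+1)/2, and the last line inverts that relation with int((nao_pair * 2)**.5). The floor is safe because sqrt(nao*(nao+1)) always falls strictly between nao and nao+1:

nao = 5
nao_pair = nao * (nao + 1) // 2           # 15 packed lower-triangle pairs
assert int((nao_pair * 2) ** 0.5) == nao  # int(sqrt(30)) = int(5.477...) = 5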
