diff --git a/setup.cfg b/setup.cfg index f60474b47..e157f611f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,12 +61,14 @@ dev = mypy pre-commit pylint + pyscf pytest pytest-random-order tox tox = covdefaults coverage + pyscf pytest pytest-random-order diff --git a/src/dxtb/basis/indexhelper.py b/src/dxtb/basis/indexhelper.py index 7b8f3a311..3a2b6c057 100644 --- a/src/dxtb/basis/indexhelper.py +++ b/src/dxtb/basis/indexhelper.py @@ -944,3 +944,25 @@ def allowed_dtypes(self) -> tuple[torch.dtype, ...]: Collection of allowed dtypes the TensorLike object can take. """ return (torch.int16, torch.int32, torch.int64, torch.long) + + def __repr__(self) -> str: + return ( + f"IndexHelper(\n" + f" unique_angular={self.unique_angular},\n" + f" angular={self.angular},\n" + f" atom_to_unique={self.atom_to_unique},\n" + f" ushells_to_unique={self.ushells_to_unique},\n" + f" ushells_per_unique={self.ushells_per_unique},\n" + f" shells_to_ushell={self.shells_to_ushell},\n" + f" shells_per_atom={self.shells_per_atom},\n" + f" shell_index={self.shell_index},\n" + f" shells_to_atom={self.shells_to_atom},\n" + f" orbitals_per_shell={self.orbitals_per_shell},\n" + f" orbital_index={self.orbital_index},\n" + f" orbitals_to_shell={self.orbitals_to_shell},\n" + f" batched={self.batched},\n" + f" store={self.store},\n" + f" device={self.device},\n" + f" dtype={self.dtype}\n" + ")" + ) diff --git a/src/dxtb/constants/defaults.py b/src/dxtb/constants/defaults.py index 2bc1a6b34..1ec82ec09 100644 --- a/src/dxtb/constants/defaults.py +++ b/src/dxtb/constants/defaults.py @@ -43,6 +43,9 @@ INTDRIVER = "libcint" """Integral driver.""" +INTLEVEL = 1 +"""Determines types of calculated integrals.""" + # SCF settings GUESS = "eeq" diff --git a/src/dxtb/interaction/list.py b/src/dxtb/interaction/list.py index 40cdcd42f..7b939a85e 100644 --- a/src/dxtb/interaction/list.py +++ b/src/dxtb/interaction/list.py @@ -49,6 +49,10 @@ def __init__(self, *interactions: Interaction | None) -> None: interaction for interaction in interactions if interaction is not None ] + @property + def labels(self) -> list[str]: + return [interaction.label for interaction in self.interactions] + def get_cache( self, numbers: Tensor, positions: Tensor, ihelp: IndexHelper ) -> InteractionList.Cache: @@ -233,3 +237,6 @@ def get_gradient( for interaction in self.interactions ] ).sum(dim=0) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.labels})" diff --git a/src/dxtb/utils/misc.py b/src/dxtb/utils/misc.py index 0e90f5433..c0a0dc041 100644 --- a/src/dxtb/utils/misc.py +++ b/src/dxtb/utils/misc.py @@ -30,6 +30,8 @@ def is_str_list(x: list[Any]) -> TypeGuard[list[str]]: TypeGuard[list[str]] `True` if all objects are strings, `False` otherwise. """ + if not isinstance(x, list): + return False return all(isinstance(i, str) for i in x) @@ -47,10 +49,25 @@ def is_int_list(x: list[Any]) -> TypeGuard[list[int]]: TypeGuard[list[int]] `True` if all objects are integers, `False` otherwise. """ + if not isinstance(x, list): + return False return all(isinstance(i, int) for i in x) -def is_list_basis(x) -> TypeGuard[list[AtomCGTOBasis]]: +def is_basis_list(x: Any) -> TypeGuard[list[AtomCGTOBasis]]: + """ + Determines whether all objects in the list are `AtomCGTOBasis`. + + Parameters + ---------- + x : list[Any] + List to check. + + Returns + ------- + TypeGuard[list[AtomCGTOBasis]] + `True` if all objects are `AtomCGTOBasis`, `False` otherwise. + """ if not isinstance(x, list): return False return all(isinstance(i, AtomCGTOBasis) for i in x) diff --git a/src/dxtb/xtb/calculator.py b/src/dxtb/xtb/calculator.py index 2286ec3d7..1db1b5ed1 100644 --- a/src/dxtb/xtb/calculator.py +++ b/src/dxtb/xtb/calculator.py @@ -28,7 +28,7 @@ from ..interaction import Charges, Interaction, InteractionList, Potential from ..param import Param, get_elem_angular from ..utils import Timers, ToleranceWarning, batch -from ..utils.misc import is_list_basis +from ..utils.misc import is_basis_list from ..wavefunction import filling from ..xtb.h0 import Hamiltonian from .h0 import Hamiltonian @@ -40,7 +40,7 @@ def use_intdriver(driver_arg: int = 1) -> Callable: def decorator(fcn: Callable) -> Callable: @wraps(fcn) def wrap(self: Calculator, *args: Any, **kwargs: Any) -> Any: - if self.opts["integral_driver"] == "libcint": + if self.opts["int_driver"] == "libcint": self.init_intdriver(args[driver_arg]) result = fcn(self, *args, **kwargs) @@ -277,8 +277,6 @@ def __init__( self.batched = numbers.ndim > 1 - drv = opts.get("intdriver", defaults.INTDRIVER) - self.opts = { "fwd_options": { "damp": opts.get("damp", defaults.DAMP), @@ -307,7 +305,7 @@ def __init__( "exclude": opts.get("exclude", defaults.EXCLUDE), "guess": opts.get("guess", defaults.GUESS), "spin": opts.get("spin", defaults.SPIN), - "integral_driver": drv, + "int_driver": opts.get("int_driver", defaults.INTDRIVER), } # set tolerances separately to catch unreasonably small values @@ -318,6 +316,7 @@ def __init__( self.hamiltonian = Hamiltonian(numbers, par, self.ihelp, **dd) self.integrals = Integrals(**dd) + # integrals do not work with a batched IndexHelper if self.batched: self._ihelp = [ IndexHelper.from_numbers( @@ -369,17 +368,22 @@ def __init__( halogen, dispersion, repulsion, *classical, timer=self.timer ) - self.timer.stop("setup calculator") + # integral-related setup - if self.opts["integral_driver"] == "libcint": + if self.opts["int_driver"] == "libcint": self.overlap = OverlapLibcint(numbers, par, self.ihelp, **dd) self.basis = Basis(numbers, par, self.ihelp, **dd) else: self.overlap = Overlap(numbers, par, self.ihelp, **dd) + # figure out integral level from interactions + self.set_intlevel(opts.get("int_level", defaults.INTLEVEL)) + self._intdriver = None self._positions = None + self.timer.stop("setup calculator") + def set_option(self, name: str, value: Any) -> None: if name not in self.opts: raise KeyError(f"Option '{name}' does not exist.") @@ -402,6 +406,12 @@ def set_tol(self, name: str, value: float) -> None: self.opts["fwd_options"][name] = value + def set_intlevel(self, value: int) -> None: + if "ElectricField" in self.interactions.labels: + value = max(2, value) + + self.opts["int_level"] = value + def driver( self, positions: Tensor ) -> intor.LibcintWrapper | list[intor.LibcintWrapper]: @@ -415,17 +425,18 @@ def driver( return [ intor.LibcintWrapper(ab, ihelp) for ab, ihelp in zip(atombases, self._ihelp) - if is_list_basis(ab) + if is_basis_list(ab) ] - assert is_list_basis(atombases) + assert is_basis_list(atombases) return intor.LibcintWrapper(atombases, self.ihelp) def init_intdriver(self, positions: Tensor): - if self.opts["integral_driver"] != "libcint": + if self.opts["int_driver"] != "libcint": return diff = 0 + # create intor.LibcintWrapper if it does not exist yet if self._intdriver is None: self._intdriver = self.driver(positions) @@ -478,54 +489,10 @@ def singlepoint( self.integrals.overlap = overlap self.timer.stop("Overlap") - # dipole intgral - if self.opts["integral_driver"] == "libcint": + # dipole integral + if self.opts["int_driver"] == "libcint" and self.opts["int_level"] > 1: self.timer.start("Dipole Integral") - - # statisfy type checking... - assert isinstance(self.overlap, OverlapLibcint) - assert self._intdriver is not None - - def dipole_integral( - driver: intor.LibcintWrapper, norm: Tensor - ) -> Tensor: - """ - Calculation of dipole integral. The integral is properly - normalized, using the diagonal of the overlap integral. - - Parameters - ---------- - driver : intor.LibcintWrapper - Integral driver (libcint interface). - norm : Tensor - Norm of the overlap integral. - - Returns - ------- - Tensor - Normalized dipole integral. - """ - return torch.einsum( - "xij,i,j->xij", intor.int1e("j", driver), norm, norm - ) - - if self.batched: - dpint_list = [] - - assert isinstance(self._intdriver, list) - for _batch, driver in enumerate(self._intdriver): - dpint = dipole_integral( - driver, batch.deflate(self.overlap.norm[_batch]) - ) - dpint_list.append(dpint) - - self.integrals.dipole = batch.pack(dpint_list) - else: - assert isinstance(self._intdriver, intor.LibcintWrapper) - self.integrals.dipole = dipole_integral( - self._intdriver, self.overlap.norm - ) - + self.dipole_integral() self.timer.stop("Dipole Integral") # Hamiltonian @@ -649,3 +616,40 @@ def dipole_integral( self.timer.print_times() return result + + def dipole_integral(self) -> None: + # statisfy type checking... + assert isinstance(self.overlap, OverlapLibcint) + assert self._intdriver is not None + + def dpint(driver: intor.LibcintWrapper, norm: Tensor) -> Tensor: + """ + Calculation of dipole integral. The integral is properly + normalized, using the diagonal of the overlap integral. + + Parameters + ---------- + driver : intor.LibcintWrapper + Integral driver (libcint interface). + norm : Tensor + Norm of the overlap integral. + + Returns + ------- + Tensor + Normalized dipole integral. + """ + return torch.einsum("xij,i,j->xij", intor.int1e("j", driver), norm, norm) + + if self.batched: + dpint_list = [] + + assert isinstance(self._intdriver, list) + for _batch, driver in enumerate(self._intdriver): + d = dpint(driver, batch.deflate(self.overlap.norm[_batch])) + dpint_list.append(d) + + self.integrals.dipole = batch.pack(dpint_list) + else: + assert isinstance(self._intdriver, intor.LibcintWrapper) + self.integrals.dipole = dpint(self._intdriver, self.overlap.norm) diff --git a/test/test_a_memory_leak/test_repulsion.py b/test/test_a_memory_leak/test_repulsion.py index 052fdb388..1fd2c1360 100644 --- a/test/test_a_memory_leak/test_repulsion.py +++ b/test/test_a_memory_leak/test_repulsion.py @@ -5,6 +5,8 @@ """ from __future__ import annotations +import gc + import pytest import torch @@ -72,4 +74,9 @@ def fcn(): # known reference cycle for create_graph=True energy.backward() - assert not has_memleak_tensor(fcn) + # run garbage collector to avoid leaks across other tests + gc.collect() + leak = has_memleak_tensor(fcn) + gc.collect() + + assert not leak, "Memory leak detected" diff --git a/test/test_basis/test_general.py b/test/test_basis/test_general.py index 8688eb56f..8760a05ad 100644 --- a/test/test_basis/test_general.py +++ b/test/test_basis/test_general.py @@ -7,7 +7,7 @@ import pytest import torch -from dxtb.basis import Basis, IndexHelper, slater +from dxtb.basis import Basis, IndexHelper, slater_to_gauss from dxtb.integral import overlap_gto from dxtb.param import GFN1_XTB as par from dxtb.param import get_elem_angular diff --git a/test/test_libcint/test_overlap.py b/test/test_libcint/test_overlap.py index c647f994e..8dd39404d 100644 --- a/test/test_libcint/test_overlap.py +++ b/test/test_libcint/test_overlap.py @@ -16,7 +16,7 @@ from dxtb.integral.libcint import LibcintWrapper, intor from dxtb.param import GFN1_XTB as par from dxtb.param import get_elem_angular -from dxtb.utils import is_list_basis, numpy_to_tensor +from dxtb.utils import batch, is_basis_list, numpy_to_tensor try: from dxtb.mol.external._pyscf import M @@ -36,6 +36,31 @@ def snorm(overlap: Tensor) -> Tensor: return torch.pow(overlap.diagonal(dim1=-1, dim2=-2), -0.5) +def extract_blocks(x: Tensor, block_sizes: list[int] | Tensor) -> list[Tensor]: + # Initialize the start index for the first block + start_index = 0 + + # Initialize an empty list to store the blocks + blocks: list[Tensor] = [] + + if isinstance(block_sizes, Tensor): + assert block_sizes.ndim == 1 + block_sizes = block_sizes.tolist() + + # Iterate over each block + for block_size in block_sizes: + # Generate the indices for the elements in the current block + indices = start_index + torch.arange(block_size) + + # Extract the block and append it to the list + blocks.append(x[indices, :][:, indices]) + + # Update the start index for the next block + start_index += block_size + + return blocks + + @pytest.mark.skipif(pyscf is False, reason="PySCF not installed") @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", sample_list) @@ -50,7 +75,7 @@ def test_single(dtype: torch.dtype, name: str) -> None: ihelp = IndexHelper.from_numbers(numbers, get_elem_angular(par.element)) bas = Basis(numbers, par, ihelp, **dd) atombases = bas.create_dqc(positions) - assert is_list_basis(atombases) + assert is_basis_list(atombases) # dxtb's libcint overlap wrapper = LibcintWrapper(atombases, ihelp) @@ -79,6 +104,10 @@ def test_single(dtype: torch.dtype, name: str) -> None: @pytest.mark.parametrize("name1", sample_list) @pytest.mark.parametrize("name2", sample_list) def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: + """ + Batched overlap using non-batched setup, i.e., one huge matrix is + calculated that is only populated on the diagonal. + """ tol = sqrt(torch.finfo(dtype).eps) * 1e-2 dd: DD = {"device": device, "dtype": dtype} @@ -92,7 +121,7 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: positions = torch.cat( ( sample1["positions"].to(**dd), - sample2["positions"].to(**dd) + 1000, + sample2["positions"].to(**dd) + 1000, # move! ), dim=0, ) @@ -100,7 +129,7 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ihelp = IndexHelper.from_numbers(numbers, get_elem_angular(par.element)) bas = Basis(numbers, par, ihelp, **dd) atombases = bas.create_dqc(positions) - assert is_list_basis(atombases) + assert is_basis_list(atombases) # dxtb's libcint overlap wrapper = LibcintWrapper(atombases, ihelp) @@ -121,10 +150,19 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: norm = snorm(pyscf_overlap) pyscf_overlap = torch.einsum("ij,i,j->ij", pyscf_overlap, norm, norm) - print(dxtb_overlap) assert dxtb_overlap.shape == pyscf_overlap.shape assert pytest.approx(pyscf_overlap, abs=tol) == dxtb_overlap + # we could also extract the blocks and pack them as usual + n = batch.pack((sample1["numbers"].to(device), sample2["numbers"].to(device))) + ihelp2 = IndexHelper.from_numbers(n, get_elem_angular(par.element)) + sizes = ihelp2.orbitals_per_shell.sum(-1) + out = extract_blocks(dxtb_overlap, sizes) + s_packed = batch.pack(out) + + max_size = int(ihelp2.orbitals_per_shell.sum(-1).max()) + assert s_packed.shape == torch.Size((2, max_size, max_size)) + @pytest.mark.skipif(pyscf is False, reason="PySCF not installed") @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -140,7 +178,7 @@ def test_grad(dtype: torch.dtype, name: str) -> None: ihelp = IndexHelper.from_numbers(numbers, get_elem_angular(par.element)) bas = Basis(numbers, par, ihelp, **dd) atombases = bas.create_dqc(positions) - assert is_list_basis(atombases) + assert is_basis_list(atombases) wrapper = LibcintWrapper(atombases, ihelp) int1 = intor.int1e("ipovlp", wrapper) diff --git a/test/test_overlap/test_grad.py b/test/test_overlap/test_grad.py index 03ff24f66..77c33894e 100644 --- a/test/test_overlap/test_grad.py +++ b/test/test_overlap/test_grad.py @@ -10,7 +10,7 @@ from torch.autograd.functional import jacobian from dxtb._types import DD, Tensor -from dxtb.basis import IndexHelper, slater +from dxtb.basis import IndexHelper, slater_to_gauss from dxtb.integral import Overlap, mmd from dxtb.param import GFN1_XTB as par from dxtb.param import get_elem_angular diff --git a/test/test_scf/test_full_tracking.py b/test/test_scf/test_full_tracking.py index b83c5282d..da2adf906 100644 --- a/test/test_scf/test_full_tracking.py +++ b/test/test_scf/test_full_tracking.py @@ -41,7 +41,7 @@ def single( opts, **{ "damp": 0.05 if mixer == "simple" else 0.4, - "intdriver": intdriver, + "int_driver": intdriver, "mixer": mixer, "scp_mode": scp_mode, "xitorch_fatol": tol, @@ -132,7 +132,7 @@ def batched( "damp": 0.05 if mixer == "simple" else 0.4, "mixer": mixer, "scp_mode": "charge", - "intdriver": intdriver, + "int_driver": intdriver, "xitorch_fatol": tol, "xitorch_xatol": tol, },