Skip to content

Commit

Permalink
More clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jul 26, 2023
1 parent 3966b61 commit c51d16b
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 69 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ dev =
mypy
pre-commit
pylint
pyscf
pytest
pytest-random-order
tox
tox =
covdefaults
coverage
pyscf
pytest
pytest-random-order

Expand Down
22 changes: 22 additions & 0 deletions src/dxtb/basis/indexhelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,25 @@ def allowed_dtypes(self) -> tuple[torch.dtype, ...]:
Collection of allowed dtypes the TensorLike object can take.
"""
return (torch.int16, torch.int32, torch.int64, torch.long)

def __repr__(self) -> str:
return (
f"IndexHelper(\n"
f" unique_angular={self.unique_angular},\n"
f" angular={self.angular},\n"
f" atom_to_unique={self.atom_to_unique},\n"
f" ushells_to_unique={self.ushells_to_unique},\n"
f" ushells_per_unique={self.ushells_per_unique},\n"
f" shells_to_ushell={self.shells_to_ushell},\n"
f" shells_per_atom={self.shells_per_atom},\n"
f" shell_index={self.shell_index},\n"
f" shells_to_atom={self.shells_to_atom},\n"
f" orbitals_per_shell={self.orbitals_per_shell},\n"
f" orbital_index={self.orbital_index},\n"
f" orbitals_to_shell={self.orbitals_to_shell},\n"
f" batched={self.batched},\n"
f" store={self.store},\n"
f" device={self.device},\n"
f" dtype={self.dtype}\n"
")"
)
3 changes: 3 additions & 0 deletions src/dxtb/constants/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
INTDRIVER = "libcint"
"""Integral driver."""

INTLEVEL = 1
"""Determines types of calculated integrals."""

# SCF settings

GUESS = "eeq"
Expand Down
7 changes: 7 additions & 0 deletions src/dxtb/interaction/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(self, *interactions: Interaction | None) -> None:
interaction for interaction in interactions if interaction is not None
]

@property
def labels(self) -> list[str]:
return [interaction.label for interaction in self.interactions]

def get_cache(
self, numbers: Tensor, positions: Tensor, ihelp: IndexHelper
) -> InteractionList.Cache:
Expand Down Expand Up @@ -233,3 +237,6 @@ def get_gradient(
for interaction in self.interactions
]
).sum(dim=0)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.labels})"
19 changes: 18 additions & 1 deletion src/dxtb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def is_str_list(x: list[Any]) -> TypeGuard[list[str]]:
TypeGuard[list[str]]
`True` if all objects are strings, `False` otherwise.
"""
if not isinstance(x, list):
return False
return all(isinstance(i, str) for i in x)


Expand All @@ -47,10 +49,25 @@ def is_int_list(x: list[Any]) -> TypeGuard[list[int]]:
TypeGuard[list[int]]
`True` if all objects are integers, `False` otherwise.
"""
if not isinstance(x, list):
return False
return all(isinstance(i, int) for i in x)


def is_list_basis(x) -> TypeGuard[list[AtomCGTOBasis]]:
def is_basis_list(x: Any) -> TypeGuard[list[AtomCGTOBasis]]:
"""
Determines whether all objects in the list are `AtomCGTOBasis`.
Parameters
----------
x : list[Any]
List to check.
Returns
-------
TypeGuard[list[AtomCGTOBasis]]
`True` if all objects are `AtomCGTOBasis`, `False` otherwise.
"""
if not isinstance(x, list):
return False
return all(isinstance(i, AtomCGTOBasis) for i in x)
Expand Down
118 changes: 61 additions & 57 deletions src/dxtb/xtb/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..interaction import Charges, Interaction, InteractionList, Potential
from ..param import Param, get_elem_angular
from ..utils import Timers, ToleranceWarning, batch
from ..utils.misc import is_list_basis
from ..utils.misc import is_basis_list
from ..wavefunction import filling
from ..xtb.h0 import Hamiltonian
from .h0 import Hamiltonian
Expand All @@ -40,7 +40,7 @@ def use_intdriver(driver_arg: int = 1) -> Callable:
def decorator(fcn: Callable) -> Callable:
@wraps(fcn)
def wrap(self: Calculator, *args: Any, **kwargs: Any) -> Any:
if self.opts["integral_driver"] == "libcint":
if self.opts["int_driver"] == "libcint":
self.init_intdriver(args[driver_arg])

result = fcn(self, *args, **kwargs)
Expand Down Expand Up @@ -277,8 +277,6 @@ def __init__(

self.batched = numbers.ndim > 1

drv = opts.get("intdriver", defaults.INTDRIVER)

self.opts = {
"fwd_options": {
"damp": opts.get("damp", defaults.DAMP),
Expand Down Expand Up @@ -307,7 +305,7 @@ def __init__(
"exclude": opts.get("exclude", defaults.EXCLUDE),
"guess": opts.get("guess", defaults.GUESS),
"spin": opts.get("spin", defaults.SPIN),
"integral_driver": drv,
"int_driver": opts.get("int_driver", defaults.INTDRIVER),
}

# set tolerances separately to catch unreasonably small values
Expand All @@ -318,6 +316,7 @@ def __init__(
self.hamiltonian = Hamiltonian(numbers, par, self.ihelp, **dd)
self.integrals = Integrals(**dd)

# integrals do not work with a batched IndexHelper
if self.batched:
self._ihelp = [
IndexHelper.from_numbers(
Expand Down Expand Up @@ -369,17 +368,22 @@ def __init__(
halogen, dispersion, repulsion, *classical, timer=self.timer
)

self.timer.stop("setup calculator")
# integral-related setup

if self.opts["integral_driver"] == "libcint":
if self.opts["int_driver"] == "libcint":
self.overlap = OverlapLibcint(numbers, par, self.ihelp, **dd)
self.basis = Basis(numbers, par, self.ihelp, **dd)
else:
self.overlap = Overlap(numbers, par, self.ihelp, **dd)

# figure out integral level from interactions
self.set_intlevel(opts.get("int_level", defaults.INTLEVEL))

self._intdriver = None
self._positions = None

self.timer.stop("setup calculator")

def set_option(self, name: str, value: Any) -> None:
if name not in self.opts:
raise KeyError(f"Option '{name}' does not exist.")
Expand All @@ -402,6 +406,12 @@ def set_tol(self, name: str, value: float) -> None:

self.opts["fwd_options"][name] = value

def set_intlevel(self, value: int) -> None:
if "ElectricField" in self.interactions.labels:
value = max(2, value)

self.opts["int_level"] = value

def driver(
self, positions: Tensor
) -> intor.LibcintWrapper | list[intor.LibcintWrapper]:
Expand All @@ -415,17 +425,18 @@ def driver(
return [
intor.LibcintWrapper(ab, ihelp)
for ab, ihelp in zip(atombases, self._ihelp)
if is_list_basis(ab)
if is_basis_list(ab)
]

assert is_list_basis(atombases)
assert is_basis_list(atombases)
return intor.LibcintWrapper(atombases, self.ihelp)

def init_intdriver(self, positions: Tensor):
if self.opts["integral_driver"] != "libcint":
if self.opts["int_driver"] != "libcint":
return

diff = 0

# create intor.LibcintWrapper if it does not exist yet
if self._intdriver is None:
self._intdriver = self.driver(positions)
Expand Down Expand Up @@ -478,54 +489,10 @@ def singlepoint(
self.integrals.overlap = overlap
self.timer.stop("Overlap")

# dipole intgral
if self.opts["integral_driver"] == "libcint":
# dipole integral
if self.opts["int_driver"] == "libcint" and self.opts["int_level"] > 1:
self.timer.start("Dipole Integral")

# statisfy type checking...
assert isinstance(self.overlap, OverlapLibcint)
assert self._intdriver is not None

def dipole_integral(
driver: intor.LibcintWrapper, norm: Tensor
) -> Tensor:
"""
Calculation of dipole integral. The integral is properly
normalized, using the diagonal of the overlap integral.
Parameters
----------
driver : intor.LibcintWrapper
Integral driver (libcint interface).
norm : Tensor
Norm of the overlap integral.
Returns
-------
Tensor
Normalized dipole integral.
"""
return torch.einsum(
"xij,i,j->xij", intor.int1e("j", driver), norm, norm
)

if self.batched:
dpint_list = []

assert isinstance(self._intdriver, list)
for _batch, driver in enumerate(self._intdriver):
dpint = dipole_integral(
driver, batch.deflate(self.overlap.norm[_batch])
)
dpint_list.append(dpint)

self.integrals.dipole = batch.pack(dpint_list)
else:
assert isinstance(self._intdriver, intor.LibcintWrapper)
self.integrals.dipole = dipole_integral(
self._intdriver, self.overlap.norm
)

self.dipole_integral()
self.timer.stop("Dipole Integral")

# Hamiltonian
Expand Down Expand Up @@ -649,3 +616,40 @@ def dipole_integral(
self.timer.print_times()

return result

def dipole_integral(self) -> None:
# statisfy type checking...
assert isinstance(self.overlap, OverlapLibcint)
assert self._intdriver is not None

def dpint(driver: intor.LibcintWrapper, norm: Tensor) -> Tensor:
"""
Calculation of dipole integral. The integral is properly
normalized, using the diagonal of the overlap integral.
Parameters
----------
driver : intor.LibcintWrapper
Integral driver (libcint interface).
norm : Tensor
Norm of the overlap integral.
Returns
-------
Tensor
Normalized dipole integral.
"""
return torch.einsum("xij,i,j->xij", intor.int1e("j", driver), norm, norm)

if self.batched:
dpint_list = []

assert isinstance(self._intdriver, list)
for _batch, driver in enumerate(self._intdriver):
d = dpint(driver, batch.deflate(self.overlap.norm[_batch]))
dpint_list.append(d)

self.integrals.dipole = batch.pack(dpint_list)
else:
assert isinstance(self._intdriver, intor.LibcintWrapper)
self.integrals.dipole = dpint(self._intdriver, self.overlap.norm)
9 changes: 8 additions & 1 deletion test/test_a_memory_leak/test_repulsion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""
from __future__ import annotations

import gc

import pytest
import torch

Expand Down Expand Up @@ -72,4 +74,9 @@ def fcn():
# known reference cycle for create_graph=True
energy.backward()

assert not has_memleak_tensor(fcn)
# run garbage collector to avoid leaks across other tests
gc.collect()
leak = has_memleak_tensor(fcn)
gc.collect()

assert not leak, "Memory leak detected"
2 changes: 1 addition & 1 deletion test/test_basis/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from dxtb.basis import Basis, IndexHelper, slater
from dxtb.basis import Basis, IndexHelper, slater_to_gauss
from dxtb.integral import overlap_gto
from dxtb.param import GFN1_XTB as par
from dxtb.param import get_elem_angular
Expand Down
Loading

0 comments on commit c51d16b

Please sign in to comment.