
Commit

Refactor

marvinfriede committed Sep 16, 2024
1 parent 2ea56df commit ada219a
Showing 41 changed files with 91 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -24,7 +24,7 @@
# adapt path to include the source code
sys.path.insert(0, op.join(op.dirname(__file__), "../../", "src"))

import dxtb # noqa
import dxtb # pylint: disable=unused-import

Check notice (Code scanning / CodeQL): Unused import. Import of 'dxtb' is not used.

project = "Fully Differentiable Extended Tight-Binding"
author = "Grimme Group"
1 change: 0 additions & 1 deletion src/dxtb/_src/calculators/types/base.py
@@ -493,7 +493,6 @@ def __init__(
if self.opts.batch_mode == 0 and numbers.ndim > 1:
self.opts.batch_mode = 1

# TODO: Should the IndexHelper be a singleton?
self.ihelp = IndexHelper.from_numbers(
numbers, par, self.opts.batch_mode
)
2 changes: 2 additions & 0 deletions src/dxtb/_src/components/interactions/coulomb/thirdorder.py
@@ -175,6 +175,8 @@ def get_cache(
Parameters
----------
numbers : Tensor
Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
ihelp : IndexHelper
Index mapping for the basis set.
4 changes: 3 additions & 1 deletion src/dxtb/_src/constants/defaults.py
@@ -25,6 +25,8 @@

import torch

from dxtb import labels

# General

STRICT = False
@@ -76,7 +78,7 @@
INTCUTOFF = 50.0
"""Real-space cutoff (in Bohr) for integral evaluation. (50.0)"""

INTDRIVER = "libcint"
INTDRIVER = labels.INTDRIVER_LIBCINT
"""Integral driver."""

INTDRIVER_CHOICES = [
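For context, a minimal sketch of what the new default amounts to (illustrative only, not part of the commit; it relies just on the module paths visible in this diff):

# Illustrative check: the default integral driver now points at the
# label constant instead of the literal string "libcint".
from dxtb import labels
from dxtb._src.constants import defaults

assert defaults.INTDRIVER == labels.INTDRIVER_LIBCINT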
2 changes: 1 addition & 1 deletion src/dxtb/_src/integral/container.py
@@ -189,7 +189,7 @@ def build_overlap(self, positions: Tensor, **kwargs: Any) -> Tensor:
# `None`, the overlap will not be rebuilt if the positions change,
# i.e., when the driver was invalidated. Hence, we would require a
# full reset of the integrals via `reset_all`. However, the integral
# reset cannot be trigger by the driver manager, so we cannot add this
# reset cannot be triggered by the driver manager, so we cannot add this
# check here. If we do, the hessian tests will fail as the overlap is
# not recalculated for positions + delta.
self.overlap.build(self.mgr.driver)
2 changes: 1 addition & 1 deletion src/dxtb/_src/integral/driver/pytorch/impls/type.py
@@ -65,4 +65,4 @@ def __call__(
Tensor
Overlap matrix or overlap gradient.
"""
... # noqa
...

Check notice (Code scanning / CodeQL): Statement has no effect.
2 changes: 1 addition & 1 deletion test/test_a_memory_leak/test_repulsion.py
@@ -94,7 +94,7 @@ def fcn():
del energy
del arep
del zeff
del kexp # noqa
del kexp

# run garbage collector to avoid leaks across other tests
garbage_collect()
4 changes: 2 additions & 2 deletions test/test_a_memory_leak/test_scf.py
@@ -72,7 +72,7 @@ def fcn():
del charges
del calc
del result
del energy # noqa
del energy

# run garbage collector to avoid leaks across other tests
garbage_collect()
@@ -156,7 +156,7 @@ def fcn():
del charges
del calc
del result
del energy # noqa
del energy

# run garbage collector to avoid leaks across other tests
garbage_collect()
22 changes: 12 additions & 10 deletions test/test_a_memory_leak/util.py
@@ -107,18 +107,18 @@ def _get_tensor_memory(
return total_mem


def _show_memsize(fcn, ntries: int = 10, gccollect: bool = False):
# show the memory growth
size0, num0 = _get_tensor_memory(return_number_tensors=True)

def _show_memsize(
fcn, size0: float, num0: int, ntries: int = 10, gccollect: bool = False
):
# show the memory growth over n iterations
for i in range(ntries):
fcn()
if gccollect:
if gccollect is True:
gc.collect()
size, num = _get_tensor_memory(return_number_tensors=True)

print(
f"{i + 1:2d} iteration: {size - size0:.16f} MiB of {num-num0:d} addtional tensors"
f"{i + 1:2d} iteration: {size - size0:.12f} MiB of {num - num0:d} addtional tensors"
)


@@ -133,8 +133,7 @@ def has_memleak_tensor(
fcn: Callable
A function with no input and output to be checked.
gccollect: bool
If True, then manually apply ``gc.collect()`` after the function
execution.
If True, run :func:`gc.collect` after function execution.
Returns
-------
@@ -144,12 +143,15 @@
size0, num0 = _get_tensor_memory(return_number_tensors=True)

fcn()
if gccollect:
if gccollect is True:
gc.collect()

size, num = _get_tensor_memory(return_number_tensors=True)

if size0 != size or num0 != num:
_show_memsize(fcn, repeats, gccollect=gccollect)
print(
f"{0:2d} iteration: {size - size0:.12f} MiB of {num - num0:d} addtional tensors"
)
_show_memsize(fcn, size0, num0, repeats, gccollect=gccollect)

return size0 != size
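The leak check in this helper boils down to counting live tensors before and after running the function under test and reporting the growth per iteration. A self-contained sketch of that pattern (an approximation for illustration; it does not reproduce the repository's _get_tensor_memory helper):

# Illustrative sketch, not the repository helper: count live torch tensors
# via the garbage collector and report the growth after each iteration.
import gc

import torch


def count_live_tensors() -> int:
    return sum(1 for obj in gc.get_objects() if isinstance(obj, torch.Tensor))


def shows_tensor_growth(fcn, ntries: int = 10, gccollect: bool = False) -> bool:
    num0 = count_live_tensors()
    for i in range(ntries):
        fcn()
        if gccollect is True:
            gc.collect()
        num = count_live_tensors()
        print(f"{i + 1:2d} iteration: {num - num0:d} additional tensors")
    return count_live_tensors() != num0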
5 changes: 5 additions & 0 deletions test/test_calculator/test_cache/test_properties.py
@@ -56,6 +56,11 @@ def test_energy(dtype: torch.dtype) -> None:
assert calc._ncalcs == 1
assert isinstance(energy, Tensor)

# different name for energy getter
energy = calc.get_potential_energy(positions)
assert calc._ncalcs == 1
assert isinstance(energy, Tensor)

# check reset
calc.cache.reset_all()
assert len(calc.cache.list_cached_properties()) == 0
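The added call checks that get_potential_energy is served from the same cache as the earlier energy call, so the calculation counter stays at 1. A hypothetical sketch of that cached-getter pattern (the TinyCalc class below is illustrative, not dxtb's Calculator):

# Hypothetical sketch: two getter names share one cached result, so the
# second call does not trigger a new calculation.
class TinyCalc:
    def __init__(self) -> None:
        self._ncalcs = 0
        self._cache = {}

    def get_energy(self, positions):
        if "energy" not in self._cache:
            self._ncalcs += 1
            self._cache["energy"] = sum(sum(row) for row in positions)
        return self._cache["energy"]

    def get_potential_energy(self, positions):
        # alternative name for the same cached property
        return self.get_energy(positions)


calc = TinyCalc()
pos = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]]
assert calc.get_energy(pos) == calc.get_potential_energy(pos)
assert calc._ncalcs == 1  # the second getter hit the cache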
File renamed without changes.
File renamed without changes.
@@ -35,7 +35,7 @@
)
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples


@@ -35,7 +35,7 @@
from dxtb._src.param.gfn2 import GFN2_XTB as par
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples


@@ -32,7 +32,7 @@
from dxtb._src.components.classicals.dispersion import new_dispersion
from dxtb._src.typing import DD

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples


File renamed without changes.
@@ -27,12 +27,11 @@
from dxtb._src.components.classicals.dispersion import new_dispersion
from dxtb._src.typing import DD

from ...conftest import DEVICE
from .samples import samples

sample_list = ["LiH", "SiH4", "MB16_43_01", "PbH4-BiH3"]

from ..conftest import DEVICE


@pytest.mark.grad
@pytest.mark.parametrize("name", sample_list)
@@ -28,7 +28,7 @@
from dxtb._src.components.classicals.dispersion import DispersionD3
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

slist = ["LiH", "SiH4"]
@@ -29,7 +29,7 @@
from dxtb._src.components.classicals.dispersion import new_dispersion
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

slist = ["LiH", "SiH4"]
@@ -26,13 +26,13 @@
import torch
from tad_mctc.autograd import jacrev
from tad_mctc.batch import pack
from tad_mctc.convert import reshape_fortran

from dxtb import GFN1_XTB as par
from dxtb._src.components.classicals.dispersion import new_dispersion
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ..utils import reshape_fortran
from ...conftest import DEVICE
from .samples import samples

sample_list = ["LiH", "SiH4", "MB16_43_01", "PbH4-BiH3"]
File renamed without changes.
File renamed without changes.
@@ -31,7 +31,7 @@
from dxtb._src.components.classicals import new_halogen
from dxtb._src.typing import DD

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples


File renamed without changes.
@@ -28,7 +28,7 @@
from dxtb._src.components.classicals import new_halogen
from dxtb._src.typing import DD

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples


@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

sample_list = ["br2nh3", "br2och2", "tmpda"]
@@ -30,7 +30,7 @@
from dxtb._src.components.classicals import new_halogen
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

# "LYS_xao" must be the last one as we have to manually exclude it for the
@@ -26,14 +26,14 @@
import torch
from tad_mctc.autograd import jacrev
from tad_mctc.batch import pack
from tad_mctc.convert import reshape_fortran

from dxtb import GFN1_XTB as par
from dxtb import IndexHelper
from dxtb._src.components.classicals import new_halogen
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ..utils import reshape_fortran
from ...conftest import DEVICE
from .samples import samples

sample_list = ["br2nh3", "br2och2", "finch", "LiH", "SiH4", "MB16_43_01"]
File renamed without changes.
File renamed without changes.
@@ -34,7 +34,7 @@
from dxtb._src.param.gfn2 import GFN2_XTB
from dxtb._src.typing import DD, Literal

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

sample_list = [
File renamed without changes.
@@ -28,7 +28,7 @@
from dxtb._src.components.classicals import new_repulsion
from dxtb._src.typing import DD

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

sample_list = ["H2O", "SiH4", "MB16_43_01", "MB16_43_02", "LYS_xao"]
@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

sample_list = ["H2O", "SiH4", "MB16_43_01", "MB16_43_02", "LYS_xao"]
@@ -37,7 +37,7 @@
)
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ...conftest import DEVICE
from .samples import samples

sample_list = ["H2O", "SiH4", "MB16_43_01", "MB16_43_02", "LYS_xao"]
@@ -26,14 +26,14 @@
import torch
from tad_mctc.autograd import jacrev
from tad_mctc.batch import pack
from tad_mctc.convert import reshape_fortran

from dxtb import GFN1_XTB as par
from dxtb import IndexHelper
from dxtb._src.components.classicals import new_repulsion
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ..utils import reshape_fortran
from ...conftest import DEVICE
from .samples import samples

sample_list = ["LiH", "SiH4", "MB16_43_01"]
20 changes: 19 additions & 1 deletion test/test_coulomb/test_es2_general.py
@@ -42,7 +42,25 @@ def test_none() -> None:
assert es2.new_es2(dummy, par) is None


def test_store_fail() -> None:
def test_cache_input_fail() -> None:
numbers = torch.tensor([3, 1])
positions = torch.randn((2, 3))
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)

es = es2.new_es2(numbers, GFN1_XTB)
assert es is not None

with pytest.raises(ValueError):
es.get_cache(numbers=None, positions=positions, ihelp=ihelp)

with pytest.raises(ValueError):
es.get_cache(numbers=numbers, positions=None, ihelp=ihelp)

with pytest.raises(ValueError):
es.get_cache(numbers=numbers, positions=positions, ihelp=None)


def test_fail_store() -> None:
numbers = torch.tensor([3, 1])
positions = torch.randn((2, 3))
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)
16 changes: 15 additions & 1 deletion test/test_coulomb/test_es3_general.py
@@ -51,7 +51,21 @@ def test_fail() -> None:
es3.new_es3(dummy, par)


def test_store_fail() -> None:
def test_fail_cache_input() -> None:
numbers = torch.tensor([3, 1])
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)

es = es3.new_es3(numbers, GFN1_XTB)
assert es is not None

with pytest.raises(ValueError):
es.get_cache(numbers=None, ihelp=ihelp)

with pytest.raises(ValueError):
es.get_cache(numbers=numbers, ihelp=None)


def test_fail_store() -> None:
numbers = torch.tensor([3, 1])
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)
