Skip to content

Commit

Permalink
Formatting and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jul 26, 2023
1 parent 01ec442 commit cad5b57
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 33 deletions.
9 changes: 8 additions & 1 deletion src/dxtb/constants/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@
experimental single-shot procedure.
"""

SCF_MODE_CHOICES = ["default", "implicit", "full", "full_tracking", "experimental", "implicit_old"]
SCF_MODE_CHOICES = [
"default",
"implicit",
"full",
"full_tracking",
"experimental",
"implicit_old",
]
"""List of possible choices for `SCF_MODE`."""

SCP_MODE = "potential"
Expand Down
75 changes: 43 additions & 32 deletions test/test_scf/test_energy.py → test/test_scf/test_fenergy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Test energy calculations from SCF iterations.
"""
Expand Down Expand Up @@ -27,87 +26,100 @@
"verbosity": 0,
}


@pytest.mark.large
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("partition", ["equal", "atomic"])
def test_element_energy_scf_mode(dtype: torch.dtype, partition: str) -> None:
"""Comparison of object SCF (old) vs. functional SCF."""
tol = 1e-8
dd: DD = {"device": device, "dtype": dtype}
tol = 1e-8

def calc(number, scf_mode):
def fcn(number, scf_mode):
numbers = torch.tensor([number])
positions = torch.zeros((1, 3), **dd)
charges = torch.tensor(0.0, **dd)

options = dict(opts, **{"xitorch_fatol": 1e-6, "xitorch_xatol": 1e-6, "fermi_fenergy_partition": partition, "scf_mode": scf_mode})
options = dict(
opts,
**{
"xitorch_fatol": 1e-6,
"xitorch_xatol": 1e-6,
"fermi_fenergy_partition": partition,
"scf_mode": scf_mode,
},
)
calc = Calculator(numbers, par, opts=options, **dd)
result = calc.singlepoint(numbers, positions, charges)
return result.scf.sum(-1)

return result.scf.sum(-1).item()

energies = [calc(n, "implicit") for n in range(1, 87)]
energies_old = [calc(n, "implicit_old") for n in range(1, 87)]
energies = [fcn(n, "implicit") for n in range(1, 87)]
energies_old = [fcn(n, "implicit_old") for n in range(1, 87)]
assert pytest.approx(energies, abs=tol) == energies_old


@pytest.mark.large
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("partition", ["equal", "atomic"])
def test_element_electronic_free_energy_scf_mode(dtype: torch.dtype, partition: str) -> None:
def test_element_scf_mode(dtype: torch.dtype, partition: str) -> None:
"""Comparison of object SCF (old) vs. functional SCF."""
tol = 1e-8
dd: DD = {"device": device, "dtype": dtype}
tol = 1e-8

def calc(number, scf_mode):
def fcn(number, scf_mode):
numbers = torch.tensor([number])
positions = torch.zeros((1, 3), **dd)
charges = torch.tensor(0.0, **dd)

options = dict(opts, **{"xitorch_fatol": 1e-6, "xitorch_xatol": 1e-6, "fermi_fenergy_partition": partition, "scf_mode": scf_mode})
options = dict(
opts,
**{
"xitorch_fatol": 1e-6,
"xitorch_xatol": 1e-6,
"fermi_fenergy_partition": partition,
"scf_mode": scf_mode,
},
)
calc = Calculator(numbers, par, opts=options, **dd)
result = calc.singlepoint(numbers, positions, charges)

return result.fenergy
fenergies = [calc(n, "implicit").item() for n in range(1, 87)]
fenergies_old = [calc(n, "implicit_old").item() for n in range(1, 87)]

fenergies = [fcn(n, "implicit").item() for n in range(1, 87)]
fenergies_old = [fcn(n, "implicit_old").item() for n in range(1, 87)]
assert pytest.approx(fenergies, abs=tol) == fenergies_old


@pytest.mark.large
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_element_electronic_free_energy(dtype: torch.dtype) -> None:
def test_element(dtype: torch.dtype) -> None:
"""Different free energies for different atoms."""

dd: DD = {"device": device, "dtype": dtype}

def calc(number):
def fcn(number):
numbers = torch.tensor([number])
positions = torch.zeros((1, 3), **dd)
charges = torch.tensor(0.0, **dd)

options = dict(opts, **{"xitorch_fatol": 1e-6, "xitorch_xatol": 1e-6})
calc = Calculator(numbers, par, opts=options, **dd)
result = calc.singlepoint(numbers, positions, charges)

return result.fenergy
fenergies = [calc(n).item() for n in range(1, 87)]

fenergies = [fcn(n).item() for n in range(1, 87)]
unique = set(fenergies)
assert len(unique) > 5



@pytest.mark.large
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_element_cation(dtype: torch.dtype) -> None:

dd: DD = {"device": device, "dtype": dtype}

def calc(number):
def fcn(number):
numbers = torch.tensor([number])
positions = torch.zeros((1, 3), **dd)
charges = torch.tensor(1.0, **dd)
Expand All @@ -123,12 +135,12 @@ def calc(number):
calc = Calculator(numbers, par, opts=options, **dd)
result = calc.singlepoint(numbers, positions, charges)
return result.fenergy

# no (valence) electrons OR gold
_exclude = [1, 3, 11, 19, 37, 55, 79]
_exclude = [1, 3, 11, 19, 37, 55, 79]
numbers = [i for i in range(1, 87) if i not in _exclude]

fenergies = [calc(n).item() for n in numbers]
fenergies = [fcn(n).item() for n in numbers]
unique = set(fenergies)
assert len(unique) > 5

Expand All @@ -137,10 +149,9 @@ def calc(number):
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_element_anion(dtype: torch.dtype) -> None:

dd: DD = {"device": device, "dtype": dtype}

def calc(number):
def fcn(number):
numbers = torch.tensor([number])
positions = torch.zeros((1, 3), **dd)
charges = torch.tensor(-1.0, **dd)
Expand All @@ -159,9 +170,9 @@ def calc(number):

# Helium doesn't have enough orbitals for negative charge,
# SCF does not converge (in tblite too)
_exclude = [2, 21, 22, 23, 25, 43, 57, 58, 59]
_exclude = [2, 21, 22, 23, 25, 43, 57, 58, 59]
numbers = [i for i in range(1, 87) if i not in _exclude]

fenergies = [calc(n).item() for n in numbers]
fenergies = [fcn(n).item() for n in numbers]
unique = set(fenergies)
assert len(unique) > 5

0 comments on commit cad5b57

Please sign in to comment.