Skip to content

Commit

Permalink
Merge pull request #13 from So-Takamoto/cell_grad_rewrite
Browse files Browse the repository at this point in the history
use shift for gradient calculation instead of cell
  • Loading branch information
corochann authored Sep 21, 2021
2 parents 4a0d47d + eca0300 commit dd8644e
Show file tree
Hide file tree
Showing 16 changed files with 241 additions and 161 deletions.
4 changes: 2 additions & 2 deletions .flexci/config.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ configs {
key: "torch-dftd.pytest"
value {
requirement {
cpu: 4
memory: 24
cpu: 6
memory: 36
disk: 10
gpu: 1
}
Expand Down
2 changes: 1 addition & 1 deletion .flexci/pytest_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ main() {
docker run --runtime=nvidia --rm --volume="$(pwd)":/workspace -w /workspace \
${IMAGE} \
bash -x -c "pip install flake8 pytest pytest-cov pytest-xdist pytest-benchmark && \
pip install cupy-cuda102 pytorch-pfn-extras && \
pip install cupy-cuda102 pytorch-pfn-extras!=0.5.0 && \
pip install -e .[develop] && \
pysen run lint && \
pytest --cov=torch_dftd -n $(nproc) -m 'not slow' tests &&
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool:pytest]
markers =
slow: mark test as slow.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup_requires: List[str] = []
install_requires: List[str] = [
"ase>=3.18, <4.0.0", # Note that we require ase==3.21.1 for pytest.
"pymatgen",
"pymatgen>=2020.1.28",
]
extras_require: Dict[str, List[str]] = {
"develop": ["pysen[lint]==0.9.1"],
Expand Down
36 changes: 22 additions & 14 deletions tests/functions_tests/test_triplets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ def test_calc_triplets():
dtype=torch.long,
device=device,
)
shift = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
shift[:, 0] = torch.tensor(
shift_pos = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
shift_pos[:, 0] = torch.tensor(
[1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], dtype=torch.float32, device=device
)
# print("shift", shift.shape)
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
edge_index, shift
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
edge_index, shift_pos
)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
# print("multiplicity", multiplicity.shape, multiplicity)
Expand All @@ -38,6 +38,20 @@ def test_calc_triplets():
)
assert multiplicity.shape == (n_triplets,)
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float32))

assert torch.allclose(
edge_jk.cpu(),
torch.tensor([[7, 6], [8, 6], [8, 7], [9, 10], [9, 11], [11, 10]], dtype=torch.long),
)
# shift for edge `i->j`, `i->k`, `j->k`.
triplet_shift = torch.stack(
[
-shift_pos[edge_jk[:, 0]],
-shift_pos[edge_jk[:, 1]],
shift_pos[edge_jk[:, 0]] - shift_pos[edge_jk[:, 1]],
],
dim=1,
)
assert torch.allclose(
triplet_shift.cpu()[:, :, 0],
torch.tensor(
Expand All @@ -61,7 +75,7 @@ def test_calc_triplets_noshift():
edge_index = torch.tensor(
[[0, 1, 1, 3, 1, 2, 3, 0], [1, 2, 3, 0, 0, 1, 1, 3]], dtype=torch.long, device=device
)
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
edge_index, dtype=torch.float64
)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
Expand All @@ -78,13 +92,7 @@ def test_calc_triplets_noshift():
assert multiplicity.shape == (n_triplets,)
assert multiplicity.dtype == torch.float64
assert torch.all(multiplicity.cpu() == torch.ones((n_triplets,), dtype=torch.float64))
assert torch.all(
triplet_shift.cpu()
== torch.zeros(
(n_triplets, 3, 3),
dtype=torch.float32,
)
)
assert torch.all(edge_jk.cpu() == torch.tensor([[1, 0], [2, 3]], dtype=torch.long))
assert torch.all(batch_triplets.cpu() == torch.zeros((n_triplets,), dtype=torch.long))


Expand All @@ -95,7 +103,7 @@ def test_calc_triplets_noshift():
def test_calc_triplets_no_triplets(edge_index):
# edge_index = edge_index.to("cuda:0")
# No triplet exist in this graph. Case1: No edge, Case 2 No triplets in this edge.
triplet_node_index, multiplicity, triplet_shift, batch_triplets = calc_triplets(edge_index)
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index)
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
# print("multiplicity", multiplicity.shape, multiplicity)
# print("triplet_shift", triplet_shift.shape, triplet_shift)
Expand All @@ -104,7 +112,7 @@ def test_calc_triplets_no_triplets(edge_index):
# 0 triplets exist.
assert triplet_node_index.shape == (0, 3)
assert multiplicity.shape == (0,)
assert triplet_shift.shape == (0, 3, 3)
assert edge_jk.shape == (0, 2)
assert batch_triplets.shape == (0,)


Expand Down
42 changes: 27 additions & 15 deletions tests/test_torch_dftd3_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,35 @@
import pytest
import torch
from ase import Atoms
from ase.build import fcc111, molecule
from ase.build import bulk, fcc111, molecule
from ase.calculators.dftd3 import DFTD3
from ase.calculators.emt import EMT
from torch_dftd.testing.damping import damping_method_list, damping_xc_combination_list
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator


def _create_atoms() -> List[Atoms]:
@pytest.fixture(
params=[
pytest.param("mol", id="mol"),
pytest.param("slab", id="slab"),
pytest.param("large", marks=[pytest.mark.slow], id="large"),
]
)
def atoms(request) -> Atoms:
"""Initialization"""
atoms = molecule("CH3CH2OCH3")
mol = molecule("CH3CH2OCH3")

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.set_cell(
slab.get_cell().array @ np.array([[1.0, 0.1, 0.2], [0.0, 1.0, 0.3], [0.0, 0.0, 1.0]])
)
slab.pbc = np.array([True, True, True])
return [atoms, slab]

large_bulk = bulk("Pt", "fcc") * (4, 4, 4)

atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}

return atoms_dict[request.param]


def _assert_energy_equal(calc1, calc2, atoms: Atoms):
Expand Down Expand Up @@ -53,20 +68,21 @@ def _test_calc_energy(damping, xc, old, atoms, device="cpu", dtype=torch.float64
_assert_energy_equal(dftd3_calc, torch_dftd3_calc, atoms)


def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms):
def _assert_energy_force_stress_equal(calc1, calc2, atoms: Atoms, force_tol: float = 1e-5):
calc1.reset()
atoms.calc = calc1
f1 = atoms.get_forces()
e1 = atoms.get_potential_energy()
if np.all(atoms.pbc == np.array([True, True, True])):
s1 = atoms.get_stress()

calc2.reset()
atoms.calc = calc2
f2 = atoms.get_forces()
e2 = atoms.get_potential_energy()
assert np.allclose(e1, e2, atol=1e-4, rtol=1e-4)
assert np.allclose(f1, f2, atol=1e-5, rtol=1e-5)
assert np.allclose(f1, f2, atol=force_tol, rtol=force_tol)
if np.all(atoms.pbc == np.array([True, True, True])):
s1 = atoms.get_stress()
s2 = atoms.get_stress()
assert np.allclose(s1, s2, atol=1e-5, rtol=1e-5)

Expand All @@ -83,6 +99,9 @@ def _test_calc_energy_force_stress(
cnthr=15.0,
):
cutoff = 22.0 # Make test faster
force_tol = 1e-5
if dtype == torch.float32:
force_tol = 1.0e-4
with tempfile.TemporaryDirectory() as tmpdirname:
dftd3_calc = DFTD3(
damping=damping,
Expand All @@ -105,25 +124,22 @@ def _test_calc_energy_force_stress(
abc=abc,
bidirectional=bidirectional,
)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms, force_tol=force_tol)


@pytest.mark.parametrize("damping,xc,old", damping_xc_combination_list)
@pytest.mark.parametrize("atoms", _create_atoms())
def test_calc_energy(damping, xc, old, atoms):
"""Test1-1: check damping,xc,old combination works for energy"""
_test_calc_energy(damping, xc, old, atoms, device="cpu")


@pytest.mark.parametrize("damping,xc,old", damping_xc_combination_list)
@pytest.mark.parametrize("atoms", _create_atoms())
def test_calc_energy_force_stress(damping, xc, old, atoms):
"""Test1-2: check damping,xc,old combination works for energy, force & stress"""
_test_calc_energy_force_stress(damping, xc, old, atoms, device="cpu")


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_calc_energy_device(damping, old, atoms, device, dtype):
Expand All @@ -133,7 +149,6 @@ def test_calc_energy_device(damping, old, atoms, device, dtype):


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_calc_energy_force_stress_device(damping, old, atoms, device, dtype):
Expand All @@ -142,7 +157,6 @@ def test_calc_energy_force_stress_device(damping, old, atoms, device, dtype):
_test_calc_energy_force_stress(damping, xc, old, atoms, device=device, dtype=dtype)


@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("damping,old", damping_method_list)
def test_calc_energy_force_stress_bidirectional(atoms, damping, old):
"""Test with bidirectional=False"""
Expand All @@ -161,7 +175,6 @@ def test_calc_energy_force_stress_bidirectional(atoms, damping, old):
_assert_energy_force_stress_equal(dftd3_calc, torch_dftd3_calc, atoms)


@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("damping,old", damping_method_list)
def test_calc_energy_force_stress_cutoff_smoothing(atoms, damping, old):
"""Test wit cutoff_smoothing."""
Expand Down Expand Up @@ -207,7 +220,6 @@ def test_calc_energy_force_stress_with_dft():


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("dtype", [torch.float64])
@pytest.mark.parametrize("bidirectional", [True, False])
Expand Down
29 changes: 22 additions & 7 deletions tests/test_torch_dftd3_calculator_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,22 @@
import pytest
import torch
from ase import Atoms
from ase.build import fcc111, molecule
from ase.build import bulk, fcc111, molecule
from torch_dftd.testing.damping import damping_method_list
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator


def _create_atoms() -> List[List[Atoms]]:
@pytest.fixture(
params=[
pytest.param("case1", id="mol+slab"),
pytest.param("case2", id="mol+slab(wo_pbc)"),
pytest.param("case3", id="null"),
pytest.param("case4", marks=[pytest.mark.slow], id="large"),
]
)
def atoms_list(request) -> List[Atoms]:
"""Initialization"""
atoms = molecule("CH3CH2OCH3")
mol = molecule("CH3CH2OCH3")

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.pbc = np.array([True, True, True])
Expand All @@ -24,7 +32,17 @@ def _create_atoms() -> List[List[Atoms]]:
slab_wo_pbc.pbc = np.array([False, False, False])

null = Atoms()
return [[atoms, slab], [atoms, slab_wo_pbc], [null]]

large_bulk = bulk("Pt", "fcc") * (8, 8, 8)

atoms_dict = {
"case1": [mol, slab],
"case2": [mol, slab_wo_pbc],
"case3": [null],
"case4": [large_bulk],
}

return atoms_dict[request.param]


def _assert_energy_equal_batch(calc1, atoms_list: List[Atoms]):
Expand Down Expand Up @@ -91,7 +109,6 @@ def _test_calc_energy_force_stress(


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms_list", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_calc_energy_device_batch(damping, old, atoms_list, device, dtype):
Expand All @@ -101,7 +118,6 @@ def test_calc_energy_device_batch(damping, old, atoms_list, device, dtype):


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms_list", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_calc_energy_force_stress_device_batch(damping, old, atoms_list, device, dtype):
Expand All @@ -111,7 +127,6 @@ def test_calc_energy_force_stress_device_batch(damping, old, atoms_list, device,


@pytest.mark.parametrize("damping,old", damping_method_list)
@pytest.mark.parametrize("atoms_list", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("bidirectional", [True, False])
@pytest.mark.parametrize("dtype", [torch.float64])
Expand Down
21 changes: 15 additions & 6 deletions tests/test_torch_dftd3_calculator_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@
import numpy as np
import pytest
from ase import Atoms
from ase.build import fcc111, molecule
from ase.build import bulk, fcc111, molecule
from ase.calculators.dftd3 import DFTD3
from ase.units import Bohr
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator


def _create_atoms() -> List[Atoms]:
@pytest.fixture(
params=[
pytest.param("mol", id="mol"),
pytest.param("slab", id="slab"),
pytest.param("large", marks=[pytest.mark.slow], id="large"),
]
)
def atoms(request) -> Atoms:
"""Initialization"""
atoms = molecule("CH3CH2OCH3")
mol = molecule("CH3CH2OCH3")

slab = fcc111("Au", size=(2, 1, 3), vacuum=80.0)
slab.pbc = np.array([True, True, True])
return [atoms, slab]

large_bulk = bulk("Pt", "fcc") * (4, 4, 4)

atoms_dict = {"mol": mol, "slab": slab, "large": large_bulk}
return atoms_dict[request.param]


def calc_energy(calculator, atoms):
Expand All @@ -35,7 +46,6 @@ def calc_force_stress(calculator, atoms):
return True


@pytest.mark.parametrize("atoms", _create_atoms())
def test_dftd3_calculator_benchmark(atoms, benchmark):
damping = "bj"
xc = "pbe"
Expand All @@ -53,7 +63,6 @@ def test_dftd3_calculator_benchmark(atoms, benchmark):
)


@pytest.mark.parametrize("atoms", _create_atoms())
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_torch_dftd3_calculator_benchmark(atoms, device, benchmark):
damping = "bj"
Expand Down
Loading

0 comments on commit dd8644e

Please sign in to comment.