Skip to content

Commit

Permalink
Added tests for MoleculeGraph.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Nov 25, 2023
1 parent 487a30a commit 62c4d7b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 9 deletions.
4 changes: 3 additions & 1 deletion mlspm/graph/_molecule_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,10 @@ def transform_xy(
"""
Transform atom positions in the xy plane.
Transformations are perfomed in the order: shift, rotate, flip x, flip y
Arguments:
shift: Shift atom positions in xy plane. Performed before rotation and flip.
shift: Shift atom positions in xy plane.
rot_xy: Rotate atoms in xy plane by rot_xy degrees around center point.
flip_x: Mirror atom positions in x direction with respect to the center point.
flip_y: Mirror atom positions in y direction with respect to the center point.
Expand Down
87 changes: 79 additions & 8 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pathlib import Path
import shutil
import torch
import pytest
from pathlib import Path

import numpy as np
import pytest
import torch


def test_collate_graph():
from mlspm.graph import MoleculeGraph
from mlspm.data_loading import collate_graph
from mlspm.graph import MoleculeGraph

# fmt: off

Expand Down Expand Up @@ -217,7 +218,7 @@ def test_molecule_graph_array():


def test_molecule_graph_remove_atoms():
from mlspm.graph import MoleculeGraph, Atom
from mlspm.graph import Atom, MoleculeGraph

# fmt: off

Expand Down Expand Up @@ -245,7 +246,7 @@ def test_molecule_graph_remove_atoms():
assert removed == []

new_molecule, removed = molecule.remove_atoms([1])
removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 2), [0, 1, 0])]
removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 'H'), [0, 1, 0])]
atoms_expected = np.array([
[0.0, 0.0, 0.0, 1],
[1.0, 0.0, 0.0, 3],
Expand Down Expand Up @@ -333,10 +334,80 @@ def test_molecule_graph_add_atom():
assert a == b


def test_GraphSeqStats():
from mlspm.graph import GraphStats
def test_molecule_graph_transform_xy():
from mlspm.graph import MoleculeGraph

# fmt:off
atoms = np.array([
[0.0, 0.0, 0.0, 1],
[1.0, 1.0, 1.0, 2],
[1.0, 0.0, 0.0, 3],
[2.0, 0.0, -1.0, 4]
])
# fmt:on
bonds = [(0, 2), (1, 2), (2, 3)]
molecule = MoleculeGraph(atoms, bonds)

molecule_transformed = molecule.transform_xy(shift=(1, 1), rot_xy=90, flip_x=True, flip_y=True, center=(1, 1))

xyz_transformed = molecule_transformed.array(xyz=True)
# fmt:off
xyz_expected = np.array(
[
[1.0, 1.0, 0.0],
[2.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[1.0, -1.0, -1.0]
]
)
# fmt:on

assert np.allclose(xyz_transformed, xyz_expected)


def test_molecule_graph_crop_atoms():
from mlspm.graph import MoleculeGraph

# fmt:off
atoms = np.array([
[0.0, 0.0, 0.0, 1],
[1.0, 1.0, 1.0, 2],
[1.0, 0.0, 0.0, 3],
[2.0, 0.0, -1.0, 4]
])
# fmt:on
bonds = [(0, 2), (1, 2), (2, 3)]
molecule = MoleculeGraph(atoms, bonds)
box_borders = np.array([[-0.5, -0.5, -0.5], [1.5, 1.5, 0.5]])

molecule_cropped = molecule.crop_atoms(box_borders)

xyz_cropped = molecule_cropped.array(xyz=True)
# fmt:off
xyz_expected = np.array(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
]
)
# fmt:on

assert np.allclose(xyz_cropped, xyz_expected)
assert molecule_cropped.bonds == [(0, 1)]


def test_molecule_graph_randomize_positions():
from mlspm.graph import MoleculeGraph

molecule = MoleculeGraph(np.zeros((3, 4)), [])
molecule_randomized = molecule.randomize_positions()

assert not np.allclose(molecule_randomized.array(xyz=True), 0.0)


def test_GraphSeqStats():
from mlspm.graph import GraphStats, MoleculeGraph

classes = [[0], [1], [2]]

# fmt:off
Expand Down

0 comments on commit 62c4d7b

Please sign in to comment.