diff --git a/mlspm/graph/_molecule_graph.py b/mlspm/graph/_molecule_graph.py index 60a90fb..7183843 100755 --- a/mlspm/graph/_molecule_graph.py +++ b/mlspm/graph/_molecule_graph.py @@ -276,8 +276,10 @@ def transform_xy( """ Transform atom positions in the xy plane. + Transformations are perfomed in the order: shift, rotate, flip x, flip y + Arguments: - shift: Shift atom positions in xy plane. Performed before rotation and flip. + shift: Shift atom positions in xy plane. rot_xy: Rotate atoms in xy plane by rot_xy degrees around center point. flip_x: Mirror atom positions in x direction with respect to the center point. flip_y: Mirror atom positions in y direction with respect to the center point. diff --git a/tests/test_graph.py b/tests/test_graph.py index a92ac29..7ce12ac 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,13 +1,14 @@ -from pathlib import Path import shutil -import torch -import pytest +from pathlib import Path + import numpy as np +import pytest +import torch def test_collate_graph(): - from mlspm.graph import MoleculeGraph from mlspm.data_loading import collate_graph + from mlspm.graph import MoleculeGraph # fmt: off @@ -217,7 +218,7 @@ def test_molecule_graph_array(): def test_molecule_graph_remove_atoms(): - from mlspm.graph import MoleculeGraph, Atom + from mlspm.graph import Atom, MoleculeGraph # fmt: off @@ -245,7 +246,7 @@ def test_molecule_graph_remove_atoms(): assert removed == [] new_molecule, removed = molecule.remove_atoms([1]) - removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 2), [0, 1, 0])] + removed_expected = [(Atom(np.array([1.0, 1.0, 0.0]), 'H'), [0, 1, 0])] atoms_expected = np.array([ [0.0, 0.0, 0.0, 1], [1.0, 0.0, 0.0, 3], @@ -333,10 +334,80 @@ def test_molecule_graph_add_atom(): assert a == b -def test_GraphSeqStats(): - from mlspm.graph import GraphStats +def test_molecule_graph_transform_xy(): + from mlspm.graph import MoleculeGraph + + # fmt:off + atoms = np.array([ + [0.0, 0.0, 0.0, 1], + [1.0, 1.0, 1.0, 2], + [1.0, 0.0, 0.0, 3], + [2.0, 0.0, -1.0, 4] + ]) + # fmt:on + bonds = [(0, 2), (1, 2), (2, 3)] + molecule = MoleculeGraph(atoms, bonds) + + molecule_transformed = molecule.transform_xy(shift=(1, 1), rot_xy=90, flip_x=True, flip_y=True, center=(1, 1)) + + xyz_transformed = molecule_transformed.array(xyz=True) + # fmt:off + xyz_expected = np.array( + [ + [1.0, 1.0, 0.0], + [2.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, -1.0, -1.0] + ] + ) + # fmt:on + + assert np.allclose(xyz_transformed, xyz_expected) + + +def test_molecule_graph_crop_atoms(): + from mlspm.graph import MoleculeGraph + + # fmt:off + atoms = np.array([ + [0.0, 0.0, 0.0, 1], + [1.0, 1.0, 1.0, 2], + [1.0, 0.0, 0.0, 3], + [2.0, 0.0, -1.0, 4] + ]) + # fmt:on + bonds = [(0, 2), (1, 2), (2, 3)] + molecule = MoleculeGraph(atoms, bonds) + box_borders = np.array([[-0.5, -0.5, -0.5], [1.5, 1.5, 0.5]]) + + molecule_cropped = molecule.crop_atoms(box_borders) + + xyz_cropped = molecule_cropped.array(xyz=True) + # fmt:off + xyz_expected = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + # fmt:on + + assert np.allclose(xyz_cropped, xyz_expected) + assert molecule_cropped.bonds == [(0, 1)] + + +def test_molecule_graph_randomize_positions(): from mlspm.graph import MoleculeGraph + molecule = MoleculeGraph(np.zeros((3, 4)), []) + molecule_randomized = molecule.randomize_positions() + + assert not np.allclose(molecule_randomized.array(xyz=True), 0.0) + + +def test_GraphSeqStats(): + from mlspm.graph import GraphStats, MoleculeGraph + classes = [[0], [1], [2]] # fmt:off