diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 1a9b35a4..73dca7b7 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -312,6 +312,10 @@ def __eq__(self, other: 'Cell') -> bool: if not isinstance(other, Cell): return False + # If the `positions` are empty, return False + if self.positions is None or other.positions is None: + return False + # The `positions` should have the same length (same number of positions) if len(self.positions) != len(other.positions): return False @@ -322,6 +326,9 @@ def __eq__(self, other: 'Cell') -> bool: return False return True + def __ne__(self, other: 'Cell') -> bool: + return not self.__eq__(other) + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) @@ -385,6 +392,9 @@ def __eq__(self, other: 'AtomicCell') -> bool: return False return True + def __ne__(self, other: 'AtomicCell') -> bool: + return not self.__eq__(other) + def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: """ Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` diff --git a/src/nomad_simulations/schema_packages/utils/__init__.py b/src/nomad_simulations/schema_packages/utils/__init__.py index 3ce600d9..dc4e7ea4 100644 --- a/src/nomad_simulations/schema_packages/utils/__init__.py +++ b/src/nomad_simulations/schema_packages/utils/__init__.py @@ -21,6 +21,5 @@ get_composition, get_sibling_section, get_variables, - is_equal_cell, is_not_representative, ) diff --git a/src/nomad_simulations/schema_packages/utils/utils.py b/src/nomad_simulations/schema_packages/utils/utils.py index ef77573f..e925b53b 100644 --- a/src/nomad_simulations/schema_packages/utils/utils.py +++ b/src/nomad_simulations/schema_packages/utils/utils.py @@ -175,25 +175,3 @@ def get_composition(children_names: 'list[str]') -> str: children_count_tup = np.unique(children_names, return_counts=True) formula = ''.join([f'{name}({count})' for name, count in zip(*children_count_tup)]) return formula if formula else None - - -def is_equal_cell(cell_1: 'Cell', cell_2: 'Cell') -> bool: - """ - Check if the two `Cell` objects are the same by checking if the defined `positions` are all matching. If - the objects are `AtomicCell`, it checks if the `AtomsState[*].chemical_symbol` are the same. - - Args: - cell_1 (Cell): The first `Cell` to compare. - cell_2 (Cell): The second `Cell` to compare. - - Returns: - bool: True if the cells are the same, False otherwise. - """ - # TODO extend this function to compare more information of the cells (`lattice_vectors`) and check ase.Atoms functions - # If any of the cells is None or the positions are empty, return False - if cell_1 is None or cell_2 is None: - return False - if cell_1.positions is None or cell_2.positions is None: - return False - - return cell_1 == cell_2 diff --git a/tests/test_model_system.py b/tests/test_model_system.py index 6ead47e8..e1610971 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -22,7 +22,10 @@ import pytest from nomad.datamodel import EntryArchive +from nomad_simulations.schema_packages.atoms_state import AtomsState from nomad_simulations.schema_packages.model_system import ( + AtomicCell, + Cell, ChemicalFormula, ModelSystem, Symmetry, @@ -32,11 +35,160 @@ from .conftest import generate_atomic_cell +class TestCell: + """ + Test the `Cell` section defined in model_system.py + """ + + @pytest.mark.parametrize( + 'cell_1, cell_2, result', + [ + (Cell(), None, False), # one cell is None + (Cell(), Cell(), False), # both cells are empty + ( + Cell(positions=[[1, 0, 0]]), + Cell(), + False, + ), # one cell has positions, the other is empty + ( + Cell(positions=[[1, 0, 0], [0, 1, 0]]), + Cell(positions=[[1, 0, 0]]), + False, + ), # length mismatch + ( + Cell(positions=[[1, 0, 0], [0, 1, 0]]), + Cell(positions=[[1, 0, 0], [0, -1, 0]]), + False, + ), # different positions + ( + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + True, + ), # same ordered positions + ( + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), + True, + ), # different ordered positions but same cell + ], + ) + def test_eq(self, cell_1: Cell, cell_2: Cell, result: bool): + """ + Test the `__eq__` operator function of `Cell`. + """ + assert (cell_1 == cell_2) == result + + class TestAtomicCell: """ Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py """ + @pytest.mark.parametrize( + 'cell_1, cell_2, result', + [ + (Cell(), None, False), # one cell is None + (Cell(), Cell(), False), # both cells are empty + ( + Cell(positions=[[1, 0, 0]]), + Cell(), + False, + ), # one cell has positions, the other is empty + ( + Cell(positions=[[1, 0, 0], [0, 1, 0]]), + Cell(positions=[[1, 0, 0]]), + False, + ), # length mismatch + ( + Cell(positions=[[1, 0, 0], [0, 1, 0]]), + Cell(positions=[[1, 0, 0], [0, -1, 0]]), + False, + ), # different positions + ( + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + True, + ), # same ordered positions + ( + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), + True, + ), # different ordered positions but same cell + ( + AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + False, + ), # one atomic cell and another cell (missing chemical symbols) + ( + AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + False, + ), # missing chemical symbols + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + True, + ), # same ordered positions and chemical symbols + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='Cu'), + AtomsState(chemical_symbol='O'), + ], + ), + False, + ), # same ordered positions but different chemical symbols + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + AtomsState(chemical_symbol='H'), + ], + ), + True, + ), # different ordered positions but same chemical symbols + ], + ) + def test_eq(self, cell_1: Cell, cell_2: Cell, result: bool): + """ + Test the `__eq__` operator function of `AtomicCell`. + """ + assert (cell_1 == cell_2) == result + @pytest.mark.parametrize( 'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions', [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 59a56098..2f6e4de9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -96,110 +96,3 @@ def test_get_variables(variables: list, result: list, result_length: int): assert len(energies) == result_length for i, energy in enumerate(energies): # asserting energies == result does not work assert energy.n_points == result[i].n_points - - -@pytest.mark.parametrize( - 'cell_1, cell_2, result', - [ - (None, None, False), # both are None - (Cell(), None, False), # one cell is None - (Cell(), Cell(), False), # both cells are empty - ( - Cell(positions=[[1, 0, 0]]), - Cell(), - False, - ), # one cell has positions, the other is empty - ( - Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0]]), - False, - ), # length mismatch - ( - Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0], [0, -1, 0]]), - False, - ), # different positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - True, - ), # same ordered positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), - True, - ), # different ordered positions but same cell - ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # one atomic cell and another cell (missing chemical symbols) - ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # missing chemical symbols - ( - AtomicCell( - positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='O'), - ], - ), - AtomicCell( - positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='O'), - ], - ), - True, - ), # same ordered positions and chemical symbols - ( - AtomicCell( - positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='O'), - ], - ), - AtomicCell( - positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='Cu'), - AtomsState(chemical_symbol='O'), - ], - ), - False, - ), # same ordered positions but different chemical symbols - ( - AtomicCell( - positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='O'), - ], - ), - AtomicCell( - positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]], - atoms_state=[ - AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='O'), - AtomsState(chemical_symbol='H'), - ], - ), - True, - ), # different ordered positions but same chemical symbols - ], -) -def test_is_equal_cell(cell_1: Cell, cell_2: Cell, result: bool): - """ - Test the `is_equal_cell` utility function. - """ - assert is_equal_cell(cell_1=cell_1, cell_2=cell_2) == result