From b99c5776140a567997549ad0b71cab7a214c6a7a Mon Sep 17 00:00:00 2001 From: ndaelman Date: Tue, 20 Aug 2024 21:27:45 +0200 Subject: [PATCH] Add more APW tests --- .../schema_packages/basis_set.py | 26 ++++++---- tests/test_basis_set.py | 51 +++++++++++++++---- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 54b19cc6..306c8477 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -234,12 +234,7 @@ def _get_open_quantities(self) -> set[str]: def _get_lengths(self, quantities: set[str]) -> list[int]: """Extract the lengths of the `quantities` contained in the set.""" present_quantities = set(quantities) & self._get_open_quantities() - lengths: list[int] = [] - for quant in present_quantities: - length = len(getattr(self, quant)) - if length > 0: # empty lists are exempt - lengths.append(length) - return lengths + return [len(getattr(self, quant)) for quant in present_quantities] def _of_equal_length(self, lengths: list[int]) -> bool: """Check if all elements in the list are of equal length.""" @@ -320,11 +315,11 @@ def n_terms_to_type(self, n_terms: Optional[int]) -> Optional[str]: """ Set the type of the APW orbital based on the differential order. """ - if n_terms is None: + if n_terms is None or n_terms == 0: return None - if n_terms == 0: + if n_terms == 1: return 'apw' - elif n_terms == 1: + elif n_terms == 2: return 'lapw' else: return 'slapw' @@ -386,6 +381,19 @@ def get_n_terms(self) -> Optional[int]: } ) + def bo_terms_to_type(self, bo_terms: Optional[int]) -> Optional[str]: + """ + Set the type of the local orbital based on the boundary order. + """ # ? include differential_order + if bo_terms is None or len(bo_terms) == 0: + return None + if sorted(bo_terms) == [0, 1]: + return 'lo' + elif sorted(bo_terms) == [0, 0, 1]: # ! double-check + return 'LO' + else: + return 'custom' + @check_normalized def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index dc7ee6ec..3995dfe2 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -100,22 +100,20 @@ def test_full_apw( @pytest.mark.parametrize( - 'ref_n_terms, e, e_n, d_o', + 'ref_n_terms, e, d_o', [ - (None, [0.0], [0, 0], []), # logically inconsistent - (1, [0.0], [0], [0]), # apw - (2, [0.0, 0.0], [0, 0], [0, 1]), # lapw + (None, None, None), # unset + (0, [], []), # empty + (None, [0.0], []), # logically inconsistent + (1, [0.0], [0]), # apw + (2, 2 * [0.0], [0, 1]), # lapw ], ) -def test_apw_base_orbital( - ref_n_terms: Optional[int], e: list[float], e_n: list[int], d_o: list[int] -): +def test_apw_base_orbital(ref_n_terms: Optional[int], e: list[float], d_o: list[int]): orb = APWBaseOrbital( energy_parameter=e, - energy_parameter_n=e_n, differential_order=d_o, ) - assert orb.get_n_terms() == ref_n_terms @@ -124,8 +122,41 @@ def test_apw_base_orbital_normalize(n_terms: Optional[int], ref_n_terms: Optiona orb = APWBaseOrbital( n_terms=n_terms, energy_parameter=[0], - energy_parameter_n=[0], differential_order=[1], ) orb.normalize(None, logger) assert orb.n_terms == ref_n_terms + + +@pytest.mark.parametrize( + 'ref_type, n_terms', + [(None, None), (None, 0), ('apw', 1), ('lapw', 2), ('slapw', 3)], +) +def test_apw_orbital(ref_type: Optional[str], n_terms: Optional[int]): + orb = APWOrbital(n_terms=n_terms) + assert orb.n_terms_to_type(orb.n_terms) == ref_type + + +@pytest.mark.parametrize( + 'ref_n_terms, ref_type, e, d_o, b_o', + [ + (None, None, [0.0], [], []), # logically inconsistent + (1, 'custom', [0.0], [0], [0]), # custom + (2, 'lo', 2 * [0.0], [0, 1], [0, 1]), # lo + (3, 'LO', 3 * [0.0], [0, 1, 0], [0, 1, 0]), # LO + ], +) +def test_apw_local_orbital( + ref_n_terms: Optional[int], + ref_type: str, + e: list[float], + d_o: list[int], + b_o: list[int], +): + orb = APWLocalOrbital( + energy_parameter=e, + differential_order=d_o, + boundary_order=d_o, + ) + assert orb.get_n_terms() == ref_n_terms + assert orb.bo_terms_to_type(orb.boundary_order) == ref_type