diff --git a/simulationdataschema/atoms_state.py b/simulationdataschema/atoms_state.py index d62de417..5584f913 100644 --- a/simulationdataschema/atoms_state.py +++ b/simulationdataschema/atoms_state.py @@ -178,7 +178,7 @@ def __init__(self, m_def: Section = None, m_context: Context = None, **kwargs): 'ms_numbers': dict((zip(('down', 'up'), (-0.5, 0.5)))), } - def _quantum_numbers_check(self, logger: BoundLogger) -> bool: + def _check_quantum_numbers(self, logger: BoundLogger) -> bool: """ Checks the physicality of the quantum numbers. @@ -191,10 +191,10 @@ def _quantum_numbers_check(self, logger: BoundLogger) -> bool: if self.n_quantum_number < 1: logger.error('The `n_quantum_number` must be greater than 0.') return False - if self.l_quantum_number and self.l_quantum_number < 1: + if self.l_quantum_number is not None and self.l_quantum_number < 1: logger.error('The `l_quantum_number` must be greater than 0.') return False - if self.ml_quantum_number and ( + if self.ml_quantum_number is not None and ( self.ml_quantum_number < -self.l_quantum_number or self.ml_quantum_number > self.l_quantum_number ): @@ -202,7 +202,10 @@ def _quantum_numbers_check(self, logger: BoundLogger) -> bool: 'The `ml_quantum_number` must be between `-l_quantum_number` and `l_quantum_number`.' ) return False - if self.ms_quantum_number and self.ms_quantum_number not in [-0.5, 0.5]: + if self.ms_quantum_number is not None and self.ms_quantum_number not in [ + -0.5, + 0.5, + ]: logger.error('The `ms_quantum_number` must be -0.5 or 0.5.') return False return True @@ -311,7 +314,7 @@ def normalize(self, archive, logger) -> None: super().normalize(archive, logger) # General checks for physical quantum numbers and symbols - if not self._quantum_numbers_check(logger): + if not self._check_quantum_numbers(logger): logger.error('The quantum numbers are not physical.') return diff --git a/tests/test_atoms_state.py b/tests/test_atoms_state.py index 383dcfe3..f037e3e7 100644 --- a/tests/test_atoms_state.py +++ b/tests/test_atoms_state.py @@ -59,6 +59,27 @@ def add_state( def orbital_state(self) -> OrbitalsState: return OrbitalsState(n_quantum_number=2) + @pytest.mark.parametrize( + 'number, values, results', + [ + ('n_quantum_number', [-1, 0, 1, 2], [False, False, True, True]), + ('l_quantum_number', [-2, 0, 1, 2], [False, False, True, True]), + # l_quantum_number == 2 when testing 'ml_quantum_number' + ('ml_quantum_number', [-3, 5, -2, 1], [False, False, True, True]), + ('ms_quantum_number', [0, 10, -0.5, 0.5], [False, False, True, True]), + ], + ) + def test_check_quantum_numbers(self, orbital_state, number, values, results): + """ + Test the quantum number check for the `OrbitalsState` section. + """ + for val, res in zip(values, results): + if number == 'ml_quantum_number': + orbital_state.l_quantum_number = 2 + setattr(orbital_state, number, val) + check = orbital_state._check_quantum_numbers(self.logger) + assert check == res + @pytest.mark.parametrize( 'quantum_name, quantum_type, value, countertype, expected_result', [