From e8fb5ef559c83170c1e0facedc521104f0892fc5 Mon Sep 17 00:00:00 2001 From: EBB2675 Date: Wed, 20 Nov 2024 13:48:53 +0100 Subject: [PATCH] add tests for AtomCenteredBasisSet and AtomCenteredFunction --- .../schema_packages/basis_set.py | 14 +-- tests/test_basis_set.py | 109 ++++++++++++++++++ 2 files changed, 115 insertions(+), 8 deletions(-) diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 5fd5ac03..4aaf55f6 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -245,15 +245,13 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: # Validation: Check that n_primitive matches the lengths of exponents and contraction coefficients if self.n_primitive is not None: - if len(self.exponents or []) != self.n_primitive: - logger.error( - f"Mismatch in number of exponents: expected {self.n_primitive}, " - f"found {len(self.exponents or [])}." + if self.exponents is not None and len(self.exponents) != self.n_primitive: + raise ValueError( + f"Mismatch in number of exponents: expected {self.n_primitive}, found {len(self.exponents)}." ) - if len(self.contraction_coefficients or []) != self.n_primitive: - logger.error( - f"Mismatch in number of contraction coefficients: expected {self.n_primitive}, " - f"found {len(self.contraction_coefficients or [])}." + if self.contraction_coefficients is not None and len(self.contraction_coefficients) != self.n_primitive: + raise ValueError( + f"Mismatch in number of contraction coefficients: expected {self.n_primitive}, found {len(self.contraction_coefficients)}." ) diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index b1dcb03c..e96c224b 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -15,6 +15,7 @@ APWOrbital, APWPlaneWaveBasisSet, AtomCenteredBasisSet, + AtomCenteredFunction, BasisSetContainer, MuffinTinRegion, PlaneWaveBasisSet, @@ -418,3 +419,111 @@ def test_quick_step() -> None: ], } # TODO: generate a QuickStep generator in the CP2K plugin + + +@pytest.mark.parametrize( + 'basis_set_name, basis_type, role', + [ + ('cc-pVTZ', 'GTO', 'orbital'), + ('def2-TZVP', 'GTO', 'auxiliary_scf'), + ('aug-cc-pVDZ', 'STO', 'auxiliary_post_hf'), + ('custom_basis', None, None), # Undefined type and role + ], +) +def test_atom_centered_basis_set_init(basis_set_name, basis_type, role) -> None: + """Test initialization of AtomCenteredBasisSet.""" + bs = AtomCenteredBasisSet(basis_set=basis_set_name, type=basis_type, role=role) + assert bs.basis_set == basis_set_name + assert bs.type == basis_type + assert bs.role == role + + +@pytest.mark.parametrize( + 'functions', + [ + [ + AtomCenteredFunction( + basis_type='spherical', + function_type='s', + n_primitive=3, + exponents=[1.0, 2.0, 3.0], + contraction_coefficients=[0.5, 0.3, 0.2], + ), + ], + [ + AtomCenteredFunction( + basis_type='cartesian', + function_type='p', + n_primitive=1, + exponents=[0.5], + contraction_coefficients=[1.0], + ), + AtomCenteredFunction( + basis_type='spherical', + function_type='d', + n_primitive=2, + exponents=[1.0, 2.0], + contraction_coefficients=[0.4, 0.6], + ), + ], + ], +) +def test_atom_centered_basis_set_functional_composition(functions) -> None: + """Test functional composition within AtomCenteredBasisSet.""" + bs = AtomCenteredBasisSet(functional_composition=functions) + assert len(bs.functional_composition) == len(functions) + for f, ref_f in zip(bs.functional_composition, functions): + assert f.basis_type == ref_f.basis_type + assert f.function_type == ref_f.function_type + assert f.n_primitive == ref_f.n_primitive + assert np.allclose(f.exponents, ref_f.exponents) + assert np.allclose(f.contraction_coefficients, ref_f.contraction_coefficients) + + +def test_atom_centered_basis_set_normalize() -> None: + """Test normalization of AtomCenteredBasisSet.""" + bs = AtomCenteredBasisSet( + basis_set='cc-pVTZ', + type='GTO', + role='orbital', + functional_composition=[ + AtomCenteredFunction( + basis_type='spherical', + function_type='s', + n_primitive=2, + exponents=[1.0, 2.0], + contraction_coefficients=[0.5, 0.5], + ) + ], + ) + bs.normalize(None, logger) + # Add checks for normalized behavior, if any + assert bs.basis_set == 'cc-pVTZ' + +def test_atom_centered_basis_set_invalid_data() -> None: + """Test behavior with missing or invalid data.""" + bs = AtomCenteredBasisSet( + basis_set='invalid_basis', + type=None, # Missing type + role=None, # Missing role + ) + assert bs.basis_set == 'invalid_basis' + assert bs.type is None + assert bs.role is None + + # Test functional composition with invalid data + invalid_function = AtomCenteredFunction( + basis_type='spherical', + function_type='s', + n_primitive=2, + exponents=[1.0], # Mismatched length + contraction_coefficients=[0.5, 0.5], + ) + bs.functional_composition = [invalid_function] + + # Call normalize to trigger validation + with pytest.raises(ValueError, match="Mismatch in number of exponents"): + invalid_function.normalize(None, logger) + + +