diff --git a/tests/test_physical_properties.py b/tests/test_physical_properties.py index d66c924e..d17650fe 100644 --- a/tests/test_physical_properties.py +++ b/tests/test_physical_properties.py @@ -18,6 +18,7 @@ import numpy as np import pytest +from typing import Union, Optional from nomad.units import ureg from nomad.datamodel import EntryArchive @@ -26,7 +27,10 @@ from . import logger from nomad_simulations.variables import Variables -from nomad_simulations.physical_property import PhysicalProperty +from nomad_simulations.physical_property import ( + PhysicalProperty, + validate_quantity_wrt_value, +) class DummyPhysicalProperty(PhysicalProperty): @@ -158,3 +162,49 @@ def test_is_derived(self): assert derived_physical_property._is_derived() is True derived_physical_property.normalize(EntryArchive(), logger) assert derived_physical_property.is_derived is True + + +# testing `validate_quantity_wrt_value` decorator +class ValidatingClass: + def __init__(self, value=None, occupation=None): + self.value = value + self.occupation = occupation + + @validate_quantity_wrt_value('occupation') + def validate_occupation(self) -> Union[bool, np.ndarray]: + return self.occupation + + +@pytest.mark.parametrize( + 'value, occupation, result', + [ + (None, None, False), # Both value and occupation are None + (np.array([[1, 2], [3, 4]]), None, False), # occupation is None + (None, np.array([[0.5, 1], [0, 0.5]]), False), # value is None + (np.array([[1, 2], [3, 4]]), np.array([]), False), # occupation is empty + ( + np.array([[1, 2], [3, 4]]), + np.array([[0.5, 1]]), + False, + ), # Shapes do not match + ( + np.array([[1, 2], [3, 4]]), + np.array([[0.5, 1], [0, 0.5]]), + np.array([[0.5, 1], [0, 0.5]]), + ), # Valid case (return `occupation`) + ], +) +def test_validate_quantity_wrt_value( + value: Optional[np.ndarray], + occupation: Optional[np.ndarray], + result: Union[bool, np.ndarray], +): + """ + Test the `validate_quantity_wrt_value` decorator. + """ + obj = ValidatingClass(value=value, occupation=occupation) + validation = obj.validate_occupation() + if isinstance(validation, bool): + assert validation == result + else: + assert np.allclose(validation, result)