diff --git a/RAT/__init__.py b/RAT/__init__.py index 2d319eed..a417d8cd 100644 --- a/RAT/__init__.py +++ b/RAT/__init__.py @@ -1,3 +1,4 @@ from RAT.classlist import ClassList -from RAT.controls import Controls from RAT.project import Project +import RAT.controls +import RAT.models diff --git a/RAT/controls.py b/RAT/controls.py index 3b4a1a56..2c16cf34 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,12 +1,14 @@ import prettytable -from pydantic import BaseModel, Field, field_validator -from typing import Union +from pydantic import BaseModel, Field, field_validator, ValidationError +from typing import Literal, Union from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions +from RAT.utils.custom_errors import custom_pydantic_validation_error -class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'): - """Defines the base class with properties used in all five procedures.""" +class Calculate(BaseModel, validate_assignment=True, extra='forbid'): + """Defines the class for the calculate procedure, which includes the properties used in all five procedures.""" + procedure: Literal[Procedures.Calculate] = Procedures.Calculate parallel: ParallelOptions = ParallelOptions.Single calcSldDuringFit: bool = False resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2) @@ -21,15 +23,16 @@ def check_resamPars(cls, resamPars): raise ValueError('resamPars[1] must be greater than or equal to 0') return resamPars - -class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'): - """Defines the class for the calculate procedure.""" - procedure: Procedures = Field(Procedures.Calculate, frozen=True) + def __repr__(self) -> str: + table = prettytable.PrettyTable() + table.field_names = ['Property', 'Value'] + table.add_rows([[k, v] for k, v in self.__dict__.items()]) + return table.get_string() -class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): - """Defines the class for the simplex procedure.""" - procedure: Procedures = Field(Procedures.Simplex, frozen=True) +class Simplex(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the simplex procedure.""" + procedure: Literal[Procedures.Simplex] = Procedures.Simplex tolX: float = Field(1.0e-6, gt=0.0) tolFun: float = Field(1.0e-6, gt=0.0) maxFunEvals: int = Field(10000, gt=0) @@ -38,9 +41,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): updatePlotFreq: int = -1 -class DE(BaseProcedure, validate_assignment=True, extra='forbid'): - """Defines the class for the Differential Evolution procedure.""" - procedure: Procedures = Field(Procedures.DE, frozen=True) +class DE(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Differential Evolution procedure.""" + procedure: Literal[Procedures.DE] = Procedures.DE populationSize: int = Field(20, ge=1) fWeight: float = 0.5 crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0) @@ -49,18 +52,18 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'): numGenerations: int = Field(500, ge=1) -class NS(BaseProcedure, validate_assignment=True, extra='forbid'): - """Defines the class for the Nested Sampler procedure.""" - procedure: Procedures = Field(Procedures.NS, frozen=True) +class NS(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Nested Sampler procedure.""" + procedure: Literal[Procedures.NS] = Procedures.NS Nlive: int = Field(150, ge=1) Nmcmc: float = Field(0.0, ge=0.0) propScale: float = Field(0.1, gt=0.0, lt=1.0) nsTolerance: float = Field(0.1, ge=0.0) -class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): - """Defines the class for the Dream procedure.""" - procedure: Procedures = Field(Procedures.Dream, frozen=True) +class Dream(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Dream procedure.""" + procedure: Literal[Procedures.Dream] = Procedures.Dream nSamples: int = Field(50000, ge=0) nChains: int = Field(10, gt=0) jumpProb: float = Field(0.5, gt=0.0, lt=1.0) @@ -68,33 +71,29 @@ class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold -class Controls: - - def __init__(self, - procedure: Procedures = Procedures.Calculate, - **properties) -> None: - - if procedure == Procedures.Calculate: - self.controls = Calculate(**properties) - elif procedure == Procedures.Simplex: - self.controls = Simplex(**properties) - elif procedure == Procedures.DE: - self.controls = DE(**properties) - elif procedure == Procedures.NS: - self.controls = NS(**properties) - elif procedure == Procedures.Dream: - self.controls = Dream(**properties) - - @property - def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]: - return self._controls - - @controls.setter - def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None: - self._controls = value - - def __repr__(self) -> str: - table = prettytable.PrettyTable() - table.field_names = ['Property', 'Value'] - table.add_rows([[k, v] for k, v in self._controls.__dict__.items()]) - return table.get_string() +def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ + -> Union[Calculate, Simplex, DE, NS, Dream]: + """Returns the appropriate controls model given the specified procedure.""" + controls = { + Procedures.Calculate: Calculate, + Procedures.Simplex: Simplex, + Procedures.DE: DE, + Procedures.NS: NS, + Procedures.Dream: Dream + } + + try: + model = controls[procedure](**properties) + except KeyError: + members = list(Procedures.__members__.values()) + allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}' + raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None + except ValidationError as exc: + custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}' + f' controls procedure are:\n ' + f'{", ".join(controls[procedure].model_fields.keys())}\n' + } + custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs) + raise ValidationError.from_exception_data(exc.title, custom_error_list) from None + + return model diff --git a/RAT/project.py b/RAT/project.py index 77827e08..661bff0a 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -10,7 +10,7 @@ from RAT.classlist import ClassList import RAT.models -from RAT.utils.custom_errors import formatted_pydantic_error +from RAT.utils.custom_errors import custom_pydantic_validation_error try: from enum import StrEnum @@ -524,11 +524,10 @@ def wrapped_func(*args, **kwargs): try: return_value = func(*args, **kwargs) Project.model_validate(self) - except ValidationError as e: + except ValidationError as exc: setattr(class_list, 'data', previous_state) - error_string = formatted_pydantic_error(e) - # Use ANSI escape sequences to print error text in red - print('\033[31m' + error_string + '\033[0m') + custom_error_list = custom_pydantic_validation_error(exc.errors()) + raise ValidationError.from_exception_data(exc.title, custom_error_list) from None except (TypeError, ValueError): setattr(class_list, 'data', previous_state) raise diff --git a/RAT/utils/custom_errors.py b/RAT/utils/custom_errors.py index 269024b6..2fd7d211 100644 --- a/RAT/utils/custom_errors.py +++ b/RAT/utils/custom_errors.py @@ -1,26 +1,36 @@ """Defines routines for custom error handling in RAT.""" +import pydantic_core -from pydantic import ValidationError +def custom_pydantic_validation_error(error_list: list[pydantic_core.ErrorDetails], custom_errors: dict[str, str] = None + ) -> list[pydantic_core.ErrorDetails]: + """Run through the list of errors generated from a pydantic ValidationError, substituting the standard error for a + PydanticCustomError for a given set of error types. -def formatted_pydantic_error(error: ValidationError) -> str: - """Write a custom string format for pydantic validation errors. + For errors that do not have a custom error message defined, we redefine them using a PydanticCustomError to remove + the url from the error message. Parameters ---------- - error : pydantic.ValidationError - A ValidationError produced by a pydantic model + error_list : list[pydantic_core.ErrorDetails] + A list of errors produced by pydantic.ValidationError.errors(). + custom_errors: dict[str, str], optional + A dict of custom error messages for given error types. Returns ------- - error_str : str - A string giving details of the ValidationError in a custom format. + new_error : list[pydantic_core.ErrorDetails] + A list of errors including PydanticCustomErrors in place of the error types in custom_errors. """ - num_errors = error.error_count() - error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}' - for this_error in error.errors(): - error_str += '\n' - if this_error['loc']: - error_str += ' '.join(this_error['loc']) + '\n' - error_str += ' ' + this_error['msg'] - return error_str + if custom_errors is None: + custom_errors = {} + custom_error_list = [] + for error in error_list: + if error['type'] in custom_errors: + RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], custom_errors[error['type']]) + else: + RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], error['msg']) + error['type'] = RAT_custom_error + custom_error_list.append(error) + + return custom_error_list diff --git a/tests/test_controls.py b/tests/test_controls.py index e1c62efd..fd128bbc 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -1,28 +1,30 @@ -"""Tests for control and procedure classes""" +"""Test the controls module.""" import pytest import pydantic from typing import Union, Any -from RAT.controls import BaseProcedure, Calculate, Simplex, DE, NS, Dream, Controls + +from RAT.controls import Calculate, Simplex, DE, NS, Dream, set_controls from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class TestBaseProcedure: - """Tests the BaseProcedure class.""" +class TestCalculate: + """Tests the Calculate class.""" @pytest.fixture(autouse=True) def setup_class(self): - self.base_procedure = BaseProcedure() + self.calculate = Calculate() @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.Single), ('calcSldDuringFit', False), ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter) + ('display', DisplayOptions.Iter), + ('procedure', Procedures.Calculate) ]) - def test_base_property_values(self, control_property: str, value: Any) -> None: - """Tests the default values of BaseProcedure class.""" - assert getattr(self.base_procedure, control_property) == value + def test_calculate_property_values(self, control_property: str, value: Any) -> None: + """Tests the default values of Calculate class.""" + assert getattr(self.calculate, control_property) == value @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.All), @@ -30,105 +32,91 @@ def test_base_property_values(self, control_property: str, value: Any) -> None: ('resamPars', [0.2, 1]), ('display', DisplayOptions.Notify) ]) - def test_base_property_setters(self, control_property: str, value: Any) -> None: - """Tests the setters in BaseProcedure class.""" - setattr(self.base_procedure, control_property, value) - assert getattr(self.base_procedure, control_property) == value + def test_calculate_property_setters(self, control_property: str, value: Any) -> None: + """Tests the setters of Calculate class.""" + setattr(self.calculate, control_property, value) + assert getattr(self.calculate, control_property) == value @pytest.mark.parametrize("var1, var2", [('test', True), ('ALL', 1), ("Contrast", 3.0)]) - def test_base_parallel_validation(self, var1: str, var2: Any) -> None: - """Tests the parallel setter validation in BaseProcedure class.""" + def test_calculate_parallel_validation(self, var1: str, var2: Any) -> None: + """Tests the parallel setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'parallel', var1) + setattr(self.calculate, 'parallel', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'single', 'points', 'contrasts' or 'all'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'parallel', var2) + setattr(self.calculate, 'parallel', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value", [5.0, 12]) - def test_base_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: - """Tests the calcSldDuringFit setter validation in BaseProcedure class.""" + def test_calculate_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: + """Tests the calcSldDuringFit setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'calcSldDuringFit', value) + setattr(self.calculate, 'calcSldDuringFit', value) assert exp.value.errors()[0]['msg'] == "Input should be a valid boolean, unable to interpret input" @pytest.mark.parametrize("var1, var2", [('test', True), ('iterate', 1), ("FINAL", 3.0)]) - def test_base_display_validation(self, var1: str, var2: Any) -> None: - """Tests the display setter validation in BaseProcedure class.""" + def test_calculate_display_validation(self, var1: str, var2: Any) -> None: + """Tests the display setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'display', var1) + setattr(self.calculate, 'display', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'off', 'iter', 'notify' or 'final'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'display', var2) + setattr(self.calculate, 'display', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value, msg", [ ([5.0], "List should have at least 2 items after validation, not 1"), ([12, 13, 14], "List should have at most 2 items after validation, not 3") ]) - def test_base_resamPars_lenght_validation(self, value: list, msg: str) -> None: - """Tests the resamPars setter length validation in BaseProcedure class.""" + def test_calculate_resamPars_length_validation(self, value: list, msg: str) -> None: + """Tests the resamPars setter length validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'resamPars', value) + setattr(self.calculate, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("value, msg", [ ([1.0, 2], "Value error, resamPars[0] must be between 0 and 1"), ([0.5, -0.1], "Value error, resamPars[1] must be greater than or equal to 0") ]) - def test_base_resamPars_value_validation(self, value: list, msg: str) -> None: - """Tests the resamPars setter value validation in BaseProcedure class.""" + def test_calculate_resamPars_value_validation(self, value: list, msg: str) -> None: + """Tests the resamPars setter value validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'resamPars', value) + setattr(self.calculate, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg - def test_base_extra_property_error(self) -> None: - """Tests the extra property setter in BaseProcedure class.""" - with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'test', 1) - assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - - -class TestCalculate: - """Tests the Calculate class.""" - - @pytest.fixture(autouse=True) - def setup_class(self): - self.calculate = Calculate() - - @pytest.mark.parametrize("control_property, value", [ - ('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.Calculate) - ]) - def test_calculate_property_values(self, control_property: str, value: Any) -> None: - """Tests the default values of Calculate class.""" - assert getattr(self.calculate, control_property) == value - - @pytest.mark.parametrize("control_property, value", [ - ('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify) - ]) - def test_calculate_property_setters(self, control_property: str, value: Any) -> None: - """Tests the setters of Calculate class.""" - setattr(self.calculate, control_property, value) - assert getattr(self.calculate, control_property) == value - def test_calculate_extra_property_error(self) -> None: """Tests the extra property setter in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.calculate, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_calculate_procedure_error(self) -> None: - """Tests the procedure property frozen in Calculate class.""" + def test_calculate_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "calculate" in Calculate class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Calculate(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_calculate_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.calculate, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_repr(self) -> None: + """Tests the Calculate model __repr__.""" + table = self.calculate.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | calculate |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "+------------------+-----------+" + ) + + assert table == table_str class TestSimplex: @@ -190,11 +178,39 @@ def test_simplex_extra_property_error(self) -> None: setattr(self.simplex, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_simplex_procedure_error(self) -> None: - """Tests the procedure property frozen in Simplex class.""" + def test_simplex_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "simplex" in Simplex class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Simplex(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_simplex_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in Simplex class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.simplex, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_repr(self) -> None: + """Tests the Simplex model __repr__.""" + table = self.simplex.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | simplex |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| tolX | 1e-06 |\n" + "| tolFun | 1e-06 |\n" + "| maxFunEvals | 10000 |\n" + "| maxIter | 1000 |\n" + "| updateFreq | -1 |\n" + "| updatePlotFreq | -1 |\n" + "+------------------+-----------+" + ) + + assert table == table_str class TestDE: @@ -238,13 +254,15 @@ def test_de_property_setters(self, control_property: str, value: Any) -> None: setattr(self.de, control_property, value) assert getattr(self.de, control_property) == value - @pytest.mark.parametrize("value", [0, 2]) - def test_de_crossoverProbability_error(self, value: int) -> None: + @pytest.mark.parametrize("value, msg", [ + (0, "Input should be greater than 0"), + (2, "Input should be less than 1") + ]) + def test_de_crossoverProbability_error(self, value: int, msg: str) -> None: """Tests the crossoverProbability setter error in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.de, 'crossoverProbability', value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("control_property, value", [ ('targetValue', 0), @@ -268,11 +286,39 @@ def test_de_extra_property_error(self) -> None: setattr(self.de, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_de_procedure_error(self) -> None: - """Tests the procedure property frozen in DE class.""" + def test_de_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "de" in DE class.""" + with pytest.raises(pydantic.ValidationError) as exp: + DE(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_de_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.de, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_repr(self) -> None: + """Tests the DE model __repr__.""" + table = self.de.__repr__() + table_str = ("+----------------------+-------------------------------------------+\n" + "| Property | Value |\n" + "+----------------------+-------------------------------------------+\n" + "| procedure | de |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| populationSize | 20 |\n" + "| fWeight | 0.5 |\n" + "| crossoverProbability | 0.8 |\n" + "| strategy | StrategyOptions.RandomWithPerVectorDither |\n" + "| targetValue | 1.0 |\n" + "| numGenerations | 500 |\n" + "+----------------------+-------------------------------------------+" + ) + + assert table == table_str class TestNS: @@ -323,13 +369,15 @@ def test_ns_Nmcmc_nsTolerance_Nlive_error(self, control_property: str, value: Un setattr(self.ns, control_property, value) assert exp.value.errors()[0]['msg'] == f"Input should be greater than or equal to {bound}" - @pytest.mark.parametrize("value", [0, 2]) - def test_ns_propScale_error(self, value: int) -> None: + @pytest.mark.parametrize("value, msg", [ + (0, "Input should be greater than 0"), + (2, "Input should be less than 1") + ]) + def test_ns_propScale_error(self, value: int, msg: str) -> None: """Tests the propScale error in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.ns, 'propScale', value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg def test_ns_extra_property_error(self) -> None: """Tests the extra property setter in NS class.""" @@ -337,11 +385,37 @@ def test_ns_extra_property_error(self) -> None: setattr(self.ns, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + def test_ns_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "ns" in NS class.""" + with pytest.raises(pydantic.ValidationError) as exp: + NS(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + def test_ns_procedure_error(self) -> None: - """Tests the procedure property frozen in NS class.""" + """Tests the procedure property is frozen in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.ns, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_control_class_ns_repr(self) -> None: + """Tests the NS model __repr__.""" + table = self.ns.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | ns |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| Nlive | 150 |\n" + "| Nmcmc | 0.0 |\n" + "| propScale | 0.1 |\n" + "| nsTolerance | 0.1 |\n" + "+------------------+-----------+" + ) + + assert table == table_str class TestDream: @@ -383,18 +457,17 @@ def test_dream_property_setters(self, control_property: str, value: Any) -> Non setattr(self.dream, control_property, value) assert getattr(self.dream, control_property) == value - @pytest.mark.parametrize("control_property, value", [ - ('jumpProb', 0), - ('jumpProb', 2), - ('pUnitGamma', -5), - ('pUnitGamma', 20) + @pytest.mark.parametrize("control_property, value, msg", [ + ('jumpProb', 0, "Input should be greater than 0"), + ('jumpProb', 2, "Input should be less than 1"), + ('pUnitGamma', -5, "Input should be greater than 0"), + ('pUnitGamma', 20, "Input should be less than 1") ]) - def test_dream_jumpprob_pUnitGamma_error(self, control_property:str, value: int) -> None: - """Tests the jumpprob pUnitGamma setter errors in Dream class.""" + def test_dream_jumpProb_pUnitGamma_error(self, control_property: str, value: int, msg: str) -> None: + """Tests the jumpProb and pUnitGamma setter errors in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.dream, control_property, value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("value", [-80, -2]) def test_dream_nSamples_error(self, value: int) -> None: @@ -416,136 +489,29 @@ def test_dream_extra_property_error(self) -> None: setattr(self.dream, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + def test_dream_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "dream" in Dream class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Dream(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + def test_dream_procedure_error(self) -> None: - """Tests the procedure property frozen in Dream class.""" + """Tests the procedure property is frozen in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.dream, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" - - -class TestControls: - """Tests the Controls class.""" - - @pytest.fixture(autouse=True) - def setup_class(self): - self.controls = Controls() - - def test_controls_class_default_type(self) -> None: - """Tests the procedure is Calculate in Controls.""" - assert type(self.controls.controls).__name__ == "Calculate" - - def test_controls_class_properties(self) -> None: - """Tests the Controls class has control property.""" - assert hasattr(self.controls, 'controls') - - @pytest.mark.parametrize("procedure, name", [ - (Procedures.Calculate, "Calculate"), - (Procedures.Simplex, "Simplex"), - (Procedures.DE, "DE"), - (Procedures.NS, "NS"), - (Procedures.Dream, "Dream") - ]) - def test_controls_class_return_type(self, procedure: Procedures, name: str) -> None: - """Tests the Controls class is set to the correct procedure class.""" - controls = Controls(procedure) - assert type(controls.controls).__name__ == name - - def test_control_class_calculate_repr(self) -> None: - """Tests the __repr__ of Controls with Calculate procedure.""" - controls = Controls() - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | calculate |\n" - "+------------------+-----------+" - ) - - assert table == table_str - - def test_control_class_simplex_repr(self) -> None: - """Tests the __repr__ of Controls with Simplex procedure.""" - controls = Controls(procedure=Procedures.Simplex) - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | simplex |\n" - "| tolX | 1e-06 |\n" - "| tolFun | 1e-06 |\n" - "| maxFunEvals | 10000 |\n" - "| maxIter | 1000 |\n" - "| updateFreq | -1 |\n" - "| updatePlotFreq | -1 |\n" - "+------------------+-----------+" - ) - - assert table == table_str - - def test_control_class_de_repr(self) -> None: - """Tests the __repr__ of Controls with DE procedure.""" - controls = Controls(procedure=Procedures.DE) - table = controls.__repr__() - table_str = ("+----------------------+-------------------------------------------+\n" - "| Property | Value |\n" - "+----------------------+-------------------------------------------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | de |\n" - "| populationSize | 20 |\n" - "| fWeight | 0.5 |\n" - "| crossoverProbability | 0.8 |\n" - "| strategy | StrategyOptions.RandomWithPerVectorDither |\n" - "| targetValue | 1.0 |\n" - "| numGenerations | 500 |\n" - "+----------------------+-------------------------------------------+" - ) - - assert table == table_str - - def test_control_class_ns_repr(self) -> None: - """Tests the __repr__ of Controls with NS procedure.""" - controls = Controls(procedure=Procedures.NS) - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | ns |\n" - "| Nlive | 150 |\n" - "| Nmcmc | 0.0 |\n" - "| propScale | 0.1 |\n" - "| nsTolerance | 0.1 |\n" - "+------------------+-----------+" - ) - - assert table == table_str + assert exp.value.errors()[0]['msg'] == "Input should be " def test_control_class_dream_repr(self) -> None: - """Tests the __repr__ of Controls with Dream procedure.""" - controls = Controls(procedure=Procedures.Dream) - table = controls.__repr__() + """Tests the Dream model __repr__.""" + table = self.dream.__repr__() table_str = ("+------------------+-----------+\n" "| Property | Value |\n" "+------------------+-----------+\n" + "| procedure | dream |\n" "| parallel | single |\n" "| calcSldDuringFit | False |\n" "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | dream |\n" + "| display | iter |\n" "| nSamples | 50000 |\n" "| nChains | 10 |\n" "| jumpProb | 0.5 |\n" @@ -555,3 +521,47 @@ def test_control_class_dream_repr(self) -> None: ) assert table == table_str + +@pytest.mark.parametrize(["procedure", "expected_model"], [ + ('calculate', Calculate), + ('simplex', Simplex), + ('de', DE), + ('ns', NS), + ('dream', Dream) +]) +def test_set_controls(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream]) -> None: + """We should return the correct model given the value of procedure.""" + controls_model = set_controls(procedure) + assert type(controls_model) == expected_model + + +def test_set_controls_default_procedure() -> None: + """We should return the default model when we call "set_controls" without specifying a procedure.""" + controls_model = set_controls() + assert type(controls_model) == Calculate + + +def test_set_controls_invalid_procedure() -> None: + """We should return the default model when we call "set_controls" without specifying a procedure.""" + with pytest.raises(ValueError, match="The controls procedure must be one of: 'calculate', 'simplex', 'de', 'ns' " + "or 'dream'"): + set_controls('invalid') + + +@pytest.mark.parametrize(["procedure", "expected_model"], [ + ('calculate', Calculate), + ('simplex', Simplex), + ('de', DE), + ('ns', NS), + ('dream', Dream) +]) +def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ + -> None: + """If we provide extra fields to a controls model through "set_controls", we should print a formatted + ValidationError with a custom error message. + """ + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for {expected_model.__name__}\n' + f'extra_field\n Extra inputs are not permitted. The fields for ' + f'the {procedure} controls procedure are:\n ' + f'{", ".join(expected_model.model_fields.keys())}\n'): + set_controls(procedure, extra_field='invalid') diff --git a/tests/test_custom_errors.py b/tests/test_custom_errors.py index 5ea283b1..6c8bdabb 100644 --- a/tests/test_custom_errors.py +++ b/tests/test_custom_errors.py @@ -1,20 +1,37 @@ """Test the utils.custom_errors module.""" - from pydantic import create_model, ValidationError import pytest +import re import RAT.utils.custom_errors -def test_formatted_pydantic_error() -> None: - """When a pytest ValidationError is raised we should be able to take it and construct a formatted string.""" +@pytest.fixture +def TestModel(): + """Create a custom pydantic model for the tests.""" + TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a'), __config__={'extra': 'forbid'}) + return TestModel - # Create a custom pydantic model for the test - TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a')) - with pytest.raises(ValidationError) as exc_info: +@pytest.mark.parametrize(["custom_errors", "expected_error_message"], [ + ({}, + "2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to parse string as an " + "integer [type=int_parsing, input_value='string', input_type=str]\nstr_field\n Input should be a valid string " + "[type=string_type, input_value=5, input_type=int]"), + ({'int_parsing': 'This is a custom error message', 'string_type': 'This is another custom error message'}, + "2 validation errors for TestModel\nint_field\n This is a custom error message [type=int_parsing, " + "input_value='string', input_type=str]\nstr_field\n This is another custom error message [type=string_type, " + "input_value=5, input_type=int]"), +]) +def test_custom_pydantic_validation_error(TestModel, custom_errors: dict[str, str], expected_error_message: str + ) -> None: + """When we call custom_pydantic_validation_error with custom errors, we should return an error list containing + PydanticCustomErrors, otherwise we return the original set of errors. + """ + try: TestModel(int_field='string', str_field=5) + except ValidationError as exc: + custom_error_list = RAT.utils.custom_errors.custom_pydantic_validation_error(exc.errors(), custom_errors) - error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value) - assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to ' - 'parse string as an integer\nstr_field\n Input should be a valid string') + with pytest.raises(ValidationError, match=re.escape(expected_error_message)): + raise ValidationError.from_exception_data('TestModel', custom_error_list) diff --git a/tests/test_project.py b/tests/test_project.py index 95c8503f..ea131398 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,8 +1,6 @@ """Test the project module.""" -import contextlib import copy -import io import numpy as np import pydantic import os @@ -400,10 +398,11 @@ def test_set_absorption(input_layer: Callable, input_absorption: bool, new_layer def test_check_protected_parameters(delete_operation) -> None: """If we try to remove a protected parameter, we should raise an error.""" project = RAT.project.Project() - with contextlib.redirect_stdout(io.StringIO()) as print_str: + + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, Can\'t delete' + f' the protected parameters: Substrate Roughness'): eval(delete_operation) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, Can\'t delete the ' - f'protected parameters: Substrate Roughness\033[0m\n') + # Ensure model was not deleted assert project.parameters[0].name == 'Substrate Roughness' @@ -740,11 +739,12 @@ def test_wrap_set(test_project, class_list: str, field: str) -> None: test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): test_attribute.set_fields(0, **{field: 'undefined'}) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + # Ensure invalid model was not changed assert test_attribute == orig_class_list @@ -764,14 +764,14 @@ def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> """If we delete a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - index = test_attribute.index(parameter) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): del test_attribute[index] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + # Ensure model was not deleted assert test_attribute == orig_class_list @@ -803,11 +803,12 @@ def test_wrap_iadd(test_project, class_list: str, field: str) -> None: orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): test_attribute += [input_model(**{field: 'undefined'})] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + # Ensure invalid model was not added assert test_attribute == orig_class_list @@ -838,11 +839,12 @@ def test_wrap_append(test_project, class_list: str, field: str) -> None: orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): test_attribute.append(input_model(**{field: 'undefined'})) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + # Ensure invalid model was not appended assert test_attribute == orig_class_list @@ -873,11 +875,12 @@ def test_wrap_insert(test_project, class_list: str, field: str) -> None: orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): test_attribute.insert(0, input_model(**{field: 'undefined'})) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + # Ensure invalid model was not inserted assert test_attribute == orig_class_list @@ -931,14 +934,14 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> """If we pop a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - index = test_attribute.index(parameter) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): test_attribute.pop(index) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + # Ensure model was not popped assert test_attribute == orig_class_list @@ -959,12 +962,12 @@ def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): test_attribute.remove(parameter) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + # Ensure model was not removed assert test_attribute == orig_class_list @@ -985,12 +988,12 @@ def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) - test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): test_attribute.clear() - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + # Ensure list was not cleared assert test_attribute == orig_class_list @@ -1022,10 +1025,11 @@ def test_wrap_extend(test_project, class_list: str, field: str) -> None: orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): test_attribute.extend([input_model(**{field: 'undefined'})]) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + # Ensure invalid model was not appended assert test_attribute == orig_class_list