From 1e24ab36ea62fa1da163dd43cc37f253e15da4a9 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:20:31 +0100 Subject: [PATCH 1/7] Changes "controlsClass" to factory function "set_controls". --- RAT/__init__.py | 3 +- RAT/controls.py | 58 +++++----- tests/test_controls.py | 234 ++++++++++++++++++----------------------- 3 files changed, 127 insertions(+), 168 deletions(-) diff --git a/RAT/__init__.py b/RAT/__init__.py index 2d319eed..a417d8cd 100644 --- a/RAT/__init__.py +++ b/RAT/__init__.py @@ -1,3 +1,4 @@ from RAT.classlist import ClassList -from RAT.controls import Controls from RAT.project import Project +import RAT.controls +import RAT.models diff --git a/RAT/controls.py b/RAT/controls.py index 3b4a1a56..980cfdf5 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -5,8 +5,9 @@ from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'): +class BaseControls(BaseModel, validate_assignment=True, extra='forbid'): """Defines the base class with properties used in all five procedures.""" + procedure: Procedures = Procedures.Calculate parallel: ParallelOptions = ParallelOptions.Single calcSldDuringFit: bool = False resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2) @@ -21,13 +22,19 @@ def check_resamPars(cls, resamPars): raise ValueError('resamPars[1] must be greater than or equal to 0') return resamPars + def __repr__(self) -> str: + table = prettytable.PrettyTable() + table.field_names = ['Property', 'Value'] + table.add_rows([[k, v] for k, v in self.__dict__.items()]) + return table.get_string() + -class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'): +class Calculate(BaseControls, validate_assignment=True, extra='forbid'): """Defines the class for the calculate procedure.""" procedure: Procedures = Field(Procedures.Calculate, frozen=True) -class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): +class Simplex(BaseControls, validate_assignment=True, extra='forbid'): """Defines the class for the simplex procedure.""" procedure: Procedures = Field(Procedures.Simplex, frozen=True) tolX: float = Field(1.0e-6, gt=0.0) @@ -38,7 +45,7 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): updatePlotFreq: int = -1 -class DE(BaseProcedure, validate_assignment=True, extra='forbid'): +class DE(BaseControls, validate_assignment=True, extra='forbid'): """Defines the class for the Differential Evolution procedure.""" procedure: Procedures = Field(Procedures.DE, frozen=True) populationSize: int = Field(20, ge=1) @@ -49,7 +56,7 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'): numGenerations: int = Field(500, ge=1) -class NS(BaseProcedure, validate_assignment=True, extra='forbid'): +class NS(BaseControls, validate_assignment=True, extra='forbid'): """Defines the class for the Nested Sampler procedure.""" procedure: Procedures = Field(Procedures.NS, frozen=True) Nlive: int = Field(150, ge=1) @@ -58,7 +65,7 @@ class NS(BaseProcedure, validate_assignment=True, extra='forbid'): nsTolerance: float = Field(0.1, ge=0.0) -class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): +class Dream(BaseControls, validate_assignment=True, extra='forbid'): """Defines the class for the Dream procedure.""" procedure: Procedures = Field(Procedures.Dream, frozen=True) nSamples: int = Field(50000, ge=0) @@ -68,33 +75,16 @@ class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold -class Controls: - - def __init__(self, - procedure: Procedures = Procedures.Calculate, - **properties) -> None: +def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ + -> Union[Calculate, Simplex, DE, NS, Dream]: - if procedure == Procedures.Calculate: - self.controls = Calculate(**properties) - elif procedure == Procedures.Simplex: - self.controls = Simplex(**properties) - elif procedure == Procedures.DE: - self.controls = DE(**properties) - elif procedure == Procedures.NS: - self.controls = NS(**properties) - elif procedure == Procedures.Dream: - self.controls = Dream(**properties) + properties.update(procedure=procedure) + controls = { + Procedures.Calculate: Calculate(**properties), + Procedures.Simplex: Simplex(**properties), + Procedures.DE: DE(**properties), + Procedures.NS: NS(**properties), + Procedures.Dream: Dream(**properties) + } - @property - def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]: - return self._controls - - @controls.setter - def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None: - self._controls = value - - def __repr__(self) -> str: - table = prettytable.PrettyTable() - table.field_names = ['Property', 'Value'] - table.add_rows([[k, v] for k, v in self._controls.__dict__.items()]) - return table.get_string() + return controls[procedure] diff --git a/tests/test_controls.py b/tests/test_controls.py index e1c62efd..7d84f519 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -1,18 +1,19 @@ -"""Tests for control and procedure classes""" +"""Test the controls module.""" import pytest import pydantic from typing import Union, Any -from RAT.controls import BaseProcedure, Calculate, Simplex, DE, NS, Dream, Controls + +from RAT.controls import BaseControls, Calculate, Simplex, DE, NS, Dream from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class TestBaseProcedure: - """Tests the BaseProcedure class.""" +class TestBaseControls: + """Tests the BaseControls class.""" @pytest.fixture(autouse=True) def setup_class(self): - self.base_procedure = BaseProcedure() + self.base_controls = BaseControls() @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.Single), @@ -22,7 +23,7 @@ def setup_class(self): ]) def test_base_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of BaseProcedure class.""" - assert getattr(self.base_procedure, control_property) == value + assert getattr(self.base_controls, control_property) == value @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.All), @@ -32,34 +33,34 @@ def test_base_property_values(self, control_property: str, value: Any) -> None: ]) def test_base_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters in BaseProcedure class.""" - setattr(self.base_procedure, control_property, value) - assert getattr(self.base_procedure, control_property) == value + setattr(self.base_controls, control_property, value) + assert getattr(self.base_controls, control_property) == value @pytest.mark.parametrize("var1, var2", [('test', True), ('ALL', 1), ("Contrast", 3.0)]) def test_base_parallel_validation(self, var1: str, var2: Any) -> None: """Tests the parallel setter validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'parallel', var1) + setattr(self.base_controls, 'parallel', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'single', 'points', 'contrasts' or 'all'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'parallel', var2) + setattr(self.base_controls, 'parallel', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value", [5.0, 12]) def test_base_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: """Tests the calcSldDuringFit setter validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'calcSldDuringFit', value) + setattr(self.base_controls, 'calcSldDuringFit', value) assert exp.value.errors()[0]['msg'] == "Input should be a valid boolean, unable to interpret input" @pytest.mark.parametrize("var1, var2", [('test', True), ('iterate', 1), ("FINAL", 3.0)]) def test_base_display_validation(self, var1: str, var2: Any) -> None: """Tests the display setter validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'display', var1) + setattr(self.base_controls, 'display', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'off', 'iter', 'notify' or 'final'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'display', var2) + setattr(self.base_controls, 'display', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value, msg", [ @@ -69,7 +70,7 @@ def test_base_display_validation(self, var1: str, var2: Any) -> None: def test_base_resamPars_lenght_validation(self, value: list, msg: str) -> None: """Tests the resamPars setter length validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'resamPars', value) + setattr(self.base_controls, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("value, msg", [ @@ -79,13 +80,13 @@ def test_base_resamPars_lenght_validation(self, value: list, msg: str) -> None: def test_base_resamPars_value_validation(self, value: list, msg: str) -> None: """Tests the resamPars setter value validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'resamPars', value) + setattr(self.base_controls, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg def test_base_extra_property_error(self) -> None: """Tests the extra property setter in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_procedure, 'test', 1) + setattr(self.base_controls, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" @@ -130,6 +131,22 @@ def test_calculate_procedure_error(self) -> None: setattr(self.calculate, 'procedure', 'test') assert exp.value.errors()[0]['msg'] == "Field is frozen" + def test_repr(self) -> None: + """Tests the Calculate model __repr__.""" + table = self.calculate.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | calculate |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "+------------------+-----------+" + ) + + assert table == table_str + class TestSimplex: """Tests the Simplex class.""" @@ -196,6 +213,28 @@ def test_simplex_procedure_error(self) -> None: setattr(self.simplex, 'procedure', 'test') assert exp.value.errors()[0]['msg'] == "Field is frozen" + def test_repr(self) -> None: + """Tests the Simplex model __repr__.""" + table = self.simplex.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | simplex |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| tolX | 1e-06 |\n" + "| tolFun | 1e-06 |\n" + "| maxFunEvals | 10000 |\n" + "| maxIter | 1000 |\n" + "| updateFreq | -1 |\n" + "| updatePlotFreq | -1 |\n" + "+------------------+-----------+" + ) + + assert table == table_str + class TestDE: """Tests the DE class.""" @@ -274,6 +313,28 @@ def test_de_procedure_error(self) -> None: setattr(self.de, 'procedure', 'test') assert exp.value.errors()[0]['msg'] == "Field is frozen" + def test_repr(self) -> None: + """Tests the DE model __repr__.""" + table = self.de.__repr__() + table_str = ("+----------------------+-------------------------------------------+\n" + "| Property | Value |\n" + "+----------------------+-------------------------------------------+\n" + "| procedure | de |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| populationSize | 20 |\n" + "| fWeight | 0.5 |\n" + "| crossoverProbability | 0.8 |\n" + "| strategy | StrategyOptions.RandomWithPerVectorDither |\n" + "| targetValue | 1.0 |\n" + "| numGenerations | 500 |\n" + "+----------------------+-------------------------------------------+" + ) + + assert table == table_str + class TestNS: """Tests the NS class.""" @@ -343,6 +404,26 @@ def test_ns_procedure_error(self) -> None: setattr(self.ns, 'procedure', 'test') assert exp.value.errors()[0]['msg'] == "Field is frozen" + def test_control_class_ns_repr(self) -> None: + """Tests the NS model __repr__.""" + table = self.ns.__repr__() + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| procedure | ns |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| Nlive | 150 |\n" + "| Nmcmc | 0.0 |\n" + "| propScale | 0.1 |\n" + "| nsTolerance | 0.1 |\n" + "+------------------+-----------+" + ) + + assert table == table_str + class TestDream: """Tests the Dream class.""" @@ -422,130 +503,17 @@ def test_dream_procedure_error(self) -> None: setattr(self.dream, 'procedure', 'test') assert exp.value.errors()[0]['msg'] == "Field is frozen" - -class TestControls: - """Tests the Controls class.""" - - @pytest.fixture(autouse=True) - def setup_class(self): - self.controls = Controls() - - def test_controls_class_default_type(self) -> None: - """Tests the procedure is Calculate in Controls.""" - assert type(self.controls.controls).__name__ == "Calculate" - - def test_controls_class_properties(self) -> None: - """Tests the Controls class has control property.""" - assert hasattr(self.controls, 'controls') - - @pytest.mark.parametrize("procedure, name", [ - (Procedures.Calculate, "Calculate"), - (Procedures.Simplex, "Simplex"), - (Procedures.DE, "DE"), - (Procedures.NS, "NS"), - (Procedures.Dream, "Dream") - ]) - def test_controls_class_return_type(self, procedure: Procedures, name: str) -> None: - """Tests the Controls class is set to the correct procedure class.""" - controls = Controls(procedure) - assert type(controls.controls).__name__ == name - - def test_control_class_calculate_repr(self) -> None: - """Tests the __repr__ of Controls with Calculate procedure.""" - controls = Controls() - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | calculate |\n" - "+------------------+-----------+" - ) - - assert table == table_str - - def test_control_class_simplex_repr(self) -> None: - """Tests the __repr__ of Controls with Simplex procedure.""" - controls = Controls(procedure=Procedures.Simplex) - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | simplex |\n" - "| tolX | 1e-06 |\n" - "| tolFun | 1e-06 |\n" - "| maxFunEvals | 10000 |\n" - "| maxIter | 1000 |\n" - "| updateFreq | -1 |\n" - "| updatePlotFreq | -1 |\n" - "+------------------+-----------+" - ) - - assert table == table_str - - def test_control_class_de_repr(self) -> None: - """Tests the __repr__ of Controls with DE procedure.""" - controls = Controls(procedure=Procedures.DE) - table = controls.__repr__() - table_str = ("+----------------------+-------------------------------------------+\n" - "| Property | Value |\n" - "+----------------------+-------------------------------------------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | de |\n" - "| populationSize | 20 |\n" - "| fWeight | 0.5 |\n" - "| crossoverProbability | 0.8 |\n" - "| strategy | StrategyOptions.RandomWithPerVectorDither |\n" - "| targetValue | 1.0 |\n" - "| numGenerations | 500 |\n" - "+----------------------+-------------------------------------------+" - ) - - assert table == table_str - - def test_control_class_ns_repr(self) -> None: - """Tests the __repr__ of Controls with NS procedure.""" - controls = Controls(procedure=Procedures.NS) - table = controls.__repr__() - table_str = ("+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | ns |\n" - "| Nlive | 150 |\n" - "| Nmcmc | 0.0 |\n" - "| propScale | 0.1 |\n" - "| nsTolerance | 0.1 |\n" - "+------------------+-----------+" - ) - - assert table == table_str - def test_control_class_dream_repr(self) -> None: - """Tests the __repr__ of Controls with Dream procedure.""" - controls = Controls(procedure=Procedures.Dream) - table = controls.__repr__() + """Tests the Dream model __repr__.""" + table = self.dream.__repr__() table_str = ("+------------------+-----------+\n" "| Property | Value |\n" "+------------------+-----------+\n" + "| procedure | dream |\n" "| parallel | single |\n" "| calcSldDuringFit | False |\n" "| resamPars | [0.9, 50] |\n" - "| display | iter |\n" - "| procedure | dream |\n" + "| display | iter |\n" "| nSamples | 50000 |\n" "| nChains | 10 |\n" "| jumpProb | 0.5 |\n" From b78a6ae52d521a2bc063be00081cb290ecea4258 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Wed, 25 Oct 2023 14:15:34 +0100 Subject: [PATCH 2/7] Removes "baseControls" class and fixes bug when initialising procedure field --- RAT/controls.py | 39 ++++----- tests/test_controls.py | 188 ++++++++++++++++++++--------------------- 2 files changed, 110 insertions(+), 117 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index 980cfdf5..1d191953 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,13 +1,13 @@ import prettytable from pydantic import BaseModel, Field, field_validator -from typing import Union +from typing import Literal, Union from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class BaseControls(BaseModel, validate_assignment=True, extra='forbid'): - """Defines the base class with properties used in all five procedures.""" - procedure: Procedures = Procedures.Calculate +class Calculate(BaseModel, validate_assignment=True, extra='forbid'): + """Defines the class for the calculate procedure, which includes the properties used in all five procedures.""" + procedure: Literal[Procedures.Calculate] = Procedures.Calculate parallel: ParallelOptions = ParallelOptions.Single calcSldDuringFit: bool = False resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2) @@ -29,14 +29,9 @@ def __repr__(self) -> str: return table.get_string() -class Calculate(BaseControls, validate_assignment=True, extra='forbid'): - """Defines the class for the calculate procedure.""" - procedure: Procedures = Field(Procedures.Calculate, frozen=True) - - -class Simplex(BaseControls, validate_assignment=True, extra='forbid'): - """Defines the class for the simplex procedure.""" - procedure: Procedures = Field(Procedures.Simplex, frozen=True) +class Simplex(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the simplex procedure.""" + procedure: Literal[Procedures.Simplex] = Procedures.Simplex tolX: float = Field(1.0e-6, gt=0.0) tolFun: float = Field(1.0e-6, gt=0.0) maxFunEvals: int = Field(10000, gt=0) @@ -45,9 +40,9 @@ class Simplex(BaseControls, validate_assignment=True, extra='forbid'): updatePlotFreq: int = -1 -class DE(BaseControls, validate_assignment=True, extra='forbid'): - """Defines the class for the Differential Evolution procedure.""" - procedure: Procedures = Field(Procedures.DE, frozen=True) +class DE(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Differential Evolution procedure.""" + procedure: Literal[Procedures.DE] = Procedures.DE populationSize: int = Field(20, ge=1) fWeight: float = 0.5 crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0) @@ -56,18 +51,18 @@ class DE(BaseControls, validate_assignment=True, extra='forbid'): numGenerations: int = Field(500, ge=1) -class NS(BaseControls, validate_assignment=True, extra='forbid'): - """Defines the class for the Nested Sampler procedure.""" - procedure: Procedures = Field(Procedures.NS, frozen=True) +class NS(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Nested Sampler procedure.""" + procedure: Literal[Procedures.NS] = Procedures.NS Nlive: int = Field(150, ge=1) Nmcmc: float = Field(0.0, ge=0.0) propScale: float = Field(0.1, gt=0.0, lt=1.0) nsTolerance: float = Field(0.1, ge=0.0) -class Dream(BaseControls, validate_assignment=True, extra='forbid'): - """Defines the class for the Dream procedure.""" - procedure: Procedures = Field(Procedures.Dream, frozen=True) +class Dream(Calculate, validate_assignment=True, extra='forbid'): + """Defines the additional fields for the Dream procedure.""" + procedure: Literal[Procedures.Dream] = Procedures.Dream nSamples: int = Field(50000, ge=0) nChains: int = Field(10, gt=0) jumpProb: float = Field(0.5, gt=0.0, lt=1.0) @@ -77,7 +72,7 @@ class Dream(BaseControls, validate_assignment=True, extra='forbid'): def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ -> Union[Calculate, Simplex, DE, NS, Dream]: - + """Returns the appropriate controls model given the specified procedure.""" properties.update(procedure=procedure) controls = { Procedures.Calculate: Calculate(**properties), diff --git a/tests/test_controls.py b/tests/test_controls.py index 7d84f519..c1b3e012 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -4,26 +4,27 @@ import pydantic from typing import Union, Any -from RAT.controls import BaseControls, Calculate, Simplex, DE, NS, Dream +from RAT.controls import Calculate, Simplex, DE, NS, Dream from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class TestBaseControls: - """Tests the BaseControls class.""" +class TestCalculate: + """Tests the Calculate class.""" @pytest.fixture(autouse=True) def setup_class(self): - self.base_controls = BaseControls() + self.calculate = Calculate() @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.Single), ('calcSldDuringFit', False), ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter) + ('display', DisplayOptions.Iter), + ('procedure', Procedures.Calculate) ]) - def test_base_property_values(self, control_property: str, value: Any) -> None: - """Tests the default values of BaseProcedure class.""" - assert getattr(self.base_controls, control_property) == value + def test_calculate_property_values(self, control_property: str, value: Any) -> None: + """Tests the default values of Calculate class.""" + assert getattr(self.calculate, control_property) == value @pytest.mark.parametrize("control_property, value", [ ('parallel', ParallelOptions.All), @@ -31,105 +32,75 @@ def test_base_property_values(self, control_property: str, value: Any) -> None: ('resamPars', [0.2, 1]), ('display', DisplayOptions.Notify) ]) - def test_base_property_setters(self, control_property: str, value: Any) -> None: - """Tests the setters in BaseProcedure class.""" - setattr(self.base_controls, control_property, value) - assert getattr(self.base_controls, control_property) == value + def test_calculate_property_setters(self, control_property: str, value: Any) -> None: + """Tests the setters of Calculate class.""" + setattr(self.calculate, control_property, value) + assert getattr(self.calculate, control_property) == value @pytest.mark.parametrize("var1, var2", [('test', True), ('ALL', 1), ("Contrast", 3.0)]) - def test_base_parallel_validation(self, var1: str, var2: Any) -> None: - """Tests the parallel setter validation in BaseProcedure class.""" + def test_calculate_parallel_validation(self, var1: str, var2: Any) -> None: + """Tests the parallel setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'parallel', var1) + setattr(self.calculate, 'parallel', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'single', 'points', 'contrasts' or 'all'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'parallel', var2) + setattr(self.calculate, 'parallel', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value", [5.0, 12]) - def test_base_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: - """Tests the calcSldDuringFit setter validation in BaseProcedure class.""" + def test_calculate_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: + """Tests the calcSldDuringFit setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'calcSldDuringFit', value) + setattr(self.calculate, 'calcSldDuringFit', value) assert exp.value.errors()[0]['msg'] == "Input should be a valid boolean, unable to interpret input" @pytest.mark.parametrize("var1, var2", [('test', True), ('iterate', 1), ("FINAL", 3.0)]) - def test_base_display_validation(self, var1: str, var2: Any) -> None: - """Tests the display setter validation in BaseProcedure class.""" + def test_calculate_display_validation(self, var1: str, var2: Any) -> None: + """Tests the display setter validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'display', var1) + setattr(self.calculate, 'display', var1) assert exp.value.errors()[0]['msg'] == "Input should be 'off', 'iter', 'notify' or 'final'" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'display', var2) + setattr(self.calculate, 'display', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" @pytest.mark.parametrize("value, msg", [ ([5.0], "List should have at least 2 items after validation, not 1"), ([12, 13, 14], "List should have at most 2 items after validation, not 3") ]) - def test_base_resamPars_lenght_validation(self, value: list, msg: str) -> None: - """Tests the resamPars setter length validation in BaseProcedure class.""" + def test_calculate_resamPars_length_validation(self, value: list, msg: str) -> None: + """Tests the resamPars setter length validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'resamPars', value) + setattr(self.calculate, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("value, msg", [ ([1.0, 2], "Value error, resamPars[0] must be between 0 and 1"), ([0.5, -0.1], "Value error, resamPars[1] must be greater than or equal to 0") ]) - def test_base_resamPars_value_validation(self, value: list, msg: str) -> None: - """Tests the resamPars setter value validation in BaseProcedure class.""" + def test_calculate_resamPars_value_validation(self, value: list, msg: str) -> None: + """Tests the resamPars setter value validation in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'resamPars', value) + setattr(self.calculate, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg - def test_base_extra_property_error(self) -> None: - """Tests the extra property setter in BaseProcedure class.""" - with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.base_controls, 'test', 1) - assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - - -class TestCalculate: - """Tests the Calculate class.""" - - @pytest.fixture(autouse=True) - def setup_class(self): - self.calculate = Calculate() - - @pytest.mark.parametrize("control_property, value", [ - ('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.Calculate) - ]) - def test_calculate_property_values(self, control_property: str, value: Any) -> None: - """Tests the default values of Calculate class.""" - assert getattr(self.calculate, control_property) == value - - @pytest.mark.parametrize("control_property, value", [ - ('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify) - ]) - def test_calculate_property_setters(self, control_property: str, value: Any) -> None: - """Tests the setters of Calculate class.""" - setattr(self.calculate, control_property, value) - assert getattr(self.calculate, control_property) == value - def test_calculate_extra_property_error(self) -> None: """Tests the extra property setter in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.calculate, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_calculate_procedure_error(self) -> None: - """Tests the procedure property frozen in Calculate class.""" + def test_calculate_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "calculate" in Calculate class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Calculate(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_calculate_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.calculate, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " def test_repr(self) -> None: """Tests the Calculate model __repr__.""" @@ -207,11 +178,17 @@ def test_simplex_extra_property_error(self) -> None: setattr(self.simplex, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_simplex_procedure_error(self) -> None: - """Tests the procedure property frozen in Simplex class.""" + def test_simplex_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "simplex" in Simplex class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Simplex(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_simplex_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in Simplex class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.simplex, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " def test_repr(self) -> None: """Tests the Simplex model __repr__.""" @@ -277,13 +254,15 @@ def test_de_property_setters(self, control_property: str, value: Any) -> None: setattr(self.de, control_property, value) assert getattr(self.de, control_property) == value - @pytest.mark.parametrize("value", [0, 2]) - def test_de_crossoverProbability_error(self, value: int) -> None: + @pytest.mark.parametrize("value, msg", [ + (0, "Input should be greater than 0"), + (2, "Input should be less than 1") + ]) + def test_de_crossoverProbability_error(self, value: int, msg: str) -> None: """Tests the crossoverProbability setter error in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.de, 'crossoverProbability', value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("control_property, value", [ ('targetValue', 0), @@ -307,11 +286,17 @@ def test_de_extra_property_error(self) -> None: setattr(self.de, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_de_procedure_error(self) -> None: - """Tests the procedure property frozen in DE class.""" + def test_de_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "de" in DE class.""" + with pytest.raises(pydantic.ValidationError) as exp: + DE(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + + def test_de_set_procedure_error(self) -> None: + """Tests the procedure property is frozen in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.de, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " def test_repr(self) -> None: """Tests the DE model __repr__.""" @@ -384,13 +369,15 @@ def test_ns_Nmcmc_nsTolerance_Nlive_error(self, control_property: str, value: Un setattr(self.ns, control_property, value) assert exp.value.errors()[0]['msg'] == f"Input should be greater than or equal to {bound}" - @pytest.mark.parametrize("value", [0, 2]) - def test_ns_propScale_error(self, value: int) -> None: + @pytest.mark.parametrize("value, msg", [ + (0, "Input should be greater than 0"), + (2, "Input should be less than 1") + ]) + def test_ns_propScale_error(self, value: int, msg: str) -> None: """Tests the propScale error in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.ns, 'propScale', value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg def test_ns_extra_property_error(self) -> None: """Tests the extra property setter in NS class.""" @@ -398,11 +385,17 @@ def test_ns_extra_property_error(self) -> None: setattr(self.ns, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + def test_ns_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "ns" in NS class.""" + with pytest.raises(pydantic.ValidationError) as exp: + NS(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + def test_ns_procedure_error(self) -> None: - """Tests the procedure property frozen in NS class.""" + """Tests the procedure property is frozen in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.ns, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " def test_control_class_ns_repr(self) -> None: """Tests the NS model __repr__.""" @@ -464,18 +457,17 @@ def test_dream_property_setters(self, control_property: str, value: Any) -> Non setattr(self.dream, control_property, value) assert getattr(self.dream, control_property) == value - @pytest.mark.parametrize("control_property, value", [ - ('jumpProb', 0), - ('jumpProb', 2), - ('pUnitGamma', -5), - ('pUnitGamma', 20) + @pytest.mark.parametrize("control_property, value, msg", [ + ('jumpProb', 0, "Input should be greater than 0"), + ('jumpProb', 2, "Input should be less than 1"), + ('pUnitGamma', -5, "Input should be greater than 0"), + ('pUnitGamma', 20, "Input should be less than 1") ]) - def test_dream_jumpprob_pUnitGamma_error(self, control_property:str, value: int) -> None: - """Tests the jumpprob pUnitGamma setter errors in Dream class.""" + def test_dream_jumpProb_pUnitGamma_error(self, control_property: str, value: int, msg: str) -> None: + """Tests the jumpProb and pUnitGamma setter errors in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.dream, control_property, value) - assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", - "Input should be less than 1"] + assert exp.value.errors()[0]['msg'] == msg @pytest.mark.parametrize("value", [-80, -2]) def test_dream_nSamples_error(self, value: int) -> None: @@ -497,11 +489,17 @@ def test_dream_extra_property_error(self) -> None: setattr(self.dream, 'test', 1) assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + def test_dream_initialise_procedure_error(self) -> None: + """Tests the procedure property can only be initialised as "dream" in Dream class.""" + with pytest.raises(pydantic.ValidationError) as exp: + Dream(procedure='test') + assert exp.value.errors()[0]['msg'] == "Input should be " + def test_dream_procedure_error(self) -> None: - """Tests the procedure property frozen in Dream class.""" + """Tests the procedure property is frozen in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.dream, 'procedure', 'test') - assert exp.value.errors()[0]['msg'] == "Field is frozen" + assert exp.value.errors()[0]['msg'] == "Input should be " def test_control_class_dream_repr(self) -> None: """Tests the Dream model __repr__.""" From bcd549c62f286c9afd0264adba63d41da5eaefa3 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Thu, 26 Oct 2023 12:08:28 +0100 Subject: [PATCH 3/7] Adds error check to "set_controls" --- RAT/controls.py | 26 +++++++++++++++++--------- tests/test_controls.py | 27 ++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index 1d191953..21cd35d2 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,5 +1,5 @@ import prettytable -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, ValidationError from typing import Literal, Union from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -52,7 +52,7 @@ class DE(Calculate, validate_assignment=True, extra='forbid'): class NS(Calculate, validate_assignment=True, extra='forbid'): - """Defines the additional fields for the Nested Sampler procedure.""" + """Defines the additional fields for the Nested Sampler procedure.""" procedure: Literal[Procedures.NS] = Procedures.NS Nlive: int = Field(150, ge=1) Nmcmc: float = Field(0.0, ge=0.0) @@ -73,13 +73,21 @@ class Dream(Calculate, validate_assignment=True, extra='forbid'): def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ -> Union[Calculate, Simplex, DE, NS, Dream]: """Returns the appropriate controls model given the specified procedure.""" - properties.update(procedure=procedure) controls = { - Procedures.Calculate: Calculate(**properties), - Procedures.Simplex: Simplex(**properties), - Procedures.DE: DE(**properties), - Procedures.NS: NS(**properties), - Procedures.Dream: Dream(**properties) + Procedures.Calculate: Calculate, + Procedures.Simplex: Simplex, + Procedures.DE: DE, + Procedures.NS: NS, + Procedures.Dream: Dream } - return controls[procedure] + try: + model = controls[procedure](**properties) + except KeyError: + members = list(Procedures.__members__.values()) + allowed_values = ', '.join([repr(member.value) for member in members[:-1]]) + f' or {members[-1].value!r}' + raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None + except ValidationError: + raise + + return model diff --git a/tests/test_controls.py b/tests/test_controls.py index c1b3e012..df31c981 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -4,7 +4,7 @@ import pydantic from typing import Union, Any -from RAT.controls import Calculate, Simplex, DE, NS, Dream +from RAT.controls import Calculate, Simplex, DE, NS, Dream, set_controls from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -521,3 +521,28 @@ def test_control_class_dream_repr(self) -> None: ) assert table == table_str + +@pytest.mark.parametrize(["procedure", "expected_model"], [ + ('calculate', Calculate), + ('simplex', Simplex), + ('de', DE), + ('ns', NS), + ('dream', Dream) +]) +def test_set_controls(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream]) -> None: + """Make sure we return the correct model given the value of procedure.""" + controls_model = set_controls(procedure) + assert type(controls_model) == expected_model + + +def test_set_controls_default_procedure() -> None: + """Make sure we return the default model when we call "set_controls" without specifying a procedure.""" + controls_model = set_controls() + assert type(controls_model) == Calculate + + +def test_set_controls_invalid_procedure() -> None: + """Make sure we return the default model when we call "set_controls" without specifying a procedure.""" + with pytest.raises(ValueError, match="The controls procedure must be one of: 'calculate', 'simplex', 'de', 'ns' " + "or 'dream'"): + set_controls('invalid') From 1dbe8dcf2c1b31c292008473d5779d2355e94c9c Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Fri, 27 Oct 2023 11:15:47 +0100 Subject: [PATCH 4/7] Adds formatted error for extra fields to "set_controls" --- RAT/controls.py | 11 +++++++++-- RAT/utils/custom_errors.py | 15 ++++++++++++--- tests/test_controls.py | 28 +++++++++++++++++++++++++--- tests/test_custom_errors.py | 24 ++++++++++++++++++++---- 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index 21cd35d2..bbc4a103 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, Field, field_validator, ValidationError from typing import Literal, Union +from RAT.utils.custom_errors import formatted_pydantic_error from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -73,6 +74,7 @@ class Dream(Calculate, validate_assignment=True, extra='forbid'): def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ -> Union[Calculate, Simplex, DE, NS, Dream]: """Returns the appropriate controls model given the specified procedure.""" + model = None controls = { Procedures.Calculate: Calculate, Procedures.Simplex: Simplex, @@ -87,7 +89,12 @@ def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ members = list(Procedures.__members__.values()) allowed_values = ', '.join([repr(member.value) for member in members[:-1]]) + f' or {members[-1].value!r}' raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None - except ValidationError: - raise + except ValidationError as exc: + custom_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure} controls ' + f'procedure are:\n {", ".join(controls[procedure].model_fields.keys())}' + } + error_string = formatted_pydantic_error(exc, custom_msgs) + # Use ANSI escape sequences to print error text in red + print('\033[31m' + error_string + '\033[0m') return model diff --git a/RAT/utils/custom_errors.py b/RAT/utils/custom_errors.py index 269024b6..76226466 100644 --- a/RAT/utils/custom_errors.py +++ b/RAT/utils/custom_errors.py @@ -3,24 +3,33 @@ from pydantic import ValidationError -def formatted_pydantic_error(error: ValidationError) -> str: +def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str: """Write a custom string format for pydantic validation errors. Parameters ---------- error : pydantic.ValidationError - A ValidationError produced by a pydantic model + A ValidationError produced by a pydantic model. + custom_error_messages: dict[str, str], optional + A dict of custom error messages for given error types. Returns ------- error_str : str A string giving details of the ValidationError in a custom format. """ + if custom_error_messages is None: + custom_error_messages = {} num_errors = error.error_count() error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}' + for this_error in error.errors(): + error_type = this_error['type'] + error_msg = custom_error_messages[error_type] if error_type in custom_error_messages else this_error["msg"] + error_str += '\n' if this_error['loc']: error_str += ' '.join(this_error['loc']) + '\n' - error_str += ' ' + this_error['msg'] + error_str += f' {error_msg}' + return error_str diff --git a/tests/test_controls.py b/tests/test_controls.py index df31c981..ee614cc0 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -1,5 +1,7 @@ """Test the controls module.""" +import contextlib +import io import pytest import pydantic from typing import Union, Any @@ -530,19 +532,39 @@ def test_control_class_dream_repr(self) -> None: ('dream', Dream) ]) def test_set_controls(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream]) -> None: - """Make sure we return the correct model given the value of procedure.""" + """We should return the correct model given the value of procedure.""" controls_model = set_controls(procedure) assert type(controls_model) == expected_model def test_set_controls_default_procedure() -> None: - """Make sure we return the default model when we call "set_controls" without specifying a procedure.""" + """We should return the default model when we call "set_controls" without specifying a procedure.""" controls_model = set_controls() assert type(controls_model) == Calculate def test_set_controls_invalid_procedure() -> None: - """Make sure we return the default model when we call "set_controls" without specifying a procedure.""" + """We should return the default model when we call "set_controls" without specifying a procedure.""" with pytest.raises(ValueError, match="The controls procedure must be one of: 'calculate', 'simplex', 'de', 'ns' " "or 'dream'"): set_controls('invalid') + + +@pytest.mark.parametrize(["procedure", "expected_model"], [ + ('calculate', Calculate), + ('simplex', Simplex), + ('de', DE), + ('ns', NS), + ('dream', Dream) +]) +def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ + -> None: + """If we provide extra fields to a controls model through "set_controls", we should print a formatted + ValidationError with a custom error message. + """ + with contextlib.redirect_stdout(io.StringIO()) as print_str: + set_controls(procedure, extra_field='invalid') + + assert print_str.getvalue() == (f'\033[31m1 validation error for {expected_model.__name__}\nextra_field\n Extra ' + f'inputs are not permitted. The fields for the {procedure} controls procedure ' + f'are:\n {", ".join(expected_model.model_fields.keys())}\033[0m\n') diff --git a/tests/test_custom_errors.py b/tests/test_custom_errors.py index 5ea283b1..240e95c3 100644 --- a/tests/test_custom_errors.py +++ b/tests/test_custom_errors.py @@ -6,15 +6,31 @@ import RAT.utils.custom_errors -def test_formatted_pydantic_error() -> None: - """When a pytest ValidationError is raised we should be able to take it and construct a formatted string.""" - - # Create a custom pydantic model for the test +@pytest.fixture +def TestModel(): + """Create a custom pydantic model for the tests.""" TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a')) + return TestModel + +def test_formatted_pydantic_error(TestModel) -> None: + """When a pytest ValidationError is raised we should be able to take it and construct a formatted string.""" with pytest.raises(ValidationError) as exc_info: TestModel(int_field='string', str_field=5) error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value) assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to ' 'parse string as an integer\nstr_field\n Input should be a valid string') + + +def test_formatted_pydantic_error_custom_messages(TestModel) -> None: + """When a pytest ValidationError is raised we should be able to take it and construct a formatted string, + including the custom error messages provided.""" + with pytest.raises(ValidationError) as exc_info: + TestModel(int_field='string', str_field=5) + + custom_messages = {'int_parsing': 'This is a custom error message', + 'string_type': 'This is another custom error message'} + error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value, custom_messages) + assert error_str == ('2 validation errors for TestModel\nint_field\n This is a custom error message\n' + 'str_field\n This is another custom error message') From d25150ae5eb392f7c09a3d12bd5c7a974f58245d Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:49:10 +0000 Subject: [PATCH 5/7] Introduces logging for error reporting --- RAT/controls.py | 8 ++- RAT/project.py | 12 ++-- RAT/utils/custom_errors.py | 9 +++ tests/test_controls.py | 13 ++-- tests/test_custom_errors.py | 37 +++++++++- tests/test_project.py | 134 +++++++++++++++++++----------------- 6 files changed, 130 insertions(+), 83 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index bbc4a103..fe587f91 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,8 +1,9 @@ +import logging import prettytable from pydantic import BaseModel, Field, field_validator, ValidationError from typing import Literal, Union -from RAT.utils.custom_errors import formatted_pydantic_error +from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -93,8 +94,9 @@ def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ custom_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure} controls ' f'procedure are:\n {", ".join(controls[procedure].model_fields.keys())}' } + traceback_string = formatted_traceback() error_string = formatted_pydantic_error(exc, custom_msgs) - # Use ANSI escape sequences to print error text in red - print('\033[31m' + error_string + '\033[0m') + logger = logging.getLogger(__name__) + logger.error(traceback_string + error_string) return model diff --git a/RAT/project.py b/RAT/project.py index 77827e08..a313b3d6 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -3,6 +3,7 @@ import collections import copy import functools +import logging import numpy as np import os from pydantic import BaseModel, ValidationInfo, field_validator, model_validator, ValidationError @@ -10,7 +11,7 @@ from RAT.classlist import ClassList import RAT.models -from RAT.utils.custom_errors import formatted_pydantic_error +from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback try: from enum import StrEnum @@ -524,11 +525,12 @@ def wrapped_func(*args, **kwargs): try: return_value = func(*args, **kwargs) Project.model_validate(self) - except ValidationError as e: + except ValidationError as exc: setattr(class_list, 'data', previous_state) - error_string = formatted_pydantic_error(e) - # Use ANSI escape sequences to print error text in red - print('\033[31m' + error_string + '\033[0m') + traceback_string = formatted_traceback() + error_string = formatted_pydantic_error(exc) + logger = logging.getLogger(__name__) + logger.error(traceback_string + error_string + '\n') except (TypeError, ValueError): setattr(class_list, 'data', previous_state) raise diff --git a/RAT/utils/custom_errors.py b/RAT/utils/custom_errors.py index 76226466..403d852d 100644 --- a/RAT/utils/custom_errors.py +++ b/RAT/utils/custom_errors.py @@ -1,6 +1,7 @@ """Defines routines for custom error handling in RAT.""" from pydantic import ValidationError +import traceback def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str: @@ -33,3 +34,11 @@ def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict error_str += f' {error_msg}' return error_str + + +def formatted_traceback() -> str: + """Takes the traceback obtained from "traceback.format_exc()" and removes the exception message for pydantic + ValidationErrors. + """ + traceback_string = traceback.format_exc() + return traceback_string.split('pydantic_core._pydantic_core.ValidationError:')[0] diff --git a/tests/test_controls.py b/tests/test_controls.py index ee614cc0..5a888a8c 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -1,7 +1,5 @@ """Test the controls module.""" -import contextlib -import io import pytest import pydantic from typing import Union, Any @@ -557,14 +555,13 @@ def test_set_controls_invalid_procedure() -> None: ('ns', NS), ('dream', Dream) ]) -def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ +def test_set_controls_extra_fields(caplog, procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ -> None: """If we provide extra fields to a controls model through "set_controls", we should print a formatted ValidationError with a custom error message. """ - with contextlib.redirect_stdout(io.StringIO()) as print_str: - set_controls(procedure, extra_field='invalid') + set_controls(procedure, extra_field='invalid') - assert print_str.getvalue() == (f'\033[31m1 validation error for {expected_model.__name__}\nextra_field\n Extra ' - f'inputs are not permitted. The fields for the {procedure} controls procedure ' - f'are:\n {", ".join(expected_model.model_fields.keys())}\033[0m\n') + assert (f'1 validation error for {expected_model.__name__}\nextra_field\n Extra inputs are not permitted. The ' + f'fields for the {procedure} controls procedure are:\n {", ".join(expected_model.model_fields.keys())}\n' + ) in caplog.text diff --git a/tests/test_custom_errors.py b/tests/test_custom_errors.py index 240e95c3..562d2333 100644 --- a/tests/test_custom_errors.py +++ b/tests/test_custom_errors.py @@ -1,5 +1,5 @@ """Test the utils.custom_errors module.""" - +import pydantic from pydantic import create_model, ValidationError import pytest @@ -9,7 +9,7 @@ @pytest.fixture def TestModel(): """Create a custom pydantic model for the tests.""" - TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a')) + TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a'), __config__={'extra': 'forbid'}) return TestModel @@ -34,3 +34,36 @@ def test_formatted_pydantic_error_custom_messages(TestModel) -> None: error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value, custom_messages) assert error_str == ('2 validation errors for TestModel\nint_field\n This is a custom error message\n' 'str_field\n This is another custom error message') + + +def test_formatted_traceback_type_error(TestModel) -> None: + """The formatted_traceback routine should return the traceback string from "traceback.format_exc()", including the + error message. + """ + error_message = '__init__() takes 1 positional argument but 2 were given' + traceback_string = '' + + try: + TestModel('invalid') + except TypeError: + traceback_string = RAT.utils.custom_errors.formatted_traceback() + + assert 'TypeError' in traceback_string + assert error_message in traceback_string + + +def test_formatted_traceback_validation_error(TestModel) -> None: + """The formatted_traceback routine should return the traceback string from "traceback.format_exc()", with the error + message removed for a pydantic ValidationError. + """ + error_message = (f"pydantic_core._pydantic_core.ValidationError: 1 validation error for {TestModel.__name__}\n" + f"extra_field\n Extra inputs are not permitted [type=extra_forbidden, input_value='invalid'," + f" input_type=str]\n") + traceback_string = error_message + + try: + TestModel(extra_field='invalid') + except pydantic.ValidationError: + traceback_string = RAT.utils.custom_errors.formatted_traceback() + + assert error_message not in traceback_string diff --git a/tests/test_project.py b/tests/test_project.py index 95c8503f..66499547 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,8 +1,6 @@ """Test the project module.""" -import contextlib import copy -import io import numpy as np import pydantic import os @@ -397,13 +395,15 @@ def test_set_absorption(input_layer: Callable, input_absorption: bool, new_layer 'project.parameters.remove("Substrate Roughness")', 'project.parameters.clear()', ]) -def test_check_protected_parameters(delete_operation) -> None: +def test_check_protected_parameters(caplog, delete_operation) -> None: """If we try to remove a protected parameter, we should raise an error.""" project = RAT.project.Project() - with contextlib.redirect_stdout(io.StringIO()) as print_str: - eval(delete_operation) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, Can\'t delete the ' - f'protected parameters: Substrate Roughness\033[0m\n') + eval(delete_operation) + + assert (f'1 validation error for Project\n Value error, Can\'t delete the protected parameters: Substrate ' + f'Roughness\n' + ) in caplog.text + # Ensure model was not deleted assert project.parameters[0].name == 'Substrate Roughness' @@ -735,16 +735,17 @@ def test_write_script_wrong_extension(test_project, extension: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_set(test_project, class_list: str, field: str) -> None: +def test_wrap_set(test_project, caplog, class_list: str, field: str) -> None: """If we set the field values of a model in a ClassList as undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.set_fields(0, **{field: 'undefined'}) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + test_attribute.set_fields(0, **{field: 'undefined'}) + + assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' + f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' + ) in caplog.text + # Ensure invalid model was not changed assert test_attribute == orig_class_list @@ -760,18 +761,18 @@ def test_wrap_set(test_project, class_list: str, field: str) -> None: ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> None: +def test_wrap_del(test_project, caplog, class_list: str, parameter: str, field: str) -> None: """If we delete a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) index = test_attribute.index(parameter) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - del test_attribute[index] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + del test_attribute[index] + + assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' + ) in caplog.text + # Ensure model was not deleted assert test_attribute == orig_class_list @@ -797,17 +798,18 @@ def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_iadd(test_project, class_list: str, field: str) -> None: +def test_wrap_iadd(test_project, caplog, class_list: str, field: str) -> None: """If we add a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute += [input_model(**{field: 'undefined'})] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + test_attribute += [input_model(**{field: 'undefined'})] + + assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' + f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' + ) in caplog.text + # Ensure invalid model was not added assert test_attribute == orig_class_list @@ -832,17 +834,18 @@ def test_wrap_iadd(test_project, class_list: str, field: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_append(test_project, class_list: str, field: str) -> None: +def test_wrap_append(test_project, caplog, class_list: str, field: str) -> None: """If we append a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.append(input_model(**{field: 'undefined'})) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + test_attribute.append(input_model(**{field: 'undefined'})) + + assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' + f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' + ) in caplog.text + # Ensure invalid model was not appended assert test_attribute == orig_class_list @@ -867,17 +870,17 @@ def test_wrap_append(test_project, class_list: str, field: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_insert(test_project, class_list: str, field: str) -> None: +def test_wrap_insert(test_project, caplog, class_list: str, field: str) -> None: """If we insert a model containing undefined values into a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.insert(0, input_model(**{field: 'undefined'})) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + test_attribute.insert(0, input_model(**{field: 'undefined'})) + + assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' + f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n') + # Ensure invalid model was not inserted assert test_attribute == orig_class_list @@ -927,18 +930,18 @@ def test_wrap_insert_type_error(test_project, class_list: str, field: str) -> No ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> None: +def test_wrap_pop(test_project, caplog, class_list: str, parameter: str, field: str) -> None: """If we pop a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) index = test_attribute.index(parameter) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.pop(index) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + test_attribute.pop(index) + + assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' + ) in caplog.text + # Ensure model was not popped assert test_attribute == orig_class_list @@ -954,17 +957,17 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) -> None: +def test_wrap_remove(test_project, caplog, class_list: str, parameter: str, field: str) -> None: """If we remove a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.remove(parameter) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + test_attribute.remove(parameter) + + assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' + ) in caplog.text + # Ensure model was not removed assert test_attribute == orig_class_list @@ -980,17 +983,17 @@ def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) -> None: +def test_wrap_clear(test_project, caplog, class_list: str, parameter: str, field: str) -> None: """If we clear a ClassList containing models with values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.clear() - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' - f'in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" ' - f'must be defined in "{class_list}".\033[0m\n') + test_attribute.clear() + + assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' + ) in caplog.text + # Ensure list was not cleared assert test_attribute == orig_class_list @@ -1016,16 +1019,17 @@ def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) - ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_extend(test_project, class_list: str, field: str) -> None: +def test_wrap_extend(test_project, caplog, class_list: str, field: str) -> None: """If we extend a ClassList with model containing undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute.extend([input_model(**{field: 'undefined'})]) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' - f'the "{field}" field of "{class_list}" must be defined in ' - f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\033[0m\n') + test_attribute.extend([input_model(**{field: 'undefined'})]) + + assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' + f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' + ) in caplog.text + # Ensure invalid model was not appended assert test_attribute == orig_class_list From bcacf296c0c5482429a230e9011da6992edddfe4 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Wed, 1 Nov 2023 12:17:21 +0000 Subject: [PATCH 6/7] Substitutes logging for raising errors directly --- RAT/controls.py | 15 ++---- RAT/project.py | 8 +-- tests/test_controls.py | 10 ++-- tests/test_project.py | 116 ++++++++++++++++++++--------------------- 4 files changed, 66 insertions(+), 83 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index fe587f91..e961de1a 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,9 +1,7 @@ -import logging import prettytable from pydantic import BaseModel, Field, field_validator, ValidationError from typing import Literal, Union -from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -75,7 +73,6 @@ class Dream(Calculate, validate_assignment=True, extra='forbid'): def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ -> Union[Calculate, Simplex, DE, NS, Dream]: """Returns the appropriate controls model given the specified procedure.""" - model = None controls = { Procedures.Calculate: Calculate, Procedures.Simplex: Simplex, @@ -88,15 +85,9 @@ def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ model = controls[procedure](**properties) except KeyError: members = list(Procedures.__members__.values()) - allowed_values = ', '.join([repr(member.value) for member in members[:-1]]) + f' or {members[-1].value!r}' + allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}' raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None - except ValidationError as exc: - custom_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure} controls ' - f'procedure are:\n {", ".join(controls[procedure].model_fields.keys())}' - } - traceback_string = formatted_traceback() - error_string = formatted_pydantic_error(exc, custom_msgs) - logger = logging.getLogger(__name__) - logger.error(traceback_string + error_string) + except ValidationError: + raise return model diff --git a/RAT/project.py b/RAT/project.py index a313b3d6..2bc7212c 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -525,13 +525,7 @@ def wrapped_func(*args, **kwargs): try: return_value = func(*args, **kwargs) Project.model_validate(self) - except ValidationError as exc: - setattr(class_list, 'data', previous_state) - traceback_string = formatted_traceback() - error_string = formatted_pydantic_error(exc) - logger = logging.getLogger(__name__) - logger.error(traceback_string + error_string + '\n') - except (TypeError, ValueError): + except (TypeError, ValueError, ValidationError): setattr(class_list, 'data', previous_state) raise finally: diff --git a/tests/test_controls.py b/tests/test_controls.py index 5a888a8c..a9061447 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -555,13 +555,11 @@ def test_set_controls_invalid_procedure() -> None: ('ns', NS), ('dream', Dream) ]) -def test_set_controls_extra_fields(caplog, procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ +def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\ -> None: """If we provide extra fields to a controls model through "set_controls", we should print a formatted ValidationError with a custom error message. """ - set_controls(procedure, extra_field='invalid') - - assert (f'1 validation error for {expected_model.__name__}\nextra_field\n Extra inputs are not permitted. The ' - f'fields for the {procedure} controls procedure are:\n {", ".join(expected_model.model_fields.keys())}\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for {expected_model.__name__}\n' + f'extra_field\n Extra inputs are not permitted'): + set_controls(procedure, extra_field='invalid') diff --git a/tests/test_project.py b/tests/test_project.py index 66499547..ea131398 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -395,14 +395,13 @@ def test_set_absorption(input_layer: Callable, input_absorption: bool, new_layer 'project.parameters.remove("Substrate Roughness")', 'project.parameters.clear()', ]) -def test_check_protected_parameters(caplog, delete_operation) -> None: +def test_check_protected_parameters(delete_operation) -> None: """If we try to remove a protected parameter, we should raise an error.""" project = RAT.project.Project() - eval(delete_operation) - assert (f'1 validation error for Project\n Value error, Can\'t delete the protected parameters: Substrate ' - f'Roughness\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, Can\'t delete' + f' the protected parameters: Substrate Roughness'): + eval(delete_operation) # Ensure model was not deleted assert project.parameters[0].name == 'Substrate Roughness' @@ -735,16 +734,16 @@ def test_write_script_wrong_extension(test_project, extension: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_set(test_project, caplog, class_list: str, field: str) -> None: +def test_wrap_set(test_project, class_list: str, field: str) -> None: """If we set the field values of a model in a ClassList as undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - test_attribute.set_fields(0, **{field: 'undefined'}) - - assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' - f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): + test_attribute.set_fields(0, **{field: 'undefined'}) # Ensure invalid model was not changed assert test_attribute == orig_class_list @@ -761,17 +760,17 @@ def test_wrap_set(test_project, caplog, class_list: str, field: str) -> None: ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_del(test_project, caplog, class_list: str, parameter: str, field: str) -> None: +def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> None: """If we delete a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - index = test_attribute.index(parameter) - del test_attribute[index] - assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): + del test_attribute[index] # Ensure model was not deleted assert test_attribute == orig_class_list @@ -798,17 +797,17 @@ def test_wrap_del(test_project, caplog, class_list: str, parameter: str, field: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_iadd(test_project, caplog, class_list: str, field: str) -> None: +def test_wrap_iadd(test_project, class_list: str, field: str) -> None: """If we add a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - test_attribute += [input_model(**{field: 'undefined'})] - - assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' - f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): + test_attribute += [input_model(**{field: 'undefined'})] # Ensure invalid model was not added assert test_attribute == orig_class_list @@ -834,17 +833,17 @@ def test_wrap_iadd(test_project, caplog, class_list: str, field: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_append(test_project, caplog, class_list: str, field: str) -> None: +def test_wrap_append(test_project, class_list: str, field: str) -> None: """If we append a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - test_attribute.append(input_model(**{field: 'undefined'})) - - assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' - f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): + test_attribute.append(input_model(**{field: 'undefined'})) # Ensure invalid model was not appended assert test_attribute == orig_class_list @@ -870,16 +869,17 @@ def test_wrap_append(test_project, caplog, class_list: str, field: str) -> None: ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_insert(test_project, caplog, class_list: str, field: str) -> None: +def test_wrap_insert(test_project, class_list: str, field: str) -> None: """If we insert a model containing undefined values into a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - test_attribute.insert(0, input_model(**{field: 'undefined'})) - - assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' - f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n') + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): + test_attribute.insert(0, input_model(**{field: 'undefined'})) # Ensure invalid model was not inserted assert test_attribute == orig_class_list @@ -930,17 +930,17 @@ def test_wrap_insert_type_error(test_project, class_list: str, field: str) -> No ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_pop(test_project, caplog, class_list: str, parameter: str, field: str) -> None: +def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> None: """If we pop a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - index = test_attribute.index(parameter) - test_attribute.pop(index) - assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): + test_attribute.pop(index) # Ensure model was not popped assert test_attribute == orig_class_list @@ -957,16 +957,16 @@ def test_wrap_pop(test_project, caplog, class_list: str, parameter: str, field: ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_remove(test_project, caplog, class_list: str, parameter: str, field: str) -> None: +def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) -> None: """If we remove a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - test_attribute.remove(parameter) - - assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): + test_attribute.remove(parameter) # Ensure model was not removed assert test_attribute == orig_class_list @@ -983,16 +983,16 @@ def test_wrap_remove(test_project, caplog, class_list: str, parameter: str, fiel ('scalefactors', 'Scalefactor 1', 'scalefactor'), ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_clear(test_project, caplog, class_list: str, parameter: str, field: str) -> None: +def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) -> None: """If we clear a ClassList containing models with values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) - test_attribute.clear() - - assert (f'1 validation error for Project\n Value error, The value "{parameter}" in the "{field}" field of ' - f'"{RAT.project.model_names_used_in[class_list].attribute}" must be defined in "{class_list}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"{parameter}" in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".'): + test_attribute.clear() # Ensure list was not cleared assert test_attribute == orig_class_list @@ -1019,17 +1019,17 @@ def test_wrap_clear(test_project, caplog, class_list: str, parameter: str, field ('contrasts', 'scalefactor'), ('contrasts', 'resolution'), ]) -def test_wrap_extend(test_project, caplog, class_list: str, field: str) -> None: +def test_wrap_extend(test_project, class_list: str, field: str) -> None: """If we extend a ClassList with model containing undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) - test_attribute.extend([input_model(**{field: 'undefined'})]) - - assert (f'1 validation error for Project\n Value error, The value "undefined" in the "{field}" field of ' - f'"{class_list}" must be defined in "{RAT.project.values_defined_in[f"{class_list}.{field}"]}".\n' - ) in caplog.text + with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' + f'"undefined" in the "{field}" field of "{class_list}" must be ' + f'defined in ' + f'"{RAT.project.values_defined_in[f"{class_list}.{field}"]}".'): + test_attribute.extend([input_model(**{field: 'undefined'})]) # Ensure invalid model was not appended assert test_attribute == orig_class_list From ffcf677f625ae6065cb853284837a4a29c9307d2 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:32:46 +0000 Subject: [PATCH 7/7] Adds routine "custom_pydantic_validation_error" to introduce custom error messages when raising a ValidationError --- RAT/controls.py | 10 ++++-- RAT/project.py | 9 +++-- RAT/utils/custom_errors.py | 56 +++++++++++++---------------- tests/test_controls.py | 4 ++- tests/test_custom_errors.py | 72 +++++++++++-------------------------- 5 files changed, 61 insertions(+), 90 deletions(-) diff --git a/RAT/controls.py b/RAT/controls.py index e961de1a..2c16cf34 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -3,6 +3,7 @@ from typing import Literal, Union from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions +from RAT.utils.custom_errors import custom_pydantic_validation_error class Calculate(BaseModel, validate_assignment=True, extra='forbid'): @@ -87,7 +88,12 @@ def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\ members = list(Procedures.__members__.values()) allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}' raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None - except ValidationError: - raise + except ValidationError as exc: + custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}' + f' controls procedure are:\n ' + f'{", ".join(controls[procedure].model_fields.keys())}\n' + } + custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs) + raise ValidationError.from_exception_data(exc.title, custom_error_list) from None return model diff --git a/RAT/project.py b/RAT/project.py index 2bc7212c..661bff0a 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -3,7 +3,6 @@ import collections import copy import functools -import logging import numpy as np import os from pydantic import BaseModel, ValidationInfo, field_validator, model_validator, ValidationError @@ -11,7 +10,7 @@ from RAT.classlist import ClassList import RAT.models -from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback +from RAT.utils.custom_errors import custom_pydantic_validation_error try: from enum import StrEnum @@ -525,7 +524,11 @@ def wrapped_func(*args, **kwargs): try: return_value = func(*args, **kwargs) Project.model_validate(self) - except (TypeError, ValueError, ValidationError): + except ValidationError as exc: + setattr(class_list, 'data', previous_state) + custom_error_list = custom_pydantic_validation_error(exc.errors()) + raise ValidationError.from_exception_data(exc.title, custom_error_list) from None + except (TypeError, ValueError): setattr(class_list, 'data', previous_state) raise finally: diff --git a/RAT/utils/custom_errors.py b/RAT/utils/custom_errors.py index 403d852d..2fd7d211 100644 --- a/RAT/utils/custom_errors.py +++ b/RAT/utils/custom_errors.py @@ -1,44 +1,36 @@ """Defines routines for custom error handling in RAT.""" +import pydantic_core -from pydantic import ValidationError -import traceback +def custom_pydantic_validation_error(error_list: list[pydantic_core.ErrorDetails], custom_errors: dict[str, str] = None + ) -> list[pydantic_core.ErrorDetails]: + """Run through the list of errors generated from a pydantic ValidationError, substituting the standard error for a + PydanticCustomError for a given set of error types. -def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str: - """Write a custom string format for pydantic validation errors. + For errors that do not have a custom error message defined, we redefine them using a PydanticCustomError to remove + the url from the error message. Parameters ---------- - error : pydantic.ValidationError - A ValidationError produced by a pydantic model. - custom_error_messages: dict[str, str], optional + error_list : list[pydantic_core.ErrorDetails] + A list of errors produced by pydantic.ValidationError.errors(). + custom_errors: dict[str, str], optional A dict of custom error messages for given error types. Returns ------- - error_str : str - A string giving details of the ValidationError in a custom format. + new_error : list[pydantic_core.ErrorDetails] + A list of errors including PydanticCustomErrors in place of the error types in custom_errors. """ - if custom_error_messages is None: - custom_error_messages = {} - num_errors = error.error_count() - error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}' - - for this_error in error.errors(): - error_type = this_error['type'] - error_msg = custom_error_messages[error_type] if error_type in custom_error_messages else this_error["msg"] - - error_str += '\n' - if this_error['loc']: - error_str += ' '.join(this_error['loc']) + '\n' - error_str += f' {error_msg}' - - return error_str - - -def formatted_traceback() -> str: - """Takes the traceback obtained from "traceback.format_exc()" and removes the exception message for pydantic - ValidationErrors. - """ - traceback_string = traceback.format_exc() - return traceback_string.split('pydantic_core._pydantic_core.ValidationError:')[0] + if custom_errors is None: + custom_errors = {} + custom_error_list = [] + for error in error_list: + if error['type'] in custom_errors: + RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], custom_errors[error['type']]) + else: + RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], error['msg']) + error['type'] = RAT_custom_error + custom_error_list.append(error) + + return custom_error_list diff --git a/tests/test_controls.py b/tests/test_controls.py index a9061447..fd128bbc 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -561,5 +561,7 @@ def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[ ValidationError with a custom error message. """ with pytest.raises(pydantic.ValidationError, match=f'1 validation error for {expected_model.__name__}\n' - f'extra_field\n Extra inputs are not permitted'): + f'extra_field\n Extra inputs are not permitted. The fields for ' + f'the {procedure} controls procedure are:\n ' + f'{", ".join(expected_model.model_fields.keys())}\n'): set_controls(procedure, extra_field='invalid') diff --git a/tests/test_custom_errors.py b/tests/test_custom_errors.py index 562d2333..6c8bdabb 100644 --- a/tests/test_custom_errors.py +++ b/tests/test_custom_errors.py @@ -1,7 +1,7 @@ """Test the utils.custom_errors module.""" -import pydantic from pydantic import create_model, ValidationError import pytest +import re import RAT.utils.custom_errors @@ -13,57 +13,25 @@ def TestModel(): return TestModel -def test_formatted_pydantic_error(TestModel) -> None: - """When a pytest ValidationError is raised we should be able to take it and construct a formatted string.""" - with pytest.raises(ValidationError) as exc_info: - TestModel(int_field='string', str_field=5) - - error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value) - assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to ' - 'parse string as an integer\nstr_field\n Input should be a valid string') - - -def test_formatted_pydantic_error_custom_messages(TestModel) -> None: - """When a pytest ValidationError is raised we should be able to take it and construct a formatted string, - including the custom error messages provided.""" - with pytest.raises(ValidationError) as exc_info: - TestModel(int_field='string', str_field=5) - - custom_messages = {'int_parsing': 'This is a custom error message', - 'string_type': 'This is another custom error message'} - error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value, custom_messages) - assert error_str == ('2 validation errors for TestModel\nint_field\n This is a custom error message\n' - 'str_field\n This is another custom error message') - - -def test_formatted_traceback_type_error(TestModel) -> None: - """The formatted_traceback routine should return the traceback string from "traceback.format_exc()", including the - error message. +@pytest.mark.parametrize(["custom_errors", "expected_error_message"], [ + ({}, + "2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to parse string as an " + "integer [type=int_parsing, input_value='string', input_type=str]\nstr_field\n Input should be a valid string " + "[type=string_type, input_value=5, input_type=int]"), + ({'int_parsing': 'This is a custom error message', 'string_type': 'This is another custom error message'}, + "2 validation errors for TestModel\nint_field\n This is a custom error message [type=int_parsing, " + "input_value='string', input_type=str]\nstr_field\n This is another custom error message [type=string_type, " + "input_value=5, input_type=int]"), +]) +def test_custom_pydantic_validation_error(TestModel, custom_errors: dict[str, str], expected_error_message: str + ) -> None: + """When we call custom_pydantic_validation_error with custom errors, we should return an error list containing + PydanticCustomErrors, otherwise we return the original set of errors. """ - error_message = '__init__() takes 1 positional argument but 2 were given' - traceback_string = '' - - try: - TestModel('invalid') - except TypeError: - traceback_string = RAT.utils.custom_errors.formatted_traceback() - - assert 'TypeError' in traceback_string - assert error_message in traceback_string - - -def test_formatted_traceback_validation_error(TestModel) -> None: - """The formatted_traceback routine should return the traceback string from "traceback.format_exc()", with the error - message removed for a pydantic ValidationError. - """ - error_message = (f"pydantic_core._pydantic_core.ValidationError: 1 validation error for {TestModel.__name__}\n" - f"extra_field\n Extra inputs are not permitted [type=extra_forbidden, input_value='invalid'," - f" input_type=str]\n") - traceback_string = error_message - try: - TestModel(extra_field='invalid') - except pydantic.ValidationError: - traceback_string = RAT.utils.custom_errors.formatted_traceback() + TestModel(int_field='string', str_field=5) + except ValidationError as exc: + custom_error_list = RAT.utils.custom_errors.custom_pydantic_validation_error(exc.errors(), custom_errors) - assert error_message not in traceback_string + with pytest.raises(ValidationError, match=re.escape(expected_error_message)): + raise ValidationError.from_exception_data('TestModel', custom_error_list)