From de3b0e3c90d0dae1e6eaccc5274ae609d7765a03 Mon Sep 17 00:00:00 2001 From: PaulSharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Tue, 24 Oct 2023 12:01:13 +0100 Subject: [PATCH] Updates tables in Controls --- RAT/__init__.py | 1 + RAT/controls.py | 104 ++++--- tests/test_classlist.py | 20 +- tests/test_controls.py | 580 +++++++++++++++++++++------------------- 4 files changed, 375 insertions(+), 330 deletions(-) diff --git a/RAT/__init__.py b/RAT/__init__.py index 29563a23..2d319eed 100644 --- a/RAT/__init__.py +++ b/RAT/__init__.py @@ -1,2 +1,3 @@ from RAT.classlist import ClassList +from RAT.controls import Controls from RAT.project import Project diff --git a/RAT/controls.py b/RAT/controls.py index dd3613a2..3b4a1a56 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,19 +1,19 @@ -import tabulate -from typing import Union +import prettytable from pydantic import BaseModel, Field, field_validator +from typing import Union + from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'): - """ - Defines the base class with properties used in all five procedures. - """ +class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'): + """Defines the base class with properties used in all five procedures.""" parallel: ParallelOptions = ParallelOptions.Single calcSldDuringFit: bool = False - resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2) + resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2) display: DisplayOptions = DisplayOptions.Iter @field_validator("resamPars") + @classmethod def check_resamPars(cls, resamPars): if not 0 < resamPars[0] < 1: raise ValueError('resamPars[0] must be between 0 and 1') @@ -22,63 +22,53 @@ def check_resamPars(cls, resamPars): return resamPars -class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'): - """ - Defines the class for the calculate procedure. - """ - procedure: Procedures = Field(Procedures.Calculate, frozen = True) +class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'): + """Defines the class for the calculate procedure.""" + procedure: Procedures = Field(Procedures.Calculate, frozen=True) -class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'): - """ - Defines the class for the simplex procedure. - """ - procedure: Procedures = Field(Procedures.Simplex, frozen = True) - tolX: float = Field(1e-6, gt = 0) - tolFun: float = Field(1e-6, gt = 0) - maxFunEvals: int = Field(10000, gt = 0) - maxIter: int = Field(1000, gt = 0) +class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): + """Defines the class for the simplex procedure.""" + procedure: Procedures = Field(Procedures.Simplex, frozen=True) + tolX: float = Field(1.0e-6, gt=0.0) + tolFun: float = Field(1.0e-6, gt=0.0) + maxFunEvals: int = Field(10000, gt=0) + maxIter: int = Field(1000, gt=0) updateFreq: int = -1 updatePlotFreq: int = -1 -class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'): - """ - Defines the class for the Differential Evolution procedure. - """ - procedure: Procedures = Field(Procedures.DE, frozen = True) - populationSize: int = Field(20, ge = 1) +class DE(BaseProcedure, validate_assignment=True, extra='forbid'): + """Defines the class for the Differential Evolution procedure.""" + procedure: Procedures = Field(Procedures.DE, frozen=True) + populationSize: int = Field(20, ge=1) fWeight: float = 0.5 - crossoverProbability: float = Field(0.8, gt = 0, lt = 1) + crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0) strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither - targetValue: float = Field(1.0, ge = 1) - numGenerations: int = Field(500, ge = 1) - - -class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'): - """ - Defines the class for the Nested Sampler procedure. - """ - procedure: Procedures = Field(Procedures.NS, frozen = True) - Nlive: int = Field(150, ge = 1) - Nmcmc: float = Field(0.0, ge = 0) - propScale: float = Field(0.1, gt = 0, lt = 1) - nsTolerance: float = Field(0.1, ge = 0) - - -class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'): - """ - Defines the class for the Dream procedure - """ - procedure: Procedures = Field(Procedures.Dream, frozen = True) - nSamples: int = Field(50000, ge = 0) - nChains: int = Field(10, gt = 0) - jumpProb: float = Field(0.5, gt = 0, lt = 1) - pUnitGamma:float = Field(0.2, gt = 0, lt = 1) + targetValue: float = Field(1.0, ge=1.0) + numGenerations: int = Field(500, ge=1) + + +class NS(BaseProcedure, validate_assignment=True, extra='forbid'): + """Defines the class for the Nested Sampler procedure.""" + procedure: Procedures = Field(Procedures.NS, frozen=True) + Nlive: int = Field(150, ge=1) + Nmcmc: float = Field(0.0, ge=0.0) + propScale: float = Field(0.1, gt=0.0, lt=1.0) + nsTolerance: float = Field(0.1, ge=0.0) + + +class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): + """Defines the class for the Dream procedure.""" + procedure: Procedures = Field(Procedures.Dream, frozen=True) + nSamples: int = Field(50000, ge=0) + nChains: int = Field(10, gt=0) + jumpProb: float = Field(0.5, gt=0.0, lt=1.0) + pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0) boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold -class ControlsClass: +class Controls: def __init__(self, procedure: Procedures = Procedures.Calculate, @@ -104,7 +94,7 @@ def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None: self._controls = value def __repr__(self) -> str: - properties = [["Property", "Value"]] +\ - [[k, v] for k, v in self._controls.__dict__.items()] - table = tabulate.tabulate(properties, headers="firstrow") - return table + table = prettytable.PrettyTable() + table.field_names = ['Property', 'Value'] + table.add_rows([[k, v] for k, v in self._controls.__dict__.items()]) + return table.get_string() diff --git a/tests/test_classlist.py b/tests/test_classlist.py index e5f09da0..b7d62d31 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -22,6 +22,19 @@ def two_name_class_list(): return ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]) +@pytest.fixture +def two_name_class_list_table(): + """The table representation of the ClassList defined in the "two_name_class_list" fixture.""" + return( + '+-------+-------+\n' + '| index | name |\n' + '+-------+-------+\n' + '| 0 | Alice |\n' + '| 1 | Bob |\n' + '+-------+-------+' + ) + + @pytest.fixture def three_name_class_list(): """A ClassList of InputAttributes, containing three elements with names defined.""" @@ -104,12 +117,9 @@ def test_identical_name_fields(self, input_list: Sequence[object], name_field: s ClassList(input_list, name_field=name_field) -@pytest.mark.parametrize("expected_string", [ - ' name\n-- ------\n 0 Alice\n 1 Bob', -]) -def test_repr_table(two_name_class_list: 'ClassList', expected_string: str) -> None: +def test_repr_table(two_name_class_list: 'ClassList', two_name_class_list_table: str) -> None: """For classes with the __dict__ attribute, we should be able to print the ClassList like a table.""" - assert repr(two_name_class_list) == expected_string + assert repr(two_name_class_list) == two_name_class_list_table def test_repr_empty_table() -> None: diff --git a/tests/test_controls.py b/tests/test_controls.py index f3cabd71..e1c62efd 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -3,35 +3,37 @@ import pytest import pydantic from typing import Union, Any -from RAT.controls import BaseProcedure, Calculate, Simplex, DE, NS, Dream, ControlsClass +from RAT.controls import BaseProcedure, Calculate, Simplex, DE, NS, Dream, Controls from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions class TestBaseProcedure: - """ - Tests the BaseProcedure class. - """ + """Tests the BaseProcedure class.""" @pytest.fixture(autouse=True) def setup_class(self): self.base_procedure = BaseProcedure() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter)]) - def test_base_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter) + ]) + def test_base_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of BaseProcedure class.""" - assert getattr(self.base_procedure, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify)]) - def test_base_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.base_procedure, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify) + ]) + def test_base_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters in BaseProcedure class.""" - setattr(self.base_procedure, property, value) - assert getattr(self.base_procedure, property) == value + setattr(self.base_procedure, control_property, value) + assert getattr(self.base_procedure, control_property) == value @pytest.mark.parametrize("var1, var2", [('test', True), ('ALL', 1), ("Contrast", 3.0)]) def test_base_parallel_validation(self, var1: str, var2: Any) -> None: @@ -60,16 +62,20 @@ def test_base_display_validation(self, var1: str, var2: Any) -> None: setattr(self.base_procedure, 'display', var2) assert exp.value.errors()[0]['msg'] == "Input should be a valid string" - @pytest.mark.parametrize("value, msg", [([5.0], "List should have at least 2 items after validation, not 1"), - ([12, 13, 14], "List should have at most 2 items after validation, not 3")]) + @pytest.mark.parametrize("value, msg", [ + ([5.0], "List should have at least 2 items after validation, not 1"), + ([12, 13, 14], "List should have at most 2 items after validation, not 3") + ]) def test_base_resamPars_lenght_validation(self, value: list, msg: str) -> None: - """Tests the resamPars setter lenght validation in BaseProcedure class.""" + """Tests the resamPars setter length validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.base_procedure, 'resamPars', value) assert exp.value.errors()[0]['msg'] == msg - @pytest.mark.parametrize("value, msg", [([1.0, 2], "Value error, resamPars[0] must be between 0 and 1"), - ([0.5, -0.1], "Value error, resamPars[1] must be greater than or equal to 0")]) + @pytest.mark.parametrize("value, msg", [ + ([1.0, 2], "Value error, resamPars[0] must be between 0 and 1"), + ([0.5, -0.1], "Value error, resamPars[1] must be greater than or equal to 0") + ]) def test_base_resamPars_value_validation(self, value: list, msg: str) -> None: """Tests the resamPars setter value validation in BaseProcedure class.""" with pytest.raises(pydantic.ValidationError) as exp: @@ -84,37 +90,39 @@ def test_base_extra_property_error(self) -> None: class TestCalculate: - """ - Tests the Calculate class. - """ + """Tests the Calculate class.""" @pytest.fixture(autouse=True) def setup_class(self): self.calculate = Calculate() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.Calculate)]) - def test_calculate_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter), + ('procedure', Procedures.Calculate) + ]) + def test_calculate_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of Calculate class.""" - assert getattr(self.calculate, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify)]) - def test_calculate_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.calculate, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify) + ]) + def test_calculate_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters of Calculate class.""" - setattr(self.calculate, property, value) - assert getattr(self.calculate, property) == value + setattr(self.calculate, control_property, value) + assert getattr(self.calculate, control_property) == value def test_calculate_extra_property_error(self) -> None: """Tests the extra property setter in Calculate class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.calculate, 'test', 1) - assert exp.value.errors()[0]['msg'] == ("Object has no attribute 'test'") + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" def test_calculate_procedure_error(self) -> None: """Tests the procedure property frozen in Calculate class.""" @@ -124,59 +132,63 @@ def test_calculate_procedure_error(self) -> None: class TestSimplex: - """ - Tests the Simplex class. - """ + """Tests the Simplex class.""" @pytest.fixture(autouse=True) def setup_class(self): self.simplex = Simplex() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.Simplex), - ('tolX', 1e-6), - ('tolFun', 1e-6), - ('maxFunEvals', 10000), - ('maxIter', 1000), - ('updateFreq', -1), - ('updatePlotFreq', -1)]) - def test_simplex_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter), + ('procedure', Procedures.Simplex), + ('tolX', 1e-6), + ('tolFun', 1e-6), + ('maxFunEvals', 10000), + ('maxIter', 1000), + ('updateFreq', -1), + ('updatePlotFreq', -1) + ]) + def test_simplex_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of Simplex class.""" - assert getattr(self.simplex, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify), - ('tolX', 4e-6), - ('tolFun', 3e-4), - ('maxFunEvals', 100), - ('maxIter', 50), - ('updateFreq', 4), - ('updatePlotFreq', 3)]) - def test_simplex_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.simplex, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify), + ('tolX', 4e-6), + ('tolFun', 3e-4), + ('maxFunEvals', 100), + ('maxIter', 50), + ('updateFreq', 4), + ('updatePlotFreq', 3) + ]) + def test_simplex_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters of Simplex class.""" - setattr(self.simplex, property, value) - assert getattr(self.simplex, property) == value - - @pytest.mark.parametrize("property, value", [('tolX', -4e-6), - ('tolFun', -3e-4), - ('maxFunEvals', -100), - ('maxIter', -50)]) - def test_simplex_property_errors(self, property: str, value: Union[float, int]) -> None: + setattr(self.simplex, control_property, value) + assert getattr(self.simplex, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('tolX', -4e-6), + ('tolFun', -3e-4), + ('maxFunEvals', -100), + ('maxIter', -50) + ]) + def test_simplex_property_errors(self, control_property: str, value: Union[float, int]) -> None: """Tests the property errors of Simplex class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.simplex, property, value) + setattr(self.simplex, control_property, value) assert exp.value.errors()[0]['msg'] == "Input should be greater than 0" def test_simplex_extra_property_error(self) -> None: """Tests the extra property setter in Simplex class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.simplex, 'test', 1) - assert exp.value.errors()[0]['msg'] == ("Object has no attribute 'test'") + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" def test_simplex_procedure_error(self) -> None: """Tests the procedure property frozen in Simplex class.""" @@ -186,43 +198,45 @@ def test_simplex_procedure_error(self) -> None: class TestDE: - """ - Tests the DE class. - """ + """Tests the DE class.""" @pytest.fixture(autouse=True) def setup_class(self): self.de = DE() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.DE), - ('populationSize', 20), - ('fWeight', 0.5), - ('crossoverProbability', 0.8), - ('strategy', StrategyOptions.RandomWithPerVectorDither), - ('targetValue', 1), - ('numGenerations', 500)]) - def test_de_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter), + ('procedure', Procedures.DE), + ('populationSize', 20), + ('fWeight', 0.5), + ('crossoverProbability', 0.8), + ('strategy', StrategyOptions.RandomWithPerVectorDither), + ('targetValue', 1), + ('numGenerations', 500) + ]) + def test_de_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of DE class.""" - assert getattr(self.de, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify), - ('populationSize', 20), - ('fWeight', 0.3), - ('crossoverProbability', 0.4), - ('strategy', StrategyOptions.BestWithJitter), - ('targetValue', 2.0), - ('numGenerations', 50)]) - def test_de_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.de, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify), + ('populationSize', 20), + ('fWeight', 0.3), + ('crossoverProbability', 0.4), + ('strategy', StrategyOptions.BestWithJitter), + ('targetValue', 2.0), + ('numGenerations', 50) + ]) + def test_de_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters of DE class.""" - setattr(self.de, property, value) - assert getattr(self.de, property) == value + setattr(self.de, control_property, value) + assert getattr(self.de, control_property) == value @pytest.mark.parametrize("value", [0, 2]) def test_de_crossoverProbability_error(self, value: int) -> None: @@ -232,25 +246,27 @@ def test_de_crossoverProbability_error(self, value: int) -> None: assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", "Input should be less than 1"] - @pytest.mark.parametrize("property, value", [('targetValue', 0), - ('targetValue',0.999), - ('numGenerations', -500), - ('numGenerations', 0), - ('populationSize', 0), - ('populationSize', -1)]) + @pytest.mark.parametrize("control_property, value", [ + ('targetValue', 0), + ('targetValue', 0.999), + ('numGenerations', -500), + ('numGenerations', 0), + ('populationSize', 0), + ('populationSize', -1) + ]) def test_de_targetValue_numGenerations_populationSize_error(self, - property: str, + control_property: str, value: Union[int, float]) -> None: """Tests the targetValue, numGenerations, populationSize setter error in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.de, property, value) + setattr(self.de, control_property, value) assert exp.value.errors()[0]['msg'] == "Input should be greater than or equal to 1" def test_de_extra_property_error(self) -> None: """Tests the extra property setter in DE class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.de, 'test', 1) - assert exp.value.errors()[0]['msg'] == ("Object has no attribute 'test'") + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" def test_de_procedure_error(self) -> None: """Tests the procedure property frozen in DE class.""" @@ -260,47 +276,51 @@ def test_de_procedure_error(self) -> None: class TestNS: - """ - Tests the NS class. - """ + """Tests the NS class.""" @pytest.fixture(autouse=True) def setup_class(self): self.ns = NS() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.NS), - ('Nlive', 150), - ('Nmcmc', 0), - ('propScale', 0.1), - ('nsTolerance', 0.1)]) - def test_ns_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter), + ('procedure', Procedures.NS), + ('Nlive', 150), + ('Nmcmc', 0), + ('propScale', 0.1), + ('nsTolerance', 0.1) + ]) + def test_ns_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of NS class.""" - assert getattr(self.ns, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify), - ('Nlive', 1500), - ('Nmcmc', 1), - ('propScale', 0.5), - ('nsTolerance', 0.8)]) - def test_ns_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.ns, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify), + ('Nlive', 1500), + ('Nmcmc', 1), + ('propScale', 0.5), + ('nsTolerance', 0.8) + ]) + def test_ns_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters of NS class.""" - setattr(self.ns, property, value) - assert getattr(self.ns, property) == value - - @pytest.mark.parametrize("property, value, bound", [('Nmcmc', -0.6, 0), - ('nsTolerance', -500, 0), - ('Nlive', -500, 1)]) - def test_ns_Nmcmc_nsTolerance_Nlive_error(self, property: str, value: Union[int, float], bound: int) -> None: + setattr(self.ns, control_property, value) + assert getattr(self.ns, control_property) == value + + @pytest.mark.parametrize("control_property, value, bound", [ + ('Nmcmc', -0.6, 0), + ('nsTolerance', -500, 0), + ('Nlive', -500, 1) + ]) + def test_ns_Nmcmc_nsTolerance_Nlive_error(self, control_property: str, value: Union[int, float], bound: int) -> None: """Tests the Nmcmc, nsTolerance, Nlive setter error in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.ns, property, value) + setattr(self.ns, control_property, value) assert exp.value.errors()[0]['msg'] == f"Input should be greater than or equal to {bound}" @pytest.mark.parametrize("value", [0, 2]) @@ -315,7 +335,7 @@ def test_ns_extra_property_error(self) -> None: """Tests the extra property setter in NS class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.ns, 'test', 1) - assert exp.value.errors()[0]['msg'] == ("Object has no attribute 'test'") + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" def test_ns_procedure_error(self) -> None: """Tests the procedure property frozen in NS class.""" @@ -325,50 +345,54 @@ def test_ns_procedure_error(self) -> None: class TestDream: - """ - Tests the Dream class. - """ + """Tests the Dream class.""" @pytest.fixture(autouse=True) def setup_class(self): self.dream = Dream() - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.Single), - ('calcSldDuringFit', False), - ('resamPars', [0.9, 50]), - ('display', DisplayOptions.Iter), - ('procedure', Procedures.Dream), - ('nSamples', 50000), - ('nChains', 10), - ('jumpProb', 0.5), - ('pUnitGamma', 0.2), - ('boundHandling', BoundHandlingOptions.Fold)]) - def test_dream_property_values(self, property: str, value: Any) -> None: + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.Single), + ('calcSldDuringFit', False), + ('resamPars', [0.9, 50]), + ('display', DisplayOptions.Iter), + ('procedure', Procedures.Dream), + ('nSamples', 50000), + ('nChains', 10), + ('jumpProb', 0.5), + ('pUnitGamma', 0.2), + ('boundHandling', BoundHandlingOptions.Fold) + ]) + def test_dream_property_values(self, control_property: str, value: Any) -> None: """Tests the default values of Dream class.""" - assert getattr(self.dream, property) == value - - @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), - ('calcSldDuringFit', True), - ('resamPars', [0.2, 1]), - ('display', DisplayOptions.Notify), - ('nSamples', 500), - ('nChains', 1000), - ('jumpProb', 0.7), - ('pUnitGamma', 0.3), - ('boundHandling', BoundHandlingOptions.Reflect)]) - def test_dream_property_setters(self, property: str, value: Any) -> None: + assert getattr(self.dream, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('parallel', ParallelOptions.All), + ('calcSldDuringFit', True), + ('resamPars', [0.2, 1]), + ('display', DisplayOptions.Notify), + ('nSamples', 500), + ('nChains', 1000), + ('jumpProb', 0.7), + ('pUnitGamma', 0.3), + ('boundHandling', BoundHandlingOptions.Reflect) + ]) + def test_dream_property_setters(self, control_property: str, value: Any) -> None: """Tests the setters in Dream class.""" - setattr(self.dream, property, value) - assert getattr(self.dream, property) == value - - @pytest.mark.parametrize("property, value", [('jumpProb',0), - ('jumpProb', 2), - ('pUnitGamma',-5), - ('pUnitGamma', 20)]) - def test_dream_jumpprob_pUnitGamma_error(self, property:str, value: int) -> None: + setattr(self.dream, control_property, value) + assert getattr(self.dream, control_property) == value + + @pytest.mark.parametrize("control_property, value", [ + ('jumpProb', 0), + ('jumpProb', 2), + ('pUnitGamma', -5), + ('pUnitGamma', 20) + ]) + def test_dream_jumpprob_pUnitGamma_error(self, control_property:str, value: int) -> None: """Tests the jumpprob pUnitGamma setter errors in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: - setattr(self.dream, property, value) + setattr(self.dream, control_property, value) assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", "Input should be less than 1"] @@ -390,7 +414,7 @@ def test_dream_extra_property_error(self) -> None: """Tests the extra property setter in Dream class.""" with pytest.raises(pydantic.ValidationError) as exp: setattr(self.dream, 'test', 1) - assert exp.value.errors()[0]['msg'] == ("Object has no attribute 'test'") + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" def test_dream_procedure_error(self) -> None: """Tests the procedure property frozen in Dream class.""" @@ -399,115 +423,135 @@ def test_dream_procedure_error(self) -> None: assert exp.value.errors()[0]['msg'] == "Field is frozen" -class TestControlsClass: - """ - Tests the Controls class. - """ +class TestControls: + """Tests the Controls class.""" @pytest.fixture(autouse=True) def setup_class(self): - self.controls = ControlsClass() + self.controls = Controls() def test_controls_class_default_type(self) -> None: - """Tests the procedure is Calculate in ControlsClass.""" + """Tests the procedure is Calculate in Controls.""" assert type(self.controls.controls).__name__ == "Calculate" def test_controls_class_properties(self) -> None: - """Tests the ControlsClass has control property.""" + """Tests the Controls class has control property.""" assert hasattr(self.controls, 'controls') - @pytest.mark.parametrize("procedure, name", [(Procedures.Calculate, "Calculate"), - (Procedures.Simplex, "Simplex"), - (Procedures.DE, "DE"), - (Procedures.NS, "NS"), - (Procedures.Dream, "Dream")]) + @pytest.mark.parametrize("procedure, name", [ + (Procedures.Calculate, "Calculate"), + (Procedures.Simplex, "Simplex"), + (Procedures.DE, "DE"), + (Procedures.NS, "NS"), + (Procedures.Dream, "Dream") + ]) def test_controls_class_return_type(self, procedure: Procedures, name: str) -> None: - """Tests the ControlsClass is set to the correct procedure class.""" - controls = ControlsClass(procedure) + """Tests the Controls class is set to the correct procedure class.""" + controls = Controls(procedure) assert type(controls.controls).__name__ == name def test_control_class_calculate_repr(self) -> None: - """Tests the __repr__ of ControlsClass with Calculate procedure.""" - controls = ControlsClass() + """Tests the __repr__ of Controls with Calculate procedure.""" + controls = Controls() table = controls.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "procedure calculate") + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| procedure | calculate |\n" + "+------------------+-----------+" + ) + assert table == table_str def test_control_class_simplex_repr(self) -> None: - """Tests the __repr__ of ControlsClass with Simplex procedure.""" - controls = ControlsClass(procedure=Procedures.Simplex) + """Tests the __repr__ of Controls with Simplex procedure.""" + controls = Controls(procedure=Procedures.Simplex) table = controls.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "procedure simplex\n" - "tolX 1e-06\n" - "tolFun 1e-06\n" - "maxFunEvals 10000\n" - "maxIter 1000\n" - "updateFreq -1\n" - "updatePlotFreq -1") + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| procedure | simplex |\n" + "| tolX | 1e-06 |\n" + "| tolFun | 1e-06 |\n" + "| maxFunEvals | 10000 |\n" + "| maxIter | 1000 |\n" + "| updateFreq | -1 |\n" + "| updatePlotFreq | -1 |\n" + "+------------------+-----------+" + ) + assert table == table_str def test_control_class_de_repr(self) -> None: - """Tests the __repr__ of ControlsClass with DE procedure.""" - controls = ControlsClass(procedure=Procedures.DE) + """Tests the __repr__ of Controls with DE procedure.""" + controls = Controls(procedure=Procedures.DE) table = controls.__repr__() - table_str = ("Property Value\n" - "-------------------- -----------------------------------------\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "procedure de\n" - "populationSize 20\n" - "fWeight 0.5\n" - "crossoverProbability 0.8\n" - "strategy StrategyOptions.RandomWithPerVectorDither\n" - "targetValue 1.0\n" - "numGenerations 500") + table_str = ("+----------------------+-------------------------------------------+\n" + "| Property | Value |\n" + "+----------------------+-------------------------------------------+\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| procedure | de |\n" + "| populationSize | 20 |\n" + "| fWeight | 0.5 |\n" + "| crossoverProbability | 0.8 |\n" + "| strategy | StrategyOptions.RandomWithPerVectorDither |\n" + "| targetValue | 1.0 |\n" + "| numGenerations | 500 |\n" + "+----------------------+-------------------------------------------+" + ) + assert table == table_str def test_control_class_ns_repr(self) -> None: - """Tests the __repr__ of ControlsClass with NS procedure.""" - controls = ControlsClass(procedure=Procedures.NS) + """Tests the __repr__ of Controls with NS procedure.""" + controls = Controls(procedure=Procedures.NS) table = controls.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "procedure ns\n" - "Nlive 150\n" - "Nmcmc 0.0\n" - "propScale 0.1\n" - "nsTolerance 0.1") + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| procedure | ns |\n" + "| Nlive | 150 |\n" + "| Nmcmc | 0.0 |\n" + "| propScale | 0.1 |\n" + "| nsTolerance | 0.1 |\n" + "+------------------+-----------+" + ) + assert table == table_str def test_control_class_dream_repr(self) -> None: - """Tests the __repr__ of ControlsClass with Dream procedure.""" - controls = ControlsClass(procedure=Procedures.Dream) + """Tests the __repr__ of Controls with Dream procedure.""" + controls = Controls(procedure=Procedures.Dream) table = controls.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "procedure dream\n" - "nSamples 50000\n" - "nChains 10\n" - "jumpProb 0.5\n" - "pUnitGamma 0.2\n" - "boundHandling fold") + table_str = ("+------------------+-----------+\n" + "| Property | Value |\n" + "+------------------+-----------+\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resamPars | [0.9, 50] |\n" + "| display | iter |\n" + "| procedure | dream |\n" + "| nSamples | 50000 |\n" + "| nChains | 10 |\n" + "| jumpProb | 0.5 |\n" + "| pUnitGamma | 0.2 |\n" + "| boundHandling | fold |\n" + "+------------------+-----------+" + ) + assert table == table_str