diff --git a/RAT/controls.py b/RAT/controls.py index c9a7c619..ca4d6e83 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,1222 +1,84 @@ import tabulate -from typing import Union, Any, Literal +from typing import Union +from pydantic import BaseModel, Field, field_validator from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions -class BaseProcedure: +class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'): """ - Defines the base class with properties used in - all five procedures. + Defines the base class with properties used in all five procedures. """ + parallel: ParallelOptions = ParallelOptions.Single + calcSldDuringFit: bool = False + resamPars: list[Union[int, float]] = Field([0.9, 50], min_length=2, max_length=2) + display: DisplayOptions = DisplayOptions.Iter - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter) -> None: - - self._parallel = parallel - self._calcSldDuringFit = calcSldDuringFit - self._resamPars = resamPars - self._display = display - - def _validate_type(self, - name: str, - value: Any, - type: type) -> None: - """ - Validates the value has the correct type. - - Parameters - ---------- - name : str - The name of the property. - value : Any - The value of the property. - type : type - The expected type of the property. - - Raises - ------ - TypeError - Raised if the property has value of the wrong type. - """ - if not isinstance(value, type): - raise TypeError(f"{name} must be of type {type.__name__}") - - def _validate_value(self, - name: str, - value: Union[str, int], - enum: Union[ParallelOptions, - Procedures, - DisplayOptions, - BoundHandlingOptions, - StrategyOptions], - type: Union[str, int] = str) -> None: - """ - Validates value is present in enum. - - Parameters - ---------- - name : str - The name of the property. - value : Union[str, int] - The value of the property. - enum : Union[ParallelOptions, Procedures, - DisplayOptions, BoundHandlingOptions, - StrategyOptions] - The expected enum which contains the value. - type : Union[str, int] - The type of values in enum. - - Raises - ------ - ValueError - Raised if the value is not in enum. - """ - allowed_options = [o.value for o in enum] - if value not in allowed_options: - raise ValueError((f"{name} must be a {enum.__name__} " - f"enum or one of the following {type.__name__} " - f"{', '.join([str(o) for o in allowed_options])}")) - - def _validate_range(self, - name: str, - value: Union[float, int], - lower_limit: Union[float, int] = float('-inf'), - lower_exclusive: bool = True, - upper_limit: Union[float, int] = float('inf'), - upper_exclusive: bool = True) -> None: - """ - Validates value is in range. - - Parameters - ---------- - name : str - The name of the property. - value : Union[float, int] - The value of the property. - lower_limit : Union[float, int] - The lower limit of the value. - lower_exclusive : bool - Boolean to indicate if lower limit is exclusive. - upper_limit : Union[float, int] - The upper limit of the value. - upper_exclusive : bool - Boolean to indicate if upper limit is exclusive. - - Raises - ------ - ValueError - Raised if the value is not in range. - """ - # Case 1 - lower and upper bounds are exclusive - if lower_exclusive and upper_exclusive: - if not lower_limit < value < upper_limit: - raise ValueError((f"{name} must be greater than " - f"{lower_limit} and less than {upper_limit}")) - - # Case 2 - upper bound is exclusive - elif not lower_exclusive and upper_exclusive: - if not lower_limit <= value < upper_limit: - raise ValueError((f"{name} must be greater than or equal to " - f"{lower_limit} and less than {upper_limit}")) - - # Case 3 - lower bound is exclusive - elif lower_exclusive and not upper_exclusive: - if not lower_limit < value <= upper_limit: - raise ValueError((f"{name} must be greater than " - f"{lower_limit} and less than or equal to {upper_limit}")) - - # Case 4 - lower and upper bounds are inclusive - elif not lower_exclusive and not upper_exclusive: - if not lower_limit <= value <= upper_limit: - raise ValueError((f"{name} must be greater than or equal to " - f"{lower_limit} and less than or equal to {upper_limit}")) - - @property - def parallel(self) -> Literal['single', 'points', 'contrasts', 'all']: - """ - Gets the parallel property. - - Returns - ------- - parallel : Literal['single', 'points', 'contrasts', 'all'] - The value of the parallel property. - """ - return self._parallel - - @parallel.setter - def parallel(self, value: Literal['single', 'points', 'contrasts', 'all']) -> None: - """ - Sets the parallel property after validation. - - Parameters - ---------- - value : Literal['single', 'points', 'contrasts', 'all'] - The value to be set for the parallel property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the parallel property has - input of the wrong type or value. - """ - self._validate_type('parallel', value, str) - self._validate_value('parallel', value, ParallelOptions) - self._parallel = value - - @property - def calcSldDuringFit(self) -> bool: - """ - Gets the calcSldDuringFit property. - - Returns - ------- - calcSldDuringFit : bool - The value of the calcSldDuringFit property. - """ - return self._calcSldDuringFit - - @calcSldDuringFit.setter - def calcSldDuringFit(self, value: bool) -> None: - """ - Sets the calcSldDuringFit property after validation. - - Parameters - ---------- - value : bool - The value to be set for the calcSldDuringFit property. - - Raises - ------ - TypeError - Raised if the calcSldDuringFit property has - input of the wrong type. - """ - self._validate_type('calcSldDuringFit', value, bool) - self._calcSldDuringFit = value - - @property - def resamPars(self) -> list[float]: - """ - Gets the resamPars property. - - Returns - ------- - resamPars : list[float] - The value of the resamPars property. - """ - return self._resamPars - - @resamPars.setter - def resamPars(self, value: list[Union[int, float]]) -> None: - """ - Sets the resamPars property after validation. - - Parameters - ---------- - value : list[Union[int, float]] - The value to be set for the resamPars property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the resamPars property has - input of the wrong type or value. - """ - self._validate_type('resamPars', value, list) - if len(value) != 2: - raise ValueError("resamPars must have length of 2") - if not all(isinstance(x, (float, int)) for x in value): - raise TypeError("resamPars must be defined using floats or ints") - self._validate_range(name = 'resamPars[0]', - value = value[0], - lower_limit = 0, - upper_limit = 1) - self._validate_range(name = 'resamPars[1]', - value = value[1], - lower_limit = 0, - lower_exclusive = False) - self._resamPars = [float(v) for v in value] - - @property - def display(self) -> Literal['off', 'iter', 'notify', 'final']: - """ - Gets the display property. - - Returns - ------- - display : Literal['off', 'iter', 'notify', 'final'] - The value of the display property. - """ - return self._display - - @display.setter - def display(self, value: Literal['off', 'iter', 'notify', 'final']) -> None: - """ - Sets the display property after validation. - - Parameters - ---------- - value : Literal['off', 'iter', 'notify', 'final'] - The value to be set for the display property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the display property has - input of the wrong type or value. - """ - self._validate_type('display', value, str) - self._validate_value('display', value, DisplayOptions) - self._display = value - - def __repr__(self, procedure: Literal['calculate', 'simplex', 'de', 'ns', 'dream']) -> str: - """ - Defines the display method. - - Parameters - ---------- - procedure : Literal['calculate', 'simplex', 'de', 'ns', 'dream'] - The procedure for the controls classes. - """ - properties = [["Property", "Value"]] +\ - [["procedure", procedure]] +\ - [[k.lstrip('_'), v] for k, v in self.__dict__.items()] - table = tabulate.tabulate(properties, headers="firstrow") - return table - - -class Calculate(BaseProcedure): - """Defines the class for the calculate procedure""" - - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter) -> None: - - # call the constructor of the parent class - super().__init__(parallel = parallel, - calcSldDuringFit = calcSldDuringFit, - resamPars = resamPars, - display = display) - - @property - def procedure(self) -> Literal['calculate']: - """ - Gets the procedure property. - - Returns - ------- - procedure : Literal['calculate'] - The value of the procedure property. - """ - return Procedures.Calculate - - def __repr__(self) -> str: - """ - Defines the display method for Calculate class - """ - table = super().__repr__(Procedures.Calculate) - return table - - -class Simplex(BaseProcedure): - """Defines the class for the simplex procedure""" - - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter, - tolX: float = 1e-6, - tolFun: float = 1e-6, - maxFunEvals: int = 10000, - maxIter: int = 1000, - updateFreq: int = -1, - updatePlotFreq: int = -1) -> None: - - # call the constructor of the parent class - super().__init__(parallel=parallel, - calcSldDuringFit=calcSldDuringFit, - resamPars=resamPars, - display=display) - - self._tolX = tolX - self._tolFun = tolFun - self._maxFunEvals = maxFunEvals - self._maxIter = maxIter - self._updateFreq = updateFreq - self._updatePlotFreq = updatePlotFreq - - @property - def procedure(self) -> Literal['simplex']: - """ - Gets the procedure property. - - Returns - ------- - procedure : Literal['simplex'] - The value of the procedure property. - """ - return Procedures.Simplex - - @property - def tolX(self) -> float: - """ - Gets the tolX property. - - Returns - ------- - tolX : float - The value of the tolX property. - """ - return self._tolX - - @tolX.setter - def tolX(self, value: float) -> None: - """ - Sets the tolX property after validation. - - Parameters - ---------- - value : float - The value to be set for the tolX property. - - Raises - ------ - TypeError - Raised if the tolX property has - input of the wrong type. - """ - self._validate_type('tolX', value, float) - self._tolX = value - - @property - def tolFun(self) -> float: - """ - Gets the tolFun property. - - Returns - ------- - tolFun : float - The value of the tolFun property. - """ - return self._tolFun - - @tolFun.setter - def tolFun(self, value: float) -> None: - """ - Sets the tolFun property after validation. - - Parameters - ---------- - value : float - The value to be set for the tolFun property. - - Raises - ------ - TypeError - Raised if the tolFun property has - input of the wrong type. - """ - self._validate_type('tolFun', value, float) - self._tolFun = value - - @property - def maxFunEvals(self) -> int: - """ - Gets the maxFunEvals property. - - Returns - ------- - maxFunEvals : int - The value of the maxFunEvals property. - """ - return self._maxFunEvals - - @maxFunEvals.setter - def maxFunEvals(self, value: int) -> None: - """ - Sets the maxFunEvals property after validation. - - Parameters - ---------- - value : int - The value to be set for the maxFunEvals property. - - Raises - ------ - TypeError - Raised if the maxFunEvals property has - input of the wrong type. - """ - self._validate_type('maxFunEvals', value, int) - self._maxFunEvals = value - - @property - def maxIter(self) -> int: - """ - Gets the maxIter property. - - Returns - ------- - maxIter : int - The value of the maxIter property. - """ - return self._maxIter - - @maxIter.setter - def maxIter(self, value: int) -> None: - """ - Sets the maxIter property after validation. - - Parameters - ---------- - value : int - The value to be set for the maxIter property. - - Raises - ------ - TypeError - Raised if the maxIter property has - input of the wrong type. - """ - self._validate_type('maxIter', value, int) - self._maxIter = value - - @property - def updateFreq(self) -> int: - """ - Gets the updateFreq property. - - Returns - ------- - updateFreq : int - The value of the updateFreq property. - """ - return self._updateFreq - - @updateFreq.setter - def updateFreq(self, value: int) -> None: - """ - Sets the updateFreq property after validation. - - Parameters - ---------- - value : int - The value to be set for the updateFreq property. - - Raises - ------ - TypeError - Raised if the updateFreq property has - input of the wrong type. - """ - self._validate_type('updateFreq', value, int) - self._updateFreq = value - - @property - def updatePlotFreq(self) -> int: - """ - Gets the updatePlotFreq property. - - Returns - ------- - updatePlotFreq : int - The value of the updatePlotFreq property. - """ - return self._updatePlotFreq - - @updatePlotFreq.setter - def updatePlotFreq(self, value: int) -> None: - """ - Sets the updatePlotFreq property after validation. - - Parameters - ---------- - value : int - The value to be set for the updatePlotFreq property. - - Raises - ------ - TypeError - Raised if the updatePlotFreq property has - input of the wrong type. - """ - self._validate_type('updatePlotFreq', value, int) - self._updatePlotFreq = value - - def __repr__(self) -> str: - """ - Defines the display method for Simplex class - """ - table = super().__repr__(Procedures.Simplex) - return table - - -class DE(BaseProcedure): - """Defines the class for the Differential Evolution procedure""" - - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter, - populationSize: int = 20, - fWeight: float = 0.5, - crossoverProbability: float = 0.8, - strategy: Literal[1, 2, 3, 4, 5, 6] = StrategyOptions.RandomWithPerVectorDither.value, - targetValue: Union[int, float] = 1.0, - numGenerations: int = 500) -> None: - - # call the constructor of the parent class - super().__init__(parallel=parallel, - calcSldDuringFit=calcSldDuringFit, - resamPars=resamPars, - display=display) - - self._populationSize = populationSize - self._fWeight = fWeight - self._crossoverProbability = crossoverProbability - self._strategy = strategy - self._targetValue = targetValue - self._numGenerations = numGenerations - - @property - def procedure(self) -> Literal['de']: - """ - Gets the procedure property. - - Returns - ------- - procedure : Literal['de'] - The value of the procedure property. - """ - return Procedures.DE - - @property - def populationSize(self) -> int: - """ - Gets the populationSize property. - - Returns - ------- - populationSize : int - The value of the populationSize property. - """ - return self._populationSize - - @populationSize.setter - def populationSize(self, value: int) -> None: - """ - Sets the populationSize property after validation. - - Parameters - ---------- - value : int - The value to be set for the populationSize property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the populationSize property has - input of the wrong type or value. - """ - self._validate_type('populationSize', value, int) - self._validate_range(name = 'populationSize', - value = value, - lower_limit = 1, - lower_exclusive = False) - self._populationSize = value - - @property - def fWeight(self) -> float: - """ - Gets the fWeight property. - - Returns - ------- - fWeight : float - The value of the fWeight property. - """ - return self._fWeight - - @fWeight.setter - def fWeight(self, value: float) -> None: - """ - Sets the fWeight property after validation. - - Parameters - ---------- - value : float - The value to be set for the fWeight property. + @field_validator("resamPars") + def check_resamPars(cls, resamPars): + if not 0 < resamPars[0] < 1: + raise ValueError('resamPars[0] must be between 0 and 1') + if not 0 <= resamPars[1]: + raise ValueError('resamPars[1] must be greater than 0') + return resamPars - Raises - ------ - TypeError - Raised if the fWeight property has - input of the wrong type. - """ - self._validate_type('fWeight', value, float) - self._fWeight = value - - @property - def crossoverProbability(self) -> float: - """ - Gets the crossoverProbability property. - - Returns - ------- - crossoverProbability : float - The value of the crossoverProbability property. - """ - return self._crossoverProbability - - @crossoverProbability.setter - def crossoverProbability(self, value: float) -> None: - """ - Sets the crossoverProbability property after validation. - - Parameters - ---------- - value : float - The value to be set for the crossoverProbability property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the crossoverProbability property has - input of the wrong type or value. - """ - self._validate_type('crossoverProbability', value, float) - self._validate_range(name = 'crossoverProbability', - value = value, - lower_limit = 0, - upper_limit = 1) - self._crossoverProbability = value - - @property - def strategy(self) -> Literal[1, 2, 3, 4, 5, 6]: - """ - Gets the strategy property. - - Returns - ------- - strategy : Literal[1, 2, 3, 4, 5, 6] - The value of the strategy property. - """ - return self._strategy - - @strategy.setter - def strategy(self, value: int) -> None: - """ - Sets the strategy property after validation. - - Parameters - ---------- - value : int - The value to be set for the strategy property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the strategy property has - input of the wrong type or value. - """ - self._validate_type('strategy', value, int) - self._validate_value(name = 'strategy', - value = value, - enum = StrategyOptions, - type = int) - self._strategy = value - - @property - def targetValue(self) -> float: - """ - Gets the targetValue property. - - Returns - ------- - targetValue : float - The value of the targetValue property. - """ - return self._targetValue - - @targetValue.setter - def targetValue(self, value: Union[int, float]) -> None: - """ - Sets the targetValue property after validation. - - Parameters - ---------- - value : Union[int, float] - The value to be set for the targetValue property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the targetValue property has - input of the wrong type or value. - """ - self._validate_type('targetValue', value, (int, float)) - self._validate_range(name = 'targetValue', - value = value, - lower_limit = 1, - lower_exclusive = False) - self._targetValue = float(value) - - @property - def numGenerations(self) -> int: - """ - Gets the numGenerations property. - - Returns - ------- - numGenerations : int - The value of the numGenerations property. - """ - return self._numGenerations - - @numGenerations.setter - def numGenerations(self, value: int) -> None: - """ - Sets the numGenerations property after validation. - - Parameters - ---------- - value : int - The value to be set for the numGenerations property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the numGenerations property has - input of the wrong type or value. - """ - self._validate_type('numGenerations', value, int) - self._validate_range(name = 'numGenerations', - value = value, - lower_limit = 1, - lower_exclusive = False) - self._numGenerations = value - - def __repr__(self) -> str: - """ - Defines the display method for DE class - """ - table = super().__repr__(Procedures.DE) - return table - - -class NS(BaseProcedure): - """Defines the class for the Nested Sampler procedure""" - - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter, - Nlive: int = 150, - Nmcmc: Union[float, int] = 0.0, - propScale: float = 0.1, - nsTolerance: Union[float, int] = 0.1) -> None: - - # call the constructor of the parent class - super().__init__(parallel=parallel, - calcSldDuringFit=calcSldDuringFit, - resamPars=resamPars, - display=display) - - self._Nlive = Nlive - self._Nmcmc = Nmcmc - self._propScale = propScale - self._nsTolerance = nsTolerance - - @property - def procedure(self) -> Literal['ns']: - """ - Gets the procedure property. - - Returns - ------- - procedure : Literal['ns'] - The value of the procedure property. - """ - return Procedures.NS - - @property - def Nlive(self) -> int: - """ - Gets the Nlive property. - - Returns - ------- - Nlive : int - The value of the Nlive property. - """ - return self._Nlive - - @Nlive.setter - def Nlive(self, value: int) -> None: - """ - Sets the Nlive property after validation. - - Parameters - ---------- - value : int - The value to be set for the Nlive property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the Nlive property has - input of the wrong type or value. - """ - self._validate_type('Nlive', value, int) - self._validate_range(name = 'Nlive', - value = value, - lower_limit = 1, - lower_exclusive = False) - self._Nlive = value - - @property - def Nmcmc(self) -> float: - """ - Gets the Nmcmc property. - - Returns - ------- - Nmcmc : float - The value of the Nmcmc property. - """ - return self._Nmcmc - - @Nmcmc.setter - def Nmcmc(self, value: Union[float, int]) -> None: - """ - Sets the Nmcmc property after validation. - - Parameters - ---------- - value : Union[float, int] - The value to be set for the Nmcmc property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the Nmcmc property has - input of the wrong type or value. - """ - self._validate_type('Nmcmc', value, (int, float)) - self._validate_range(name = 'Nmcmc', - value = value, - lower_limit = 1, - lower_exclusive = False) - self._Nmcmc = float(value) - - @property - def propScale(self) -> float: - """ - Gets the propScale property. - - Returns - ------- - propScale : float - The value of the propScale property. - """ - return self._propScale - - @propScale.setter - def propScale(self, value: float) -> None: - """ - Sets the propScale property after validation. - - Parameters - ---------- - value : float - The value to be set for the propScale property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the propScale property has - input of the wrong type or value. - """ - self._validate_type('propScale', value, float) - self._validate_range(name = 'propScale', - value = value, - lower_limit = 0, - upper_limit = 1) - self._propScale = value - - @property - def nsTolerance(self) -> float: - """ - Gets the nsTolerance property. - - Returns - ------- - nsTolerance : float - The value of the nsTolerance property. - """ - return self._nsTolerance - - @nsTolerance.setter - def nsTolerance(self, value: Union[float, int]) -> None: - """ - Sets the nsTolerance property after validation. - - Parameters - ---------- - value : Union[float, int] - The value to be set for the nsTolerance property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the nsTolerance property has - input of the wrong type or value. - """ - self._validate_type('nsTolerance', value, (int, float)) - self._validate_range(name = 'nsTolerance', - value = value, - lower_limit = 0, - lower_exclusive = False) - self._nsTolerance = float(value) - - def __repr__(self) -> str: - """ - Defines the display method for NS class - """ - table = super().__repr__(Procedures.NS) - return table +class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'): + """ + Defines the class for the calculate procedure + """ + procedure: Procedures = Field(Procedures.Calculate, frozen=True) -class Dream(BaseProcedure): - """Defines the class for the Dream procedure""" - - def __init__(self, - parallel: Literal['single', 'points', 'contrasts', 'all'] = ParallelOptions.Single, - calcSldDuringFit: bool = False, - resamPars: list[Union[int, float]] = [0.9, 50], - display: Literal['off', 'iter', 'notify', 'final'] = DisplayOptions.Iter, - nSamples: int = 50000, - nChains: int = 10, - jumpProb: float = 0.5, - pUnitGamma:float = 0.2, - boundHandling: Literal['no', 'reflect', 'bound', 'fold'] = BoundHandlingOptions.Fold) -> None: - # call the constructor of the parent class - super().__init__(parallel=parallel, - calcSldDuringFit=calcSldDuringFit, - resamPars=resamPars, - display=display) - - self._nSamples = nSamples - self._nChains = nChains - self._jumpProb = jumpProb # lambda in MATLAB - self._pUnitGamma = pUnitGamma - self._boundHandling = boundHandling - - @property - def procedure(self) -> Literal['dream']: - """ - Gets the procedure property. - - Returns - ------- - procedure : Literal['dream'] - The value of the procedure property. - """ - return Procedures.Dream - - @property - def nSamples(self) -> int: - """ - Gets the nSamples property. - - Returns - ------- - nSamples : int - The value of the nSamples property. - """ - return self._nSamples - - @nSamples.setter - def nSamples(self, value: int) -> None: - """ - Sets the nSamples property after validation. - - Parameters - ---------- - value : int - The value to be set for the nSamples property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the nSamples property has - input of the wrong type or value. - """ - self._validate_type('nSamples', value, int) - self._validate_range(name = 'nSamples', - value = value, - lower_limit = 0, - lower_exclusive = False) - self._nSamples = value - - @property - def nChains(self) -> int: - """ - Gets the nChains property. - - Returns - ------- - nChains : int - The value of the nChains property. - """ - return self._nChains - - @nChains.setter - def nChains(self, value: int) -> None: - """ - Sets the nChains property after validation. - - Parameters - ---------- - value : int - The value to be set for the nChains property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the nChains property has - input of the wrong type or value. - """ - self._validate_type('nChains', value, int) - self._validate_range(name = 'nChains', - value = value, - lower_limit = 0, - lower_exclusive = False) - self._nChains = value - - @property - def jumpProb(self) -> float: - """ - Gets the jumpProb property. - - Returns - ------- - jumpProb : float - The value of the jumpProb property. - """ - return self._jumpProb - - @jumpProb.setter - def jumpProb(self, value: float) -> None: - """ - Sets the jumpProb property after validation. - - Parameters - ---------- - value : float - The value to be set for the jumpProb property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the jumpProb property has - input of the wrong type or value. - """ - self._validate_type('jumpProb', value, float) - self._validate_range(name = 'jumpProb', - value = value, - lower_limit = 0, - upper_limit = 1) - self._jumpProb = value - - @property - def pUnitGamma(self) -> float: - """ - Gets the pUnitGamma property. - - Returns - ------- - pUnitGamma : float - The value of the pUnitGamma property. - """ - return self._pUnitGamma - - @pUnitGamma.setter - def pUnitGamma(self, value: float) -> None: - """ - Sets the pUnitGamma property after validation. - - Parameters - ---------- - value : float - The value to be set for the pUnitGamma property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the pUnitGamma property has - input of the wrong type or value. - """ - self._validate_type('pUnitGamma', value, float) - self._validate_range(name = 'pUnitGamma', - value = value, - lower_limit = 0, - upper_limit = 1) - self._pUnitGamma = value - - @property - def boundHandling(self) -> Literal['no', 'reflect', 'bound', 'fold']: - """ - Gets the boundHandling property. - - Returns - ------- - boundHandling : Literal['no', 'reflect', 'bound', 'fold'] - The value of the boundHandling property. - """ - return self._boundHandling - - @boundHandling.setter - def boundHandling(self, value: Literal['no', 'reflect', 'bound', 'fold']) -> None: - """ - Sets the boundHandling property after validation. - - Parameters - ---------- - value : Literal['no', 'reflect', 'bound', 'fold'] - The value to be set for the boundHandling property. - - Raises - ------ - Union[TypeError, ValueError] - Raised if the boundHandling property has - input of the wrong type or value. - """ - self._validate_type('boundHandling', value, str) - self._validate_value('boundHandling', value, BoundHandlingOptions) - self._boundHandling = value +class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'): + """ + Defines the class for the simplex procedure + """ + procedure: Procedures = Field(Procedures.Simplex, frozen=True) + tolX: float = Field(1e-6, gt=0) + tolFun: float = Field(1e-6, gt=0) + maxFunEvals: int = Field(10000, gt=0) + maxIter: int = Field(1000, gt=0) + updateFreq: int = -1 + updatePlotFreq: int = -1 + +class DE(BaseProcedure, validate_assignment=True, extra='forbid'): + """ + Defines the class for the Differential Evolution procedure + """ + procedure: Procedures = Field(Procedures.DE, frozen=True) + populationSize: int = Field(20, ge=1) + fWeight: float = 0.5 + crossoverProbability: float = Field(0.8, gt=0, lt=1) + strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither + targetValue: Union[int, float] = Field(1.0, ge=1) + numGenerations: int = Field(500, ge=1) + +class NS(BaseProcedure, validate_assignment=True, extra='forbid'): + """ + Defines the class for the Nested Sampler procedure + """ + procedure: Procedures = Field(Procedures.NS, frozen=True) + Nlive: int = Field(150, ge=1) + Nmcmc: Union[float, int] = Field(0.0, ge=0) + propScale: float = Field(0.1, gt=0, lt=1) + nsTolerance: Union[float, int] = Field(0.1, ge=0) - def __repr__(self) -> str: - """ - Defines the display method for Dream class - """ - table = super().__repr__(Procedures.Dream) - return table +class Dream(BaseProcedure, validate_assignment=True, extra='forbid'): + """ + Defines the class for the Dream procedure + """ + procedure: Procedures = Field(Procedures.Dream, frozen=True) + nSamples: int = Field(50000, ge=0) + nChains: int = Field(10, gt=0) + jumpProb: float = Field(0.5, gt=0, lt=1) + pUnitGamma:float = Field(0.2, gt=0, lt=1) + boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold class ControlsClass: def __init__(self, - procedure: Literal['calculate', 'simplex', 'de', 'ns', 'dream'] = Procedures.Calculate, + procedure: Procedures = Procedures.Calculate, **properties) -> None: self._procedure = procedure @@ -1224,7 +86,7 @@ def __init__(self, if self._procedure == Procedures.Calculate: self._controls = Calculate(**properties) - + elif self._procedure == Procedures.Simplex: self._controls = Simplex(**properties) @@ -1313,10 +175,3 @@ def _validate_properties(self, **properties) -> None: if not (expected_properties | input_properties == expected_properties): raise ValueError((f"Properties that can be set for {self._procedure} are " f"{', '.join(sorted(expected_properties))}")) - - def __repr__(self) -> str: - """ - Defines the display method for Controls class - """ - table = self._controls.__repr__() - return table diff --git a/RAT/utils/enums.py b/RAT/utils/enums.py index badb7d57..e6961326 100644 --- a/RAT/utils/enums.py +++ b/RAT/utils/enums.py @@ -32,7 +32,7 @@ class DisplayOptions(StrEnum): class BoundHandlingOptions(StrEnum): """Defines the avaliable options for bound handling""" - No = 'no' + Off = 'off' Reflect = 'reflect' Bound = 'bound' Fold = 'fold' diff --git a/tests/test_controls.py b/tests/test_controls.py index 8f17d59f..fc2f2842 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -1,6 +1,7 @@ """Tests for control and procedure classes""" import pytest +import pydantic from typing import Union, Any from RAT.controls import BaseProcedure, Calculate, Simplex, DE, NS, Dream, ControlsClass from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions @@ -16,94 +17,66 @@ def setup_class(self): ('calcSldDuringFit', False), ('resamPars', [0.9, 50]), ('display', DisplayOptions.Iter)]) - def test_base_procedure_values(self, property: str, value: Any) -> None: + def test_base_property_values(self, property: str, value: Any) -> None: assert getattr(self.base_procedure, property) == value - - @pytest.mark.parametrize("property", ['parallel', 'calcSldDuringFit', 'resamPars', 'display']) - def test_base_procedure_properties(self, property: str) -> None: - assert hasattr(self.base_procedure, property) - - @pytest.mark.parametrize("property, var_type", [('parallel', str), + + @pytest.mark.parametrize("property, var_type", [('parallel', ParallelOptions), ('calcSldDuringFit', bool), ('resamPars', list), - ('display', str)]) - def test_base_procedure_property_types(self, property: str, var_type) -> None: + ('display', DisplayOptions)]) + def test_base_property_types(self, property: str, var_type) -> None: assert isinstance(getattr(self.base_procedure, property), var_type) - + @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), ('calcSldDuringFit', True), ('resamPars', [0.2, 1]), ('display', DisplayOptions.Notify)]) - def test_base_procedure_setters(self, property: str, value: Any) -> None: + def test_base_property_setters(self, property: str, value: Any) -> None: setattr(self.base_procedure, property, value) assert getattr(self.base_procedure, property) == value + + @pytest.mark.parametrize("value1, value2", [('test', True), ('ALL', 1), ("Contrast", 3.0)]) + def test_base_parallel_validation(self, value1: str, value2: Any) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'parallel', value1) + assert exp.value.errors()[0]['msg'] == "Input should be 'single','points','contrasts' or 'all'" + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'parallel', value2) + assert exp.value.errors()[0]['msg'] == "Input should be a valid string" + + @pytest.mark.parametrize("value", [5.0, 12]) + def test_base_calcSldDuringFit_validation(self, value: Any) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'calcSldDuringFit', value) + assert exp.value.errors()[0]['msg'] == "Input should be a valid boolean, unable to interpret input" + + @pytest.mark.parametrize("value1, value2", [('test', True), ('iterate', 1), ("FINAL", 3.0)]) + def test_base_display_validation(self, value1: str, value2: Any) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'display', value1) + assert exp.value.errors()[0]['msg'] == "Input should be 'off','iter','notify' or 'final'" + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'display', value2) + assert exp.value.errors()[0]['msg'] == "Input should be a valid string" - @pytest.mark.parametrize("var, exp_type", [('string', float), - (True, str), - (1, bool), - (1.0, int)]) - def test_validate_type(self, var: Any, exp_type: Any) -> None: - with pytest.raises(TypeError) as exc: - self.base_procedure._validate_type('var', var, exp_type) - assert f"var must be of type {exp_type.__name__}" in str(exc.value) - - @pytest.mark.parametrize("enum, enum_type", [(ParallelOptions, str), - (Procedures, str), - (DisplayOptions, str), - (BoundHandlingOptions, str), - (StrategyOptions, int)]) - def test_validate_value(self, enum: str, enum_type: Union[int, str]) -> None: - allowed_options = [str(o.value) for o in enum] - with pytest.raises(ValueError) as exc: - self.base_procedure._validate_value('var', 'test_variable', enum, enum_type) - assert (f"var must be a {enum.__name__} " - f"enum or one of the following {enum_type.__name__} " - f"{', '.join(allowed_options)}") == str(exc.value) - - @pytest.mark.parametrize("lower, upper, msg", - [(True, True, "var must be greater than 0 and less than 1"), - (False, True, "var must be greater than or equal to 0 and less than 1"), - (True, False, "var must be greater than 0 and less than or equal to 1"), - (False, False, "var must be greater than or equal to 0 and less than or equal to 1")]) - def test_validate_range(self, lower: bool, upper: bool, msg: str) -> None: - with pytest.raises(ValueError) as exc: - self.base_procedure._validate_range(name='var', - value=10, - lower_limit=0, - upper_limit=1, - lower_exclusive=lower, - upper_exclusive=upper) - assert msg == str(exc.value) - - @pytest.mark.parametrize("property, value, msg", - [('parallel', 1, 'parallel must be of type str'), - ('calcSldDuringFit', 1, 'calcSldDuringFit must be of type bool'), - ('resamPars', True, 'resamPars must be of type list'), - ('display', True, 'display must be of type str')]) - def test_base_procedure_properties_type_exceptions(self, property: str, value: Any, msg: str) -> None: - with pytest.raises(TypeError) as exc: - setattr(self.base_procedure, property, value) - assert msg == str(exc.value) - - def test_base_procedure_resamPars_type_exceptions(self) -> None: - with pytest.raises(TypeError) as exc: - self.base_procedure.resamPars = ['f', 'g'] - assert 'resamPars must be defined using floats or ints' == str(exc.value) - - @pytest.mark.parametrize("property, value, msg", - [('parallel', - 'test_value', - 'parallel must be a ParallelOptions enum or one of the following str single, points, contrasts, all'), - ('resamPars', - [1, 2, 3], - 'resamPars must have length of 2'), - ('display', - 'test_value', - 'display must be a DisplayOptions enum or one of the following str off, iter, notify, final')]) - def test_base_procedure_properties_type_exceptions(self, property: str, value: Any, msg: str) -> None: - with pytest.raises(ValueError) as exc: - setattr(self.base_procedure, property, value) - assert msg == str(exc.value) + @pytest.mark.parametrize("value, msg", [([5.0], "List should have at least 2 items after validation, not 1"), + ([12, 13, 14], "List should have at most 2 items after validation, not 3")]) + def test_base_resamPars_lenght_validation(self, value: Any, msg: str) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'resamPars', value) + assert exp.value.errors()[0]['msg'] == msg + + @pytest.mark.parametrize("value, msg", [([1.0, 2], "Value error, resamPars[0] must be between 0 and 1"), + ([0.5, -0.1], "Value error, resamPars[1] must be greater than 0")]) + def test_base_resamPars_value_validation(self, value: Any, msg: str) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'resamPars', value) + assert exp.value.errors()[0]['msg'] == msg + + def test_base_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.base_procedure, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" class TestCalculate: @@ -117,40 +90,30 @@ def setup_class(self): ('resamPars', [0.9, 50]), ('display', DisplayOptions.Iter), ('procedure', Procedures.Calculate)]) - def test_calculate_procedure_values(self, property: str, value: Any) -> None: + def test_calculate_property_values(self, property: str, value: Any) -> None: assert getattr(self.calulate, property) == value - @pytest.mark.parametrize("property", ['parallel', - 'calcSldDuringFit', - 'resamPars', - 'display', - 'procedure']) - def test_calulate_procedure_properties(self, property: str) -> None: - assert hasattr(self.calulate, property) - - def test_calculate_procedure_property_types(self) -> None: + def test_calculate_property_types(self) -> None: assert isinstance(getattr(self.calulate, 'procedure'), str) @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), ('calcSldDuringFit', True), ('resamPars', [0.2, 1]), ('display', DisplayOptions.Notify)]) - def test_calculate_procedure_setters(self, property: str, value: Any) -> None: + def test_calculate_property_setters(self, property: str, value: Any) -> None: setattr(self.calulate, property, value) assert getattr(self.calulate, property) == value + + def test_calculate_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.calulate, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" - def test_calculate_procedure_repr(self) -> None: - calulate = Calculate() - table = calulate.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "procedure calculate\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter") - assert table == table_str - + def test_calculate_procedure_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.calulate, 'procedure', 'test') + assert exp.value.errors()[0]['msg'] == "Field is frozen" + class TestSimplex: """Tests the Simplex class.""" @@ -169,31 +132,17 @@ def setup_class(self): ('maxIter', 1000), ('updateFreq', -1), ('updatePlotFreq', -1)]) - def test_simplex_procedure_values(self, property: str, value: Any) -> None: + def test_simplex_property_values(self, property: str, value: Any) -> None: assert getattr(self.simplex, property) == value - - @pytest.mark.parametrize("property", ['parallel', - 'calcSldDuringFit', - 'resamPars', - 'display', - 'procedure', - 'tolX', - 'tolFun', - 'maxFunEvals', - 'maxIter', - 'updateFreq', - 'updatePlotFreq']) - def test_simplex_procedure_properties(self, property: str) -> None: - assert hasattr(self.simplex, property) - @pytest.mark.parametrize("property, var_type", [('procedure', str), + @pytest.mark.parametrize("property, var_type", [('procedure', Procedures), ('tolX', float), ('tolFun', float), ('maxFunEvals', int), ('maxIter', int), ('updateFreq', int), ('updatePlotFreq', int),]) - def test_simplex_procedure_property_types(self, property: str, var_type) -> None: + def test_simplex_property_types(self, property: str, var_type) -> None: assert isinstance(getattr(self.simplex, property), var_type) @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), @@ -206,27 +155,28 @@ def test_simplex_procedure_property_types(self, property: str, var_type) -> None ('maxIter', 50), ('updateFreq', 4), ('updatePlotFreq', 3)]) - def test_simplex_procedure_setters(self, property: str, value: Any) -> None: + def test_simplex_property_setters(self, property: str, value: Any) -> None: setattr(self.simplex, property, value) assert getattr(self.simplex, property) == value - def test_simplex_procedure_repr(self) -> None: - simplex = Simplex() - table = simplex.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "procedure simplex\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "tolX 1e-06\n" - "tolFun 1e-06\n" - "maxFunEvals 10000\n" - "maxIter 1000\n" - "updateFreq -1\n" - "updatePlotFreq -1") - assert table == table_str + @pytest.mark.parametrize("property, value", [('tolX', -4e-6), + ('tolFun', -3e-4), + ('maxFunEvals', -100), + ('maxIter', -50)]) + def test_simplex_property_errors(self, property: str, value: Any) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.simplex, property, value) + assert exp.value.errors()[0]['msg'] == "Input should be greater than 0" + + def test_simplex_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.simplex, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + + def test_simplex_procedure_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.simplex, 'procedure', 'test') + assert exp.value.errors()[0]['msg'] == "Field is frozen" class TestDE: @@ -243,34 +193,20 @@ def setup_class(self): ('populationSize', 20), ('fWeight', 0.5), ('crossoverProbability', 0.8), - ('strategy', StrategyOptions.RandomWithPerVectorDither.value), + ('strategy', StrategyOptions.RandomWithPerVectorDither), ('targetValue', 1), ('numGenerations', 500)]) - def test_de_procedure_values(self, property: str, value: Any) -> None: + def test_de_property_values(self, property: str, value: Any) -> None: assert getattr(self.de, property) == value - @pytest.mark.parametrize("property", ['parallel', - 'calcSldDuringFit', - 'resamPars', - 'display', - 'procedure', - 'populationSize', - 'fWeight', - 'crossoverProbability', - 'strategy', - 'targetValue', - 'numGenerations']) - def test_de_procedure_properties(self, property: str) -> None: - assert hasattr(self.de, property) - - @pytest.mark.parametrize("property, var_type", [('procedure', str), + @pytest.mark.parametrize("property, var_type", [('procedure', Procedures), ('populationSize', int), ('fWeight', float), ('crossoverProbability', float), - ('strategy', int), + ('strategy', StrategyOptions), ('targetValue', float), ('numGenerations', int)]) - def test_de_procedure_property_types(self, property: str, var_type) -> None: + def test_de_property_types(self, property: str, var_type) -> None: assert isinstance(getattr(self.de, property), var_type) @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), @@ -280,30 +216,42 @@ def test_de_procedure_property_types(self, property: str, var_type) -> None: ('populationSize', 20), ('fWeight', 0.3), ('crossoverProbability', 0.4), - ('strategy', 3), - ('targetValue', 2), + ('strategy', StrategyOptions.BestWithJitter), + ('targetValue', 2.0), ('numGenerations', 50)]) - def test_de_procedure_setters(self, property: str, value: Any) -> None: + def test_de_property_setters(self, property: str, value: Any) -> None: setattr(self.de, property, value) assert getattr(self.de, property) == value - - def test_de_procedure_repr(self) -> None: - de = DE() - table = de.__repr__() - table_str = ("Property Value\n" - "-------------------- ---------\n" - "procedure de\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "populationSize 20\n" - "fWeight 0.5\n" - "crossoverProbability 0.8\n" - "strategy 4\n" - "targetValue 1.0\n" - "numGenerations 500") - assert table == table_str + + @pytest.mark.parametrize("value", [0, 2]) + def test_de_crossoverProbability_error(self, value: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.de, 'crossoverProbability', value) + assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", + "Input should be less than 1"] + + @pytest.mark.parametrize("property, value", [('targetValue', 0), + ('targetValue',0.999), + ('numGenerations', -500), + ('numGenerations', 0), + ('populationSize', 0), + ('populationSize', -1)]) + def test_de_targetValue_numGenerations_populationSize_error(self, + property: str, + value: Union[int, float]) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.de, property, value) + assert exp.value.errors()[0]['msg'] == "Input should be greater than or equal to 1" + + def test_de_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.de, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + + def test_de_procedure_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.de, 'procedure', 'test') + assert exp.value.errors()[0]['msg'] == "Field is frozen" class TestNS: @@ -321,27 +269,15 @@ def setup_class(self): ('Nmcmc', 0), ('propScale', 0.1), ('nsTolerance', 0.1)]) - def test_ns_procedure_values(self, property: str, value: Any) -> None: + def test_ns_property_values(self, property: str, value: Any) -> None: assert getattr(self.ns, property) == value - @pytest.mark.parametrize("property", ['parallel', - 'calcSldDuringFit', - 'resamPars', - 'display', - 'procedure', - 'Nlive', - 'Nmcmc', - 'propScale', - 'nsTolerance']) - def test_ns_procedure_properties(self, property: str) -> None: - assert hasattr(self.ns, property) - - @pytest.mark.parametrize("property, var_type", [('procedure', str), + @pytest.mark.parametrize("property, var_type", [('procedure', Procedures), ('Nlive', int), ('Nmcmc', float), ('propScale', float), ('nsTolerance', float)]) - def test_ns_procedure_property_types(self, property: str, var_type) -> None: + def test_ns_property_types(self, property: str, var_type) -> None: assert isinstance(getattr(self.ns, property), var_type) @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), @@ -352,26 +288,35 @@ def test_ns_procedure_property_types(self, property: str, var_type) -> None: ('Nmcmc', 1), ('propScale', 0.5), ('nsTolerance', 0.8)]) - def test_ns_procedure_setters(self, property: str, value: Any) -> None: + def test_ns_property_setters(self, property: str, value: Any) -> None: setattr(self.ns, property, value) assert getattr(self.ns, property) == value - - def test_ns_procedure_repr(self) -> None: - ns = NS() - table = ns.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "procedure ns\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "Nlive 150\n" - "Nmcmc 0.0\n" - "propScale 0.1\n" - "nsTolerance 0.1") - assert table == table_str - + + @pytest.mark.parametrize("property, value, bound", [('Nmcmc', -0.6, 0), + ('nsTolerance', -500, 0), + ('Nlive', -500, 1)]) + def test_ns_Nmcmc_nsTolerance_Nlive_error(self, property: str, value: Union[int, float], bound: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.ns, property, value) + assert exp.value.errors()[0]['msg'] == f"Input should be greater than or equal to {bound}" + + @pytest.mark.parametrize("value", [0, 2]) + def test_ns_propScale_error(self, value: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.ns, 'propScale', value) + assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", + "Input should be less than 1"] + + def test_ns_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.ns, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + + def test_ns_procedure_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.ns, 'procedure', 'test') + assert exp.value.errors()[0]['msg'] == "Field is frozen" + class TestDream: """Tests the Dream class.""" @@ -389,29 +334,16 @@ def setup_class(self): ('jumpProb', 0.5), ('pUnitGamma', 0.2), ('boundHandling', BoundHandlingOptions.Fold)]) - def test_dream_procedure_values(self, property: str, value: Any) -> None: + def test_dream_property_values(self, property: str, value: Any) -> None: assert getattr(self.dream, property) == value - @pytest.mark.parametrize("property", ['parallel', - 'calcSldDuringFit', - 'resamPars', - 'display', - 'procedure', - 'nSamples', - 'nChains', - 'jumpProb', - 'pUnitGamma', - 'boundHandling']) - def test_dream_procedure_properties(self, property: str) -> None: - assert hasattr(self.dream, property) - - @pytest.mark.parametrize("property, var_type", [('procedure', str), + @pytest.mark.parametrize("property, var_type", [('procedure', Procedures), ('nSamples', int), ('nChains', int), ('jumpProb', float), ('pUnitGamma', float), - ('boundHandling', str)]) - def test_dream_procedure_property_types(self, property: str, var_type) -> None: + ('boundHandling', BoundHandlingOptions)]) + def test_dream_property_types(self, property: str, var_type) -> None: assert isinstance(getattr(self.dream, property), var_type) @pytest.mark.parametrize("property, value", [('parallel', ParallelOptions.All), @@ -423,26 +355,42 @@ def test_dream_procedure_property_types(self, property: str, var_type) -> None: ('jumpProb', 0.7), ('pUnitGamma', 0.3), ('boundHandling', BoundHandlingOptions.Reflect)]) - def test_dream_procedure_setters(self, property: str, value: Any) -> None: + def test_dream_property_setters(self, property: str, value: Any) -> None: setattr(self.dream, property, value) assert getattr(self.dream, property) == value + + @pytest.mark.parametrize("property, value", [('jumpProb',0), + ('jumpProb', 2), + ('pUnitGamma',-5), + ('pUnitGamma', 20)]) + def test_dream_jumpprob_pUnitGamma_error(self, property:str, value: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.dream, property, value) + assert exp.value.errors()[0]['msg'] in ["Input should be greater than 0", + "Input should be less than 1"] + + @pytest.mark.parametrize("value", [-80, -2]) + def test_dream_nSamples_error(self, value: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.dream, 'nSamples', value) + assert exp.value.errors()[0]['msg'] == "Input should be greater than or equal to 0" + + @pytest.mark.parametrize("value", [-5, 0]) + def test_dream_nChains_error(self, value: int) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.dream, 'nChains', value) + assert exp.value.errors()[0]['msg'] == "Input should be greater than 0" + + def test_dream_extra_property_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.dream, 'test', 1) + assert exp.value.errors()[0]['msg'] == "Object has no attribute 'test'" + + def test_dream_procedure_error(self) -> None: + with pytest.raises(pydantic.ValidationError) as exp: + setattr(self.dream, 'procedure', 'test') + assert exp.value.errors()[0]['msg'] == "Field is frozen" - def test_dream_procedure_repr(self) -> None: - dream = Dream() - table = dream.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "procedure dream\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter\n" - "nSamples 50000\n" - "nChains 10\n" - "jumpProb 0.5\n" - "pUnitGamma 0.2\n" - "boundHandling fold") - assert table == table_str class TestControlsClass: """Tests the Controls class.""" @@ -453,7 +401,7 @@ def setup_class(self): def test_controls_class_default_type(self) -> None: assert type(self.controls.controls).__name__ == "Calculate" - def test_dream_procedure_properties(self) -> None: + def test_controls_class_properties(self) -> None: assert hasattr(self.controls, 'controls') @pytest.mark.parametrize("procedure, name", [(Procedures.Calculate, "Calculate"), @@ -480,15 +428,3 @@ def test_controls_class_validate_properties(self, procedure: str, msg: str) -> N with pytest.raises(ValueError) as exc: controls._validate_properties(test_variable = 200) assert msg == str(exc.value) - - def test_control_class_repr(self) -> None: - controls = ControlsClass() - table = controls.__repr__() - table_str = ("Property Value\n" - "---------------- ---------\n" - "procedure calculate\n" - "parallel single\n" - "calcSldDuringFit False\n" - "resamPars [0.9, 50]\n" - "display iter") - assert table == table_str \ No newline at end of file