Skip to content

Commit

Permalink
added model verification
Browse files Browse the repository at this point in the history
  • Loading branch information
RabiyaF committed Sep 21, 2023
1 parent 01d29de commit 6ebe7c2
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 200 deletions.
284 changes: 158 additions & 126 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,177 +1,209 @@
import tabulate
from typing import Union
from pydantic import BaseModel, Field, field_validator
from typing import Union, Any
from pydantic import BaseModel, Field, field_validator, model_validator
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'):
"""
Defines the base class with properties used in all five procedures.
"""
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[Union[int, float]] = Field([0.9, 50], min_length=2, max_length=2)
resamPars: list[Union[int, float]] = Field([0.9, 50], min_length = 2, max_length = 2)
display: DisplayOptions = DisplayOptions.Iter

@field_validator("resamPars")
def check_resamPars(cls, resamPars):
if not 0 < resamPars[0] < 1:
raise ValueError('resamPars[0] must be between 0 and 1')
if not 0 <= resamPars[1]:
raise ValueError('resamPars[1] must be greater than 0')
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars

@classmethod
def verify_inputs(cls,
data: dict[str, Any],
expected_properties: list[str],
msg: str) -> None:
if isinstance(data, dict) and data != {}:
if not all(x in expected_properties for x in data.keys()):
raise ValueError(msg)

class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):

class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the calculate procedure
Defines the class for the calculate procedure.
"""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)


class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
procedure: Procedures = Field(Procedures.Calculate, frozen = True)

@model_validator(mode = 'before')
@classmethod
def verify_inputs(cls, data: dict[str, Any]) -> dict[str, Any]:
expected_properties = ['parallel',
'calcSldDuringFit',
'resamPars',
'display',
'procedure']
msg = ("Properties that can be set for calculate are"
" calcSLdDuringFit, display, parallel, resamPars")
super().verify_inputs(data, expected_properties, msg)
return data


class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the simplex procedure
Defines the class for the simplex procedure.
"""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
tolX: float = Field(1e-6, gt=0)
tolFun: float = Field(1e-6, gt=0)
maxFunEvals: int = Field(10000, gt=0)
maxIter: int = Field(1000, gt=0)
procedure: Procedures = Field(Procedures.Simplex, frozen = True)
tolX: float = Field(1e-6, gt = 0)
tolFun: float = Field(1e-6, gt = 0)
maxFunEvals: int = Field(10000, gt = 0)
maxIter: int = Field(1000, gt = 0)
updateFreq: int = -1
updatePlotFreq: int = -1

class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
@model_validator(mode = 'before')
@classmethod
def verify_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
expected_properties = ['parallel',
'calcSldDuringFit',
'resamPars',
'display',
'procedure',
'tolX',
'tolFun',
'maxFunEvals',
'maxIter',
'updateFreq',
'updatePlotFreq']
msg = ("Properties that can be set for simplex are"
" calcSLdDuringFit, display, maxFunEvals, maxIter,"
" parallel, resamPars, tolFun, tolX, updateFreq, "
"updatePlotFreq")
super().verify_inputs(data, expected_properties, msg)
return data


class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Differential Evolution procedure
Defines the class for the Differential Evolution procedure.
"""
procedure: Procedures = Field(Procedures.DE, frozen=True)
populationSize: int = Field(20, ge=1)
procedure: Procedures = Field(Procedures.DE, frozen = True)
populationSize: int = Field(20, ge = 1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0, lt=1)
crossoverProbability: float = Field(0.8, gt = 0, lt = 1)
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
targetValue: Union[int, float] = Field(1.0, ge=1)
numGenerations: int = Field(500, ge=1)

class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
targetValue: Union[int, float] = Field(1.0, ge = 1)
numGenerations: int = Field(500, ge = 1)

@model_validator(mode = 'before')
@classmethod
def verify_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
expected_properties = ['parallel',
'calcSldDuringFit',
'resamPars',
'display',
'procedure',
'populationSize',
'fWeight',
'crossoverProbability',
'strategy',
'targetValue',
'numGenerations']
msg = ("Properties that can be set for de are "
"calcSLdDuringFit, crossoverProbability, display,"
" fWeight, numGenerations, parallel, populationSize,"
" resamPars, strategy, targetValue")
super().verify_inputs(data, expected_properties, msg)
return data


class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Nested Sampler procedure
Defines the class for the Nested Sampler procedure.
"""
procedure: Procedures = Field(Procedures.NS, frozen=True)
Nlive: int = Field(150, ge=1)
Nmcmc: Union[float, int] = Field(0.0, ge=0)
propScale: float = Field(0.1, gt=0, lt=1)
nsTolerance: Union[float, int] = Field(0.1, ge=0)

class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
procedure: Procedures = Field(Procedures.NS, frozen = True)
Nlive: int = Field(150, ge = 1)
Nmcmc: Union[float, int] = Field(0.0, ge = 0)
propScale: float = Field(0.1, gt = 0, lt = 1)
nsTolerance: Union[float, int] = Field(0.1, ge = 0)

@model_validator(mode = 'before')
@classmethod
def verify_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
expected_properties = ['parallel',
'calcSldDuringFit',
'resamPars',
'display',
'procedure',
'Nlive',
'Nmcmc',
'propScale',
'nsTolerance']
msg = ("Properties that can be set for ns are Nlive,"
" Nmcmc, calcSLdDuringFit, display, nsTolerance,"
" parallel, propScale, resamPars")
super().verify_inputs(data, expected_properties, msg)
return data


class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Dream procedure
"""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0, lt=1)
pUnitGamma:float = Field(0.2, gt=0, lt=1)
procedure: Procedures = Field(Procedures.Dream, frozen = True)
nSamples: int = Field(50000, ge = 0)
nChains: int = Field(10, gt = 0)
jumpProb: float = Field(0.5, gt = 0, lt = 1)
pUnitGamma:float = Field(0.2, gt = 0, lt = 1)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold

@model_validator(mode = 'before')
@classmethod
def verify_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
expected_properties = ['parallel',
'calcSldDuringFit',
'resamPars',
'display',
'procedure',
'nSamples',
'nChains',
'jumpProb',
'pUnitGamma',
'boundHandling']
msg = ("Properties that can be set for dream are"
" boundHandling, calcSLdDuringFit, display, "
"jumpProb, nChains, nSamples, pUnitGamma, "
"parallel, resamPars")
super().verify_inputs(data, expected_properties, msg)
return data


class ControlsClass:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

self._procedure = procedure
self._validate_properties(**properties)

if self._procedure == Procedures.Calculate:
self._controls = Calculate(**properties)
if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)

elif self._procedure == Procedures.Simplex:
self._controls = Simplex(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)

elif self._procedure == Procedures.DE:
self._controls = DE(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)

elif self._procedure == Procedures.NS:
self._controls = NS(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)

elif self._procedure == Procedures.Dream:
self._controls = Dream(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
"""
Gets the controls.
Returns
-------
controls : Union[Calculate,
Simplex,
DE,
NS,
Dream]
The class with control properties.
"""
return self._controls

def _validate_properties(self, **properties) -> None:
"""
Validates the inputs for the procedure.
Parameters
----------
properties : dict[str, Any]
The properties of the procedure.
Raises
------
ValueError
Raised if properties are not validated.
"""
property_names = {Procedures.Calculate: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display'},
Procedures.Simplex: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'tolX',
'tolFun',
'maxFunEvals',
'maxIter',
'updateFreq',
'updatePlotFreq'},
Procedures.DE: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'populationSize',
'fWeight',
'crossoverProbability',
'strategy',
'targetValue',
'numGenerations'},
Procedures.NS: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'Nlive',
'Nmcmc',
'propScale',
'nsTolerance'},
Procedures.Dream: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'nSamples',
'nChains',
'jumpProb',
'pUnitGamma',
'boundHandling'}}
expected_properties = property_names[self._procedure]
input_properties = properties.keys()
if not (expected_properties | input_properties == expected_properties):
raise ValueError((f"Properties that can be set for {self._procedure} are "
f"{', '.join(sorted(expected_properties))}"))

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value
Loading

0 comments on commit 6ebe7c2

Please sign in to comment.