Skip to content

Commit

Permalink
updated enums in controls and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RabiyaF committed Sep 18, 2023
1 parent 4269813 commit b6602e9
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 118 deletions.
140 changes: 70 additions & 70 deletions RAT/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self,
parallel: str = ParallelOptions.Single.value,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value) -> None:
display: str = DisplayOptions.Iter) -> None:

self._parallel = parallel
self._calcSldDuringFit = calcSldDuringFit
Expand Down Expand Up @@ -293,10 +293,10 @@ class Calculate(BaseProcedure):
"""Defines the class for the calculate procedure"""

def __init__(self,
parallel: str = ParallelOptions.Single.value,
parallel: str = ParallelOptions.Single,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value) -> None:
display: str = DisplayOptions.Iter) -> None:

# call the constructor of the parent class
super().__init__(parallel = parallel,
Expand All @@ -314,24 +314,24 @@ def procedure(self) -> str:
procedure : str
The value of the procedure property.
"""
return Procedures.Calculate.value
return Procedures.Calculate

def __repr__(self):
"""
Defines the display method for Calculate class
"""
table = super().__repr__(Procedures.Calculate.value)
table = super().__repr__(Procedures.Calculate)
return table


class Simplex(BaseProcedure):
"""Defines the class for the simplex procedure"""

def __init__(self,
parallel: str = ParallelOptions.Single.value,
parallel: str = ParallelOptions.Single,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value,
display: str = DisplayOptions.Iter,
tolX: float = 1e-6,
tolFun: float = 1e-6,
maxFunEvals: int = 10000,
Expand Down Expand Up @@ -362,7 +362,7 @@ def procedure(self) -> str:
procedure : str
The value of the procedure property.
"""
return Procedures.Simplex.value
return Procedures.Simplex

@property
def tolX(self) -> float:
Expand Down Expand Up @@ -554,22 +554,22 @@ def __repr__(self):
"""
Defines the display method for Simplex class
"""
table = super().__repr__(Procedures.Simplex.value)
table = super().__repr__(Procedures.Simplex)
return table


class DE(BaseProcedure):
"""Defines the class for the Differential Evolution procedure"""

def __init__(self,
parallel: str = ParallelOptions.Single.value,
parallel: str = ParallelOptions.Single,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value,
display: str = DisplayOptions.Iter,
populationSize: int = 20,
fWeight: float = 0.5,
crossoverProbability: float = 0.8,
strategy: int = StrategyOptions.RandomWithPerVectorDither.value,
strategy: int = StrategyOptions.RandomWithPerVectorDither,
targetValue: Union[int, float] = 1,
numGenerations: int = 500) -> None:

Expand All @@ -596,7 +596,7 @@ def procedure(self) -> str:
procedure : str
The value of the procedure property.
"""
return Procedures.DE.value
return Procedures.DE

@property
def populationSize(self) -> int:
Expand Down Expand Up @@ -808,18 +808,18 @@ def __repr__(self):
"""
Defines the display method for DE class
"""
table = super().__repr__(Procedures.DE.value)
table = super().__repr__(Procedures.DE)
return table


class NS(BaseProcedure):
"""Defines the class for the Nested Sampler procedure"""

def __init__(self,
parallel: str = ParallelOptions.Single.value,
parallel: str = ParallelOptions.Single,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value,
display: str = DisplayOptions.Iter,
Nlive: int = 150,
Nmcmc: Union[float, int] = 0,
propScale: float = 0.1,
Expand All @@ -846,7 +846,7 @@ def procedure(self) -> str:
procedure : str
The value of the procedure property.
"""
return Procedures.NS.value
return Procedures.NS

@property
def Nlive(self) -> int:
Expand Down Expand Up @@ -992,23 +992,23 @@ def __repr__(self):
"""
Defines the display method for NS class
"""
table = super().__repr__(Procedures.NS.value)
table = super().__repr__(Procedures.NS)
return table


class Dream(BaseProcedure):
"""Defines the class for the Dream procedure"""

def __init__(self,
parallel: str = ParallelOptions.Single.value,
parallel: str = ParallelOptions.Single,
calcSldDuringFit: bool = False,
resamPars: list[Union[int, float]] = [0.9, 50],
display: str = DisplayOptions.Iter.value,
display: str = DisplayOptions.Iter,
nSamples: int = 50000,
nChains: int = 10,
jumpProb: float = 0.5,
pUnitGamma:float = 0.2,
boundHandling: str = BoundHandlingOptions.Fold.value) -> None:
boundHandling: str = BoundHandlingOptions.Fold) -> None:

# call the constructor of the parent class
super().__init__(parallel=parallel,
Expand All @@ -1032,7 +1032,7 @@ def procedure(self) -> str:
procedure : str
The value of the procedure property.
"""
return Procedures.Dream.value
return Procedures.Dream

@property
def nSamples(self) -> int:
Expand Down Expand Up @@ -1210,32 +1210,32 @@ def __repr__(self):
"""
Defines the display method for Dream class
"""
table = super().__repr__(Procedures.Dream.value)
table = super().__repr__(Procedures.Dream)
return table


class ControlsClass:

def __init__(self,
procedure: str = Procedures.Calculate.value,
procedure: str = Procedures.Calculate,
**properties) -> None:

self._procedure = procedure
self._validate_properties(**properties)

if self._procedure == Procedures.Calculate.value:
if self._procedure == Procedures.Calculate:
self._controls = Calculate(**properties)

elif self._procedure == Procedures.Simplex.value:
elif self._procedure == Procedures.Simplex:
self._controls = Simplex(**properties)

elif self._procedure == Procedures.DE.value:
elif self._procedure == Procedures.DE:
self._controls = DE(**properties)

elif self._procedure == Procedures.NS.value:
elif self._procedure == Procedures.NS:
self._controls = NS(**properties)

elif self._procedure == Procedures.Dream.value:
elif self._procedure == Procedures.Dream:
self._controls = Dream(**properties)

@property
Expand Down Expand Up @@ -1273,47 +1273,47 @@ def _validate_properties(self, **properties) -> bool:
ValueError
Raised if properties are not validated.
"""
property_names = {Procedures.Calculate.value: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display'},
Procedures.Simplex.value: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'tolX',
'tolFun',
'maxFunEvals',
'maxIter',
'updateFreq',
'updatePlotFreq'},
Procedures.DE.value: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'populationSize',
'fWeight',
'crossoverProbability',
'strategy',
'targetValue',
'numGenerations'},
Procedures.NS.value: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'Nlive',
'Nmcmc',
'propScale',
'nsTolerance'},
Procedures.Dream.value: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'nSamples',
'nChains',
'jumpProb',
'pUnitGamma',
'boundHandling'}}
property_names = {Procedures.Calculate: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display'},
Procedures.Simplex: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'tolX',
'tolFun',
'maxFunEvals',
'maxIter',
'updateFreq',
'updatePlotFreq'},
Procedures.DE: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'populationSize',
'fWeight',
'crossoverProbability',
'strategy',
'targetValue',
'numGenerations'},
Procedures.NS: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'Nlive',
'Nmcmc',
'propScale',
'nsTolerance'},
Procedures.Dream: {'parallel',
'calcSLdDuringFit',
'resamPars',
'display',
'nSamples',
'nChains',
'jumpProb',
'pUnitGamma',
'boundHandling'}}
expected_properties = property_names[self._procedure]
input_properties = properties.keys()
if not (expected_properties | input_properties == expected_properties):
Expand Down
12 changes: 8 additions & 4 deletions RAT/utils/enums.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from enum import Enum
try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum


class ParallelOptions(Enum):
class ParallelOptions(StrEnum):
"""Defines the avaliable options for parallelization"""
Single = 'single'
Points = 'points'
Contrasts = 'contrasts'
All = 'all'


class Procedures(Enum):
class Procedures(StrEnum):
"""Defines the avaliable options for procedures"""
Calculate = 'calculate'
Simplex = 'simplex'
Expand All @@ -18,15 +22,15 @@ class Procedures(Enum):
Dream = 'dream'


class DisplayOptions(Enum):
class DisplayOptions(StrEnum):
"""Defines the avaliable options for display"""
Off = 'off'
Iter = 'iter'
Notify = 'notify'
Final = 'final'


class BoundHandlingOptions(Enum):
class BoundHandlingOptions(StrEnum):
"""Defines the avaliable options for bound handling"""
No = 'no'
Reflect = 'reflect'
Expand Down
Loading

0 comments on commit b6602e9

Please sign in to comment.