Skip to content

Commit

Permalink
added the controlsClass (#10)
Browse files Browse the repository at this point in the history
* added the test file for controls class

* adds enums

* adds the procedure classes with doc strings and typings

* added controls class with input validation

* added the display methods for all the classes

* added tests for the procedure classes

* added tests for controls class

* updating typing in control classes

* updated enums in controls and tests

* updated typings to literal and added tests to check property types and updated docs

* converted procedures to pydantic classes

* added model verification

* added __repr__ method for control class

* addressed the review comments

* addressed comments
  • Loading branch information
RabiyaF authored Sep 25, 2023
1 parent eb2ec43 commit 6f329ed
Show file tree
Hide file tree
Showing 3 changed files with 671 additions and 0 deletions.
110 changes: 110 additions & 0 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import tabulate
from typing import Union
from pydantic import BaseModel, Field, field_validator
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'):
"""
Defines the base class with properties used in all five procedures.
"""
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2)
display: DisplayOptions = DisplayOptions.Iter

@field_validator("resamPars")
def check_resamPars(cls, resamPars):
if not 0 < resamPars[0] < 1:
raise ValueError('resamPars[0] must be between 0 and 1')
if resamPars[1] < 0:
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars


class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the calculate procedure.
"""
procedure: Procedures = Field(Procedures.Calculate, frozen = True)


class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the simplex procedure.
"""
procedure: Procedures = Field(Procedures.Simplex, frozen = True)
tolX: float = Field(1e-6, gt = 0)
tolFun: float = Field(1e-6, gt = 0)
maxFunEvals: int = Field(10000, gt = 0)
maxIter: int = Field(1000, gt = 0)
updateFreq: int = -1
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Differential Evolution procedure.
"""
procedure: Procedures = Field(Procedures.DE, frozen = True)
populationSize: int = Field(20, ge = 1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt = 0, lt = 1)
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
targetValue: float = Field(1.0, ge = 1)
numGenerations: int = Field(500, ge = 1)


class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Nested Sampler procedure.
"""
procedure: Procedures = Field(Procedures.NS, frozen = True)
Nlive: int = Field(150, ge = 1)
Nmcmc: float = Field(0.0, ge = 0)
propScale: float = Field(0.1, gt = 0, lt = 1)
nsTolerance: float = Field(0.1, ge = 0)


class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Dream procedure
"""
procedure: Procedures = Field(Procedures.Dream, frozen = True)
nSamples: int = Field(50000, ge = 0)
nChains: int = Field(10, gt = 0)
jumpProb: float = Field(0.5, gt = 0, lt = 1)
pUnitGamma:float = Field(0.2, gt = 0, lt = 1)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class ControlsClass:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
return self._controls

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
properties = [["Property", "Value"]] +\
[[k, v] for k, v in self._controls.__dict__.items()]
table = tabulate.tabulate(properties, headers="firstrow")
return table
48 changes: 48 additions & 0 deletions RAT/utils/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from enum import Enum
try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum


class ParallelOptions(StrEnum):
"""Defines the avaliable options for parallelization"""
Single = 'single'
Points = 'points'
Contrasts = 'contrasts'
All = 'all'


class Procedures(StrEnum):
"""Defines the avaliable options for procedures"""
Calculate = 'calculate'
Simplex = 'simplex'
DE = 'de'
NS = 'ns'
Dream = 'dream'


class DisplayOptions(StrEnum):
"""Defines the avaliable options for display"""
Off = 'off'
Iter = 'iter'
Notify = 'notify'
Final = 'final'


class BoundHandlingOptions(StrEnum):
"""Defines the avaliable options for bound handling"""
Off = 'off'
Reflect = 'reflect'
Bound = 'bound'
Fold = 'fold'


class StrategyOptions(Enum):
"""Defines the avaliable options for strategies"""
Random = 1
LocalToBest = 2
BestWithJitter = 3
RandomWithPerVectorDither = 4
RandomWithPerGenerationDither = 5
RandomEitherOrAlgorithm = 6
Loading

0 comments on commit 6f329ed

Please sign in to comment.