-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added the test file for controls class * adds enums * adds the procedure classes with doc strings and typings * added controls class with input validation * added the display methods for all the classes * added tests for the procedure classes * added tests for controls class * updating typing in control classes * updated enums in controls and tests * updated typings to literal and added tests to check property types and updated docs * converted procedures to pydantic classes * added model verification * added __repr__ method for control class * addressed the review comments * addressed comments
- Loading branch information
Showing
3 changed files
with
671 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import tabulate | ||
from typing import Union | ||
from pydantic import BaseModel, Field, field_validator | ||
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions | ||
|
||
|
||
class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the base class with properties used in all five procedures. | ||
""" | ||
parallel: ParallelOptions = ParallelOptions.Single | ||
calcSldDuringFit: bool = False | ||
resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2) | ||
display: DisplayOptions = DisplayOptions.Iter | ||
|
||
@field_validator("resamPars") | ||
def check_resamPars(cls, resamPars): | ||
if not 0 < resamPars[0] < 1: | ||
raise ValueError('resamPars[0] must be between 0 and 1') | ||
if resamPars[1] < 0: | ||
raise ValueError('resamPars[1] must be greater than or equal to 0') | ||
return resamPars | ||
|
||
|
||
class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the class for the calculate procedure. | ||
""" | ||
procedure: Procedures = Field(Procedures.Calculate, frozen = True) | ||
|
||
|
||
class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the class for the simplex procedure. | ||
""" | ||
procedure: Procedures = Field(Procedures.Simplex, frozen = True) | ||
tolX: float = Field(1e-6, gt = 0) | ||
tolFun: float = Field(1e-6, gt = 0) | ||
maxFunEvals: int = Field(10000, gt = 0) | ||
maxIter: int = Field(1000, gt = 0) | ||
updateFreq: int = -1 | ||
updatePlotFreq: int = -1 | ||
|
||
|
||
class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the class for the Differential Evolution procedure. | ||
""" | ||
procedure: Procedures = Field(Procedures.DE, frozen = True) | ||
populationSize: int = Field(20, ge = 1) | ||
fWeight: float = 0.5 | ||
crossoverProbability: float = Field(0.8, gt = 0, lt = 1) | ||
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither | ||
targetValue: float = Field(1.0, ge = 1) | ||
numGenerations: int = Field(500, ge = 1) | ||
|
||
|
||
class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the class for the Nested Sampler procedure. | ||
""" | ||
procedure: Procedures = Field(Procedures.NS, frozen = True) | ||
Nlive: int = Field(150, ge = 1) | ||
Nmcmc: float = Field(0.0, ge = 0) | ||
propScale: float = Field(0.1, gt = 0, lt = 1) | ||
nsTolerance: float = Field(0.1, ge = 0) | ||
|
||
|
||
class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'): | ||
""" | ||
Defines the class for the Dream procedure | ||
""" | ||
procedure: Procedures = Field(Procedures.Dream, frozen = True) | ||
nSamples: int = Field(50000, ge = 0) | ||
nChains: int = Field(10, gt = 0) | ||
jumpProb: float = Field(0.5, gt = 0, lt = 1) | ||
pUnitGamma:float = Field(0.2, gt = 0, lt = 1) | ||
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold | ||
|
||
|
||
class ControlsClass: | ||
|
||
def __init__(self, | ||
procedure: Procedures = Procedures.Calculate, | ||
**properties) -> None: | ||
|
||
if procedure == Procedures.Calculate: | ||
self.controls = Calculate(**properties) | ||
elif procedure == Procedures.Simplex: | ||
self.controls = Simplex(**properties) | ||
elif procedure == Procedures.DE: | ||
self.controls = DE(**properties) | ||
elif procedure == Procedures.NS: | ||
self.controls = NS(**properties) | ||
elif procedure == Procedures.Dream: | ||
self.controls = Dream(**properties) | ||
|
||
@property | ||
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]: | ||
return self._controls | ||
|
||
@controls.setter | ||
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None: | ||
self._controls = value | ||
|
||
def __repr__(self) -> str: | ||
properties = [["Property", "Value"]] +\ | ||
[[k, v] for k, v in self._controls.__dict__.items()] | ||
table = tabulate.tabulate(properties, headers="firstrow") | ||
return table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from enum import Enum | ||
try: | ||
from enum import StrEnum | ||
except ImportError: | ||
from strenum import StrEnum | ||
|
||
|
||
class ParallelOptions(StrEnum): | ||
"""Defines the avaliable options for parallelization""" | ||
Single = 'single' | ||
Points = 'points' | ||
Contrasts = 'contrasts' | ||
All = 'all' | ||
|
||
|
||
class Procedures(StrEnum): | ||
"""Defines the avaliable options for procedures""" | ||
Calculate = 'calculate' | ||
Simplex = 'simplex' | ||
DE = 'de' | ||
NS = 'ns' | ||
Dream = 'dream' | ||
|
||
|
||
class DisplayOptions(StrEnum): | ||
"""Defines the avaliable options for display""" | ||
Off = 'off' | ||
Iter = 'iter' | ||
Notify = 'notify' | ||
Final = 'final' | ||
|
||
|
||
class BoundHandlingOptions(StrEnum): | ||
"""Defines the avaliable options for bound handling""" | ||
Off = 'off' | ||
Reflect = 'reflect' | ||
Bound = 'bound' | ||
Fold = 'fold' | ||
|
||
|
||
class StrategyOptions(Enum): | ||
"""Defines the avaliable options for strategies""" | ||
Random = 1 | ||
LocalToBest = 2 | ||
BestWithJitter = 3 | ||
RandomWithPerVectorDither = 4 | ||
RandomWithPerGenerationDither = 5 | ||
RandomEitherOrAlgorithm = 6 |
Oops, something went wrong.