Skip to content

Commit

Permalink
Updates tables in Controls
Browse files Browse the repository at this point in the history
  • Loading branch information
DrPaulSharp committed Oct 24, 2023
1 parent f8c8ce0 commit de3b0e3
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 330 deletions.
1 change: 1 addition & 0 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from RAT.classlist import ClassList
from RAT.controls import Controls
from RAT.project import Project
104 changes: 47 additions & 57 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import tabulate
from typing import Union
import prettytable
from pydantic import BaseModel, Field, field_validator
from typing import Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


class BaseProcedure(BaseModel, validate_assignment = True, extra = 'forbid'):
"""
Defines the base class with properties used in all five procedures.
"""
class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the base class with properties used in all five procedures."""
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length = 2, max_length = 2)
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
display: DisplayOptions = DisplayOptions.Iter

@field_validator("resamPars")
@classmethod
def check_resamPars(cls, resamPars):
if not 0 < resamPars[0] < 1:
raise ValueError('resamPars[0] must be between 0 and 1')
Expand All @@ -22,63 +22,53 @@ def check_resamPars(cls, resamPars):
return resamPars


class Calculate(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the calculate procedure.
"""
procedure: Procedures = Field(Procedures.Calculate, frozen = True)
class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure."""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)


class Simplex(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the simplex procedure.
"""
procedure: Procedures = Field(Procedures.Simplex, frozen = True)
tolX: float = Field(1e-6, gt = 0)
tolFun: float = Field(1e-6, gt = 0)
maxFunEvals: int = Field(10000, gt = 0)
maxIter: int = Field(1000, gt = 0)
class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the simplex procedure."""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
maxIter: int = Field(1000, gt=0)
updateFreq: int = -1
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Differential Evolution procedure.
"""
procedure: Procedures = Field(Procedures.DE, frozen = True)
populationSize: int = Field(20, ge = 1)
class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Differential Evolution procedure."""
procedure: Procedures = Field(Procedures.DE, frozen=True)
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt = 0, lt = 1)
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
targetValue: float = Field(1.0, ge = 1)
numGenerations: int = Field(500, ge = 1)


class NS(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Nested Sampler procedure.
"""
procedure: Procedures = Field(Procedures.NS, frozen = True)
Nlive: int = Field(150, ge = 1)
Nmcmc: float = Field(0.0, ge = 0)
propScale: float = Field(0.1, gt = 0, lt = 1)
nsTolerance: float = Field(0.1, ge = 0)


class Dream(BaseProcedure, validate_assignment = True, extra = 'forbid'):
"""
Defines the class for the Dream procedure
"""
procedure: Procedures = Field(Procedures.Dream, frozen = True)
nSamples: int = Field(50000, ge = 0)
nChains: int = Field(10, gt = 0)
jumpProb: float = Field(0.5, gt = 0, lt = 1)
pUnitGamma:float = Field(0.2, gt = 0, lt = 1)
targetValue: float = Field(1.0, ge=1.0)
numGenerations: int = Field(500, ge=1)


class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Nested Sampler procedure."""
procedure: Procedures = Field(Procedures.NS, frozen=True)
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Dream procedure."""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class ControlsClass:
class Controls:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
Expand All @@ -104,7 +94,7 @@ def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
properties = [["Property", "Value"]] +\
[[k, v] for k, v in self._controls.__dict__.items()]
table = tabulate.tabulate(properties, headers="firstrow")
return table
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
return table.get_string()
20 changes: 15 additions & 5 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def two_name_class_list():
return ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')])


@pytest.fixture
def two_name_class_list_table():
"""The table representation of the ClassList defined in the "two_name_class_list" fixture."""
return(
'+-------+-------+\n'
'| index | name |\n'
'+-------+-------+\n'
'| 0 | Alice |\n'
'| 1 | Bob |\n'
'+-------+-------+'
)


@pytest.fixture
def three_name_class_list():
"""A ClassList of InputAttributes, containing three elements with names defined."""
Expand Down Expand Up @@ -104,12 +117,9 @@ def test_identical_name_fields(self, input_list: Sequence[object], name_field: s
ClassList(input_list, name_field=name_field)


@pytest.mark.parametrize("expected_string", [
' name\n-- ------\n 0 Alice\n 1 Bob',
])
def test_repr_table(two_name_class_list: 'ClassList', expected_string: str) -> None:
def test_repr_table(two_name_class_list: 'ClassList', two_name_class_list_table: str) -> None:
"""For classes with the __dict__ attribute, we should be able to print the ClassList like a table."""
assert repr(two_name_class_list) == expected_string
assert repr(two_name_class_list) == two_name_class_list_table


def test_repr_empty_table() -> None:
Expand Down
Loading

0 comments on commit de3b0e3

Please sign in to comment.