Skip to content

Commit

Permalink
Adds code to construct inputs for the compiled code (#25)
Browse files Browse the repository at this point in the history
* Tidies up configs for pydantic models

* Moves enums into "utils/enums.py"

* First draft of "inputs.py" code to construct inputs for compiled code

* Renames enums

* Moves definition of dataclasses to "utils/dataclasses.py"

* Tidies up "inputs" module with "make_problem"2" and "make_cells" routines

* Adds offset to the "index" routine in "classList.py"

* Adds "NaNList" class to "tests/utils.py"

* Adds "test_inputs.py", and corrects code in "inputs.py"

* Adjusts tests to fit with pybind example

* Updates code to ensure tests pass

* Renames parameters to match matlab updates

* Converts "inputs.py" to use C++ objects directly.

* Renames "misc.py" as "wrappers.py" and "Calc" enum as "Calculations"

* Adds background actions to the contrast model

* Adds file wrappers to "make"cells"

* Adds additional examples to "test_inputs.py" to improve test coverage"

* Adds code to support recording custom files in "make_cells"

* Updates submodule and tidying up

* Updates requirements

* . . .and "pyproject.toml"

* Fixes pydantic to version 2.6.4

* . . . and "setup.py"

* Sort out version requirements

* Addresses review comments and import statements

* Changes parameters from optional to compulsory in "Layer" and "AbsorptionLayer" models

* Enables optional hydration in layer models
  • Loading branch information
DrPaulSharp authored Apr 25, 2024
1 parent b62c5d4 commit f635416
Show file tree
Hide file tree
Showing 22 changed files with 1,595 additions and 627 deletions.
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ And finally create a separate branch to begin work

git checkout -b new-feature

If there are updates to the C++ RAT submodule, run the following command to update the local branch

git submodule update --remote

Once complete submit a [pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) via GitHub.
Ensure to rebase your branch to include the latest changes on your branch and resolve possible merge conflicts.

Expand Down
3 changes: 1 addition & 2 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
from RAT.classlist import ClassList
from RAT.project import Project
import RAT.controls
from RAT.controls import set_controls
import RAT.models


dir_path = os.path.dirname(os.path.realpath(__file__))
os.environ["RAT_PATH"] = os.path.join(dir_path, '')
16 changes: 9 additions & 7 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def __repr__(self):
output = repr(self.data)
return output

def __setitem__(self, index: int, item: 'RAT.models') -> None:
def __setitem__(self, index: int, item: object) -> None:
"""Replace the object at an existing index of the ClassList."""
self._setitem(index, item)

def _setitem(self, index: int, item: 'RAT.models') -> None:
def _setitem(self, index: int, item: object) -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._check_classes(self + [item])
self._check_unique_name_fields(self + [item])
Expand Down Expand Up @@ -171,7 +171,7 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
inserted into the ClassList and the keyword arguments are discarded.
"""
if obj and kwargs:
warnings.warn('ClassList.insert() called with both object and keyword arguments. '
warnings.warn('ClassList.insert() called with both an object and keyword arguments. '
'The keyword arguments will be ignored.', SyntaxWarning)
if obj:
if not hasattr(self, '_class_handle'):
Expand All @@ -193,15 +193,17 @@ def remove(self, item: Union[object, str]) -> None:

def count(self, item: Union[object, str]) -> int:
"""Return the number of times an object appears in the ClassList using either the object itself or its
name_field value."""
name_field value.
"""
item = self._get_item_from_name_field(item)
return self.data.count(item)

def index(self, item: Union[object, str], *args) -> int:
def index(self, item: Union[object, str], offset: bool = False, *args) -> int:
"""Return the index of a particular object in the ClassList using either the object itself or its
name_field value."""
name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
"""
item = self._get_item_from_name_field(item)
return self.data.index(item, *args)
return self.data.index(item, *args) + int(offset)

def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
Expand Down
89 changes: 64 additions & 25 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,65 @@
from dataclasses import dataclass, field
import prettytable
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies
from RAT.utils.custom_errors import custom_pydantic_validation_error


@dataclass(frozen=True)
class Controls:
"""The full set of controls parameters required for the compiled RAT code."""
# All Procedures
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = field(default_factory=list[0.9, 50.0])
display: Display = Display.Iter
# Simplex
xTolerance: float = 1.0e-6
funcTolerance: float = 1.0e-6
maxFuncEvals: int = 10000
maxIterations: int = 1000
updateFreq: int = -1
updatePlotFreq: int = 1
# DE
populationSize: int = 20
fWeight: float = 0.5
crossoverProbability: float = 0.8
strategy: Strategies = Strategies.RandomWithPerVectorDither.value
targetValue: float = 1.0
numGenerations: int = 500
# NS
nLive: int = 150
nMCMC: float = 0.0
propScale: float = 0.1
nsTolerance: float = 0.1
# Dream
nSamples: int = 50000
nChains: int = 10
jumpProbability: float = 0.5
pUnitGamma: float = 0.2
boundHandling: BoundHandling = BoundHandling.Fold
adaptPCR: bool = False


class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: ParallelOptions = ParallelOptions.Single
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
display: DisplayOptions = DisplayOptions.Iter
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
display: Display = Display.Iter

@field_validator("resamPars")
@field_validator("resampleParams")
@classmethod
def check_resamPars(cls, resamPars):
if not 0 < resamPars[0] < 1:
raise ValueError('resamPars[0] must be between 0 and 1')
if resamPars[1] < 0:
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars
def check_resample_params(cls, resampleParams):
if not 0 < resampleParams[0] < 1:
raise ValueError('resampleParams[0] must be between 0 and 1')
if resampleParams[1] < 0:
raise ValueError('resampleParams[1] must be greater than or equal to 0')
return resampleParams

def __repr__(self) -> str:
table = prettytable.PrettyTable()
Expand All @@ -30,45 +68,46 @@ def __repr__(self) -> str:
return table.get_string()


class Simplex(Calculate, validate_assignment=True, extra='forbid'):
class Simplex(Calculate):
"""Defines the additional fields for the simplex procedure."""
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
maxIter: int = Field(1000, gt=0)
xTolerance: float = Field(1.0e-6, gt=0.0)
funcTolerance: float = Field(1.0e-6, gt=0.0)
maxFuncEvals: int = Field(10000, gt=0)
maxIterations: int = Field(1000, gt=0)
updateFreq: int = -1
updatePlotFreq: int = -1
updatePlotFreq: int = 1


class DE(Calculate, validate_assignment=True, extra='forbid'):
class DE(Calculate):
"""Defines the additional fields for the Differential Evolution procedure."""
procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
strategy: StrategyOptions = StrategyOptions.RandomWithPerVectorDither
strategy: Strategies = Strategies.RandomWithPerVectorDither
targetValue: float = Field(1.0, ge=1.0)
numGenerations: int = Field(500, ge=1)


class NS(Calculate, validate_assignment=True, extra='forbid'):
class NS(Calculate):
"""Defines the additional fields for the Nested Sampler procedure."""
procedure: Literal[Procedures.NS] = Procedures.NS
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
nLive: int = Field(150, ge=1)
nMCMC: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(Calculate, validate_assignment=True, extra='forbid'):
class Dream(Calculate):
"""Defines the additional fields for the Dream procedure."""
procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
jumpProbability: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold
boundHandling: BoundHandling = BoundHandling.Fold
adaptPCR: bool = False


def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
Expand Down
6 changes: 3 additions & 3 deletions RAT/events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Callable, Union, List
import RAT.rat_core
from RAT.rat_core import EventTypes, PlotEventData, ProgressEventData
from RAT.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData


def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None:
Expand All @@ -18,6 +17,7 @@ def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEvent
for callback in callbacks:
callback(data)


def get_event_callback(event_type: EventTypes) -> List[Callable[[Union[str, PlotEventData, ProgressEventData]], None]]:
"""Returns all callbacks registered for the given event type.
Expand Down Expand Up @@ -59,5 +59,5 @@ def clear() -> None:
__event_callbacks[key] = set()


__event_impl = RAT.rat_core.EventBridge(notify)
__event_impl = EventBridge(notify)
__event_callbacks = {EventTypes.Message: set(), EventTypes.Plot: set(), EventTypes.Progress: set()}
Loading

0 comments on commit f635416

Please sign in to comment.