Skip to content

Commit

Permalink
Addresses review comments and import statements
Browse files Browse the repository at this point in the history
  • Loading branch information
DrPaulSharp committed Apr 19, 2024
1 parent e4b7960 commit dc52b7e
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 272 deletions.
1 change: 0 additions & 1 deletion RAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
from RAT.controls import set_controls
import RAT.models


dir_path = os.path.dirname(os.path.realpath(__file__))
os.environ["RAT_PATH"] = os.path.join(dir_path, '')
7 changes: 3 additions & 4 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,12 @@ def count(self, item: Union[object, str]) -> int:
item = self._get_item_from_name_field(item)
return self.data.count(item)

def index(self, item: Union[object, str], offset: int = 0, *args) -> int:
def index(self, item: Union[object, str], offset: bool = False, *args) -> int:
"""Return the index of a particular object in the ClassList using either the object itself or its
name_field value. If an offset is specified, add this value to the index. This is used to account for one-based
indexing.
name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
"""
item = self._get_item_from_name_field(item)
return self.data.index(item, *args) + offset
return self.data.index(item, *args) + int(offset)

def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
Expand Down
38 changes: 38 additions & 0 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass, field
import prettytable
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union
Expand All @@ -6,6 +7,43 @@
from RAT.utils.custom_errors import custom_pydantic_validation_error


@dataclass(frozen=True)
class Controls:
"""The full set of controls parameters required for the compiled RAT code."""
# All Procedures
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = field(default_factory=list[0.9, 50.0])
display: Display = Display.Iter
# Simplex
xTolerance: float = 1.0e-6
funcTolerance: float = 1.0e-6
maxFuncEvals: int = 10000
maxIterations: int = 1000
updateFreq: int = -1
updatePlotFreq: int = 1
# DE
populationSize: int = 20
fWeight: float = 0.5
crossoverProbability: float = 0.8
strategy: Strategies = Strategies.RandomWithPerVectorDither.value
targetValue: float = 1.0
numGenerations: int = 500
# NS
nLive: int = 150
nMCMC: float = 0.0
propScale: float = 0.1
nsTolerance: float = 0.1
# Dream
nSamples: int = 50000
nChains: int = 10
jumpProbability: float = 0.5
pUnitGamma: float = 0.2
boundHandling: BoundHandling = BoundHandling.Fold
adaptPCR: bool = False


class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
Expand Down
5 changes: 2 additions & 3 deletions RAT/events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Callable, Union, List
import RAT.rat_core
from RAT.rat_core import EventTypes, PlotEventData, ProgressEventData
from RAT.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData


def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None:
Expand Down Expand Up @@ -60,5 +59,5 @@ def clear() -> None:
__event_callbacks[key] = set()


__event_impl = RAT.rat_core.EventBridge(notify)
__event_impl = EventBridge(notify)
__event_callbacks = {EventTypes.Message: set(), EventTypes.Plot: set(), EventTypes.Progress: set()}
40 changes: 19 additions & 21 deletions RAT/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
from typing import Union

import RAT
import RAT.project
import RAT.wrappers
import RAT.utils.dataclasses
from RAT.utils.enums import Calculations, Languages, Models
import RAT.controls
from RAT.utils.enums import Calculations, LayerModels

from RAT.rat_core import Cells, Checks, Control, Limits, Priors, ProblemDefinition

Expand Down Expand Up @@ -79,7 +77,7 @@ def make_input(project: RAT.Project, controls: Union[RAT.controls.Calculate, RAT
for class_list in RAT.project.parameter_class_lists
for param in getattr(project, class_list)]

if project.model == Models.CustomXY:
if project.model == LayerModels.CustomXY:
controls.calcSldDuringFit = True

cpp_controls = make_controls(controls, checks)
Expand All @@ -103,10 +101,10 @@ def make_problem(project: RAT.Project) -> ProblemDefinition:
action_id = {'add': 1, 'subtract': 2}

# Set contrast parameters according to model type
if project.model == Models.StandardLayers:
if project.model == LayerModels.StandardLayers:
contrast_custom_files = [float('NaN')] * len(project.contrasts)
else:
contrast_custom_files = [project.custom_files.index(contrast.model[0], 1) for contrast in project.contrasts]
contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts]

problem = ProblemDefinition()

Expand All @@ -122,15 +120,15 @@ def make_problem(project: RAT.Project) -> ProblemDefinition:
problem.domainRatio = [param.value for param in project.domain_ratios]
problem.backgroundParams = [param.value for param in project.background_parameters]
problem.resolutionParams = [param.value for param in project.resolution_parameters]
problem.contrastBulkIns = [project.bulk_in.index(contrast.bulk_in, 1) for contrast in project.contrasts]
problem.contrastBulkOuts = [project.bulk_out.index(contrast.bulk_out, 1) for contrast in project.contrasts]
problem.contrastBulkIns = [project.bulk_in.index(contrast.bulk_in, True) for contrast in project.contrasts]
problem.contrastBulkOuts = [project.bulk_out.index(contrast.bulk_out, True) for contrast in project.contrasts]
problem.contrastQzshifts = [1] * len(project.contrasts) # This is marked as "to do" in RAT
problem.contrastScalefactors = [project.scalefactors.index(contrast.scalefactor, 1) for contrast in project.contrasts]
problem.contrastDomainRatios = [project.domain_ratios.index(contrast.domain_ratio, 1)
problem.contrastScalefactors = [project.scalefactors.index(contrast.scalefactor, True) for contrast in project.contrasts]
problem.contrastDomainRatios = [project.domain_ratios.index(contrast.domain_ratio, True)
if hasattr(contrast, 'domain_ratio') else 0 for contrast in project.contrasts]
problem.contrastBackgrounds = [project.backgrounds.index(contrast.background, 1) for contrast in project.contrasts]
problem.contrastBackgrounds = [project.backgrounds.index(contrast.background, True) for contrast in project.contrasts]
problem.contrastBackgroundActions = [action_id[contrast.background_action] for contrast in project.contrasts]
problem.contrastResolutions = [project.resolutions.index(contrast.resolution, 1) for contrast in project.contrasts]
problem.contrastResolutions = [project.resolutions.index(contrast.resolution, True) for contrast in project.contrasts]
problem.contrastCustomFiles = contrast_custom_files
problem.resample = [contrast.resample for contrast in project.contrasts]
problem.dataPresent = [1 if contrast.data else 0 for contrast in project.contrasts]
Expand Down Expand Up @@ -169,12 +167,12 @@ def make_cells(project: RAT.Project) -> Cells:
hydrate_id = {'bulk in': 1, 'bulk out': 2}

# Set contrast parameters according to model type
if project.model == Models.StandardLayers:
if project.model == LayerModels.StandardLayers:
if project.calculation == Calculations.Domains:
contrast_models = [[project.domain_contrasts.index(domain_contrast, 1) for domain_contrast in contrast.model]
contrast_models = [[project.domain_contrasts.index(domain_contrast, True) for domain_contrast in contrast.model]
for contrast in project.contrasts]
else:
contrast_models = [[project.layers.index(layer, 1) for layer in contrast.model]
contrast_models = [[project.layers.index(layer, True) for layer in contrast.model]
for contrast in project.contrasts]
else:
contrast_models = [[]] * len(project.contrasts)
Expand All @@ -183,8 +181,8 @@ def make_cells(project: RAT.Project) -> Cells:
layer_details = []
for layer in project.layers:

layer_params = []#[project.parameters.index(getattr(layer, attribute), 1) for attribute in list(layer.model_fields.keys())[1:-2]]
layer_params.append(project.parameters.index(layer.hydration, 1) if layer.hydration else float('NaN'))
layer_params = []#[project.parameters.index(getattr(layer, attribute), True) for attribute in list(layer.model_fields.keys())[1:-2]]
layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float('NaN'))
layer_params.append(hydrate_id[layer.hydrate_with])

layer_details.append(layer_params)
Expand Down Expand Up @@ -214,7 +212,7 @@ def make_cells(project: RAT.Project) -> Cells:
cells.f3 = data_limits
cells.f4 = simulation_limits
cells.f5 = [contrast_model if contrast_model else 0 for contrast_model in contrast_models]
#cells.f6 = layer_details if project.model == Models.StandardLayers else [0]
#cells.f6 = layer_details if project.model == LayerModels.StandardLayers else [0]
cells.f7 = [param.name for param in project.parameters]
cells.f8 = [param.name for param in project.background_parameters]
cells.f9 = [param.name for param in project.scalefactors]
Expand All @@ -229,7 +227,7 @@ def make_cells(project: RAT.Project) -> Cells:
cells.f17 = [[0.0, 0.0, 0.0]] * len(project.contrasts) # Placeholder for oil chi data
cells.f18 = [[0, 1]] * len(project.domain_contrasts) # This is marked as "to do" in RAT

domain_contrast_models = [[project.layers.index(layer, 1) for layer in domain_contrast.model]
domain_contrast_models = [[project.layers.index(layer, True) for layer in domain_contrast.model]
for domain_contrast in project.domain_contrasts]
cells.f19 = [domain_contrast_model if domain_contrast_model else 0
for domain_contrast_model in domain_contrast_models]
Expand All @@ -256,7 +254,7 @@ def make_controls(controls: Union[RAT.controls.Calculate, RAT.controls.Simplex,
The controls object used in the compiled RAT code.
"""

full_controls = RAT.utils.dataclasses.Controls(**controls.model_dump())
full_controls = RAT.controls.Controls(**controls.model_dump())
cpp_controls = Control()

cpp_controls.procedure = full_controls.procedure
Expand Down
10 changes: 5 additions & 5 deletions RAT/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from typing import Any

from RAT.utils.enums import Actions, Hydration, Languages, Priors, Types
from RAT.utils.enums import BackgroundActions, Hydration, Languages, Priors, TypeOptions

try:
from enum import StrEnum
Expand Down Expand Up @@ -46,7 +46,7 @@ def __repr__(self):
class Background(RATModel):
"""Defines the Backgrounds in RAT."""
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
type: Types = Types.Constant
type: TypeOptions = TypeOptions.Constant
value_1: str = ''
value_2: str = ''
value_3: str = ''
Expand All @@ -59,7 +59,7 @@ class Contrast(RATModel):
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
background: str = ''
background_action: Actions = Actions.Add
background_action: BackgroundActions = BackgroundActions.Add
bulk_in: str = ''
bulk_out: str = ''
scalefactor: str = ''
Expand All @@ -73,7 +73,7 @@ class ContrastWithRatio(RATModel):
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
background: str = ''
background_action: Actions = Actions.Add
background_action: BackgroundActions = BackgroundActions.Add
bulk_in: str = ''
bulk_out: str = ''
scalefactor: str = ''
Expand Down Expand Up @@ -233,7 +233,7 @@ class ProtectedParameter(Parameter):
class Resolution(RATModel):
"""Defines Resolutions in RAT."""
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
type: Types = Types.Constant
type: TypeOptions = TypeOptions.Constant
value_1: str = ''
value_2: str = ''
value_3: str = ''
Expand Down
34 changes: 17 additions & 17 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import custom_pydantic_validation_error
from RAT.utils.enums import Calculations, Geometries, Models
from RAT.utils.enums import Calculations, Geometries, LayerModels, Priors, TypeOptions


# Map project fields to pydantic models
Expand Down Expand Up @@ -87,44 +87,44 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
"""
name: str = ''
calculation: Calculations = Calculations.NonPolarised
model: Models = Models.StandardLayers
model: LayerModels = LayerModels.StandardLayers
geometry: Geometries = Geometries.AirSubstrate
absorption: bool = False

parameters: ClassList = ClassList()

bulk_in: ClassList = ClassList(RAT.models.Parameter(name='SLD Air', min=0.0, value=0.0, max=0.0, fit=False,
prior_type=RAT.models.Priors.Uniform, mu=0.0, sigma=np.inf))
prior_type=Priors.Uniform, mu=0.0, sigma=np.inf))

bulk_out: ClassList = ClassList(RAT.models.Parameter(name='SLD D2O', min=6.2e-6, value=6.35e-6, max=6.35e-6,
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0,
fit=False, prior_type=Priors.Uniform, mu=0.0,
sigma=np.inf))

qz_shifts: ClassList = ClassList(RAT.models.Parameter(name='Qz shift 1', min=-1e-4, value=0.0, max=1e-4, fit=False,
prior_type=RAT.models.Priors.Uniform, mu=0.0, sigma=np.inf))
prior_type=Priors.Uniform, mu=0.0, sigma=np.inf))

scalefactors: ClassList = ClassList(RAT.models.Parameter(name='Scalefactor 1', min=0.02, value=0.23, max=0.25,
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0,
fit=False, prior_type=Priors.Uniform, mu=0.0,
sigma=np.inf))

domain_ratios: ClassList = ClassList(RAT.models.Parameter(name='Domain Ratio 1', min=0.4, value=0.5, max=0.6,
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0,
fit=False, prior_type=Priors.Uniform, mu=0.0,
sigma=np.inf))

background_parameters: ClassList = ClassList(RAT.models.Parameter(name='Background Param 1', min=1e-7, value=1e-6,
max=1e-5, fit=False,
prior_type=RAT.models.Priors.Uniform, mu=0.0,
prior_type=Priors.Uniform, mu=0.0,
sigma=np.inf))

backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=RAT.models.Types.Constant,
backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=TypeOptions.Constant,
value_1='Background Param 1'))

resolution_parameters: ClassList = ClassList(RAT.models.Parameter(name='Resolution Param 1', min=0.01, value=0.03,
max=0.05, fit=False,
prior_type=RAT.models.Priors.Uniform, mu=0.0,
prior_type=Priors.Uniform, mu=0.0,
sigma=np.inf))

resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=RAT.models.Types.Constant,
resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=TypeOptions.Constant,
value_1='Resolution Param 1'))

custom_files: ClassList = ClassList()
Expand Down Expand Up @@ -217,14 +217,14 @@ def set_domain_contrasts(self) -> 'Project':
"""If we are not running a domains calculation with standard layers, ensure the domain_contrasts component of
the model is empty.
"""
if not (self.calculation == Calculations.Domains and self.model == Models.StandardLayers):
if not (self.calculation == Calculations.Domains and self.model == LayerModels.StandardLayers):
self.domain_contrasts.data = []
return self

@model_validator(mode='after')
def set_layers(self) -> 'Project':
"""If we are not using a standard layers model, ensure the layers component of the model is empty."""
if self.model != Models.StandardLayers:
if self.model != LayerModels.StandardLayers:
self.layers.data = []
return self

Expand Down Expand Up @@ -267,12 +267,12 @@ def check_contrast_model_length(self) -> 'Project':
"""Given certain values of the "calculation" and "model" defined in the project, the "model" field of "contrasts"
may be constrained in its length.
"""
if self.model == Models.StandardLayers and self.calculation == Calculations.Domains:
if self.model == LayerModels.StandardLayers and self.calculation == Calculations.Domains:
for contrast in self.contrasts:
if contrast.model and len(contrast.model) != 2:
raise ValueError('For a standard layers domains calculation the "model" field of "contrasts" must '
'contain exactly two values.')
elif self.model != Models.StandardLayers:
elif self.model != LayerModels.StandardLayers:
for contrast in self.contrasts:
if len(contrast.model) > 1:
raise ValueError('For a custom model calculation the "model" field of "contrasts" cannot contain '
Expand Down Expand Up @@ -434,7 +434,7 @@ def get_contrast_model_field(self):
model_field : str
The name of the field used to define the contrasts' model field.
"""
if self.model == Models.StandardLayers:
if self.model == LayerModels.StandardLayers:
if self.calculation == Calculations.Domains:
model_field = 'domain_contrasts'
else:
Expand Down Expand Up @@ -480,7 +480,7 @@ def write_script(self, obj_name: str = 'problem', script: str = 'project_script.
f.write(f'{indent}{class_list}=RAT.ClassList({contents}),\n')
f.write(f'{indent})\n')

def _classlist_wrapper(self, class_list: 'ClassList', func: Callable):
def _classlist_wrapper(self, class_list: ClassList, func: Callable):
"""Defines the function used to wrap around ClassList routines to force revalidation.
Parameters
Expand Down
Loading

0 comments on commit dc52b7e

Please sign in to comment.