From dc52b7e7a4dfc0db30fef61d0c7a286ef3ba079f Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:13:28 +0100 Subject: [PATCH] Addresses review comments and import statements --- RAT/__init__.py | 1 - RAT/classlist.py | 7 +- RAT/controls.py | 38 +++++++++ RAT/events.py | 5 +- RAT/inputs.py | 40 +++++----- RAT/models.py | 10 +-- RAT/project.py | 34 ++++---- RAT/utils/dataclasses.py | 42 ---------- RAT/utils/enums.py | 6 +- tests/test_classlist.py | 88 ++++++++++----------- tests/test_inputs.py | 94 +++++++++++----------- tests/test_project.py | 166 +++++++++++++++++++-------------------- 12 files changed, 259 insertions(+), 272 deletions(-) delete mode 100644 RAT/utils/dataclasses.py diff --git a/RAT/__init__.py b/RAT/__init__.py index 990182bc..8bb0027d 100644 --- a/RAT/__init__.py +++ b/RAT/__init__.py @@ -4,6 +4,5 @@ from RAT.controls import set_controls import RAT.models - dir_path = os.path.dirname(os.path.realpath(__file__)) os.environ["RAT_PATH"] = os.path.join(dir_path, '') diff --git a/RAT/classlist.py b/RAT/classlist.py index bdcbb9e1..0e246152 100644 --- a/RAT/classlist.py +++ b/RAT/classlist.py @@ -198,13 +198,12 @@ def count(self, item: Union[object, str]) -> int: item = self._get_item_from_name_field(item) return self.data.count(item) - def index(self, item: Union[object, str], offset: int = 0, *args) -> int: + def index(self, item: Union[object, str], offset: bool = False, *args) -> int: """Return the index of a particular object in the ClassList using either the object itself or its - name_field value. If an offset is specified, add this value to the index. This is used to account for one-based - indexing. + name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing. """ item = self._get_item_from_name_field(item) - return self.data.index(item, *args) + offset + return self.data.index(item, *args) + int(offset) def extend(self, other: Sequence[object]) -> None: """Extend the ClassList by adding another sequence.""" diff --git a/RAT/controls.py b/RAT/controls.py index 0572c6d4..bc4c0bad 100644 --- a/RAT/controls.py +++ b/RAT/controls.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import prettytable from pydantic import BaseModel, Field, field_validator, ValidationError from typing import Literal, Union @@ -6,6 +7,43 @@ from RAT.utils.custom_errors import custom_pydantic_validation_error +@dataclass(frozen=True) +class Controls: + """The full set of controls parameters required for the compiled RAT code.""" + # All Procedures + procedure: Procedures = Procedures.Calculate + parallel: Parallel = Parallel.Single + calcSldDuringFit: bool = False + resampleParams: list[float] = field(default_factory=list[0.9, 50.0]) + display: Display = Display.Iter + # Simplex + xTolerance: float = 1.0e-6 + funcTolerance: float = 1.0e-6 + maxFuncEvals: int = 10000 + maxIterations: int = 1000 + updateFreq: int = -1 + updatePlotFreq: int = 1 + # DE + populationSize: int = 20 + fWeight: float = 0.5 + crossoverProbability: float = 0.8 + strategy: Strategies = Strategies.RandomWithPerVectorDither.value + targetValue: float = 1.0 + numGenerations: int = 500 + # NS + nLive: int = 150 + nMCMC: float = 0.0 + propScale: float = 0.1 + nsTolerance: float = 0.1 + # Dream + nSamples: int = 50000 + nChains: int = 10 + jumpProbability: float = 0.5 + pUnitGamma: float = 0.2 + boundHandling: BoundHandling = BoundHandling.Fold + adaptPCR: bool = False + + class Calculate(BaseModel, validate_assignment=True, extra='forbid'): """Defines the class for the calculate procedure, which includes the properties used in all five procedures.""" procedure: Literal[Procedures.Calculate] = Procedures.Calculate diff --git a/RAT/events.py b/RAT/events.py index 4b03ef96..a308625b 100644 --- a/RAT/events.py +++ b/RAT/events.py @@ -1,6 +1,5 @@ from typing import Callable, Union, List -import RAT.rat_core -from RAT.rat_core import EventTypes, PlotEventData, ProgressEventData +from RAT.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None: @@ -60,5 +59,5 @@ def clear() -> None: __event_callbacks[key] = set() -__event_impl = RAT.rat_core.EventBridge(notify) +__event_impl = EventBridge(notify) __event_callbacks = {EventTypes.Message: set(), EventTypes.Plot: set(), EventTypes.Progress: set()} diff --git a/RAT/inputs.py b/RAT/inputs.py index 40d98ffa..2a685aaf 100644 --- a/RAT/inputs.py +++ b/RAT/inputs.py @@ -4,10 +4,8 @@ from typing import Union import RAT -import RAT.project -import RAT.wrappers -import RAT.utils.dataclasses -from RAT.utils.enums import Calculations, Languages, Models +import RAT.controls +from RAT.utils.enums import Calculations, LayerModels from RAT.rat_core import Cells, Checks, Control, Limits, Priors, ProblemDefinition @@ -79,7 +77,7 @@ def make_input(project: RAT.Project, controls: Union[RAT.controls.Calculate, RAT for class_list in RAT.project.parameter_class_lists for param in getattr(project, class_list)] - if project.model == Models.CustomXY: + if project.model == LayerModels.CustomXY: controls.calcSldDuringFit = True cpp_controls = make_controls(controls, checks) @@ -103,10 +101,10 @@ def make_problem(project: RAT.Project) -> ProblemDefinition: action_id = {'add': 1, 'subtract': 2} # Set contrast parameters according to model type - if project.model == Models.StandardLayers: + if project.model == LayerModels.StandardLayers: contrast_custom_files = [float('NaN')] * len(project.contrasts) else: - contrast_custom_files = [project.custom_files.index(contrast.model[0], 1) for contrast in project.contrasts] + contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts] problem = ProblemDefinition() @@ -122,15 +120,15 @@ def make_problem(project: RAT.Project) -> ProblemDefinition: problem.domainRatio = [param.value for param in project.domain_ratios] problem.backgroundParams = [param.value for param in project.background_parameters] problem.resolutionParams = [param.value for param in project.resolution_parameters] - problem.contrastBulkIns = [project.bulk_in.index(contrast.bulk_in, 1) for contrast in project.contrasts] - problem.contrastBulkOuts = [project.bulk_out.index(contrast.bulk_out, 1) for contrast in project.contrasts] + problem.contrastBulkIns = [project.bulk_in.index(contrast.bulk_in, True) for contrast in project.contrasts] + problem.contrastBulkOuts = [project.bulk_out.index(contrast.bulk_out, True) for contrast in project.contrasts] problem.contrastQzshifts = [1] * len(project.contrasts) # This is marked as "to do" in RAT - problem.contrastScalefactors = [project.scalefactors.index(contrast.scalefactor, 1) for contrast in project.contrasts] - problem.contrastDomainRatios = [project.domain_ratios.index(contrast.domain_ratio, 1) + problem.contrastScalefactors = [project.scalefactors.index(contrast.scalefactor, True) for contrast in project.contrasts] + problem.contrastDomainRatios = [project.domain_ratios.index(contrast.domain_ratio, True) if hasattr(contrast, 'domain_ratio') else 0 for contrast in project.contrasts] - problem.contrastBackgrounds = [project.backgrounds.index(contrast.background, 1) for contrast in project.contrasts] + problem.contrastBackgrounds = [project.backgrounds.index(contrast.background, True) for contrast in project.contrasts] problem.contrastBackgroundActions = [action_id[contrast.background_action] for contrast in project.contrasts] - problem.contrastResolutions = [project.resolutions.index(contrast.resolution, 1) for contrast in project.contrasts] + problem.contrastResolutions = [project.resolutions.index(contrast.resolution, True) for contrast in project.contrasts] problem.contrastCustomFiles = contrast_custom_files problem.resample = [contrast.resample for contrast in project.contrasts] problem.dataPresent = [1 if contrast.data else 0 for contrast in project.contrasts] @@ -169,12 +167,12 @@ def make_cells(project: RAT.Project) -> Cells: hydrate_id = {'bulk in': 1, 'bulk out': 2} # Set contrast parameters according to model type - if project.model == Models.StandardLayers: + if project.model == LayerModels.StandardLayers: if project.calculation == Calculations.Domains: - contrast_models = [[project.domain_contrasts.index(domain_contrast, 1) for domain_contrast in contrast.model] + contrast_models = [[project.domain_contrasts.index(domain_contrast, True) for domain_contrast in contrast.model] for contrast in project.contrasts] else: - contrast_models = [[project.layers.index(layer, 1) for layer in contrast.model] + contrast_models = [[project.layers.index(layer, True) for layer in contrast.model] for contrast in project.contrasts] else: contrast_models = [[]] * len(project.contrasts) @@ -183,8 +181,8 @@ def make_cells(project: RAT.Project) -> Cells: layer_details = [] for layer in project.layers: - layer_params = []#[project.parameters.index(getattr(layer, attribute), 1) for attribute in list(layer.model_fields.keys())[1:-2]] - layer_params.append(project.parameters.index(layer.hydration, 1) if layer.hydration else float('NaN')) + layer_params = []#[project.parameters.index(getattr(layer, attribute), True) for attribute in list(layer.model_fields.keys())[1:-2]] + layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float('NaN')) layer_params.append(hydrate_id[layer.hydrate_with]) layer_details.append(layer_params) @@ -214,7 +212,7 @@ def make_cells(project: RAT.Project) -> Cells: cells.f3 = data_limits cells.f4 = simulation_limits cells.f5 = [contrast_model if contrast_model else 0 for contrast_model in contrast_models] - #cells.f6 = layer_details if project.model == Models.StandardLayers else [0] + #cells.f6 = layer_details if project.model == LayerModels.StandardLayers else [0] cells.f7 = [param.name for param in project.parameters] cells.f8 = [param.name for param in project.background_parameters] cells.f9 = [param.name for param in project.scalefactors] @@ -229,7 +227,7 @@ def make_cells(project: RAT.Project) -> Cells: cells.f17 = [[0.0, 0.0, 0.0]] * len(project.contrasts) # Placeholder for oil chi data cells.f18 = [[0, 1]] * len(project.domain_contrasts) # This is marked as "to do" in RAT - domain_contrast_models = [[project.layers.index(layer, 1) for layer in domain_contrast.model] + domain_contrast_models = [[project.layers.index(layer, True) for layer in domain_contrast.model] for domain_contrast in project.domain_contrasts] cells.f19 = [domain_contrast_model if domain_contrast_model else 0 for domain_contrast_model in domain_contrast_models] @@ -256,7 +254,7 @@ def make_controls(controls: Union[RAT.controls.Calculate, RAT.controls.Simplex, The controls object used in the compiled RAT code. """ - full_controls = RAT.utils.dataclasses.Controls(**controls.model_dump()) + full_controls = RAT.controls.Controls(**controls.model_dump()) cpp_controls = Control() cpp_controls.procedure = full_controls.procedure diff --git a/RAT/models.py b/RAT/models.py index 7700bca4..4b2b3881 100644 --- a/RAT/models.py +++ b/RAT/models.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator from typing import Any -from RAT.utils.enums import Actions, Hydration, Languages, Priors, Types +from RAT.utils.enums import BackgroundActions, Hydration, Languages, Priors, TypeOptions try: from enum import StrEnum @@ -46,7 +46,7 @@ def __repr__(self): class Background(RATModel): """Defines the Backgrounds in RAT.""" name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1) - type: Types = Types.Constant + type: TypeOptions = TypeOptions.Constant value_1: str = '' value_2: str = '' value_3: str = '' @@ -59,7 +59,7 @@ class Contrast(RATModel): name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1) data: str = '' background: str = '' - background_action: Actions = Actions.Add + background_action: BackgroundActions = BackgroundActions.Add bulk_in: str = '' bulk_out: str = '' scalefactor: str = '' @@ -73,7 +73,7 @@ class ContrastWithRatio(RATModel): name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1) data: str = '' background: str = '' - background_action: Actions = Actions.Add + background_action: BackgroundActions = BackgroundActions.Add bulk_in: str = '' bulk_out: str = '' scalefactor: str = '' @@ -233,7 +233,7 @@ class ProtectedParameter(Parameter): class Resolution(RATModel): """Defines Resolutions in RAT.""" name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1) - type: Types = Types.Constant + type: TypeOptions = TypeOptions.Constant value_1: str = '' value_2: str = '' value_3: str = '' diff --git a/RAT/project.py b/RAT/project.py index 080ac209..54636309 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -11,7 +11,7 @@ from RAT.classlist import ClassList import RAT.models from RAT.utils.custom_errors import custom_pydantic_validation_error -from RAT.utils.enums import Calculations, Geometries, Models +from RAT.utils.enums import Calculations, Geometries, LayerModels, Priors, TypeOptions # Map project fields to pydantic models @@ -87,44 +87,44 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ """ name: str = '' calculation: Calculations = Calculations.NonPolarised - model: Models = Models.StandardLayers + model: LayerModels = LayerModels.StandardLayers geometry: Geometries = Geometries.AirSubstrate absorption: bool = False parameters: ClassList = ClassList() bulk_in: ClassList = ClassList(RAT.models.Parameter(name='SLD Air', min=0.0, value=0.0, max=0.0, fit=False, - prior_type=RAT.models.Priors.Uniform, mu=0.0, sigma=np.inf)) + prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) bulk_out: ClassList = ClassList(RAT.models.Parameter(name='SLD D2O', min=6.2e-6, value=6.35e-6, max=6.35e-6, - fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0, + fit=False, prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) qz_shifts: ClassList = ClassList(RAT.models.Parameter(name='Qz shift 1', min=-1e-4, value=0.0, max=1e-4, fit=False, - prior_type=RAT.models.Priors.Uniform, mu=0.0, sigma=np.inf)) + prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) scalefactors: ClassList = ClassList(RAT.models.Parameter(name='Scalefactor 1', min=0.02, value=0.23, max=0.25, - fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0, + fit=False, prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) domain_ratios: ClassList = ClassList(RAT.models.Parameter(name='Domain Ratio 1', min=0.4, value=0.5, max=0.6, - fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0, + fit=False, prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) background_parameters: ClassList = ClassList(RAT.models.Parameter(name='Background Param 1', min=1e-7, value=1e-6, max=1e-5, fit=False, - prior_type=RAT.models.Priors.Uniform, mu=0.0, + prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) - backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=RAT.models.Types.Constant, + backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=TypeOptions.Constant, value_1='Background Param 1')) resolution_parameters: ClassList = ClassList(RAT.models.Parameter(name='Resolution Param 1', min=0.01, value=0.03, max=0.05, fit=False, - prior_type=RAT.models.Priors.Uniform, mu=0.0, + prior_type=Priors.Uniform, mu=0.0, sigma=np.inf)) - resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=RAT.models.Types.Constant, + resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=TypeOptions.Constant, value_1='Resolution Param 1')) custom_files: ClassList = ClassList() @@ -217,14 +217,14 @@ def set_domain_contrasts(self) -> 'Project': """If we are not running a domains calculation with standard layers, ensure the domain_contrasts component of the model is empty. """ - if not (self.calculation == Calculations.Domains and self.model == Models.StandardLayers): + if not (self.calculation == Calculations.Domains and self.model == LayerModels.StandardLayers): self.domain_contrasts.data = [] return self @model_validator(mode='after') def set_layers(self) -> 'Project': """If we are not using a standard layers model, ensure the layers component of the model is empty.""" - if self.model != Models.StandardLayers: + if self.model != LayerModels.StandardLayers: self.layers.data = [] return self @@ -267,12 +267,12 @@ def check_contrast_model_length(self) -> 'Project': """Given certain values of the "calculation" and "model" defined in the project, the "model" field of "contrasts" may be constrained in its length. """ - if self.model == Models.StandardLayers and self.calculation == Calculations.Domains: + if self.model == LayerModels.StandardLayers and self.calculation == Calculations.Domains: for contrast in self.contrasts: if contrast.model and len(contrast.model) != 2: raise ValueError('For a standard layers domains calculation the "model" field of "contrasts" must ' 'contain exactly two values.') - elif self.model != Models.StandardLayers: + elif self.model != LayerModels.StandardLayers: for contrast in self.contrasts: if len(contrast.model) > 1: raise ValueError('For a custom model calculation the "model" field of "contrasts" cannot contain ' @@ -434,7 +434,7 @@ def get_contrast_model_field(self): model_field : str The name of the field used to define the contrasts' model field. """ - if self.model == Models.StandardLayers: + if self.model == LayerModels.StandardLayers: if self.calculation == Calculations.Domains: model_field = 'domain_contrasts' else: @@ -480,7 +480,7 @@ def write_script(self, obj_name: str = 'problem', script: str = 'project_script. f.write(f'{indent}{class_list}=RAT.ClassList({contents}),\n') f.write(f'{indent})\n') - def _classlist_wrapper(self, class_list: 'ClassList', func: Callable): + def _classlist_wrapper(self, class_list: ClassList, func: Callable): """Defines the function used to wrap around ClassList routines to force revalidation. Parameters diff --git a/RAT/utils/dataclasses.py b/RAT/utils/dataclasses.py deleted file mode 100644 index 40893f13..00000000 --- a/RAT/utils/dataclasses.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Defines dataclasses used for the inputs to the compiled RAT code.""" - -from dataclasses import dataclass, field - -from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies - - -@dataclass(frozen=True) -class Controls: - """The full set of controls parameters required for RATMain.""" - # All Procedures - procedure: Procedures = Procedures.Calculate - parallel: Parallel = Parallel.Single - calcSldDuringFit: bool = False - resampleParams: list[float] = field(default_factory=list[0.9, 50.0]) - display: Display = Display.Iter - # Simplex - xTolerance: float = 1.0e-6 - funcTolerance: float = 1.0e-6 - maxFuncEvals: int = 10000 - maxIterations: int = 1000 - updateFreq: int = -1 - updatePlotFreq: int = 1 - # DE - populationSize: int = 20 - fWeight: float = 0.5 - crossoverProbability: float = 0.8 - strategy: Strategies = Strategies.RandomWithPerVectorDither.value - targetValue: float = 1.0 - numGenerations: int = 500 - # NS - nLive: int = 150 - nMCMC: float = 0.0 - propScale: float = 0.1 - nsTolerance: float = 0.1 - # Dream - nSamples: int = 50000 - nChains: int = 10 - jumpProbability: float = 0.5 - pUnitGamma: float = 0.2 - boundHandling: BoundHandling = BoundHandling.Fold - adaptPCR: bool = False diff --git a/RAT/utils/enums.py b/RAT/utils/enums.py index e7389332..c1f8e395 100644 --- a/RAT/utils/enums.py +++ b/RAT/utils/enums.py @@ -67,13 +67,13 @@ class Priors(StrEnum): Gaussian = 'gaussian' -class Types(StrEnum): +class TypeOptions(StrEnum): Constant = 'constant' Data = 'data' Function = 'function' -class Actions(StrEnum): +class BackgroundActions(StrEnum): Add = 'add' Subtract = 'subtract' @@ -89,7 +89,7 @@ class Geometries(StrEnum): SubstrateLiquid = 'substrate/liquid' -class Models(StrEnum): +class LayerModels(StrEnum): CustomLayers = 'custom layers' CustomXY = 'custom xy' StandardLayers = 'standard layers' diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 8de1b251..7c71b6f7 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -117,7 +117,7 @@ def test_identical_name_fields(self, input_list: Sequence[object], name_field: s ClassList(input_list, name_field=name_field) -def test_repr_table(two_name_class_list: 'ClassList', two_name_class_list_table: str) -> None: +def test_repr_table(two_name_class_list: ClassList, two_name_class_list_table: str) -> None: """For classes with the __dict__ attribute, we should be able to print the ClassList like a table.""" assert repr(two_name_class_list) == two_name_class_list_table @@ -158,7 +158,7 @@ def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expe @pytest.mark.parametrize("new_item", [ (InputAttributes(name='Bob')), ]) -def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_item: InputAttributes) -> None: +def test_setitem_same_name_field(two_name_class_list: ClassList, new_item: InputAttributes) -> None: """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"): two_name_class_list[0] = new_item @@ -167,20 +167,20 @@ def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_item: Inp @pytest.mark.parametrize("new_values", [ 'Bob', ]) -def test_setitem_different_classes(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_setitem_different_classes(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None: """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input list contains elements of type other than 'InputAttributes'"): two_name_class_list[0] = new_values -def test_delitem(two_name_class_list: 'ClassList', one_name_class_list: 'ClassList') -> None: +def test_delitem(two_name_class_list: ClassList, one_name_class_list: ClassList) -> None: """We should be able to delete elements from a ClassList with the del operator.""" class_list = two_name_class_list del class_list[1] assert class_list == one_name_class_list -def test_delitem_not_present(two_name_class_list: 'ClassList') -> None: +def test_delitem_not_present(two_name_class_list: ClassList) -> None: """If we use the del operator to delete an index out of range, we should raise an IndexError.""" class_list = two_name_class_list with pytest.raises(IndexError, match=re.escape("list assignment index out of range")): @@ -193,7 +193,7 @@ def test_delitem_not_present(two_name_class_list: 'ClassList') -> None: (InputAttributes(name='Eve'),), (InputAttributes(name='Eve')), ]) -def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name_class_list: 'ClassList') -> None: +def test_iadd(two_name_class_list: ClassList, added_list: Iterable, three_name_class_list: ClassList) -> None: """We should be able to use the "+=" operator to add iterables to a ClassList. Individual objects should be wrapped in a list before being added.""" class_list = two_name_class_list @@ -206,7 +206,7 @@ def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name ([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), (InputAttributes(name='Alice'), InputAttributes(name='Bob')), ]) -def test_iadd_empty_classlist(added_list: Sequence, two_name_class_list: 'ClassList') -> None: +def test_iadd_empty_classlist(added_list: Sequence, two_name_class_list: ClassList) -> None: """We should be able to use the "+=" operator to add iterables to an empty ClassList, whilst also setting _class_handle.""" class_list = ClassList() @@ -215,7 +215,7 @@ def test_iadd_empty_classlist(added_list: Sequence, two_name_class_list: 'ClassL assert isinstance(added_list[-1], class_list._class_handle) -def test_mul(two_name_class_list: 'ClassList') -> None: +def test_mul(two_name_class_list: ClassList) -> None: """If we use the "*" operator on a ClassList, we should raise a TypeError.""" n = 2 with pytest.raises(TypeError, match=re.escape(f"unsupported operand type(s) for *: " @@ -224,7 +224,7 @@ def test_mul(two_name_class_list: 'ClassList') -> None: two_name_class_list * n -def test_rmul(two_name_class_list: 'ClassList') -> None: +def test_rmul(two_name_class_list: ClassList) -> None: """If we use the "*" operator on a ClassList, we should raise a TypeError.""" n = 2 with pytest.raises(TypeError, match=re.escape(f"unsupported operand type(s) for *: " @@ -233,7 +233,7 @@ def test_rmul(two_name_class_list: 'ClassList') -> None: n * two_name_class_list -def test_imul(two_name_class_list: 'ClassList') -> None: +def test_imul(two_name_class_list: ClassList) -> None: """If we use the "*=" operator on a ClassList, we should raise a TypeError.""" n = 2 with pytest.raises(TypeError, match=re.escape(f"unsupported operand type(s) for *=: " @@ -245,9 +245,9 @@ def test_imul(two_name_class_list: 'ClassList') -> None: @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Eve')), ]) -def test_append_object(two_name_class_list: 'ClassList', +def test_append_object(two_name_class_list: ClassList, new_object: object, - three_name_class_list: 'ClassList') -> None: + three_name_class_list: ClassList) -> None: """We should be able to append to a ClassList using a new object.""" class_list = two_name_class_list class_list.append(new_object) @@ -257,9 +257,9 @@ def test_append_object(two_name_class_list: 'ClassList', @pytest.mark.parametrize("new_values", [ ({'name': 'Eve'}), ]) -def test_append_kwargs(two_name_class_list: 'ClassList', +def test_append_kwargs(two_name_class_list: ClassList, new_values: dict[str, Any], - three_name_class_list: 'ClassList') -> None: + three_name_class_list: ClassList) -> None: """We should be able to append to a ClassList using keyword arguments.""" class_list = two_name_class_list class_list.append(**new_values) @@ -269,10 +269,10 @@ def test_append_kwargs(two_name_class_list: 'ClassList', @pytest.mark.parametrize(["new_object", "new_values"], [ (InputAttributes(name='Eve'), {'name': 'John'}), ]) -def test_append_object_and_kwargs(two_name_class_list: 'ClassList', +def test_append_object_and_kwargs(two_name_class_list: ClassList, new_object: object, new_values: dict[str, Any], - three_name_class_list: 'ClassList') -> None: + three_name_class_list: ClassList) -> None: """If we append to a ClassList using a new object and keyword arguments, we raise a warning, and append the object, discarding the keyword arguments.""" class_list = two_name_class_list @@ -286,7 +286,7 @@ def test_append_object_and_kwargs(two_name_class_list: 'ClassList', @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Alice')), ]) -def test_append_object_empty_classlist(new_object: object, one_name_class_list: 'ClassList') -> None: +def test_append_object_empty_classlist(new_object: object, one_name_class_list: ClassList) -> None: """We should be able to append to an empty ClassList using a new object, whilst also setting _class_handle.""" class_list = ClassList() class_list.append(new_object) @@ -309,7 +309,7 @@ def test_append_kwargs_empty_classlist(new_values: dict[str, Any]) -> None: @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Alice')), ]) -def test_append_object_same_name_field(two_name_class_list: 'ClassList', new_object: object) -> None: +def test_append_object_same_name_field(two_name_class_list: ClassList, new_object: object) -> None: """If we append an object with an already-specified name_field value to a ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute"): @@ -319,7 +319,7 @@ def test_append_object_same_name_field(two_name_class_list: 'ClassList', new_obj @pytest.mark.parametrize("new_values", [ ({'name': 'Alice'}), ]) -def test_append_kwargs_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_append_kwargs_same_name_field(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None: """If we append an object with an already-specified name_field value to a ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " f"'{new_values[two_name_class_list.name_field]}', " @@ -330,7 +330,7 @@ def test_append_kwargs_same_name_field(two_name_class_list: 'ClassList', new_val @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Eve')) ]) -def test_insert_object(two_name_class_list: 'ClassList', new_object: object) -> None: +def test_insert_object(two_name_class_list: ClassList, new_object: object) -> None: """We should be able to insert an object within a ClassList using a new object.""" two_name_class_list.insert(1, new_object) assert two_name_class_list == ClassList([InputAttributes(name='Alice'), @@ -341,7 +341,7 @@ def test_insert_object(two_name_class_list: 'ClassList', new_object: object) -> @pytest.mark.parametrize("new_values", [ ({'name': 'Eve'}) ]) -def test_insert_kwargs(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_insert_kwargs(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None: """We should be able to insert an object within a ClassList using keyword arguments.""" two_name_class_list.insert(1, **new_values) assert two_name_class_list == ClassList([InputAttributes(name='Alice'), @@ -352,10 +352,10 @@ def test_insert_kwargs(two_name_class_list: 'ClassList', new_values: dict[str, A @pytest.mark.parametrize(["new_object", "new_values"], [ (InputAttributes(name='Eve'), {'name': 'John'}), ]) -def test_insert_object_and_kwargs(two_name_class_list: 'ClassList', +def test_insert_object_and_kwargs(two_name_class_list: ClassList, new_object: object, new_values: dict[str, Any], - three_name_class_list: 'ClassList') -> None: + three_name_class_list: ClassList) -> None: """If call insert() on a ClassList using a new object and keyword arguments, we raise a warning, and append the object, discarding the keyword arguments.""" class_list = two_name_class_list @@ -371,7 +371,7 @@ def test_insert_object_and_kwargs(two_name_class_list: 'ClassList', @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Alice')), ]) -def test_insert_object_empty_classlist(new_object: object, one_name_class_list: 'ClassList') -> None: +def test_insert_object_empty_classlist(new_object: object, one_name_class_list: ClassList) -> None: """We should be able to insert a new object into an empty ClassList, whilst also setting _class_handle.""" class_list = ClassList() class_list.insert(0, new_object) @@ -394,7 +394,7 @@ def test_insert_kwargs_empty_classlist(new_values: dict[str, Any]) -> None: @pytest.mark.parametrize("new_object", [ (InputAttributes(name='Alice')) ]) -def test_insert_object_same_name(two_name_class_list: 'ClassList', new_object: object) -> None: +def test_insert_object_same_name(two_name_class_list: ClassList, new_object: object) -> None: """If we insert an object with an already-specified name_field value to a ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute"): @@ -404,7 +404,7 @@ def test_insert_object_same_name(two_name_class_list: 'ClassList', new_object: o @pytest.mark.parametrize("new_values", [ ({'name': 'Alice'}) ]) -def test_insert_kwargs_same_name(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_insert_kwargs_same_name(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None: """If we insert an object with an already-specified name_field value to a ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " f"'{new_values[two_name_class_list.name_field]}', " @@ -416,7 +416,7 @@ def test_insert_kwargs_same_name(two_name_class_list: 'ClassList', new_values: d "Bob", (InputAttributes(name='Bob')), ]) -def test_remove(two_name_class_list: 'ClassList', remove_value: Union[object, str]) -> None: +def test_remove(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: """We should be able to remove an object either by the value of the name_field or by specifying the object itself.""" two_name_class_list.remove(remove_value) @@ -427,7 +427,7 @@ def test_remove(two_name_class_list: 'ClassList', remove_value: Union[object, st 'Eve', (InputAttributes(name='Eve')), ]) -def test_remove_not_present(two_name_class_list: 'ClassList', remove_value: Union[object, str]) -> None: +def test_remove_not_present(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: """If we remove an object not included in the ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=re.escape("list.remove(x): x not in list")): two_name_class_list.remove(remove_value) @@ -439,7 +439,7 @@ def test_remove_not_present(two_name_class_list: 'ClassList', remove_value: Unio ('Eve', 0), (InputAttributes(name='Eve'), 0), ]) -def test_count(two_name_class_list: 'ClassList', count_value: Union[object, str], expected_count: int) -> None: +def test_count(two_name_class_list: ClassList, count_value: Union[object, str], expected_count: int) -> None: """We should be able to determine the number of times an object is in the ClassList using either the object itself or its name_field value. """ @@ -450,7 +450,7 @@ def test_count(two_name_class_list: 'ClassList', count_value: Union[object, str] ('Bob', 1), (InputAttributes(name='Bob'), 1), ]) -def test_index(two_name_class_list: 'ClassList', index_value: Union[object, str], expected_index: int) -> None: +def test_index(two_name_class_list: ClassList, index_value: Union[object, str], expected_index: int) -> None: """We should be able to find the index of an object in the ClassList either by its name_field value or by specifying the object itself. """ @@ -461,7 +461,7 @@ def test_index(two_name_class_list: 'ClassList', index_value: Union[object, str] ('Bob', 1, 2), (InputAttributes(name='Bob'), -3, -2), ]) -def test_index_offset(two_name_class_list: 'ClassList', index_value: Union[object, str], offset: int, +def test_index_offset(two_name_class_list: ClassList, index_value: Union[object, str], offset: int, expected_index: int) -> None: """We should be able to find the index of an object in the ClassList either by its name_field value or by specifying the object itself. When using an offset, the value of the index should be shifted accordingly. @@ -472,7 +472,7 @@ def test_index_offset(two_name_class_list: 'ClassList', index_value: Union[objec 'Eve', (InputAttributes(name='Eve')), ]) -def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[object, str]) -> None: +def test_index_not_present(two_name_class_list: ClassList, index_value: Union[object, str]) -> None: """If we try to find the index of an object not included in the ClassList we should raise a ValueError.""" # with pytest.raises(ValueError, match=f"'{index_value}' is not in list") as e: with pytest.raises(ValueError): @@ -485,7 +485,7 @@ def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[ (InputAttributes(name='Eve'),), (InputAttributes(name='Eve')), ]) -def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three_name_class_list: 'ClassList') -> None: +def test_extend(two_name_class_list: ClassList, extended_list: Sequence, three_name_class_list: ClassList) -> None: """We should be able to extend a ClassList using another ClassList or a sequence. Individual objects should be wrapped in a list before being added.""" class_list = two_name_class_list @@ -498,7 +498,7 @@ def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three ([InputAttributes(name='Alice')]), (InputAttributes(name='Alice'),), ]) -def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: 'ClassList') -> None: +def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: ClassList) -> None: """We should be able to extend a ClassList using another ClassList or a sequence""" class_list = ClassList() class_list.extend(extended_list) @@ -511,7 +511,7 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: 'C ({'name': 'John', 'surname': 'Luther'}, ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])), ]) -def test_set_fields(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList')\ +def test_set_fields(two_name_class_list: ClassList, new_values: dict[str, Any], expected_classlist: ClassList)\ -> None: """We should be able to set field values in an element of a ClassList using keyword arguments.""" class_list = two_name_class_list @@ -522,7 +522,7 @@ def test_set_fields(two_name_class_list: 'ClassList', new_values: dict[str, Any] @pytest.mark.parametrize("new_values", [ ({'name': 'Bob'}), ]) -def test_set_fields_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_set_fields_same_name_field(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None: """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " f"'{new_values[two_name_class_list.name_field]}', " @@ -540,7 +540,7 @@ def test_set_fields_same_name_field(two_name_class_list: 'ClassList', new_values name_field='surname'), ["Morgan", "Terwilliger"]), (ClassList(InputAttributes()), []), ]) -def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None: +def test_get_names(class_list: ClassList, expected_names: list[str]) -> None: """We should get a list of the values of the name_field attribute from each object with it defined in the ClassList.""" assert class_list.get_names() == expected_names @@ -552,7 +552,7 @@ def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None: (ClassList([InputAttributes(surname='Morgan'), InputAttributes(surname='Terwilliger')]), []), (ClassList(InputAttributes()), []), ]) -def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) -> None: +def test_get_all_matches(class_list: ClassList, expected_matches: list[tuple]) -> None: """We should get a list of (index, field) tuples matching the given value in the ClassList.""" assert class_list.get_all_matches("Alice") == expected_matches @@ -561,7 +561,7 @@ def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) ({'name': 'Eve'}), ({'surname': 'Polastri'}), ]) -def test__validate_name_field(two_name_class_list: 'ClassList', input_dict: dict[str, Any]) -> None: +def test__validate_name_field(two_name_class_list: ClassList, input_dict: dict[str, Any]) -> None: """We should not raise an error if the input values do not contain a name_field value defined in an object in the ClassList.""" assert two_name_class_list._validate_name_field(input_dict) is None @@ -570,7 +570,7 @@ def test__validate_name_field(two_name_class_list: 'ClassList', input_dict: dict @pytest.mark.parametrize("input_dict", [ ({'name': 'Alice'}), ]) -def test__validate_name_field_not_unique(two_name_class_list: 'ClassList', input_dict: dict[str, Any]) -> None: +def test__validate_name_field_not_unique(two_name_class_list: ClassList, input_dict: dict[str, Any]) -> None: """We should raise a ValueError if we input values containing a name_field defined in an object in the ClassList.""" with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " f"'{input_dict[two_name_class_list.name_field]}', " @@ -585,7 +585,7 @@ def test__validate_name_field_not_unique(two_name_class_list: 'ClassList', input ([InputAttributes()]), ([]), ]) -def test__check_unique_name_fields(two_name_class_list: 'ClassList', input_list: Iterable) -> None: +def test__check_unique_name_fields(two_name_class_list: ClassList, input_list: Iterable) -> None: """We should not raise an error if an input list contains objects with different name_field values, or if the name_field is not defined.""" assert two_name_class_list._check_unique_name_fields(input_list) is None @@ -594,7 +594,7 @@ def test__check_unique_name_fields(two_name_class_list: 'ClassList', input_list: @pytest.mark.parametrize("input_list", [ ([InputAttributes(name='Alice'), InputAttributes(name='Alice')]), ]) -def test__check_unique_name_fields_not_unique(two_name_class_list: 'ClassList', input_list: Iterable) -> None: +def test__check_unique_name_fields_not_unique(two_name_class_list: ClassList, input_list: Iterable) -> None: """We should raise a ValueError if an input list contains multiple objects with matching name_field values defined.""" with pytest.raises(ValueError, match=f"Input list contains objects with the same value of the " @@ -626,7 +626,7 @@ def test__check_classes_different_classes(input_list: Iterable) -> None: ("Alice", InputAttributes(name='Alice')), ("Eve", "Eve"), ]) -def test__get_item_from_name_field(two_name_class_list: 'ClassList', +def test__get_item_from_name_field(two_name_class_list: ClassList, value: str, expected_output: Union[object, str]) -> None: """When we input the name_field value of an object defined in the ClassList, we should return the object. @@ -644,7 +644,7 @@ def test__get_item_from_name_field(two_name_class_list: 'ClassList', ([InputAttributes(name='Alice'), dict(name='Bob')], InputAttributes), ([dict(name='Alice'), InputAttributes(name='Bob')], dict), ]) -def test_determine_class_handle(input_list: 'ClassList', expected_type: type) -> None: +def test_determine_class_handle(input_list: ClassList, expected_type: type) -> None: """The _class_handle for the ClassList should be the type that satisfies the condition "isinstance(element, type)" for all elements in the ClassList. """ diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 8382ee79..8fcf3674 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -3,10 +3,10 @@ import numpy as np import pytest -from RAT.controls import set_controls +import RAT from RAT.inputs import make_input, make_problem, make_cells, make_controls -import RAT.project -import RAT.utils.enums +from RAT.utils.enums import (BoundHandling, Calculations, Display, Geometries, LayerModels, Parallel, Procedures, + TypeOptions) from RAT.rat_core import Cells, Checks, Control, Limits, Priors, ProblemDefinition @@ -27,7 +27,7 @@ def standard_layers_project(): @pytest.fixture def domains_project(): """Add parameters to the default project for a domains calculation.""" - test_project = RAT.Project(calculation=RAT.utils.enums.Calculations.Domains, + test_project = RAT.Project(calculation=Calculations.Domains, data=RAT.ClassList([RAT.models.Data(name='Simulation', data=np.array([[1.0, 1.0, 1.0]]))])) test_project.parameters.append(name='Test SLD') test_project.custom_files.append(name='Test Custom File', filename='matlab_test.m', language='matlab') @@ -43,7 +43,7 @@ def domains_project(): @pytest.fixture def custom_xy_project(): """Add parameters to the default project for a non polarised calculation and use the custom xy model.""" - test_project = RAT.Project(model=RAT.utils.enums.Models.CustomXY) + test_project = RAT.Project(model=LayerModels.CustomXY) test_project.parameters.append(name='Test SLD') test_project.custom_files.append(name='Test Custom File', filename='matlab_test.m', language='matlab') test_project.contrasts.append(name='Test Contrast', data='Simulation', background='Background 1', bulk_in='SLD Air', @@ -56,9 +56,9 @@ def custom_xy_project(): def standard_layers_problem(): """The expected problem object from "standard_layers_project".""" problem = ProblemDefinition() - problem.TF = RAT.utils.enums.Calculations.NonPolarised - problem.modelType = RAT.utils.enums.Models.StandardLayers - problem.geometry = RAT.utils.enums.Geometries.AirSubstrate + problem.TF = Calculations.NonPolarised + problem.modelType = LayerModels.StandardLayers + problem.geometry = Geometries.AirSubstrate problem.useImaginary = False problem.params = [3.0, 0.0] problem.bulkIn = [0.0] @@ -96,9 +96,9 @@ def standard_layers_problem(): def domains_problem(): """The expected problem object from "standard_layers_project".""" problem = ProblemDefinition() - problem.TF = RAT.utils.enums.Calculations.Domains - problem.modelType = RAT.utils.enums.Models.StandardLayers - problem.geometry = RAT.utils.enums.Geometries.AirSubstrate + problem.TF = Calculations.Domains + problem.modelType = LayerModels.StandardLayers + problem.geometry = Geometries.AirSubstrate problem.useImaginary = False problem.params = [3.0, 0.0] problem.bulkIn = [0.0] @@ -136,9 +136,9 @@ def domains_problem(): def custom_xy_problem(): """The expected problem object from "custom_xy_project".""" problem = ProblemDefinition() - problem.TF = RAT.utils.enums.Calculations.NonPolarised - problem.modelType = RAT.utils.enums.Models.CustomXY - problem.geometry = RAT.utils.enums.Geometries.AirSubstrate + problem.TF = Calculations.NonPolarised + problem.modelType = LayerModels.CustomXY + problem.geometry = Geometries.AirSubstrate problem.useImaginary = False problem.params = [3.0, 0.0] problem.bulkIn = [0.0] @@ -190,8 +190,8 @@ def standard_layers_cells(): cells.f12 = ['SLD D2O'] cells.f13 = ['Resolution Param 1'] cells.f14 = ['matlab_test'] - cells.f15 = [RAT.models.Types.Constant] - cells.f16 = [RAT.models.Types.Constant] + cells.f15 = [TypeOptions.Constant] + cells.f16 = [TypeOptions.Constant] cells.f17 = [[0.0, 0.0, 0.0]] cells.f18 = [] cells.f19 = [] @@ -218,8 +218,8 @@ def domains_cells(): cells.f12 = ['SLD D2O'] cells.f13 = ['Resolution Param 1'] cells.f14 = ['matlab_test'] - cells.f15 = [RAT.models.Types.Constant] - cells.f16 = [RAT.models.Types.Constant] + cells.f15 = [TypeOptions.Constant] + cells.f16 = [TypeOptions.Constant] cells.f17 = [[0.0, 0.0, 0.0]] cells.f18 = [[0, 1], [0, 1]] cells.f19 = [[1], [1]] @@ -246,8 +246,8 @@ def custom_xy_cells(): cells.f12 = ['SLD D2O'] cells.f13 = ['Resolution Param 1'] cells.f14 = ['matlab_test'] - cells.f15 = [RAT.models.Types.Constant] - cells.f16 = [RAT.models.Types.Constant] + cells.f15 = [TypeOptions.Constant] + cells.f16 = [TypeOptions.Constant] cells.f17 = [[0.0, 0.0, 0.0]] cells.f18 = [] cells.f19 = [] @@ -292,14 +292,14 @@ def domains_limits(): def non_polarised_priors(): """The expected priors object from "standard_layers_project" and "custom_xy_project".""" priors = Priors() - priors.param = [['Substrate Roughness', RAT.models.Priors.Uniform, 0.0, np.inf], - ['Test SLD', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.backgroundParam = [['Background Param 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.qzshift = [['Qz shift 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.scalefactor = [['Scalefactor 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.bulkIn = [['SLD Air', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.bulkOut = [['SLD D2O', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.resolutionParam = [['Resolution Param 1', RAT.models.Priors.Uniform, 0.0, np.inf]] + priors.param = [['Substrate Roughness', RAT.utils.enums.Priors.Uniform, 0.0, np.inf], + ['Test SLD', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.backgroundParam = [['Background Param 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.qzshift = [['Qz shift 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.scalefactor = [['Scalefactor 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.bulkIn = [['SLD Air', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.bulkOut = [['SLD D2O', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.resolutionParam = [['Resolution Param 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] priors.domainRatio = [] priors.priorNames = ['Substrate Roughness', 'Test SLD', 'Background Param 1', 'Scalefactor 1', 'Qz shift 1', 'SLD Air', 'SLD D2O', 'Resolution Param 1'] @@ -313,15 +313,15 @@ def non_polarised_priors(): def domains_priors(): """The expected priors object from "domains_project".""" priors = Priors() - priors.param = [['Substrate Roughness', RAT.models.Priors.Uniform, 0.0, np.inf], - ['Test SLD', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.backgroundParam = [['Background Param 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.qzshift = [['Qz shift 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.scalefactor = [['Scalefactor 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.bulkIn = [['SLD Air', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.bulkOut = [['SLD D2O', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.resolutionParam = [['Resolution Param 1', RAT.models.Priors.Uniform, 0.0, np.inf]] - priors.domainRatio = [['Domain Ratio 1', RAT.models.Priors.Uniform, 0.0, np.inf]] + priors.param = [['Substrate Roughness', RAT.utils.enums.Priors.Uniform, 0.0, np.inf], + ['Test SLD', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.backgroundParam = [['Background Param 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.qzshift = [['Qz shift 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.scalefactor = [['Scalefactor 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.bulkIn = [['SLD Air', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.bulkOut = [['SLD D2O', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.resolutionParam = [['Resolution Param 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] + priors.domainRatio = [['Domain Ratio 1', RAT.utils.enums.Priors.Uniform, 0.0, np.inf]] priors.priorNames = ['Substrate Roughness', 'Test SLD', 'Background Param 1', 'Scalefactor 1', 'Qz shift 1', 'SLD Air', 'SLD D2O', 'Resolution Param 1', 'Domain Ratio 1'] priors.priorValues = [[1, 0.0, np.inf], [1, 0.0, np.inf], [1, 0.0, np.inf], [1, 0.0, np.inf], [1, 0.0, np.inf], @@ -336,11 +336,11 @@ def standard_layers_controls(): "standard_layers_project". """ controls = Control() - controls.procedure = RAT.utils.enums.Procedures.Calculate - controls.parallel = RAT.utils.enums.Parallel.Single + controls.procedure = Procedures.Calculate + controls.parallel = Parallel.Single controls.calcSldDuringFit = False controls.resampleParams = [0.9, 50.0] - controls.display = RAT.utils.enums.Display.Iter + controls.display = Display.Iter controls.xTolerance = 1.0e-6 controls.funcTolerance = 1.0e-6 controls.maxFuncEvals = 10000 @@ -361,7 +361,7 @@ def standard_layers_controls(): controls.nChains = 10 controls.jumpProbability = 0.5 controls.pUnitGamma = 0.2 - controls.boundHandling = RAT.utils.enums.BoundHandling.Fold + controls.boundHandling = BoundHandling.Fold controls.checks.fitParam = [1, 0] controls.checks.fitBackgroundParam = [0] controls.checks.fitQzshift = [0] @@ -379,11 +379,11 @@ def custom_xy_controls(): """The expected controls object for input to the compiled RAT code given the default inputs and "custom_xy_project". """ controls = Control() - controls.procedure = RAT.utils.enums.Procedures.Calculate - controls.parallel = RAT.utils.enums.Parallel.Single + controls.procedure = Procedures.Calculate + controls.parallel = Parallel.Single controls.calcSldDuringFit = True controls.resampleParams = [0.9, 50.0] - controls.display = RAT.utils.enums.Display.Iter + controls.display = Display.Iter controls.xTolerance = 1.0e-6 controls.funcTolerance = 1.0e-6 controls.maxFuncEvals = 10000 @@ -404,7 +404,7 @@ def custom_xy_controls(): controls.nChains = 10 controls.jumpProbability = 0.5 controls.pUnitGamma = 0.2 - controls.boundHandling = RAT.utils.enums.BoundHandling.Fold + controls.boundHandling = BoundHandling.Fold controls.checks.fitParam = [1, 0] controls.checks.fitBackgroundParam = [0] controls.checks.fitQzshift = [0] @@ -455,7 +455,7 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr parameter_fields = ["param", "backgroundParam", "scalefactor", "qzshift", "bulkIn", "bulkOut", "resolutionParam", "domainRatio"] - problem, cells, limits, priors, controls = make_input(test_project, set_controls()) + problem, cells, limits, priors, controls = make_input(test_project, RAT.set_controls()) check_problem_equal(problem, test_problem) check_cells_equal(cells, test_cells) @@ -504,7 +504,7 @@ def test_make_controls(standard_layers_controls, test_checks) -> None: """The controls object should contain the full set of controls parameters, with the appropriate set defined by the input controls. """ - controls = make_controls(set_controls(), test_checks) + controls = make_controls(RAT.set_controls(), test_checks) check_controls_equal(controls, standard_layers_controls) diff --git a/tests/test_project.py b/tests/test_project.py index d8becfaa..612533e4 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -9,9 +9,8 @@ from typing import Callable import tempfile -from RAT.classlist import ClassList -import RAT.models -import RAT.project +import RAT +from RAT.utils.enums import Calculations, LayerModels @pytest.fixture @@ -142,7 +141,7 @@ def test_classlists_specific_cases() -> None: """The ClassLists in the "Project" model should contain instances of specific models given various non-default options. """ - project = RAT.project.Project(calculation=RAT.project.Calculations.Domains, absorption=True) + project = RAT.Project(calculation=Calculations.Domains, absorption=True) assert project.layers._class_handle.__name__ == 'AbsorptionLayer' assert project.contrasts._class_handle.__name__ == 'ContrastWithRatio' @@ -163,7 +162,7 @@ def test_initialise_wrong_classes(input_model: Callable) -> None: with pytest.raises(pydantic.ValidationError, match='1 validation error for Project\nparameters\n Value error, ' '"parameters" ClassList contains objects other than ' '"Parameter"'): - RAT.project.Project(parameters=ClassList(input_model())) + RAT.Project(parameters=RAT.ClassList(input_model())) @pytest.mark.parametrize(["input_model", "absorption", "actual_model_name"], [ @@ -177,14 +176,14 @@ def test_initialise_wrong_layers(input_model: Callable, absorption: bool, actual with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\nlayers\n Value error, ' f'"layers" ClassList contains objects other than ' f'"{actual_model_name}"'): - RAT.project.Project(absorption=absorption, layers=ClassList(input_model())) + RAT.Project(absorption=absorption, layers=RAT.ClassList(input_model())) @pytest.mark.parametrize(["input_model", "calculation", "actual_model_name"], [ - (RAT.models.Contrast, RAT.project.Calculations.Domains, "ContrastWithRatio"), - (RAT.models.ContrastWithRatio, RAT.project.Calculations.NonPolarised, "Contrast"), + (RAT.models.Contrast, Calculations.Domains, "ContrastWithRatio"), + (RAT.models.ContrastWithRatio, Calculations.NonPolarised, "Contrast"), ]) -def test_initialise_wrong_contrasts(input_model: Callable, calculation: 'RAT.project.Calculations', +def test_initialise_wrong_contrasts(input_model: Callable, calculation: Calculations, actual_model_name: str) -> None: """If the "Project" model is initialised with the incorrect contrast model given the value of calculation, we should raise a ValidationError. @@ -192,7 +191,7 @@ def test_initialise_wrong_contrasts(input_model: Callable, calculation: 'RAT.pro with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\ncontrasts\n Value error, ' f'"contrasts" ClassList contains objects other than ' f'"{actual_model_name}"'): - RAT.project.Project(calculation=calculation, contrasts=ClassList(input_model())) + RAT.Project(calculation=calculation, contrasts=RAT.ClassList(input_model())) @pytest.mark.parametrize("input_parameter", [ @@ -203,7 +202,7 @@ def test_initialise_without_substrate_roughness(input_parameter: Callable) -> No """If the "Project" model is initialised without "Substrate Roughness as a protected parameter, add it to the front of the "parameters" ClassList. """ - project = RAT.project.Project(parameters=ClassList(RAT.models.Parameter(name='Substrate Roughness'))) + project = RAT.Project(parameters=RAT.ClassList(RAT.models.Parameter(name='Substrate Roughness'))) assert project.parameters[0] == RAT.models.ProtectedParameter(name='Substrate Roughness') @@ -222,7 +221,7 @@ def test_assign_wrong_classes(test_project, field: str, wrong_input_model: Calla with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n{field}\n Value error, ' f'"{field}" ClassList contains objects other than ' f'"{RAT.project.model_in_classlist[field]}"'): - setattr(test_project, field, ClassList(wrong_input_model())) + setattr(test_project, field, RAT.ClassList(wrong_input_model())) @pytest.mark.parametrize(["wrong_input_model", "absorption", "actual_model_name"], [ @@ -231,25 +230,24 @@ def test_assign_wrong_classes(test_project, field: str, wrong_input_model: Calla ]) def test_assign_wrong_layers(wrong_input_model: Callable, absorption: bool, actual_model_name: str) -> None: """If we assign incorrect classes to the "Project" model, we should raise a ValidationError.""" - project = RAT.project.Project(absorption=absorption) + project = RAT.Project(absorption=absorption) with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\nlayers\n Value error, ' f'"layers" ClassList contains objects other than ' f'"{actual_model_name}"'): - setattr(project, 'layers', ClassList(wrong_input_model())) + setattr(project, 'layers', RAT.ClassList(wrong_input_model())) @pytest.mark.parametrize(["wrong_input_model", "calculation", "actual_model_name"], [ - (RAT.models.Contrast, RAT.project.Calculations.Domains, "ContrastWithRatio"), - (RAT.models.ContrastWithRatio, RAT.project.Calculations.NonPolarised, "Contrast"), + (RAT.models.Contrast, Calculations.Domains, "ContrastWithRatio"), + (RAT.models.ContrastWithRatio, Calculations.NonPolarised, "Contrast"), ]) -def test_assign_wrong_contrasts(wrong_input_model: Callable, calculation: 'RAT.project.Calculations', - actual_model_name: str) -> None: +def test_assign_wrong_contrasts(wrong_input_model: Callable, calculation: Calculations, actual_model_name: str) -> None: """If we assign incorrect classes to the "Project" model, we should raise a ValidationError.""" - project = RAT.project.Project(calculation=calculation) + project = RAT.Project(calculation=calculation) with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\ncontrasts\n Value error, ' f'"contrasts" ClassList contains objects other than ' f'"{actual_model_name}"'): - setattr(project, 'contrasts', ClassList(wrong_input_model())) + setattr(project, 'contrasts', RAT.ClassList(wrong_input_model())) @pytest.mark.parametrize("field", [ @@ -287,29 +285,29 @@ def test_set_domain_ratios(test_project) -> None: @pytest.mark.parametrize("project_parameters", [ - ({'calculation': RAT.project.Calculations.NonPolarised, 'model': RAT.project.Models.StandardLayers}), - ({'calculation': RAT.project.Calculations.NonPolarised, 'model': RAT.project.Models.CustomLayers}), - ({'calculation': RAT.project.Calculations.NonPolarised, 'model': RAT.project.Models.CustomXY}), - ({'calculation': RAT.project.Calculations.Domains, 'model': RAT.project.Models.CustomLayers}), - ({'calculation': RAT.project.Calculations.Domains, 'model': RAT.project.Models.CustomXY}), + ({'calculation': Calculations.NonPolarised, 'model': LayerModels.StandardLayers}), + ({'calculation': Calculations.NonPolarised, 'model': LayerModels.CustomLayers}), + ({'calculation': Calculations.NonPolarised, 'model': LayerModels.CustomXY}), + ({'calculation': Calculations.Domains, 'model': LayerModels.CustomLayers}), + ({'calculation': Calculations.Domains, 'model': LayerModels.CustomXY}), ]) def test_set_domain_contrasts(project_parameters: dict) -> None: """If we are not running a domains calculation with standard layers, the "domain_contrasts" field of the model should always be empty. """ - project = RAT.project.Project(**project_parameters) + project = RAT.Project(**project_parameters) assert project.domain_contrasts == [] project.domain_contrasts.append(name='New Domain Contrast') assert project.domain_contrasts == [] @pytest.mark.parametrize("project_parameters", [ - ({'model': RAT.project.Models.CustomLayers}), - ({'model': RAT.project.Models.CustomXY}), + ({'model': LayerModels.CustomLayers}), + ({'model': LayerModels.CustomXY}), ]) def test_set_domain_contrasts(project_parameters: dict) -> None: """If we are not using a standard layers model, the "layers" field of the model should always be empty.""" - project = RAT.project.Project(**project_parameters) + project = RAT.Project(**project_parameters) assert project.layers == [] project.layers.append(name='New Layer') assert project.layers == [] @@ -317,16 +315,15 @@ def test_set_domain_contrasts(project_parameters: dict) -> None: @pytest.mark.parametrize(["input_calculation", "input_contrast", "new_calculation", "new_contrast_model", "num_domain_ratios"], [ - (RAT.project.Calculations.NonPolarised, RAT.models.Contrast, RAT.project.Calculations.Domains, "ContrastWithRatio", 1), - (RAT.project.Calculations.Domains, RAT.models.ContrastWithRatio, RAT.project.Calculations.NonPolarised, "Contrast", 0), + (Calculations.NonPolarised, RAT.models.Contrast, Calculations.Domains, "ContrastWithRatio", 1), + (Calculations.Domains, RAT.models.ContrastWithRatio, Calculations.NonPolarised, "Contrast", 0), ]) -def test_set_calculation(input_calculation: 'RAT.project.Calculations', input_contrast: Callable, - new_calculation: 'RAT.project.Calculations', new_contrast_model: str, - num_domain_ratios: int) -> None: +def test_set_calculation(input_calculation: Calculations, input_contrast: Callable, new_calculation: Calculations, + new_contrast_model: str, num_domain_ratios: int) -> None: """When changing the value of the calculation option, the "contrasts" ClassList should switch to using the appropriate Contrast model. """ - project = RAT.project.Project(calculation=input_calculation, contrasts=ClassList(input_contrast())) + project = RAT.Project(calculation=input_calculation, contrasts=RAT.ClassList(input_contrast())) project.calculation = new_calculation assert project.calculation is new_calculation @@ -336,14 +333,14 @@ def test_set_calculation(input_calculation: 'RAT.project.Calculations', input_co @pytest.mark.parametrize(["new_calc", "new_model", "expected_contrast_model"], [ - (RAT.project.Calculations.NonPolarised, RAT.project.Models.StandardLayers, ['Test Layer']), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomLayers, []), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomXY, []), - (RAT.project.Calculations.Domains, RAT.project.Models.StandardLayers, []), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomLayers, []), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomXY, []), + (Calculations.NonPolarised, LayerModels.StandardLayers, ['Test Layer']), + (Calculations.NonPolarised, LayerModels.CustomLayers, []), + (Calculations.NonPolarised, LayerModels.CustomXY, []), + (Calculations.Domains, LayerModels.StandardLayers, []), + (Calculations.Domains, LayerModels.CustomLayers, []), + (Calculations.Domains, LayerModels.CustomXY, []), ]) -def test_set_contrast_model_field(test_project, new_calc: 'RAT.project.Calculations', new_model: 'RAT.project.Models', +def test_set_contrast_model_field(test_project, new_calc: Calculations, new_model: LayerModels, expected_contrast_model: list[str]) -> None: """If we change the calculation and model such that the values of the "model" field of the "contrasts" model come from a different field of the project, we should clear the contrast "model" field. @@ -354,24 +351,24 @@ def test_set_contrast_model_field(test_project, new_calc: 'RAT.project.Calculati @pytest.mark.parametrize(["input_model", "test_contrast_model", "error_message"], [ - (RAT.project.Models.StandardLayers, ['Test Domain Ratio'], + (LayerModels.StandardLayers, ['Test Domain Ratio'], 'For a standard layers domains calculation the "model" field of "contrasts" must contain exactly two values.'), - (RAT.project.Models.StandardLayers, ['Test Domain Ratio', 'Test Domain Ratio', 'Test Domain Ratio'], + (LayerModels.StandardLayers, ['Test Domain Ratio', 'Test Domain Ratio', 'Test Domain Ratio'], 'For a standard layers domains calculation the "model" field of "contrasts" must contain exactly two values.'), - (RAT.project.Models.CustomLayers, ['Test Custom File', 'Test Custom File'], + (LayerModels.CustomLayers, ['Test Custom File', 'Test Custom File'], 'For a custom model calculation the "model" field of "contrasts" cannot contain more than one value.'), ]) -def test_check_contrast_model_length(test_project, input_model: 'RAT.project.Models', - test_contrast_model: list[str], error_message: str) -> None: +def test_check_contrast_model_length(test_project, input_model: LayerModels, test_contrast_model: list[str], + error_message: str) -> None: """If we are not running a domains calculation with standard layers, the "domain_contrasts" field of the model should always be empty. """ - test_domain_ratios = ClassList(RAT.models.Parameter(name='Test Domain Ratio')) - test_contrasts = ClassList(RAT.models.ContrastWithRatio(model=test_contrast_model)) + test_domain_ratios = RAT.ClassList(RAT.models.Parameter(name='Test Domain Ratio')) + test_contrasts = RAT.ClassList(RAT.models.ContrastWithRatio(model=test_contrast_model)) with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, {error_message}'): - RAT.project.Project(calculation=RAT.project.Calculations.Domains, model=input_model, - domain_ratios=test_domain_ratios, contrasts=test_contrasts) + RAT.Project(calculation=Calculations.Domains, model=input_model, domain_ratios=test_domain_ratios, + contrasts=test_contrasts) @pytest.mark.parametrize(["input_layer", "input_absorption", "new_layer_model"], [ @@ -382,7 +379,7 @@ def test_set_absorption(input_layer: Callable, input_absorption: bool, new_layer """When changing the value of the absorption option, the "layers" ClassList should switch to using the appropriate Layer model. """ - project = RAT.project.Project(absorption=input_absorption, layers=ClassList(input_layer())) + project = RAT.Project(absorption=input_absorption, layers=RAT.ClassList(input_layer())) project.absorption = not input_absorption assert project.absorption is not input_absorption @@ -398,7 +395,7 @@ def test_set_absorption(input_layer: Callable, input_absorption: bool, new_layer ]) def test_check_protected_parameters(delete_operation) -> None: """If we try to remove a protected parameter, we should raise an error.""" - project = RAT.project.Project() + project = RAT.Project() with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, Can\'t delete' f' the protected parameters: Substrate Roughness'): @@ -443,7 +440,7 @@ def test_allowed_backgrounds(field: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "backgrounds" must be ' f'defined in "background_parameters".'): - RAT.project.Project(backgrounds=ClassList(test_background)) + RAT.Project(backgrounds=RAT.ClassList(test_background)) @pytest.mark.parametrize("field", [ @@ -459,7 +456,7 @@ def test_allowed_layers(field: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "layers" must be ' f'defined in "parameters".'): - RAT.project.Project(absorption=False, layers=ClassList(test_layer)) + RAT.Project(absorption=False, layers=RAT.ClassList(test_layer)) @pytest.mark.parametrize("field", [ @@ -476,7 +473,7 @@ def test_allowed_absorption_layers(field: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "layers" must be ' f'defined in "parameters".'): - RAT.project.Project(absorption=True, layers=ClassList(test_layer)) + RAT.Project(absorption=True, layers=RAT.ClassList(test_layer)) @pytest.mark.parametrize("field", [ @@ -494,7 +491,7 @@ def test_allowed_resolutions(field: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "resolutions" must be ' f'defined in "resolution_parameters".'): - RAT.project.Project(resolutions=ClassList(test_resolution)) + RAT.Project(resolutions=RAT.ClassList(test_resolution)) @pytest.mark.parametrize(["field", "model_name"], [ @@ -513,7 +510,7 @@ def test_allowed_contrasts(field: str, model_name: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "contrasts" must be ' f'defined in "{model_name}".'): - RAT.project.Project(calculation=RAT.project.Calculations.NonPolarised, contrasts=ClassList(test_contrast)) + RAT.Project(calculation=Calculations.NonPolarised, contrasts=RAT.ClassList(test_contrast)) @pytest.mark.parametrize(["field", "model_name"], [ @@ -533,32 +530,32 @@ def test_allowed_contrasts_with_ratio(field: str, model_name: str) -> None: with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The value ' f'"undefined" in the "{field}" field of "contrasts" must be ' f'defined in "{model_name}".'): - RAT.project.Project(calculation=RAT.project.Calculations.Domains, contrasts=ClassList(test_contrast)) + RAT.Project(calculation=Calculations.Domains, contrasts=RAT.ClassList(test_contrast)) @pytest.mark.parametrize(["input_calc", "input_model", "test_contrast", "field_name"], [ - (RAT.project.Calculations.Domains, RAT.project.Models.StandardLayers, + (Calculations.Domains, LayerModels.StandardLayers, RAT.models.ContrastWithRatio(name='Test Contrast', model=['undefined', 'undefined']), 'domain_contrasts'), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomLayers, + (Calculations.Domains, LayerModels.CustomLayers, RAT.models.ContrastWithRatio(name='Test Contrast', model=['undefined']), 'custom_files'), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomXY, + (Calculations.Domains, LayerModels.CustomXY, RAT.models.ContrastWithRatio(name='Test Contrast', model=['undefined']), 'custom_files'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.StandardLayers, + (Calculations.NonPolarised, LayerModels.StandardLayers, RAT.models.Contrast(name='Test Contrast', model=['undefined', 'undefined', 'undefined']), 'layers'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomLayers, + (Calculations.NonPolarised, LayerModels.CustomLayers, RAT.models.Contrast(name='Test Contrast', model=['undefined']), 'custom_files'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomXY, + (Calculations.NonPolarised, LayerModels.CustomXY, RAT.models.Contrast(name='Test Contrast', model=['undefined']), 'custom_files'), ]) -def test_allowed_contrast_models(input_calc: 'RAT.project.Calculations', input_model: 'RAT.project.Models', - test_contrast: 'RAT.models', field_name: str) -> None: +def test_allowed_contrast_models(input_calc: Calculations, input_model: LayerModels, test_contrast: 'RAT.models', + field_name: str) -> None: """If any value in the model field of the contrasts is set to a value not specified in the appropriate part of the project, we should raise a ValidationError. """ with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n Value error, The values: ' f'"{", ".join(test_contrast.model)}" in the "model" field of ' f'"contrasts" must be defined in "{field_name}".'): - RAT.project.Project(calculation=input_calc, model=input_model, contrasts=ClassList(test_contrast)) + RAT.Project(calculation=input_calc, model=input_model, contrasts=RAT.ClassList(test_contrast)) def test_allowed_domain_contrast_models() -> None: @@ -569,12 +566,12 @@ def test_allowed_domain_contrast_models() -> None: with pytest.raises(pydantic.ValidationError, match='1 validation error for Project\n Value error, The values: ' '"undefined" in the "model" field of "domain_contrasts" must be ' 'defined in "layers".'): - RAT.project.Project(calculation=RAT.project.Calculations.Domains, domain_contrasts=ClassList(test_contrast)) + RAT.Project(calculation=Calculations.Domains, domain_contrasts=RAT.ClassList(test_contrast)) def test_repr(default_project_repr: str) -> None: """We should be able to print the "Project" model as a formatted list of the fields.""" - assert repr(RAT.project.Project()) == default_project_repr + assert repr(RAT.Project()) == default_project_repr def test_get_all_names(test_project) -> None: @@ -617,7 +614,7 @@ def test_get_all_protected_parameters(test_project) -> None: ]) def test_check_allowed_values(test_value: str) -> None: """We should not raise an error if string values are defined and on the list of allowed values.""" - project = RAT.project.Project.model_construct(backgrounds=ClassList(RAT.models.Background(value_1=test_value))) + project = RAT.Project.model_construct(backgrounds=RAT.ClassList(RAT.models.Background(value_1=test_value))) assert project.check_allowed_values("backgrounds", ["value_1"], ["Background Param 1"]) is None @@ -626,7 +623,7 @@ def test_check_allowed_values(test_value: str) -> None: ]) def test_check_allowed_values_not_on_list(test_value: str) -> None: """If string values are defined and are not included on the list of allowed values we should raise a ValueError.""" - project = RAT.project.Project.model_construct(backgrounds=ClassList(RAT.models.Background(value_1=test_value))) + project = RAT.Project.model_construct(backgrounds=RAT.ClassList(RAT.models.Background(value_1=test_value))) with pytest.raises(ValueError, match=f'The value "{test_value}" in the "value_1" field of "backgrounds" must be ' f'defined in "background_parameters".'): project.check_allowed_values("backgrounds", ["value_1"], ["Background Param 1"]) @@ -639,8 +636,8 @@ def test_check_allowed_values_not_on_list(test_value: str) -> None: def test_check_contrast_model_allowed_values(test_values: list[str]) -> None: """We should not raise an error if values are defined in a non-empty list and all are on the list of allowed values. """ - project = RAT.project.Project.model_construct(contrasts=ClassList(RAT.models.Contrast(name='Test Contrast', - model=test_values))) + project = RAT.Project.model_construct(contrasts=RAT.ClassList(RAT.models.Contrast(name='Test Contrast', + model=test_values))) assert project.check_contrast_model_allowed_values("contrasts", ["Test Layer"], "layers") is None @@ -652,27 +649,26 @@ def test_check_allowed_values_not_on_list(test_values: list[str]) -> None: """If string values are defined in a non-empty list and any of them are not included on the list of allowed values we should raise a ValueError. """ - project = RAT.project.Project.model_construct(contrasts=ClassList(RAT.models.Contrast(name='Test Contrast', - model=test_values))) + project = RAT.Project.model_construct(contrasts=RAT.ClassList(RAT.models.Contrast(name='Test Contrast', + model=test_values))) with pytest.raises(ValueError, match=f'The values: "{", ".join(str(i) for i in test_values)}" in the "model" field ' f'of "contrasts" must be defined in "layers".'): project.check_contrast_model_allowed_values("contrasts", ["Test Layer"], "layers") @pytest.mark.parametrize(["input_calc", "input_model", "expected_field_name"], [ - (RAT.project.Calculations.Domains, RAT.project.Models.StandardLayers, 'domain_contrasts'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.StandardLayers, 'layers'), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomLayers, 'custom_files'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomLayers, 'custom_files'), - (RAT.project.Calculations.Domains, RAT.project.Models.CustomXY, 'custom_files'), - (RAT.project.Calculations.NonPolarised, RAT.project.Models.CustomXY, 'custom_files'), + (Calculations.Domains, LayerModels.StandardLayers, 'domain_contrasts'), + (Calculations.NonPolarised, LayerModels.StandardLayers, 'layers'), + (Calculations.Domains, LayerModels.CustomLayers, 'custom_files'), + (Calculations.NonPolarised, LayerModels.CustomLayers, 'custom_files'), + (Calculations.Domains, LayerModels.CustomXY, 'custom_files'), + (Calculations.NonPolarised, LayerModels.CustomXY, 'custom_files'), ]) -def test_get_contrast_model_field(input_calc: 'RAT.project.Calculations', input_model: 'RAT.project.Models', - expected_field_name: str) -> None: +def test_get_contrast_model_field(input_calc: Calculations, input_model: LayerModels, expected_field_name: str) -> None: """Each combination of calculation and model determines the field where the values of "model" field of "contrasts" are defined. """ - project = RAT.project.Project(calculation=input_calc, model=input_model) + project = RAT.Project(calculation=input_calc, model=input_model) assert project.get_contrast_model_field() == expected_field_name