Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds additional validators and "write_script" routine #15

Merged
merged 9 commits into from
Oct 24, 2023
2 changes: 2 additions & 0 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from RAT.classlist import ClassList
from RAT.project import Project
75 changes: 56 additions & 19 deletions RAT/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from typing import Any

try:
from enum import StrEnum
Expand Down Expand Up @@ -43,7 +44,6 @@ class Languages(StrEnum):
class Priors(StrEnum):
Uniform = 'uniform'
Gaussian = 'gaussian'
Jeffreys = 'jeffreys'


class Types(StrEnum):
Expand All @@ -52,7 +52,13 @@ class Types(StrEnum):
Function = 'function'


class Background(BaseModel, validate_assignment=True, extra='forbid'):
class RATModel(BaseModel):
"""A BaseModel where enums are represented by their value."""
def __repr__(self):
return f'{self.__repr_name__()}({", ".join(repr(v) if a is None else f"{a}={v.value!r}" if isinstance(v, StrEnum) else f"{a}={v!r}" for a, v in self.__repr_args__())})'
DrPaulSharp marked this conversation as resolved.
Show resolved Hide resolved


class Background(RATModel, validate_assignment=True, extra='forbid'):
"""Defines the Backgrounds in RAT."""
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
type: Types = Types.Constant
Expand All @@ -63,7 +69,7 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):
value_5: str = ''


class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
class Contrast(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
Expand All @@ -76,7 +82,7 @@ class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
model: list[str] = []


class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
class ContrastWithRatio(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model including domain terms."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
Expand All @@ -90,20 +96,20 @@ class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
model: list[str] = []


class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
class CustomFile(RATModel, validate_assignment=True, extra='forbid'):
"""Defines the files containing functions to run when using custom models."""
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number), min_length=1)
filename: str = ''
language: Languages = Languages.Python
path: str = 'pwd' # Should later expand to find current file path


class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
class Data(RATModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the dataset required for each contrast."""
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number), min_length=1)
data: np.ndarray[float] = np.empty([0, 3])
data_range: list[float] = []
simulation_range: list[float] = [0.005, 0.7]
data: np.ndarray[np.float64] = np.empty([0, 3])
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)

@field_validator('data')
@classmethod
Expand All @@ -120,22 +126,53 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:

@field_validator('data_range', 'simulation_range')
@classmethod
def check_list_elements(cls, limits: list[float], info: ValidationInfo) -> list[float]:
"""The data range and simulation range must contain exactly two parameters."""
if len(limits) != 2:
raise ValueError(f'{info.field_name} must contain exactly two values')
def check_min_max(cls, limits: list[float], info: ValidationInfo) -> list[float]:
"""The data range and simulation range maximum must be greater than the minimum."""
if limits[0] > limits[1]:
raise ValueError(f'{info.field_name} "min" value is greater than the "max" value')
return limits

# Also need model validators for data range compared to data etc -- need more details.
def model_post_init(self, __context: Any) -> None:
"""If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
set to the min and max values of the first column (assumed to be q) of the supplied data.
"""
if len(self.data[:, 0]) > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
for field in ["data_range", "simulation_range"]:
if field not in self.model_fields_set:
getattr(self, field).extend([data_min, data_max])

@model_validator(mode='after')
def check_ranges(self) -> 'Data':
"""The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
of the "simulation_range" field must lie outside of the range of the supplied data.
"""
if len(self.data[:, 0]) > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
if "data_range" in self.model_fields_set and (self.data_range[0] < data_min or
self.data_range[1] > data_max):
raise ValueError(f'The data_range value of: {self.data_range} must lie within the min/max values of '
f'the data: [{data_min}, {data_max}]')
if "simulation_range" in self.model_fields_set and (self.simulation_range[0] > data_min or
self.simulation_range[1] < data_max):
raise ValueError(f'The simulation_range value of: {self.simulation_range} must lie outside of the '
f'min/max values of the data: [{data_min}, {data_max}]')
return self

def __repr__(self):
"""Only include the name if the data is empty."""
return f'{self.__repr_name__()}({f"name={self.name!r}" if self.data.size == 0 else ", ".join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())})'
DrPaulSharp marked this conversation as resolved.
Show resolved Hide resolved


class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
class DomainContrast(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together the layers required for each domain."""
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number), min_length=1)
model: list[str] = []


class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
class Layer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
Expand All @@ -145,7 +182,7 @@ class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_nam
hydrate_with: Hydration = Hydration.BulkOut


class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
class AbsorptionLayer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers including absorption terms."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
Expand All @@ -156,7 +193,7 @@ class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', popul
hydrate_with: Hydration = Hydration.BulkOut


class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
class Parameter(RATModel, validate_assignment=True, extra='forbid'):
"""Defines parameters needed to specify the model."""
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number), min_length=1)
min: float = 0.0
Expand All @@ -180,7 +217,7 @@ class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
name: str = Field(frozen=True, min_length=1)


class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
class Resolution(RATModel, validate_assignment=True, extra='forbid'):
"""Defines Resolutions in RAT."""
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
type: Types = Types.Constant
Expand Down
121 changes: 100 additions & 21 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ class Geometries(StrEnum):
'resolutions': AllFields('contrasts', ['resolution']),
}

class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'domain_ratios',
'background_parameters', 'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data',
'layers', 'domain_contrasts', 'contrasts']
parameter_class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'domain_ratios',
'background_parameters', 'resolution_parameters']
class_lists = parameter_class_lists + ['backgrounds', 'resolutions', 'custom_files', 'data', 'layers',
DrPaulSharp marked this conversation as resolved.
Show resolved Hide resolved
'domain_contrasts', 'contrasts']


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
Expand Down Expand Up @@ -152,6 +153,7 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ

_all_names: dict
_contrast_model_field: str
_protected_parameters: dict

@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
Expand Down Expand Up @@ -196,12 +198,20 @@ def model_post_init(self, __context: Any) -> None:
if not hasattr(field, "_class_handle"):
setattr(field, "_class_handle", getattr(RAT.models, model))

self.parameters.insert(0, RAT.models.ProtectedParameter(name='Substrate Roughness', min=1, value=3, max=5,
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
sigma=np.inf))
if 'Substrate Roughness' not in self.parameters.get_names():
self.parameters.insert(0, RAT.models.ProtectedParameter(name='Substrate Roughness', min=1.0, value=3.0,
max=5.0, fit=True,
prior_type=RAT.models.Priors.Uniform, mu=0.0,
sigma=np.inf))
elif 'Substrate Roughness' not in self.get_all_protected_parameters().values():
# If substrate roughness is included as a standard parameter replace it with a protected parameter
DrPaulSharp marked this conversation as resolved.
Show resolved Hide resolved
substrate_roughness_values = self.parameters[self.parameters.index('Substrate Roughness')].model_dump()
self.parameters.remove('Substrate Roughness')
self.parameters.insert(0, RAT.models.ProtectedParameter(**substrate_roughness_values))

self._all_names = self.get_all_names()
self._contrast_model_field = self.get_contrast_model_field()
self._protected_parameters = self.get_all_protected_parameters()

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
Expand Down Expand Up @@ -229,6 +239,13 @@ def set_domain_contrasts(self) -> 'Project':
self.domain_contrasts.data = []
return self

@model_validator(mode='after')
def set_layers(self) -> 'Project':
"""If we are not using a standard layers model, ensure the layers component of the model is empty."""
if self.model != ModelTypes.StandardLayers:
self.layers.data = []
return self

@model_validator(mode='after')
def set_calc_type(self) -> 'Project':
"""Apply the calc_type setting to the project."""
Expand All @@ -238,6 +255,9 @@ def set_calc_type(self) -> 'Project':
for contrast in self.contrasts:
contrast_list.append(RAT.models.ContrastWithRatio(**contrast.model_dump()))
self.contrasts.data = contrast_list
self.domain_ratios.data = [RAT.models.Parameter(name='Domain Ratio 1', min=0.4, value=0.5, max=0.6,
fit=False, prior_type=RAT.models.Priors.Uniform, mu=0.0,
sigma=np.inf)]
setattr(self.contrasts, '_class_handle', getattr(RAT.models, 'ContrastWithRatio'))
elif self.calc_type != CalcTypes.Domains and handle == 'ContrastWithRatio':
for contrast in self.contrasts:
Expand Down Expand Up @@ -280,21 +300,20 @@ def check_contrast_model_length(self) -> 'Project':
@model_validator(mode='after')
def set_absorption(self) -> 'Project':
"""Apply the absorption setting to the project."""
if hasattr(self, 'layers'):
layer_list = []
handle = getattr(self.layers, '_class_handle').__name__
if self.absorption and handle == 'Layer':
for layer in self.layers:
layer_list.append(RAT.models.AbsorptionLayer(**layer.model_dump()))
self.layers.data = layer_list
setattr(self.layers, '_class_handle', getattr(RAT.models, 'AbsorptionLayer'))
elif not self.absorption and handle == 'AbsorptionLayer':
for layer in self.layers:
layer_params = layer.model_dump()
del layer_params['SLD_imaginary']
layer_list.append(RAT.models.Layer(**layer_params))
self.layers.data = layer_list
setattr(self.layers, '_class_handle', getattr(RAT.models, 'Layer'))
layer_list = []
handle = getattr(self.layers, '_class_handle').__name__
if self.absorption and handle == 'Layer':
for layer in self.layers:
layer_list.append(RAT.models.AbsorptionLayer(**layer.model_dump()))
self.layers.data = layer_list
setattr(self.layers, '_class_handle', getattr(RAT.models, 'AbsorptionLayer'))
elif not self.absorption and handle == 'AbsorptionLayer':
for layer in self.layers:
layer_params = layer.model_dump()
del layer_params['SLD_imaginary']
layer_list.append(RAT.models.Layer(**layer_params))
self.layers.data = layer_list
setattr(self.layers, '_class_handle', getattr(RAT.models, 'Layer'))
return self

@model_validator(mode='after')
Expand Down Expand Up @@ -337,6 +356,20 @@ def cross_check_model_values(self) -> 'Project':
self.check_contrast_model_allowed_values('domain_contrasts', self.layers.get_names(), 'layers')
return self

@model_validator(mode='after')
def check_protected_parameters(self) -> 'Project':
"""Protected parameters should not be deleted. If this is attempted, raise an error."""
for class_list in parameter_class_lists:
protected_parameters = [param.name for param in getattr(self, class_list)
if isinstance(param, RAT.models.ProtectedParameter)]
# All previously existing protected parameters should be present in new list
if not all(element in protected_parameters for element in self._protected_parameters[class_list]):
removed_params = [param for param in self._protected_parameters[class_list]
if param not in protected_parameters]
raise ValueError(f'Can\'t delete the protected parameters: {", ".join(str(i) for i in removed_params)}')
self._protected_parameters = self.get_all_protected_parameters()
return self

def __repr__(self):
output = ''
for key, value in self.__dict__.items():
Expand All @@ -354,6 +387,12 @@ def get_all_names(self):
"""Record the names of all models defined in the project."""
return {class_list: getattr(self, class_list).get_names() for class_list in class_lists}

def get_all_protected_parameters(self):
"""Record the protected parameters defined in the project."""
return {class_list: [param.name for param in getattr(self, class_list)
if isinstance(param, RAT.models.ProtectedParameter)]
for class_list in parameter_class_lists}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.

Expand Down Expand Up @@ -422,6 +461,46 @@ def get_contrast_model_field(self):
model_field = 'custom_files'
return model_field

def write_script(self, obj_name: str = 'problem', script: str = 'project_script.py'):
"""Write a python script that can be run to reproduce this project object.

Parameters
----------
obj_name : str, optional
The name given to the project object under construction (default is "problem").
script : str, optional
The filename of the generated script (default is "project_script.py").
"""
# Need to ensure correct format for script name
file_parts = script.split('.')
DrPaulSharp marked this conversation as resolved.
Show resolved Hide resolved

try:
file_parts[1]
except IndexError:
script += '.py'
else:
if file_parts[1] != 'py':
raise ValueError('The script name provided to "write_script" must use the ".py" format')

indent = 4 * " "

with open(script, 'w') as f:

f.write('# THIS FILE IS GENERATED FROM RAT VIA THE "WRITE_SCRIPT" ROUTINE. IT IS NOT PART OF THE RAT CODE.'
'\n\n')

# Need imports
f.write('import RAT\nfrom RAT.models import *\nfrom numpy import array, inf\n\n')

f.write(f"{obj_name} = RAT.Project(\n{indent}name='{self.name}', calc_type='{self.calc_type}',"
f" model='{self.model}', geometry='{self.geometry}', absorption={self.absorption},\n")

for class_list in class_lists:
contents = getattr(self, class_list).data
if contents:
f.write(f'{indent}{class_list}=RAT.ClassList({contents}),\n')
f.write(f'{indent})\n')

def _classlist_wrapper(self, class_list: 'ClassList', func: Callable):
"""Defines the function used to wrap around ClassList routines to force revalidation.

Expand Down
Loading