Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors the controls module, adding custom errors #17

Merged
merged 7 commits into from
Nov 1, 2023
3 changes: 2 additions & 1 deletion RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from RAT.classlist import ClassList
from RAT.controls import Controls
from RAT.project import Project
import RAT.controls
import RAT.models
99 changes: 49 additions & 50 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import prettytable
from pydantic import BaseModel, Field, field_validator
from typing import Union
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
from RAT.utils.custom_errors import custom_pydantic_validation_error


class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the base class with properties used in all five procedures."""
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
Expand All @@ -21,15 +23,16 @@ def check_resamPars(cls, resamPars):
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars


class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure."""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)
def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the simplex procedure."""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the simplex procedure."""
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
Expand All @@ -38,9 +41,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Differential Evolution procedure."""
procedure: Procedures = Field(Procedures.DE, frozen=True)
class DE(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Differential Evolution procedure."""
procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
Expand All @@ -49,52 +52,48 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
numGenerations: int = Field(500, ge=1)


class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Nested Sampler procedure."""
procedure: Procedures = Field(Procedures.NS, frozen=True)
class NS(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Nested Sampler procedure."""
procedure: Literal[Procedures.NS] = Procedures.NS
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Dream procedure."""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
class Dream(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Dream procedure."""
procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class Controls:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
return self._controls

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
return table.get_string()
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
except ValidationError as exc:
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n'
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None

return model
9 changes: 4 additions & 5 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import formatted_pydantic_error
from RAT.utils.custom_errors import custom_pydantic_validation_error

try:
from enum import StrEnum
Expand Down Expand Up @@ -524,11 +524,10 @@ def wrapped_func(*args, **kwargs):
try:
return_value = func(*args, **kwargs)
Project.model_validate(self)
except ValidationError as e:
except ValidationError as exc:
setattr(class_list, 'data', previous_state)
error_string = formatted_pydantic_error(e)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')
custom_error_list = custom_pydantic_validation_error(exc.errors())
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
except (TypeError, ValueError):
setattr(class_list, 'data', previous_state)
raise
Expand Down
40 changes: 25 additions & 15 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
"""Defines routines for custom error handling in RAT."""
import pydantic_core

from pydantic import ValidationError

def custom_pydantic_validation_error(error_list: list[pydantic_core.ErrorDetails], custom_errors: dict[str, str] = None
) -> list[pydantic_core.ErrorDetails]:
"""Run through the list of errors generated from a pydantic ValidationError, substituting the standard error for a
PydanticCustomError for a given set of error types.

def formatted_pydantic_error(error: ValidationError) -> str:
"""Write a custom string format for pydantic validation errors.
For errors that do not have a custom error message defined, we redefine them using a PydanticCustomError to remove
the url from the error message.

Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model
error_list : list[pydantic_core.ErrorDetails]
A list of errors produced by pydantic.ValidationError.errors().
custom_errors: dict[str, str], optional
A dict of custom error messages for given error types.

Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
new_error : list[pydantic_core.ErrorDetails]
A list of errors including PydanticCustomErrors in place of the error types in custom_errors.
"""
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'
for this_error in error.errors():
error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
return error_str
if custom_errors is None:
custom_errors = {}
custom_error_list = []
for error in error_list:
if error['type'] in custom_errors:
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], custom_errors[error['type']])
else:
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], error['msg'])
error['type'] = RAT_custom_error
custom_error_list.append(error)

return custom_error_list
Loading