Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors the controls module, adding custom errors #17

Merged
merged 7 commits into from
Nov 1, 2023
3 changes: 2 additions & 1 deletion RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from RAT.classlist import ClassList
from RAT.controls import Controls
from RAT.project import Project
import RAT.controls
import RAT.models
93 changes: 43 additions & 50 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import prettytable
from pydantic import BaseModel, Field, field_validator
from typing import Union
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the base class with properties used in all five procedures."""
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
Expand All @@ -21,15 +22,16 @@ def check_resamPars(cls, resamPars):
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars


class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure."""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)
def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the simplex procedure."""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the simplex procedure."""
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
Expand All @@ -38,9 +40,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Differential Evolution procedure."""
procedure: Procedures = Field(Procedures.DE, frozen=True)
class DE(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Differential Evolution procedure."""
procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
Expand All @@ -49,52 +51,43 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
numGenerations: int = Field(500, ge=1)


class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Nested Sampler procedure."""
procedure: Procedures = Field(Procedures.NS, frozen=True)
class NS(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Nested Sampler procedure."""
procedure: Literal[Procedures.NS] = Procedures.NS
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Dream procedure."""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
class Dream(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Dream procedure."""
procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class Controls:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
return self._controls

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
return table.get_string()
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
except ValidationError:
raise

return model
10 changes: 3 additions & 7 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import collections
import copy
import functools
import logging
import numpy as np
import os
from pydantic import BaseModel, ValidationInfo, field_validator, model_validator, ValidationError
from typing import Any, Callable

from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import formatted_pydantic_error
from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback

try:
from enum import StrEnum
Expand Down Expand Up @@ -524,12 +525,7 @@ def wrapped_func(*args, **kwargs):
try:
return_value = func(*args, **kwargs)
Project.model_validate(self)
except ValidationError as e:
setattr(class_list, 'data', previous_state)
error_string = formatted_pydantic_error(e)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')
except (TypeError, ValueError):
except (TypeError, ValueError, ValidationError):
setattr(class_list, 'data', previous_state)
raise
finally:
Expand Down
24 changes: 21 additions & 3 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
"""Defines routines for custom error handling in RAT."""

from pydantic import ValidationError
import traceback


def formatted_pydantic_error(error: ValidationError) -> str:
def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str:
"""Write a custom string format for pydantic validation errors.

Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model
A ValidationError produced by a pydantic model.
custom_error_messages: dict[str, str], optional
A dict of custom error messages for given error types.

Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
"""
if custom_error_messages is None:
custom_error_messages = {}
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'

for this_error in error.errors():
error_type = this_error['type']
error_msg = custom_error_messages[error_type] if error_type in custom_error_messages else this_error["msg"]

error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
error_str += f' {error_msg}'

return error_str


def formatted_traceback() -> str:
"""Takes the traceback obtained from "traceback.format_exc()" and removes the exception message for pydantic
ValidationErrors.
"""
traceback_string = traceback.format_exc()
return traceback_string.split('pydantic_core._pydantic_core.ValidationError:')[0]
Loading