Skip to content

Commit

Permalink
Adds formatted error for extra fields to "set_controls"
Browse files Browse the repository at this point in the history
  • Loading branch information
DrPaulSharp committed Oct 27, 2023
1 parent bcd549c commit 1dbe8dc
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 12 deletions.
11 changes: 9 additions & 2 deletions RAT/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.custom_errors import formatted_pydantic_error
from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


Expand Down Expand Up @@ -73,6 +74,7 @@ class Dream(Calculate, validate_assignment=True, extra='forbid'):
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
model = None
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Expand All @@ -87,7 +89,12 @@ def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
members = list(Procedures.__members__.values())
allowed_values = ', '.join([repr(member.value) for member in members[:-1]]) + f' or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
except ValidationError:
raise
except ValidationError as exc:
custom_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure} controls '
f'procedure are:\n {", ".join(controls[procedure].model_fields.keys())}'
}
error_string = formatted_pydantic_error(exc, custom_msgs)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')

return model
15 changes: 12 additions & 3 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,33 @@
from pydantic import ValidationError


def formatted_pydantic_error(error: ValidationError) -> str:
def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str:
"""Write a custom string format for pydantic validation errors.
Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model
A ValidationError produced by a pydantic model.
custom_error_messages: dict[str, str], optional
A dict of custom error messages for given error types.
Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
"""
if custom_error_messages is None:
custom_error_messages = {}
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'

for this_error in error.errors():
error_type = this_error['type']
error_msg = custom_error_messages[error_type] if error_type in custom_error_messages else this_error["msg"]

error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
error_str += f' {error_msg}'

return error_str
28 changes: 25 additions & 3 deletions tests/test_controls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test the controls module."""

import contextlib
import io
import pytest
import pydantic
from typing import Union, Any
Expand Down Expand Up @@ -530,19 +532,39 @@ def test_control_class_dream_repr(self) -> None:
('dream', Dream)
])
def test_set_controls(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
"""Make sure we return the correct model given the value of procedure."""
"""We should return the correct model given the value of procedure."""
controls_model = set_controls(procedure)
assert type(controls_model) == expected_model


def test_set_controls_default_procedure() -> None:
"""Make sure we return the default model when we call "set_controls" without specifying a procedure."""
"""We should return the default model when we call "set_controls" without specifying a procedure."""
controls_model = set_controls()
assert type(controls_model) == Calculate


def test_set_controls_invalid_procedure() -> None:
"""Make sure we return the default model when we call "set_controls" without specifying a procedure."""
"""We should return the default model when we call "set_controls" without specifying a procedure."""
with pytest.raises(ValueError, match="The controls procedure must be one of: 'calculate', 'simplex', 'de', 'ns' "
"or 'dream'"):
set_controls('invalid')


@pytest.mark.parametrize(["procedure", "expected_model"], [
('calculate', Calculate),
('simplex', Simplex),
('de', DE),
('ns', NS),
('dream', Dream)
])
def test_set_controls_extra_fields(procedure: Procedures, expected_model: Union[Calculate, Simplex, DE, NS, Dream])\
-> None:
"""If we provide extra fields to a controls model through "set_controls", we should print a formatted
ValidationError with a custom error message.
"""
with contextlib.redirect_stdout(io.StringIO()) as print_str:
set_controls(procedure, extra_field='invalid')

assert print_str.getvalue() == (f'\033[31m1 validation error for {expected_model.__name__}\nextra_field\n Extra '
f'inputs are not permitted. The fields for the {procedure} controls procedure '
f'are:\n {", ".join(expected_model.model_fields.keys())}\033[0m\n')
24 changes: 20 additions & 4 deletions tests/test_custom_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,31 @@
import RAT.utils.custom_errors


def test_formatted_pydantic_error() -> None:
"""When a pytest ValidationError is raised we should be able to take it and construct a formatted string."""

# Create a custom pydantic model for the test
@pytest.fixture
def TestModel():
"""Create a custom pydantic model for the tests."""
TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a'))
return TestModel


def test_formatted_pydantic_error(TestModel) -> None:
"""When a pytest ValidationError is raised we should be able to take it and construct a formatted string."""
with pytest.raises(ValidationError) as exc_info:
TestModel(int_field='string', str_field=5)

error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value)
assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to '
'parse string as an integer\nstr_field\n Input should be a valid string')


def test_formatted_pydantic_error_custom_messages(TestModel) -> None:
"""When a pytest ValidationError is raised we should be able to take it and construct a formatted string,
including the custom error messages provided."""
with pytest.raises(ValidationError) as exc_info:
TestModel(int_field='string', str_field=5)

custom_messages = {'int_parsing': 'This is a custom error message',
'string_type': 'This is another custom error message'}
error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value, custom_messages)
assert error_str == ('2 validation errors for TestModel\nint_field\n This is a custom error message\n'
'str_field\n This is another custom error message')

0 comments on commit 1dbe8dc

Please sign in to comment.