Skip to content

Commit

Permalink
Introduces ruff as linter and formatter (#38)
Browse files Browse the repository at this point in the history
* Adds ruff and resolves linting errors on standard rule set

* Resolves automatically fixable linting errors on advanced rule set

* Resolves most manually fixable linting errors on advanced rule set

* Switches code to use double quotes

* Resolves automatically fixable linting errors on full rule set

* Applies ruff formatter

* Finalises rule selection and tidies up code

* Adds "requirements-dev.txt"

* Adds "requirements-dev.txt"

* Adds new github action for linter and formatter

* Addresses review comments
  • Loading branch information
DrPaulSharp authored Jun 26, 2024
1 parent a14835e commit a7842e6
Show file tree
Hide file tree
Showing 46 changed files with 9,475 additions and 4,484 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/run_ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Ruff

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- uses: chartboost/ruff-action@v1
with:
args: 'format --check'

3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ docs/*.inv
build/*
dist/*
*.whl

# Local pre-commit hooks
.pre-commit-config.yaml
9 changes: 6 additions & 3 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os

from RAT import models
from RAT.classlist import ClassList
from RAT.project import Project
from RAT.controls import set_controls
from RAT.project import Project
from RAT.run import run
import RAT.models

__all__ = ["ClassList", "Project", "run", "set_controls", "models"]

dir_path = os.path.dirname(os.path.realpath(__file__))
os.environ["RAT_PATH"] = os.path.join(dir_path, '')
os.environ["RAT_PATH"] = os.path.join(dir_path, "")
81 changes: 56 additions & 25 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular class.
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular
class.
"""

import collections
from collections.abc import Iterable, Sequence
import contextlib
import prettytable
from typing import Any, Union
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Union

import prettytable


class ClassList(collections.UserList):
Expand All @@ -31,7 +33,9 @@ class ClassList(collections.UserList):
An instance, or list of instance(s), of the class to be used in this ClassList.
name_field : str, optional
The field used to define unique objects in the ClassList (default is "name").
"""

def __init__(self, init_list: Union[Sequence[object], object] = None, name_field: str = "name") -> None:
self.name_field = name_field

Expand All @@ -56,7 +60,7 @@ def __repr__(self):
else:
if any(model.__dict__ for model in self.data):
table = prettytable.PrettyTable()
table.field_names = ['index'] + [key.replace('_', ' ') for key in self.data[0].__dict__.keys()]
table.field_names = ["index"] + [key.replace("_", " ") for key in self.data[0].__dict__]
table.add_rows([[index] + list(model.__dict__.values()) for index, model in enumerate(self.data)])
output = table.get_string()
else:
Expand All @@ -81,15 +85,15 @@ def _delitem(self, index: int) -> None:
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
del self.data[index]

def __iadd__(self, other: Sequence[object]) -> 'ClassList':
def __iadd__(self, other: Sequence[object]) -> "ClassList":
"""Define in-place addition using the "+=" operator."""
return self._iadd(other)

def _iadd(self, other: Sequence[object]) -> 'ClassList':
def _iadd(self, other: Sequence[object]) -> "ClassList":
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
Expand Down Expand Up @@ -129,20 +133,27 @@ def append(self, obj: object = None, **kwargs) -> None:
SyntaxWarning
Raised if the input arguments contain BOTH an object and keyword arguments. In this situation the object is
appended to the ClassList and the keyword arguments are discarded.
"""
if obj and kwargs:
warnings.warn('ClassList.append() called with both an object and keyword arguments. '
'The keyword arguments will be ignored.', SyntaxWarning)
warnings.warn(
"ClassList.append() called with both an object and keyword arguments. "
"The keyword arguments will be ignored.",
SyntaxWarning,
stacklevel=2,
)
if obj:
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self.data.append(obj)
else:
if not hasattr(self, '_class_handle'):
raise TypeError('ClassList.append() called with keyword arguments for a ClassList without a class '
'defined. Call ClassList.append() with an object to define the class.')
if not hasattr(self, "_class_handle"):
raise TypeError(
"ClassList.append() called with keyword arguments for a ClassList without a class "
"defined. Call ClassList.append() with an object to define the class.",
)
self._validate_name_field(kwargs)
self.data.append(self._class_handle(**kwargs))

Expand All @@ -169,20 +180,27 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
SyntaxWarning
Raised if the input arguments contain both an object and keyword arguments. In this situation the object is
inserted into the ClassList and the keyword arguments are discarded.
"""
if obj and kwargs:
warnings.warn('ClassList.insert() called with both an object and keyword arguments. '
'The keyword arguments will be ignored.', SyntaxWarning)
warnings.warn(
"ClassList.insert() called with both an object and keyword arguments. "
"The keyword arguments will be ignored.",
SyntaxWarning,
stacklevel=2,
)
if obj:
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self.data.insert(index, obj)
else:
if not hasattr(self, '_class_handle'):
raise TypeError('ClassList.insert() called with keyword arguments for a ClassList without a class '
'defined. Call ClassList.insert() with an object to define the class.')
if not hasattr(self, "_class_handle"):
raise TypeError(
"ClassList.insert() called with keyword arguments for a ClassList without a class "
"defined. Call ClassList.insert() with an object to define the class.",
)
self._validate_name_field(kwargs)
self.data.insert(index, self._class_handle(**kwargs))

Expand All @@ -209,7 +227,7 @@ def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
Expand All @@ -229,6 +247,7 @@ def get_names(self) -> list[str]:
-------
names : list [str]
The value of the name_field attribute of each object in the ClassList.
"""
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]

Expand All @@ -244,9 +263,14 @@ def get_all_matches(self, value: Any) -> list[tuple]:
-------
: list [tuple]
A list of (index, field) tuples matching the given value.
"""
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
if getattr(element, field) == value]
return [
(index, field)
for index, element in enumerate(self.data)
for field in vars(element)
if getattr(element, field) == value
]

def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
Expand All @@ -261,12 +285,15 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
------
ValueError
Raised if the input arguments contain a name_field value already defined in the ClassList.
"""
names = self.get_names()
with contextlib.suppress(KeyError):
if input_args[self.name_field] in names:
raise ValueError(f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList")
raise ValueError(
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList",
)

def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
Expand All @@ -281,6 +308,7 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
------
ValueError
Raised if the input list defines more than one object with the same value of name_field.
"""
names = [getattr(model, self.name_field) for model in input_list if hasattr(model, self.name_field)]
if len(set(names)) != len(names):
Expand All @@ -298,6 +326,7 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
------
ValueError
Raised if the input list defines objects of different types.
"""
if not (all(isinstance(element, self._class_handle) for element in input_list)):
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
Expand All @@ -315,6 +344,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
instance : object or str
Either the object with the value of the name_field attribute given by value, or the input value if an
object with that value of the name_field attribute cannot be found.
"""
return next((model for model in self.data if getattr(model, self.name_field) == value), value)

Expand All @@ -333,6 +363,7 @@ def _determine_class_handle(input_list: Sequence[object]):
class_handle : type
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
elements.
"""
for this_element in input_list:
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):
Expand Down
40 changes: 25 additions & 15 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from dataclasses import dataclass, field
import prettytable
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies
import prettytable
from pydantic import BaseModel, Field, ValidationError, field_validator

from RAT.utils.custom_errors import custom_pydantic_validation_error
from RAT.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies


@dataclass(frozen=True)
class Controls:
"""The full set of controls parameters required for the compiled RAT code."""

# All Procedures
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
Expand Down Expand Up @@ -44,8 +46,9 @@ class Controls:
adaptPCR: bool = False


class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
class Calculate(BaseModel, validate_assignment=True, extra="forbid"):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""

procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
Expand All @@ -56,20 +59,21 @@ class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
@classmethod
def check_resample_params(cls, resampleParams):
if not 0 < resampleParams[0] < 1:
raise ValueError('resampleParams[0] must be between 0 and 1')
raise ValueError("resampleParams[0] must be between 0 and 1")
if resampleParams[1] < 0:
raise ValueError('resampleParams[1] must be greater than or equal to 0')
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return resampleParams

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.field_names = ["Property", "Value"]
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(Calculate):
"""Defines the additional fields for the simplex procedure."""

procedure: Literal[Procedures.Simplex] = Procedures.Simplex
xTolerance: float = Field(1.0e-6, gt=0.0)
funcTolerance: float = Field(1.0e-6, gt=0.0)
Expand All @@ -81,6 +85,7 @@ class Simplex(Calculate):

class DE(Calculate):
"""Defines the additional fields for the Differential Evolution procedure."""

procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
Expand All @@ -92,6 +97,7 @@ class DE(Calculate):

class NS(Calculate):
"""Defines the additional fields for the Nested Sampler procedure."""

procedure: Literal[Procedures.NS] = Procedures.NS
nLive: int = Field(150, ge=1)
nMCMC: float = Field(0.0, ge=0.0)
Expand All @@ -101,6 +107,7 @@ class NS(Calculate):

class Dream(Calculate):
"""Defines the additional fields for the Dream procedure."""

procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
Expand All @@ -110,28 +117,31 @@ class Dream(Calculate):
adaptPCR: bool = False


def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
def set_controls(
procedure: Procedures = Procedures.Calculate,
**properties,
) -> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
Procedures.Dream: Dream,
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None
except ValidationError as exc:
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n'
}
custom_error_msgs = {
"extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n',
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None

Expand Down
Loading

0 comments on commit a7842e6

Please sign in to comment.