Skip to content

Commit

Permalink
Introduces case insensitive enums and Classlists (RascalSoftware#53)
Browse files Browse the repository at this point in the history
* Adds case insensitive enums

* Adds "__str__" method to format contrast models as vertical list

* Adds code to ensure searching parameter names in project and classlist is case insensitive
  • Loading branch information
DrPaulSharp authored Aug 1, 2024
1 parent 48b27ba commit 0a8fa3c
Show file tree
Hide file tree
Showing 12 changed files with 305 additions and 101 deletions.
17 changes: 12 additions & 5 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def __str__(self):
+ list(
f"{'Data array: ['+' x '.join(str(i) for i in v.shape) if v.size > 0 else '['}]"
if isinstance(v, np.ndarray)
else "\n".join(element for element in v)
if k == "model"
else str(v)
for v in model.__dict__.values()
for k, v in model.__dict__.items()
)
for index, model in enumerate(self.data)
]
Expand Down Expand Up @@ -308,9 +310,9 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
Raised if the input arguments contain a name_field value already defined in the ClassList.
"""
names = self.get_names()
names = [name.lower() for name in self.get_names()]
with contextlib.suppress(KeyError):
if input_args[self.name_field] in names:
if input_args[self.name_field].lower() in names:
raise ValueError(
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList",
Expand All @@ -331,7 +333,7 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
Raised if the input list defines more than one object with the same value of name_field.
"""
names = [getattr(model, self.name_field) for model in input_list if hasattr(model, self.name_field)]
names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
if len(set(names)) != len(names):
raise ValueError(f"Input list contains objects with the same value of the {self.name_field} attribute")

Expand Down Expand Up @@ -367,7 +369,12 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
object with that value of the name_field attribute cannot be found.
"""
return next((model for model in self.data if getattr(model, self.name_field) == value), value)
try:
lower_value = value.lower()
except AttributeError:
lower_value = value

return next((model for model in self.data if getattr(model, self.name_field).lower() == lower_value), value)

@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
Expand Down
6 changes: 3 additions & 3 deletions RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
fields = {
"calculate": common_fields,
"simplex": [*common_fields, "xTolerance", "funcTolerance", "maxFuncEvals", "maxIterations", *update_fields],
"de": [
"DE": [
*common_fields,
"populationSize",
"fWeight",
Expand All @@ -29,8 +29,8 @@
"numGenerations",
*update_fields,
],
"ns": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
"dream": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
"NS": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
"DREAM": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
}


Expand Down
47 changes: 47 additions & 0 deletions RATapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ class Contrast(RATModel):
resample: bool = False
model: list[str] = []

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.__dict__]
model_entry = "\n".join(element for element in self.model)
table.add_row(
[
self.name,
self.data,
self.background,
self.background_action,
self.bulk_in,
self.bulk_out,
self.scalefactor,
self.resolution,
self.resample,
model_entry,
]
)
return table.get_string()


class ContrastWithRatio(RATModel):
"""Groups together all of the components of the model including domain terms."""
Expand All @@ -93,6 +113,26 @@ class ContrastWithRatio(RATModel):
domain_ratio: str = ""
model: list[str] = []

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.__dict__]
model_entry = "\n".join(element for element in self.model)
table.add_row(
[
self.name,
self.data,
self.background,
self.background_action,
self.bulk_in,
self.bulk_out,
self.scalefactor,
self.resolution,
self.resample,
model_entry,
]
)
return table.get_string()


class CustomFile(RATModel):
"""Defines the files containing functions to run when using custom models."""
Expand Down Expand Up @@ -219,6 +259,13 @@ class DomainContrast(RATModel):
name: str = Field(default_factory=lambda: "New Domain Contrast " + next(domain_contrast_number), min_length=1)
model: list[str] = []

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.__dict__]
model_entry = "\n".join(element for element in self.model)
table.add_row([self.name, model_entry])
return table.get_string()


class Layer(RATModel, populate_by_name=True):
"""Combines parameters into defined layers."""
Expand Down
2 changes: 1 addition & 1 deletion RATapi/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def make_results(
resample=output_results.contrastParams.resample,
)

if procedure in [Procedures.NS, Procedures.Dream]:
if procedure in [Procedures.NS, Procedures.DREAM]:
prediction_intervals = PredictionIntervals(
reflectivity=bayes_results.predictionIntervals.reflectivity,
sld=bayes_results.predictionIntervals.sld,
Expand Down
6 changes: 3 additions & 3 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def model_post_init(self, __context: Any) -> None:
if not hasattr(field, "_class_handle"):
field._class_handle = getattr(RATapi.models, model)

if "Substrate Roughness" not in self.parameters.get_names():
if "Substrate Roughness" not in [name.title() for name in self.parameters.get_names()]:
self.parameters.insert(
0,
RATapi.models.ProtectedParameter(
Expand All @@ -283,13 +283,13 @@ def model_post_init(self, __context: Any) -> None:
sigma=np.inf,
),
)
elif "Substrate Roughness" not in self.get_all_protected_parameters().values():
elif "Substrate Roughness" not in [name.title() for name in self.get_all_protected_parameters()["parameters"]]:
# If substrate roughness is included as a standard parameter replace it with a protected parameter
substrate_roughness_values = self.parameters[self.parameters.index("Substrate Roughness")].model_dump()
self.parameters.remove("Substrate Roughness")
self.parameters.insert(0, RATapi.models.ProtectedParameter(**substrate_roughness_values))

if "Simulation" not in self.data.get_names():
if "Simulation" not in [name.title() for name in self.data.get_names()]:
self.data.insert(0, RATapi.models.Data(name="Simulation", simulation_range=[0.005, 0.7]))

self._all_names = self.get_all_names()
Expand Down
103 changes: 58 additions & 45 deletions RATapi/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,40 @@
from strenum import StrEnum


# Controls
class Parallel(StrEnum):
"""Defines the available options for parallelization"""
class RATEnum(StrEnum):
@classmethod
def _missing_(cls, value: str):
value = value.lower()

Single = "single"
Points = "points"
Contrasts = "contrasts"
# Replace common alternative spellings
value = value.replace("-", " ").replace("_", " ").replace("++", "pp").replace("polarized", "polarised")

for member in cls:
if member.value.lower() == value:
return member
return None


class Procedures(StrEnum):
# Controls
class Procedures(RATEnum):
"""Defines the available options for procedures"""

Calculate = "calculate"
Simplex = "simplex"
DE = "de"
NS = "ns"
Dream = "dream"
DE = "DE"
NS = "NS"
DREAM = "DREAM"


class Display(StrEnum):
class Parallel(RATEnum):
"""Defines the available options for parallelization"""

Single = "single"
Points = "points"
Contrasts = "contrasts"


class Display(RATEnum):
"""Defines the available options for display"""

Off = "off"
Expand All @@ -34,15 +48,6 @@ class Display(StrEnum):
Final = "final"


class BoundHandling(StrEnum):
"""Defines the available options for bound handling"""

Off = "off"
Reflect = "reflect"
Bound = "bound"
Fold = "fold"


class Strategies(Enum):
"""Defines the available options for strategies"""

Expand All @@ -54,48 +59,56 @@ class Strategies(Enum):
RandomEitherOrAlgorithm = 6


# Models
class Hydration(StrEnum):
None_ = "none"
BulkIn = "bulk in"
BulkOut = "bulk out"
Oil = "oil"


class Languages(StrEnum):
Cpp = "cpp"
Python = "python"
Matlab = "matlab"

class BoundHandling(RATEnum):
"""Defines the available options for bound handling"""

class Priors(StrEnum):
Uniform = "uniform"
Gaussian = "gaussian"
Off = "off"
Reflect = "reflect"
Bound = "bound"
Fold = "fold"


class TypeOptions(StrEnum):
# Models
class TypeOptions(RATEnum):
Constant = "constant"
Data = "data"
Function = "function"


class BackgroundActions(StrEnum):
class BackgroundActions(RATEnum):
Add = "add"
Subtract = "subtract"


class Languages(RATEnum):
Cpp = "Cpp"
Python = "python"
Matlab = "matlab"


class Hydration(RATEnum):
None_ = "none"
BulkIn = "bulk in"
BulkOut = "bulk out"


class Priors(RATEnum):
Uniform = "uniform"
Gaussian = "gaussian"


# Project
class Calculations(StrEnum):
class Calculations(RATEnum):
NonPolarised = "non polarised"
Domains = "domains"


class Geometries(StrEnum):
AirSubstrate = "air/substrate"
SubstrateLiquid = "substrate/liquid"


class LayerModels(StrEnum):
class LayerModels(RATEnum):
CustomLayers = "custom layers"
CustomXY = "custom xy"
StandardLayers = "standard layers"


class Geometries(RATEnum):
AirSubstrate = "air/substrate"
SubstrateLiquid = "substrate/liquid"
15 changes: 12 additions & 3 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,13 @@ def test__validate_name_field(two_name_class_list: ClassList, input_dict: dict[s
"input_dict",
[
({"name": "Alice"}),
({"name": "ALICE"}),
({"name": "alice"}),
],
)
def test__validate_name_field_not_unique(two_name_class_list: ClassList, input_dict: dict[str, Any]) -> None:
"""We should raise a ValueError if we input values containing a name_field defined in an object in the ClassList."""
"""We should raise a ValueError if we input values containing a name_field defined in an object in the ClassList,
accounting for case sensitivity."""
with pytest.raises(
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
Expand Down Expand Up @@ -801,11 +804,13 @@ def test__check_unique_name_fields(two_name_class_list: ClassList, input_list: I
"input_list",
[
([InputAttributes(name="Alice"), InputAttributes(name="Alice")]),
([InputAttributes(name="Alice"), InputAttributes(name="ALICE")]),
([InputAttributes(name="Alice"), InputAttributes(name="alice")]),
],
)
def test__check_unique_name_fields_not_unique(two_name_class_list: ClassList, input_list: Iterable) -> None:
"""We should raise a ValueError if an input list contains multiple objects with matching name_field values
defined.
"""We should raise a ValueError if an input list contains multiple objects with (case-insensitive) matching
name_field values defined.
"""
with pytest.raises(
ValueError,
Expand Down Expand Up @@ -846,7 +851,11 @@ def test__check_classes_different_classes(input_list: Iterable) -> None:
["value", "expected_output"],
[
("Alice", InputAttributes(name="Alice")),
("ALICE", InputAttributes(name="Alice")),
("alice", InputAttributes(name="Alice")),
("Eve", "Eve"),
("EVE", "EVE"),
("eve", "eve"),
],
)
def test__get_item_from_name_field(
Expand Down
Loading

0 comments on commit 0a8fa3c

Please sign in to comment.