Skip to content

Commit

Permalink
added model renaming for data and function backgrounds/resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhroom committed Dec 10, 2024
1 parent caa40e2 commit 427e289
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 38 deletions.
48 changes: 27 additions & 21 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,25 @@ def discriminate_contrasts(contrast_input):

AllFields = collections.namedtuple("AllFields", ["attribute", "fields"])
model_names_used_in = {
"background_parameters": AllFields(
"backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]
),
"resolution_parameters": AllFields(
"resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]
),
"parameters": AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"]),
"data": AllFields("contrasts", ["data"]),
"backgrounds": AllFields("contrasts", ["background"]),
"bulk_in": AllFields("contrasts", ["bulk_in"]),
"bulk_out": AllFields("contrasts", ["bulk_out"]),
"scalefactors": AllFields("contrasts", ["scalefactor"]),
"domain_ratios": AllFields("contrasts", ["domain_ratio"]),
"resolutions": AllFields("contrasts", ["resolution"]),
"background_parameters": [
AllFields("backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"])
],
"resolution_parameters": [
AllFields("resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"])
],
"parameters": [AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"])],
"data": [
AllFields("contrasts", ["data"]),
AllFields("backgrounds", ["source"]),
AllFields("resolutions", ["source"]),
],
"custom_files": [AllFields("backgrounds", ["source"]), AllFields("resolutions", ["source"])],
"backgrounds": [AllFields("contrasts", ["background"])],
"bulk_in": [AllFields("contrasts", ["bulk_in"])],
"bulk_out": [AllFields("contrasts", ["bulk_out"])],
"scalefactors": [AllFields("contrasts", ["scalefactor"])],
"domain_ratios": [AllFields("contrasts", ["domain_ratio"])],
"resolutions": [AllFields("contrasts", ["resolution"])],
}

# Note that the order of these parameters is hard-coded into RAT
Expand Down Expand Up @@ -508,18 +513,19 @@ def set_absorption(self) -> "Project":
@model_validator(mode="after")
def update_renamed_models(self) -> "Project":
"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
for class_list in model_names_used_in:
for class_list, fields_to_update in model_names_used_in.items():
old_names = self._all_names[class_list]
new_names = getattr(self, class_list).get_names()
if len(old_names) == len(new_names):
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
for old_name, new_name in name_diff:
model_names_list = getattr(self, model_names_used_in[class_list].attribute)
all_matches = model_names_list.get_all_matches(old_name)
fields = model_names_used_in[class_list].fields
for index, field in all_matches:
if field in fields:
setattr(model_names_list[index], field, new_name)
for field in fields_to_update:
project_field = getattr(self, field.attribute)
all_matches = project_field.get_all_matches(old_name)
params = field.fields
for index, param in all_matches:
if param in params:
setattr(project_field[index], param, new_name)
self._all_names = self.get_all_names()
return self

Expand Down
43 changes: 26 additions & 17 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,26 +625,35 @@ def test_check_protected_parameters(delete_operation) -> None:


@pytest.mark.parametrize(
["model", "field"],
["model", "fields"],
[
("background_parameters", "source"),
("resolution_parameters", "source"),
("parameters", "roughness"),
("data", "data"),
("backgrounds", "background"),
("bulk_in", "bulk_in"),
("bulk_out", "bulk_out"),
("scalefactors", "scalefactor"),
("resolutions", "resolution"),
("background_parameters", ["source"]),
("resolution_parameters", ["source"]),
("parameters", ["roughness"]),
("data", ["data", "source", "source"]),
("custom_files", ["source", "source"]),
("backgrounds", ["background"]),
("bulk_in", ["bulk_in"]),
("bulk_out", ["bulk_out"]),
("scalefactors", ["scalefactor"]),
("resolutions", ["resolution"]),
],
)
def test_rename_models(test_project, model: str, field: str) -> None:
def test_rename_models(test_project, model: str, fields: list[str]) -> None:
"""When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere
in the project.
"""
if model == "data":
test_project.backgrounds[0] = RATapi.models.Background(type="data", source="Simulation")
test_project.resolutions[0] = RATapi.models.Resolution(type="data", source="Simulation")
if model == "custom_files":
test_project.backgrounds[0] = RATapi.models.Background(type="function", source="Test Custom File")
test_project.resolutions[0] = RATapi.models.Resolution(type="function", source="Test Custom File")
getattr(test_project, model).set_fields(-1, name="New Name")
attribute = RATapi.project.model_names_used_in[model].attribute
assert getattr(getattr(test_project, attribute)[-1], field) == "New Name"
model_name_lists = RATapi.project.model_names_used_in[model]
for model_name_list, field in zip(model_name_lists, fields):
attribute = model_name_list.attribute
assert getattr(getattr(test_project, attribute)[-1], field) == "New Name"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1197,7 +1206,7 @@ def test_wrap_del(test_project, class_list: str, parameter: str, field: str) ->
pydantic.ValidationError,
match=f"1 validation error for Project\n Value error, The value "
f'"{parameter}" in the "{field}" field of '
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
f'must be defined in "{class_list}".',
):
del test_attribute[index]
Expand Down Expand Up @@ -1405,7 +1414,7 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) ->
pydantic.ValidationError,
match=f"1 validation error for Project\n Value error, The value "
f'"{parameter}" in the "{field}" field of '
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
f'must be defined in "{class_list}".',
):
test_attribute.pop(index)
Expand Down Expand Up @@ -1437,7 +1446,7 @@ def test_wrap_remove(test_project, class_list: str, parameter: str, field: str)
pydantic.ValidationError,
match=f"1 validation error for Project\n Value error, The value "
f'"{parameter}" in the "{field}" field of '
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
f'must be defined in "{class_list}".',
):
test_attribute.remove(parameter)
Expand Down Expand Up @@ -1469,7 +1478,7 @@ def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) -
pydantic.ValidationError,
match=f"1 validation error for Project\n Value error, The value "
f'"{parameter}" in the "{field}" field of '
f'"{RATapi.project.model_names_used_in[class_list].attribute}" '
f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" '
f'must be defined in "{class_list}".',
):
test_attribute.clear()
Expand Down

0 comments on commit 427e289

Please sign in to comment.