diff --git a/RATapi/project.py b/RATapi/project.py index 61e049d8..b4d71a54 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -80,20 +80,25 @@ def discriminate_contrasts(contrast_input): AllFields = collections.namedtuple("AllFields", ["attribute", "fields"]) model_names_used_in = { - "background_parameters": AllFields( - "backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"] - ), - "resolution_parameters": AllFields( - "resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"] - ), - "parameters": AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"]), - "data": AllFields("contrasts", ["data"]), - "backgrounds": AllFields("contrasts", ["background"]), - "bulk_in": AllFields("contrasts", ["bulk_in"]), - "bulk_out": AllFields("contrasts", ["bulk_out"]), - "scalefactors": AllFields("contrasts", ["scalefactor"]), - "domain_ratios": AllFields("contrasts", ["domain_ratio"]), - "resolutions": AllFields("contrasts", ["resolution"]), + "background_parameters": [ + AllFields("backgrounds", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]) + ], + "resolution_parameters": [ + AllFields("resolutions", ["source", "value_1", "value_2", "value_3", "value_4", "value_5"]) + ], + "parameters": [AllFields("layers", ["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness", "hydration"])], + "data": [ + AllFields("contrasts", ["data"]), + AllFields("backgrounds", ["source"]), + AllFields("resolutions", ["source"]), + ], + "custom_files": [AllFields("backgrounds", ["source"]), AllFields("resolutions", ["source"])], + "backgrounds": [AllFields("contrasts", ["background"])], + "bulk_in": [AllFields("contrasts", ["bulk_in"])], + "bulk_out": [AllFields("contrasts", ["bulk_out"])], + "scalefactors": [AllFields("contrasts", ["scalefactor"])], + "domain_ratios": [AllFields("contrasts", ["domain_ratio"])], + "resolutions": [AllFields("contrasts", ["resolution"])], } # Note that the order of these parameters is hard-coded into RAT @@ -508,18 +513,19 @@ def set_absorption(self) -> "Project": @model_validator(mode="after") def update_renamed_models(self) -> "Project": """When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project.""" - for class_list in model_names_used_in: + for class_list, fields_to_update in model_names_used_in.items(): old_names = self._all_names[class_list] new_names = getattr(self, class_list).get_names() if len(old_names) == len(new_names): name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new] for old_name, new_name in name_diff: - model_names_list = getattr(self, model_names_used_in[class_list].attribute) - all_matches = model_names_list.get_all_matches(old_name) - fields = model_names_used_in[class_list].fields - for index, field in all_matches: - if field in fields: - setattr(model_names_list[index], field, new_name) + for field in fields_to_update: + project_field = getattr(self, field.attribute) + all_matches = project_field.get_all_matches(old_name) + params = field.fields + for index, param in all_matches: + if param in params: + setattr(project_field[index], param, new_name) self._all_names = self.get_all_names() return self diff --git a/tests/test_project.py b/tests/test_project.py index b734788c..6e757e8f 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -625,26 +625,35 @@ def test_check_protected_parameters(delete_operation) -> None: @pytest.mark.parametrize( - ["model", "field"], + ["model", "fields"], [ - ("background_parameters", "source"), - ("resolution_parameters", "source"), - ("parameters", "roughness"), - ("data", "data"), - ("backgrounds", "background"), - ("bulk_in", "bulk_in"), - ("bulk_out", "bulk_out"), - ("scalefactors", "scalefactor"), - ("resolutions", "resolution"), + ("background_parameters", ["source"]), + ("resolution_parameters", ["source"]), + ("parameters", ["roughness"]), + ("data", ["data", "source", "source"]), + ("custom_files", ["source", "source"]), + ("backgrounds", ["background"]), + ("bulk_in", ["bulk_in"]), + ("bulk_out", ["bulk_out"]), + ("scalefactors", ["scalefactor"]), + ("resolutions", ["resolution"]), ], ) -def test_rename_models(test_project, model: str, field: str) -> None: +def test_rename_models(test_project, model: str, fields: list[str]) -> None: """When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere in the project. """ + if model == "data": + test_project.backgrounds[0] = RATapi.models.Background(type="data", source="Simulation") + test_project.resolutions[0] = RATapi.models.Resolution(type="data", source="Simulation") + if model == "custom_files": + test_project.backgrounds[0] = RATapi.models.Background(type="function", source="Test Custom File") + test_project.resolutions[0] = RATapi.models.Resolution(type="function", source="Test Custom File") getattr(test_project, model).set_fields(-1, name="New Name") - attribute = RATapi.project.model_names_used_in[model].attribute - assert getattr(getattr(test_project, attribute)[-1], field) == "New Name" + model_name_lists = RATapi.project.model_names_used_in[model] + for model_name_list, field in zip(model_name_lists, fields): + attribute = model_name_list.attribute + assert getattr(getattr(test_project, attribute)[-1], field) == "New Name" @pytest.mark.parametrize( @@ -1197,7 +1206,7 @@ def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> pydantic.ValidationError, match=f"1 validation error for Project\n Value error, The value " f'"{parameter}" in the "{field}" field of ' - f'"{RATapi.project.model_names_used_in[class_list].attribute}" ' + f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" ' f'must be defined in "{class_list}".', ): del test_attribute[index] @@ -1405,7 +1414,7 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> pydantic.ValidationError, match=f"1 validation error for Project\n Value error, The value " f'"{parameter}" in the "{field}" field of ' - f'"{RATapi.project.model_names_used_in[class_list].attribute}" ' + f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" ' f'must be defined in "{class_list}".', ): test_attribute.pop(index) @@ -1437,7 +1446,7 @@ def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) pydantic.ValidationError, match=f"1 validation error for Project\n Value error, The value " f'"{parameter}" in the "{field}" field of ' - f'"{RATapi.project.model_names_used_in[class_list].attribute}" ' + f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" ' f'must be defined in "{class_list}".', ): test_attribute.remove(parameter) @@ -1469,7 +1478,7 @@ def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) - pydantic.ValidationError, match=f"1 validation error for Project\n Value error, The value " f'"{parameter}" in the "{field}" field of ' - f'"{RATapi.project.model_names_used_in[class_list].attribute}" ' + f'"{RATapi.project.model_names_used_in[class_list][0].attribute}" ' f'must be defined in "{class_list}".', ): test_attribute.clear()