diff --git a/pyproject.toml b/pyproject.toml index d4927e5..e6de6e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,4 +34,16 @@ mark-parentheses = false [tool.ruff.lint.pep8-naming] # if overriding a PyQt method, please add it here! # names should be in alphabetical order for readability -extend-ignore-names = ["allKeys", "createEditor", "mergeWith", "setEditorData", "textFromValue", ] +extend-ignore-names = ['allKeys', + 'columnCount', + 'createEditor', + 'headerData', + 'mergeWith', + 'rowCount', + 'setData', + 'setEditorData', + 'setModelData', + 'setValue', + 'stepBy', + 'textFromValue', + 'valueFromText',] diff --git a/rascal2/core/commands.py b/rascal2/core/commands.py index bb75ce7..ebf8717 100644 --- a/rascal2/core/commands.py +++ b/rascal2/core/commands.py @@ -1,8 +1,10 @@ """File for Qt commands.""" from enum import IntEnum, unique +from typing import Callable from PyQt6 import QtGui +from RATapi import ClassList @unique @@ -10,24 +12,48 @@ class CommandID(IntEnum): """Unique ID for undoable commands""" EditControls = 1000 + EditProject = 2000 -class EditControls(QtGui.QUndoCommand): - """Command for editing the Controls object.""" +class AbstractModelEdit(QtGui.QUndoCommand): + """Command for editing an attribute of the model.""" - def __init__(self, attr, value, presenter): + attribute = None + + def __init__(self, new_values: dict, presenter): super().__init__() self.presenter = presenter - self.attr = attr - self.value = value - self.old_value = getattr(self.presenter.model.controls, self.attr) - self.setText(f"Set control {self.attr} to {self.value}") + self.new_values = new_values + if self.attribute is None: + raise NotImplementedError("AbstractEditModel should not be instantiated directly.") + else: + self.model_class = getattr(self.presenter.model, self.attribute) + self.old_values = {attr: getattr(self.model_class, attr) for attr in self.new_values} + self.update_text() + + def update_text(self): + """Update the undo command text.""" + if len(self.new_values) == 1: + attr, value = list(self.new_values.items())[0] + if isinstance(list(self.new_values.values())[0], ClassList): + text = f"Changed values in {attr}" + else: + text = f"Set {self.attribute} {attr} to {value}" + else: + text = f"Save update to {self.attribute}" + + self.setText(text) + + @property + def update_attribute(self) -> Callable: + """Return the method used to update the attribute.""" + raise NotImplementedError def undo(self): - self.presenter.model.update_controls({self.attr: self.old_value}) + self.update_attribute(self.old_values) def redo(self): - self.presenter.model.update_controls({self.attr: self.value}) + self.update_attribute(self.new_values) def mergeWith(self, command): """Merges consecutive Edit controls commands if the attributes are the @@ -35,16 +61,38 @@ def mergeWith(self, command): # We should think about if merging all Edit controls irrespective of # attribute is the way to go for UX - if self.attr != command.attr: + if list(self.new_values.keys()) != list(command.new_values.keys()): return False - if self.old_value == command.value: + if list(self.old_values.values()) == list(command.new_values.values()): self.setObsolete(True) - self.value = command.value - self.setText(f"Set control {self.attr} to {self.value}") + self.new_values = command.new_values + self.update_text() return True def id(self): """Returns ID used for merging commands""" + raise NotImplementedError + + +class EditControls(AbstractModelEdit): + attribute = "controls" + + @property + def update_attribute(self): + return self.presenter.model.update_controls + + def id(self): return CommandID.EditControls + + +class EditProject(AbstractModelEdit): + attribute = "project" + + @property + def update_attribute(self): + return self.presenter.model.update_project + + def id(self): + return CommandID.EditProject diff --git a/rascal2/static/images/delete.png b/rascal2/static/images/delete.png new file mode 100644 index 0000000..eaef9b0 Binary files /dev/null and b/rascal2/static/images/delete.png differ diff --git a/rascal2/ui/model.py b/rascal2/ui/model.py index e1da445..171bdfe 100644 --- a/rascal2/ui/model.py +++ b/rascal2/ui/model.py @@ -33,7 +33,7 @@ def create_project(self, name: str, save_path: str): self.controls = RAT.Controls() self.save_path = save_path - def update_project(self, problem_definition: RAT.rat_core.ProblemDefinition): + def handle_results(self, problem_definition: RAT.rat_core.ProblemDefinition): """Update the project given a set of results.""" parameter_field = { "parameters": "params", @@ -49,16 +49,17 @@ def update_project(self, problem_definition: RAT.rat_core.ProblemDefinition): for index, value in enumerate(getattr(problem_definition, parameter_field[class_list])): getattr(self.project, class_list)[index].value = value - def update_controls(self, new_values): - """ + def update_project(self, new_values: dict) -> None: + """Replaces the project with a new project. Parameters ---------- - new_values: Dict - The attribute name-value pair to updated on the controls. + new_values : dict + New values to set in the project. + """ - vars(self.controls).update(new_values) - self.controls_updated.emit() + vars(self.project).update(new_values) + self.project_updated.emit() def save_project(self): """Save the project to the save path.""" @@ -110,13 +111,13 @@ def load_r1_project(self, load_path: str): self.controls = RAT.Controls() self.save_path = str(Path(load_path).parent) - def edit_project(self, updated_project) -> None: - """Updates the project. + def update_controls(self, new_values): + """ Parameters ---------- - updated_project : RAT.Project - The updated project. + new_values: Dict + The attribute name-value pair to updated on the controls. """ - self.project = updated_project - self.project_updated.emit() + vars(self.controls).update(new_values) + self.controls_updated.emit() diff --git a/rascal2/ui/presenter.py b/rascal2/ui/presenter.py index 3109253..6ec7f06 100644 --- a/rascal2/ui/presenter.py +++ b/rascal2/ui/presenter.py @@ -106,7 +106,7 @@ def edit_controls(self, setting: str, value: Any): with warnings.catch_warnings(): warnings.simplefilter("ignore") self.model.controls.model_validate({setting: value}) - self.view.undo_stack.push(commands.EditControls(setting, value, self)) + self.view.undo_stack.push(commands.EditControls({setting: value}, self)) def save_project(self, save_as: bool = False): """Save the model. @@ -150,7 +150,7 @@ def run(self): def handle_results(self): """Handle a RAT run being finished.""" - self.model.update_project(self.runner.updated_problem) + self.model.handle_results(self.runner.updated_problem) self.view.handle_results(self.runner.results) def handle_interrupt(self): @@ -174,15 +174,24 @@ def handle_event(self): elif isinstance(event, LogData): self.view.logging.log(event.level, event.msg) - def edit_project(self, updated_project) -> None: - """Updates the project. + def edit_project(self, updated_project: dict) -> None: + """Edit the Project with a dictionary of attributes. Parameters ---------- - updated_project : RAT.Project - The updated project. + updated_project : dict + The updated project attributes. + + Raises + ------ + ValidationError + If the updated project attributes are not valid. + """ - self.model.edit_project(updated_project) + project_dict = self.model.project.model_dump() + project_dict.update(updated_project) + self.model.project.model_validate(project_dict) + self.view.undo_stack.push(commands.EditProject(updated_project, self)) # '\d+\.\d+' is the regex for diff --git a/rascal2/ui/view.py b/rascal2/ui/view.py index 50fc4a1..49f8076 100644 --- a/rascal2/ui/view.py +++ b/rascal2/ui/view.py @@ -35,7 +35,6 @@ def __init__(self): # TODO replace the widgets below # plotting: NO ISSUE YET # https://github.com/RascalSoftware/RasCAL-2/issues/5 - # project: NO ISSUE YET self.plotting_widget = QtWidgets.QWidget() self.terminal_widget = TerminalWidget(self) self.controls_widget = ControlsWidget(self) @@ -246,7 +245,7 @@ def setup_mdi(self): # if windows are already created, don't set them up again, # just refresh the widget data if len(self.mdi.subWindowList()) == 4: - self.controls_widget.setup_controls() + self.setup_mdi_widgets() return widgets = { @@ -255,8 +254,7 @@ def setup_mdi(self): "Terminal": self.terminal_widget, "Fitting Controls": self.controls_widget, } - self.controls_widget.setup_controls() - self.project_widget.update_project_view() + self.setup_mdi_widgets() self.terminal_widget.text_area.setVisible(True) for title, widget in reversed(widgets.items()): @@ -269,6 +267,11 @@ def setup_mdi(self): self.startup_dlg = self.takeCentralWidget() self.setCentralWidget(self.mdi) + def setup_mdi_widgets(self): + """Performs setup of MDI widgets that relies on the Project existing.""" + self.controls_widget.setup_controls() + self.project_widget.update_project_view() + def reset_mdi_layout(self): """Reset MDI layout to the default.""" if self.settings.mdi_defaults is None: diff --git a/rascal2/widgets/controls.py b/rascal2/widgets/controls.py index 1bb3494..dea5974 100644 --- a/rascal2/widgets/controls.py +++ b/rascal2/widgets/controls.py @@ -213,6 +213,7 @@ def __init__(self, parent, settings, presenter): for i, setting in enumerate(settings): field_info = controls_fields[setting] self.rows[setting] = get_validated_input(field_info) + self.rows[setting].layout().setContentsMargins(5, 0, 0, 0) self.datasetter[setting] = self.create_model_data_setter(setting) self.rows[setting].edited_signal.connect(self.datasetter[setting]) label = QtWidgets.QLabel(setting) diff --git a/rascal2/widgets/delegates.py b/rascal2/widgets/delegates.py index 87b1a87..f98dd1c 100644 --- a/rascal2/widgets/delegates.py +++ b/rascal2/widgets/delegates.py @@ -1,40 +1,78 @@ """Delegates for items in Qt tables.""" +from typing import Literal + from PyQt6 import QtCore, QtGui, QtWidgets +from rascal2.widgets.inputs import AdaptiveDoubleSpinBox, get_validated_input + -class EnumDelegate(QtWidgets.QStyledItemDelegate): - """Item delegate for Enums.""" +class ValidatedInputDelegate(QtWidgets.QStyledItemDelegate): + """Item delegate for validated inputs.""" - def __init__(self, parent, enum): + def __init__(self, field_info, parent): super().__init__(parent) - self.enum = enum + self.table = parent + self.field_info = field_info def createEditor(self, parent, option, index): - combobox = QtWidgets.QComboBox(parent) - combobox.addItems(str(e.value) for e in self.enum) - return combobox + widget = get_validated_input(self.field_info, parent) + widget.set_data(index.data(QtCore.Qt.ItemDataRole.DisplayRole)) + + # fill in background as otherwise you can see the original View text underneath + widget.setAutoFillBackground(True) + widget.setBackgroundRole(QtGui.QPalette.ColorRole.Base) - def setEditorData(self, editor: QtWidgets.QCheckBox, index): + return widget + + def setEditorData(self, editor: QtWidgets.QWidget, index): data = index.data(QtCore.Qt.ItemDataRole.DisplayRole) - editor.setCurrentText(data) + editor.set_data(data) + + def setModelData(self, editor, model, index): + data = editor.get_data() + model.setData(index, data, QtCore.Qt.ItemDataRole.EditRole) + +class ValueSpinBoxDelegate(QtWidgets.QStyledItemDelegate): + """Item delegate for parameter values between a dynamic min and max. -class BoolDelegate(QtWidgets.QStyledItemDelegate): - """Item delegate for bools.""" + Parameters + ---------- + field : Literal["min", "value", "max"] + The field of the parameter - def __init__(self, parent): + """ + + def __init__(self, field: Literal["min", "value", "max"], parent): super().__init__(parent) - self.parent = parent + self.table = parent + self.field = field def createEditor(self, parent, option, index): - checkbox = QtWidgets.QCheckBox(parent) + widget = AdaptiveDoubleSpinBox(parent) + + max_val = float("inf") + min_val = -float("inf") + + if self.field in ["min", "value"]: + max_val = index.siblingAtColumn(index.column() + 1).data(QtCore.Qt.ItemDataRole.DisplayRole) + if self.field in ["value", "max"]: + min_val = index.siblingAtColumn(index.column() - 1).data(QtCore.Qt.ItemDataRole.DisplayRole) + + widget.setMinimum(min_val) + widget.setMaximum(max_val) + # fill in background as otherwise you can see the original View text underneath - checkbox.setAutoFillBackground(True) - checkbox.setBackgroundRole(QtGui.QPalette.ColorRole.Base) - return checkbox + widget.setAutoFillBackground(True) + widget.setBackgroundRole(QtGui.QPalette.ColorRole.Base) + + return widget - def setEditorData(self, editor: QtWidgets.QCheckBox, index): + def setEditorData(self, editor: AdaptiveDoubleSpinBox, index): data = index.data(QtCore.Qt.ItemDataRole.DisplayRole) - data = data == "True" # data from model is given as a string - editor.setChecked(data) + editor.setValue(data) + + def setModelData(self, editor, model, index): + data = editor.value() + model.setData(index, data, QtCore.Qt.ItemDataRole.EditRole) diff --git a/rascal2/widgets/inputs.py b/rascal2/widgets/inputs.py index 78cf285..b360ed5 100644 --- a/rascal2/widgets/inputs.py +++ b/rascal2/widgets/inputs.py @@ -8,13 +8,15 @@ from PyQt6 import QtCore, QtGui, QtWidgets -def get_validated_input(field_info: FieldInfo) -> QtWidgets.QWidget: +def get_validated_input(field_info: FieldInfo, parent=None) -> QtWidgets.QWidget: """Get a validated input widget from Pydantic field info. Parameters ---------- field_info : FieldInfo The Pydantic field info for the field. + parent : QWidget or None, default None + The parent widget of this widget. Returns ------- @@ -31,9 +33,9 @@ def get_validated_input(field_info: FieldInfo) -> QtWidgets.QWidget: for input_type, widget in class_widgets.items(): if issubclass(field_info.annotation, input_type): - return widget(field_info) + return widget(field_info, parent) - return BaseInputWidget(field_info) + return BaseInputWidget(field_info, parent) class BaseInputWidget(QtWidgets.QWidget): @@ -64,7 +66,7 @@ def __init__(self, field_info: FieldInfo, parent=None): layout = QtWidgets.QVBoxLayout() layout.addWidget(self.editor) - layout.setContentsMargins(5, 0, 0, 0) + layout.setContentsMargins(0, 0, 0, 0) self.setLayout(layout) self.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Fixed) @@ -125,7 +127,7 @@ def create_editor(self, field_info: FieldInfo) -> QtWidgets.QWidget: if hasattr(item, attr): editor.setMaximum(getattr(item, attr)) # if no default exists, field_info.default is PydanticUndefined not a nonexistent attribute - if isinstance(field_info.default, (int, float)) and field_info.default > 0: + if isinstance(field_info.default, (int, float)) and 0 < field_info.default < float("inf"): # set default decimals to order of magnitude of default value editor.setDecimals(-floor(log10(abs(field_info.default)))) @@ -161,8 +163,15 @@ def create_editor(self, field_info: FieldInfo) -> QtWidgets.QWidget: class AdaptiveDoubleSpinBox(QtWidgets.QDoubleSpinBox): """A double spinbox which adapts to given numbers of decimals.""" + MIN_DECIMALS = 2 + def __init__(self, parent=None): super().__init__(parent) + + # default max and min are 99.99 and 0.0 + self.setMaximum(float("inf")) + self.setMinimum(-float("inf")) + self.setStepType(self.StepType.AdaptiveDecimalStepType) self.setKeyboardTracking(False) @@ -182,7 +191,50 @@ def textFromValue(self, value): The string displayed on the spinbox. """ - return f"{round(value, self.decimals()):.{self.decimals()}g}" + if value == float("inf"): + return "inf" + if value == -float("inf"): + return "-inf" + return f"{round(value, 12):.4g}" + + def valueFromText(self, text: str) -> float: + """Set the underlying value of the spinbox from the text input.""" + if text == "inf": + return float("inf") + if text == "-inf": + return -float("inf") + return float(text) + + def setValue(self, value: float): + """Hook into setValue that sets the decimals when the value is manually set. + + Parameters + ---------- + value : float + The value to set the spinbox to. + """ + state, text, _ = self.validate(str(value), 0) + if state == QtGui.QValidator.State.Acceptable: + value = float(text) + super().setValue(value) + + def stepBy(self, steps: int): + """Step the value up or down by some amount. + + Override of QtWidgets.QDoubleSpinBox.stepBy to handle infs. + + Parameters + ---------- + steps : int + The number of linesteps to step by. + + """ + if self.value() == float("inf") and steps < 0: + self.setValue(1e12) # largest possible float that doesn't look ugly in the box + if self.value() == -float("inf") and steps > 0: + self.setValue(1e-12) # smallest possible float that pyqt doesn't round to 0 + else: + super().stepBy(steps) def validate(self, input_text, pos) -> tuple[QtGui.QValidator.State, str, int]: """Validate a string written into the spinbox. @@ -202,9 +254,19 @@ def validate(self, input_text, pos) -> tuple[QtGui.QValidator.State, str, int]: The validation state of the input, the input string, and position. """ - if "e" in input_text: + if input_text in "inf" or input_text in "-inf": + if input_text in ["inf", "-inf"]: + return (QtGui.QValidator.State.Acceptable, input_text, pos) + else: + return (QtGui.QValidator.State.Intermediate, input_text, pos) + if "e" in input_text or "E" in input_text: + components = input_text.lower().split("e") + significand = components[0] + significand_decimals = len(significand.split(".")[-1]) + exponent = components[1] try: - self.setDecimals(-int(input_text.split("e")[-1])) + exponent_order = int(exponent) + self.setDecimals(max(significand_decimals - exponent_order, 0)) return (QtGui.QValidator.State.Acceptable, input_text, pos) except ValueError: return (QtGui.QValidator.State.Intermediate, input_text, pos) diff --git a/rascal2/widgets/project.py b/rascal2/widgets/project.py deleted file mode 100644 index 0425e8c..0000000 --- a/rascal2/widgets/project.py +++ /dev/null @@ -1,219 +0,0 @@ -from PyQt6 import QtCore, QtGui, QtWidgets -from RATapi.utils.enums import Calculations, Geometries, LayerModels - -from rascal2.config import path_for - - -class ProjectWidget(QtWidgets.QWidget): - """ - The Project MDI Widget - """ - - def __init__(self, parent): - """ - Initialize widget. - - Parameters - ---------- - parent: MainWindowView - An instance of the MainWindowView - """ - super().__init__(parent) - self.parent = parent - self.presenter = self.parent.presenter - self.model = self.parent.presenter.model - - self.presenter.model.project_updated.connect(self.update_project_view) - - self.create_project_view() - self.create_edit_view() - - self.stacked_widget = QtWidgets.QStackedWidget() - self.stacked_widget.addWidget(self.project_widget) - self.stacked_widget.addWidget(self.edit_project_widget) - - layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.stacked_widget) - self.setLayout(layout) - - def update_project_view(self) -> None: - """Updates the project view.""" - self.modified_project = self.presenter.model.project.model_copy(deep=True) - - self.calculation_type.setText(self.model.project.calculation) - self.model_type.setText(self.model.project.model) - self.geometry_type.setText(self.model.project.geometry) - - self.calculation_combobox.setCurrentText(self.model.project.calculation) - self.model_combobox.setCurrentText(self.model.project.model) - self.geometry_combobox.setCurrentText(self.model.project.geometry) - - self.handle_domains_tab() - - def create_project_view(self) -> None: - """Creates the project (non-edit) veiw""" - self.project_widget = QtWidgets.QWidget() - main_layout = QtWidgets.QGridLayout() - main_layout.setVerticalSpacing(20) - - self.edit_project_button = QtWidgets.QPushButton( - " Edit Project", self, objectName="bluebutton", icon=QtGui.QIcon(path_for("edit.png")) - ) - self.edit_project_button.clicked.connect(self.show_edit_view) - main_layout.addWidget(self.edit_project_button, 0, 5) - - self.calculation_label = QtWidgets.QLabel("Calculation:", self, objectName="boldlabel") - - self.calculation_type = QtWidgets.QLineEdit(self) - self.calculation_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) - self.calculation_type.setReadOnly(True) - - main_layout.addWidget(self.calculation_label, 1, 0, 1, 1) - main_layout.addWidget(self.calculation_type, 1, 1, 1, 1) - - self.model_type_label = QtWidgets.QLabel("Model Type:", self, objectName="boldlabel") - - self.model_type = QtWidgets.QLineEdit(self) - self.model_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) - self.model_type.setReadOnly(True) - - main_layout.addWidget(self.model_type_label, 1, 2, 1, 1) - main_layout.addWidget(self.model_type, 1, 3, 1, 1) - - self.geometry_label = QtWidgets.QLabel("Geometry:", self, objectName="boldlabel") - - self.geometry_type = QtWidgets.QLineEdit(self) - self.geometry_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) - self.geometry_type.setReadOnly(True) - - main_layout.addWidget(self.geometry_label, 1, 4, 1, 1) - main_layout.addWidget(self.geometry_type, 1, 5, 1, 1) - - self.project_tab = QtWidgets.QTabWidget() - - # Replace QtWidgets.QWidget() with methods to create - # the tabs in project view. - self.project_tab.addTab(QtWidgets.QWidget(), "Parameters") - self.project_tab.addTab(QtWidgets.QWidget(), "Backgrounds") - self.project_tab.addTab(QtWidgets.QWidget(), "Experimental Parameters") - self.project_tab.addTab(QtWidgets.QWidget(), "Layers") - self.project_tab.addTab(QtWidgets.QWidget(), "Data") - self.project_tab.addTab(QtWidgets.QWidget(), "Contrasts") - - main_layout.addWidget(self.project_tab, 2, 0, 1, 6) - self.project_widget.setLayout(main_layout) - - def create_edit_view(self) -> None: - """Creates the project edit veiw""" - - self.edit_project_widget = QtWidgets.QWidget() - main_layout = QtWidgets.QVBoxLayout() - main_layout.setSpacing(20) - - self.save_project_button = QtWidgets.QPushButton(" Save Project", self, objectName="greybutton") - self.save_project_button.setIcon(QtGui.QIcon(path_for("save-project.png"))) - self.save_project_button.clicked.connect(self.save_changes) - - self.cancel_button = QtWidgets.QPushButton(" Cancel", self, objectName="redbutton") - self.cancel_button.setIcon(QtGui.QIcon(path_for("cancel-dark.png"))) - self.cancel_button.clicked.connect(self.cancel_changes) - - layout = QtWidgets.QHBoxLayout() - layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignRight) - layout.addWidget(self.save_project_button) - layout.addWidget(self.cancel_button) - main_layout.addLayout(layout) - - self.edit_calculation_label = QtWidgets.QLabel("Calculation:", self, objectName="boldlabel") - - self.calculation_combobox = QtWidgets.QComboBox(self) - self.calculation_combobox.addItems([calc for calc in Calculations]) - - layout = QtWidgets.QHBoxLayout() - layout.addWidget(self.edit_calculation_label) - layout.addWidget(self.calculation_combobox) - - self.edit_model_type_label = QtWidgets.QLabel("Model Type:", self, objectName="boldlabel") - - self.model_combobox = QtWidgets.QComboBox(self) - self.model_combobox.addItems([model for model in LayerModels]) - - layout.addWidget(self.edit_model_type_label) - layout.addWidget(self.model_combobox) - - self.edit_geometry_label = QtWidgets.QLabel("Geometry:", self, objectName="boldlabel") - - self.geometry_combobox = QtWidgets.QComboBox(self) - self.geometry_combobox.addItems([geo for geo in Geometries]) - - self.calculation_combobox.currentTextChanged.connect(lambda s: self.process_combobox_update("calculation", s)) - self.model_combobox.currentTextChanged.connect(lambda s: self.process_combobox_update("model", s)) - self.geometry_combobox.currentTextChanged.connect(lambda s: self.process_combobox_update("geometry", s)) - - layout.addWidget(self.edit_geometry_label) - layout.addWidget(self.geometry_combobox) - main_layout.addLayout(layout) - - self.edit_project_tab = QtWidgets.QTabWidget() - - # Replace QtWidgets.QWidget() with methods to create - # the tabs in edit view. - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Parameters") - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Backgrounds") - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Experimental Parameters") - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Layers") - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Data") - self.edit_project_tab.addTab(QtWidgets.QWidget(), "Contrasts") - - main_layout.addWidget(self.edit_project_tab) - - self.edit_project_widget.setLayout(main_layout) - - def process_combobox_update(self, attr_name: str, selected_value: str) -> None: - """ - Updates the copy of the project. - - Parameters - ---------- - attr_name: str - The attr that needs to be updated. - selected_value: str - The new selected value from the combobox. - """ - setattr(self.modified_project, attr_name, selected_value) - - def handle_domains_tab(self) -> None: - """Displays or hides the domains tab""" - domain_tab_ix = 6 - if ( - self.calculation_type.text() == Calculations.Domains - and self.project_tab.tabText(domain_tab_ix) != "Domains" - and self.edit_project_tab.tabText(domain_tab_ix) != "Domains" - ): - # Replace QtWidgets.QWidget() with methods to create - # the domains tab in project and edit view. - self.project_tab.insertTab(domain_tab_ix, QtWidgets.QWidget(), "Domains") - self.edit_project_tab.insertTab(domain_tab_ix, QtWidgets.QWidget(), "Domains") - elif self.calculation_type.text() != Calculations.Domains: - self.project_tab.removeTab(domain_tab_ix) - self.edit_project_tab.removeTab(domain_tab_ix) - - def show_project_view(self) -> None: - """Show project view""" - self.setWindowTitle("Project") - self.stacked_widget.setCurrentIndex(0) - - def show_edit_view(self) -> None: - """Show edit view""" - self.setWindowTitle("Edit Project") - self.stacked_widget.setCurrentIndex(1) - - def save_changes(self) -> None: - """Save changes to the project.""" - self.presenter.edit_project(self.modified_project) - self.show_project_view() - - def cancel_changes(self) -> None: - """Cancel changes to the project.""" - self.update_project_view() - self.show_project_view() diff --git a/rascal2/widgets/project/__init__.py b/rascal2/widgets/project/__init__.py new file mode 100644 index 0000000..70251c1 --- /dev/null +++ b/rascal2/widgets/project/__init__.py @@ -0,0 +1,3 @@ +from rascal2.widgets.project.project import ProjectWidget + +__all__ = ["ProjectWidget"] diff --git a/rascal2/widgets/project/models.py b/rascal2/widgets/project/models.py new file mode 100644 index 0000000..dbcae04 --- /dev/null +++ b/rascal2/widgets/project/models.py @@ -0,0 +1,298 @@ +"""Models and widgets for project fields.""" + +from enum import Enum + +import pydantic +import RATapi +from PyQt6 import QtCore, QtGui, QtWidgets +from RATapi.utils.enums import Procedures + +from rascal2.config import path_for +from rascal2.widgets.delegates import ValidatedInputDelegate, ValueSpinBoxDelegate + + +class ClassListModel(QtCore.QAbstractTableModel): + """Table model for a project ClassList field. + + Parameters + ---------- + classlist : ClassList + The initial classlist to represent in this model. + field : str + The name of the field represented by this model. + parent : QtWidgets.QWidget + The parent widget for the model. + + """ + + def __init__(self, classlist: RATapi.ClassList, parent: QtWidgets.QWidget): + super().__init__(parent) + self.parent = parent + self.classlist = classlist + self.item_type = classlist._class_handle + if not issubclass(self.item_type, pydantic.BaseModel): + raise NotImplementedError("ClassListModel only works for classlists of Pydantic models!") + self.headers = list(self.item_type.model_fields) + self.edit_mode = False + + def rowCount(self, parent=None) -> int: + return len(self.classlist) + + def columnCount(self, parent=None) -> int: + return len(self.headers) + 1 + + def data(self, index, role=QtCore.Qt.ItemDataRole.DisplayRole): + param = self.index_header(index) + + if param is None: + return None + + data = getattr(self.classlist[index.row()], param) + + if role == QtCore.Qt.ItemDataRole.DisplayRole and self.index_header(index) != "fit": + data = getattr(self.classlist[index.row()], param) + # pyqt can't automatically coerce enums to strings... + if isinstance(data, Enum): + return str(data) + return data + elif role == QtCore.Qt.ItemDataRole.CheckStateRole and self.index_header(index) == "fit": + return QtCore.Qt.CheckState.Checked if data else QtCore.Qt.CheckState.Unchecked + + def setData(self, index, value, role=QtCore.Qt.ItemDataRole.EditRole) -> bool: + if role == QtCore.Qt.ItemDataRole.EditRole or role == QtCore.Qt.ItemDataRole.CheckStateRole: + row = index.row() + param = self.index_header(index) + if self.index_header(index) == "fit": + value = QtCore.Qt.CheckState(value) == QtCore.Qt.CheckState.Checked + if param is not None: + try: + setattr(self.classlist[row], param, value) + except pydantic.ValidationError: + return False + if not self.edit_mode: + self.parent.update_project() + return True + return False + + def headerData(self, section, orientation, role=QtCore.Qt.ItemDataRole.DisplayRole): + if ( + orientation == QtCore.Qt.Orientation.Horizontal + and role == QtCore.Qt.ItemDataRole.DisplayRole + and section != 0 + ): + return self.headers[section - 1].replace("_", " ").title() + return None + + def append_item(self): + """Append an item to the ClassList.""" + self.classlist.append(self.item_type()) + self.endResetModel() + + def delete_item(self, row: int): + """Delete an item in the ClassList. + + Parameters + ---------- + row : int + The row containing the item to delete. + + """ + self.classlist.pop(row) + self.endResetModel() + + def index_header(self, index): + """Get the header for an index. + + Parameters: + ----------- + index : QModelIndex + The model index for the header. + + Returns + ------- + str or None + Either the name of the header, or None if none exists. + + """ + col = index.column() + if col == 0: + return None + return self.headers[col - 1] + + +class ProjectFieldWidget(QtWidgets.QWidget): + """Widget to show a project ClassList. + + Parameters + ---------- + field : str + The field of the project represented by this widget. + parent : ProjectTabWidget + The tab this field belongs to. + + """ + + classlist_model = ClassListModel + + def __init__(self, field: str, parent): + super().__init__(parent) + self.field = field + header = field.replace("_", " ").title() + self.parent = parent + self.table = QtWidgets.QTableView(parent) + self.table.setSizePolicy( + QtWidgets.QSizePolicy.Policy.MinimumExpanding, QtWidgets.QSizePolicy.Policy.MinimumExpanding + ) + + layout = QtWidgets.QVBoxLayout() + topbar = QtWidgets.QHBoxLayout() + topbar.addWidget(QtWidgets.QLabel(header)) + # change to icon: remember to mention that plus.png in the icons is wonky + self.add_button = QtWidgets.QPushButton(f"+ Add new {header[:-1] if header[-1] == 's' else header}") + self.add_button.setHidden(True) + self.add_button.pressed.connect(self.append_item) + topbar.addWidget(self.add_button) + + layout.addLayout(topbar) + layout.addWidget(self.table) + self.setLayout(layout) + + def update_model(self, classlist): + """Update the table model to synchronise with the project field.""" + self.model = self.classlist_model(classlist, self) + + self.table.setModel(self.model) + self.table.hideColumn(0) + self.set_item_delegates() + header = self.table.horizontalHeader() + + header.setSectionResizeMode(self.model.headers.index("name") + 1, QtWidgets.QHeaderView.ResizeMode.Stretch) + header.setSectionResizeMode(0, QtWidgets.QHeaderView.ResizeMode.ResizeToContents) + + def set_item_delegates(self): + """Set item delegates and open persistent editors for the table.""" + for i, header in enumerate(self.model.headers): + self.table.setItemDelegateForColumn( + i + 1, ValidatedInputDelegate(self.model.item_type.model_fields[header], self.table) + ) + + def append_item(self): + """Append an item to the model if the model exists.""" + if self.model is not None: + self.model.append_item() + + # call edit again to recreate delete buttons + self.edit() + + def delete_item(self, index): + """Delete an item at the index if the model exists. + + Parameters + ---------- + index : int + The row to be deleted. + + """ + if self.model is not None: + self.model.delete_item(index) + + # call edit again to recreate delete buttons + self.edit() + + def edit(self): + """Change the widget to be in edit mode.""" + self.model.edit_mode = True + self.add_button.setHidden(False) + self.table.showColumn(0) + self.set_item_delegates() + for i in range(0, self.model.rowCount()): + self.table.setIndexWidget(self.model.index(i, 0), self.make_delete_button(i)) + + def make_delete_button(self, index): + """Make a button that deletes index `index` from the list.""" + button = QtWidgets.QPushButton(icon=QtGui.QIcon(path_for("delete.png"))) + button.resize(button.sizeHint().width(), button.sizeHint().width()) + button.pressed.connect(lambda: self.delete_item(index)) + + return button + + def update_project(self): + """Update the field in the parent Project.""" + presenter = self.parent.parent.parent.presenter + presenter.edit_project({self.field: self.model.classlist}) + + +class ParametersModel(ClassListModel): + """Classlist model for Parameters.""" + + def __init__(self, classlist: RATapi.ClassList, parent: QtWidgets.QWidget): + super().__init__(classlist, parent) + self.headers.insert(0, self.headers.pop(self.headers.index("fit"))) + + self.protected_indices = [] + if self.item_type is RATapi.models.Parameter: + for i, item in enumerate(classlist): + if isinstance(item, RATapi.models.ProtectedParameter): + self.protected_indices.append(i) + + def flags(self, index): + flags = super().flags(index) + header = self.index_header(index) + # disable editing on the delete widget column + # and disable mu, sigma if prior type is not Gaussian + if (index.column() == 0) or ( + self.classlist[index.row()].prior_type != "gaussian" and header in ["mu", "sigma"] + ): + return QtCore.Qt.ItemFlag.NoItemFlags + # never allow name editing for protected parameters, allow everything else to be edited by default + if header == "fit": + flags |= QtCore.Qt.ItemFlag.ItemIsUserCheckable + elif header != "name" or (self.edit_mode and index.row() not in self.protected_indices): + flags |= QtCore.Qt.ItemFlag.ItemIsEditable + + return flags + + +class ParameterFieldWidget(ProjectFieldWidget): + """Subclass of field widgets for parameters.""" + + classlist_model = ParametersModel + + def set_item_delegates(self): + for i, header in enumerate(self.model.headers): + if header in ["min", "value", "max"]: + self.table.setItemDelegateForColumn(i + 1, ValueSpinBoxDelegate(header, self.table)) + else: + self.table.setItemDelegateForColumn( + i + 1, ValidatedInputDelegate(self.model.item_type.model_fields[header], self.table) + ) + + def update_model(self, classlist): + super().update_model(classlist) + header = self.table.horizontalHeader() + header.setSectionResizeMode( + self.model.headers.index("fit") + 1, QtWidgets.QHeaderView.ResizeMode.ResizeToContents + ) + + def handle_bayesian_columns(self, procedure: Procedures): + """Hide or show Bayes-related columns based on procedure. + + Parameters + ---------- + procedure : Procedure + The procedure in Controls. + """ + is_bayesian = procedure in ["ns", "dream"] + bayesian_columns = ["prior_type", "mu", "sigma"] + for item in bayesian_columns: + index = self.model.headers.index(item) + if is_bayesian: + self.table.showColumn(index + 1) + else: + self.table.hideColumn(index + 1) + + def edit(self): + super().edit() + for i in range(0, self.model.rowCount()): + if i in self.model.protected_indices: + self.table.setIndexWidget(self.model.index(i, 0), None) diff --git a/rascal2/widgets/project/project.py b/rascal2/widgets/project/project.py new file mode 100644 index 0000000..14528ed --- /dev/null +++ b/rascal2/widgets/project/project.py @@ -0,0 +1,340 @@ +"""Widget for the Project window.""" + +from copy import deepcopy + +import RATapi +from PyQt6 import QtCore, QtGui, QtWidgets +from RATapi.utils.enums import Calculations, Geometries, LayerModels + +from rascal2.config import path_for +from rascal2.widgets.project.models import ParameterFieldWidget, ProjectFieldWidget + + +class ProjectWidget(QtWidgets.QWidget): + """ + The Project MDI Widget + """ + + def __init__(self, parent): + """ + Initialize widget. + + Parameters + ---------- + parent: MainWindowView + An instance of the MainWindowView + """ + super().__init__(parent) + self.parent = parent + self.parent_model = self.parent.presenter.model + + self.parent_model.project_updated.connect(self.update_project_view) + self.parent_model.controls_updated.connect(self.handle_controls_update) + + self.tabs = { + "Parameters": ["parameters"], + "Experimental Parameters": ["scalefactors", "bulk_in", "bulk_out"], + "Layers": [], + "Data": [], + "Backgrounds": [], + "Contrasts": [], + "Domains": [], + } + + self.view_tabs = {} + self.edit_tabs = {} + self.draft_project = None + + project_view = self.create_project_view() + project_edit = self.create_edit_view() + + self.project_tab.currentChanged.connect(self.edit_project_tab.setCurrentIndex) + self.edit_project_tab.currentChanged.connect(self.project_tab.setCurrentIndex) + + self.stacked_widget = QtWidgets.QStackedWidget() + self.stacked_widget.addWidget(project_view) + self.stacked_widget.addWidget(project_edit) + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(self.stacked_widget) + self.setLayout(layout) + + def create_project_view(self) -> None: + """Creates the project (non-edit) view""" + project_widget = QtWidgets.QWidget() + main_layout = QtWidgets.QGridLayout() + main_layout.setVerticalSpacing(20) + + self.edit_project_button = QtWidgets.QPushButton( + "Edit Project", self, objectName="bluebutton", icon=QtGui.QIcon(path_for("edit.png")) + ) + self.edit_project_button.clicked.connect(self.show_edit_view) + main_layout.addWidget(self.edit_project_button, 0, 5) + + self.calculation_label = QtWidgets.QLabel("Calculation:", self, objectName="boldlabel") + + self.calculation_type = QtWidgets.QLineEdit(self) + self.calculation_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + self.calculation_type.setReadOnly(True) + + main_layout.addWidget(self.calculation_label, 1, 0, 1, 1) + main_layout.addWidget(self.calculation_type, 1, 1, 1, 1) + + self.model_type_label = QtWidgets.QLabel("Model Type:", self, objectName="boldlabel") + + self.model_type = QtWidgets.QLineEdit(self) + self.model_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + self.model_type.setReadOnly(True) + + main_layout.addWidget(self.model_type_label, 1, 2, 1, 1) + main_layout.addWidget(self.model_type, 1, 3, 1, 1) + + self.geometry_label = QtWidgets.QLabel("Geometry:", self, objectName="boldlabel") + + self.geometry_type = QtWidgets.QLineEdit(self) + self.geometry_type.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + self.geometry_type.setReadOnly(True) + + main_layout.addWidget(self.geometry_label, 1, 4, 1, 1) + main_layout.addWidget(self.geometry_type, 1, 5, 1, 1) + + self.project_tab = QtWidgets.QTabWidget() + + for tab, fields in self.tabs.items(): + widget = self.view_tabs[tab] = ProjectTabWidget(fields, self) + self.project_tab.addTab(widget, tab) + + main_layout.addWidget(self.project_tab, 2, 0, 1, 6) + project_widget.setLayout(main_layout) + + return project_widget + + def create_edit_view(self) -> None: + """Creates the project edit view""" + + edit_project_widget = QtWidgets.QWidget() + main_layout = QtWidgets.QVBoxLayout() + main_layout.setSpacing(20) + + self.save_project_button = QtWidgets.QPushButton("Save Project", self, objectName="greybutton") + self.save_project_button.setIcon(QtGui.QIcon(path_for("save-project.png"))) + self.save_project_button.clicked.connect(self.save_changes) + + self.cancel_button = QtWidgets.QPushButton("Cancel", self, objectName="redbutton") + self.cancel_button.setIcon(QtGui.QIcon(path_for("cancel-dark.png"))) + self.cancel_button.clicked.connect(self.cancel_changes) + + layout = QtWidgets.QHBoxLayout() + layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignRight) + layout.addWidget(self.save_project_button) + layout.addWidget(self.cancel_button) + main_layout.addLayout(layout) + + self.edit_calculation_label = QtWidgets.QLabel("Calculation:", self, objectName="boldlabel") + + self.calculation_combobox = QtWidgets.QComboBox(self) + self.calculation_combobox.addItems([calc for calc in Calculations]) + + layout = QtWidgets.QHBoxLayout() + layout.addWidget(self.edit_calculation_label) + layout.addWidget(self.calculation_combobox) + + self.edit_model_type_label = QtWidgets.QLabel("Model Type:", self, objectName="boldlabel") + + self.model_combobox = QtWidgets.QComboBox(self) + self.model_combobox.addItems([model for model in LayerModels]) + + layout.addWidget(self.edit_model_type_label) + layout.addWidget(self.model_combobox) + + self.edit_geometry_label = QtWidgets.QLabel("Geometry:", self, objectName="boldlabel") + + self.geometry_combobox = QtWidgets.QComboBox(self) + self.geometry_combobox.addItems([geo for geo in Geometries]) + + layout.addWidget(self.edit_geometry_label) + layout.addWidget(self.geometry_combobox) + main_layout.addLayout(layout) + + self.calculation_combobox.currentTextChanged.connect(lambda s: self.update_draft_project({"calculation": s})) + self.calculation_combobox.currentTextChanged.connect(lambda: self.handle_domains_tab()) + self.model_combobox.currentTextChanged.connect(lambda s: self.update_draft_project({"model": s})) + self.geometry_combobox.currentTextChanged.connect(lambda s: self.update_draft_project({"geometry": s})) + self.edit_project_tab = QtWidgets.QTabWidget() + + for tab, fields in self.tabs.items(): + widget = self.edit_tabs[tab] = ProjectTabWidget(fields, self, edit_mode=True) + self.edit_project_tab.addTab(widget, tab) + + main_layout.addWidget(self.edit_project_tab) + + edit_project_widget.setLayout(main_layout) + + return edit_project_widget + + def update_project_view(self) -> None: + """Updates the project view.""" + # draft project is a dict containing all the attributes of the parent model, + # because we don't want validation errors going off while editing the model is in-progress + self.draft_project: dict = create_draft_project(self.parent_model.project) + + self.calculation_type.setText(self.parent_model.project.calculation) + self.model_type.setText(self.parent_model.project.model) + self.geometry_type.setText(self.parent_model.project.geometry) + + self.calculation_combobox.setCurrentText(self.parent_model.project.calculation) + self.model_combobox.setCurrentText(self.parent_model.project.model) + self.geometry_combobox.setCurrentText(self.parent_model.project.geometry) + + for tab in self.tabs: + self.view_tabs[tab].update_model(self.draft_project) + self.edit_tabs[tab].update_model(self.draft_project) + + self.handle_domains_tab() + self.handle_controls_update() + + def update_draft_project(self, new_values: dict) -> None: + """ + Updates the draft project. + + Parameters + ---------- + new_values: dict + A dictionary of new values to update in the draft project. + + """ + self.draft_project.update(new_values) + + def handle_domains_tab(self) -> None: + """Displays or hides the domains tab""" + domain_tab_index = list(self.view_tabs).index("Domains") + is_domains = self.calculation_combobox.currentText() == Calculations.Domains + self.project_tab.setTabVisible(domain_tab_index, is_domains) + self.edit_project_tab.setTabVisible(domain_tab_index, is_domains) + + def handle_controls_update(self): + """Handle updates to Controls that need to be reflected in the project.""" + if self.draft_project is None: + return + + controls = self.parent_model.controls + + for tab in self.tabs: + self.view_tabs[tab].handle_controls_update(controls) + self.edit_tabs[tab].handle_controls_update(controls) + + def show_project_view(self) -> None: + """Show project view""" + self.setWindowTitle("Project") + self.stacked_widget.setCurrentIndex(0) + + def show_edit_view(self) -> None: + """Show edit view""" + self.setWindowTitle("Edit Project") + self.update_project_view() + self.stacked_widget.setCurrentIndex(1) + + def save_changes(self) -> None: + """Save changes to the project.""" + self.parent.presenter.edit_project(self.draft_project) + self.update_project_view() + self.show_project_view() + + def cancel_changes(self) -> None: + """Cancel changes to the project.""" + self.update_project_view() + self.show_project_view() + + +class ProjectTabWidget(QtWidgets.QWidget): + """Widget that combines multiple ProjectFieldWidgets to create a tab of the project window. + + Subclasses must reimplement the function update_model. + + Parameters + ---------- + fields : list[str] + The fields to display in the tab. + parent : QtWidgets.QWidget + The parent to this widget. + + """ + + def __init__(self, fields: list[str], parent, edit_mode: bool = False): + super().__init__(parent) + self.parent = parent + self.fields = fields + self.edit_mode = edit_mode + self.tables = {} + + layout = QtWidgets.QVBoxLayout() + for field in self.fields: + if field in RATapi.project.parameter_class_lists: + self.tables[field] = ParameterFieldWidget(field, self) + else: + self.tables[field] = ProjectFieldWidget(field, self) + layout.addWidget(self.tables[field]) + + scroll_area = QtWidgets.QScrollArea() + # one widget must be given, not a layout, + # or scrolling won't work properly! + tab_widget = QtWidgets.QFrame() + tab_widget.setLayout(layout) + scroll_area.setWidget(tab_widget) + scroll_area.setWidgetResizable(True) + + widget_layout = QtWidgets.QVBoxLayout() + widget_layout.addWidget(scroll_area) + + self.setLayout(widget_layout) + + def update_model(self, new_model): + """Update the model for each table. + + Parameters + ---------- + new_model + The new model data. + + """ + for field, table in self.tables.items(): + classlist = new_model[field] + table.update_model(classlist) + if self.edit_mode: + table.edit() + + def handle_controls_update(self, controls): + """Reflect changes to the Controls object.""" + for field in RATapi.project.parameter_class_lists: + if field in self.tables: + self.tables[field].handle_bayesian_columns(controls.procedure) + + +def create_draft_project(project: RATapi.Project) -> dict: + """Create a draft project (dictionary of project attributes) from a Project. + + Parameters + ---------- + project : RATapi.Project + The project to create a draft from. + + Returns + ------- + dict + The draft project. + + """ + # in an ideal world, we could just copy and dump the Project with something like + # project.model_copy(deep=True).model_dump() + # but some references get shared for some reason: e.g. draft_project['parameters'].append + # will point towards project.parameters.append (???) and so on + + draft_project = {} + for field in RATapi.Project.model_fields: + attr = getattr(project, field) + if isinstance(attr, RATapi.ClassList): + draft_project[field] = RATapi.ClassList(deepcopy(attr.data)) + draft_project[field]._class_handle = getattr(project, field)._class_handle + else: + draft_project[field] = attr + return draft_project diff --git a/tests/widgets/project/test_models.py b/tests/widgets/project/test_models.py new file mode 100644 index 0000000..da98738 --- /dev/null +++ b/tests/widgets/project/test_models.py @@ -0,0 +1,262 @@ +from unittest.mock import MagicMock + +import pydantic +import pytest +import RATapi +from PyQt6 import QtCore, QtWidgets + +import rascal2.widgets.delegates as delegates +import rascal2.widgets.inputs as inputs +from rascal2.widgets.project.models import ( + ClassListModel, + ParameterFieldWidget, + ParametersModel, + ProjectFieldWidget, +) + + +class MockMainWindow(QtWidgets.QMainWindow): + def __init__(self): + super().__init__() + self.presenter = MagicMock() + self.update_project = MagicMock() + + +class DataModel(pydantic.BaseModel, validate_assignment=True): + """A test Pydantic model.""" + + name: str = "Test Model" + value: int = 15 + + +@pytest.fixture +def classlist(): + """A test ClassList.""" + return RATapi.ClassList([DataModel(name="A", value=1), DataModel(name="B", value=6), DataModel(name="C", value=18)]) + + +@pytest.fixture +def table_model(classlist): + """A test ClassListModel.""" + return ClassListModel(classlist, parent) + + +@pytest.fixture +def param_classlist(): + def _classlist(protected_indices): + return RATapi.ClassList( + [ + RATapi.models.ProtectedParameter(name=str(i)) if i in protected_indices else RATapi.models.Parameter() + for i in [0, 1, 2] + ] + ) + + return _classlist + + +@pytest.fixture +def param_model(param_classlist): + def _param_model(protected_indices): + model = ParametersModel(param_classlist(protected_indices), parent) + return model + + return _param_model + + +parent = MockMainWindow() + + +def test_model_init(table_model, classlist): + """Test that initialisation works correctly for ClassListModels.""" + model = table_model + + assert model.classlist == classlist + assert model.item_type == DataModel + assert model.headers == ["name", "value"] + assert not model.edit_mode + + +def test_model_layout_data(table_model): + """Test that the model layout and data is as expected.""" + model = table_model + + assert model.rowCount() == 3 + assert model.columnCount() == 3 + + expected_data = [[None, "A", 1], [None, "B", 6], [None, "C", 18]] + headers = [None, "Name", "Value"] + + for row in [0, 1, 2]: + for column in [0, 1, 2]: + assert model.data(model.index(row, column)) == expected_data[row][column] + + for column in [0, 1, 2]: + assert model.headerData(column, QtCore.Qt.Orientation.Horizontal) == headers[column] + + +def test_model_set_data(table_model): + """Test that data can be set successfully, but is thrown out if it breaks the Pydantic model rules.""" + model = table_model + + assert model.setData(model.index(1, 2), 4) + assert model.classlist[1].value == 4 + + assert model.setData(model.index(1, 1), "D") + assert model.classlist[1].name == "D" + + assert not model.setData(model.index(2, 2), 19.4) + assert model.classlist[2].value == 18 + + +def test_append(table_model): + """Test that append_item successfully adds an item of the relevant type.""" + model = table_model + + model.append_item() + + assert len(model.classlist) == 4 + assert model.classlist[-1].name == "Test Model" + assert model.classlist[-1].value == 15 + + +def test_delete(table_model): + """Test that delete_item deletes the item at the desired index.""" + model = table_model + + model.delete_item(1) + + assert len(model.classlist) == 2 + assert [m.name for m in model.classlist] == ["A", "C"] + assert [m.value for m in model.classlist] == [1, 18] + + +def test_project_field_init(): + """Test that the ProjectFieldWidget is initialised correctly.""" + widget = ProjectFieldWidget("test", parent) + + assert widget.table.model() is None + assert widget.add_button.isHidden() + + +def test_project_field_update_model(classlist): + """Test that the correct changes are made when the model is updated in the ProjectFieldWidget.""" + widget = ProjectFieldWidget("test", parent) + widget.update_model(classlist) + + assert widget.table.isColumnHidden(0) + + assert widget.model.classlist == classlist + assert isinstance( + widget.table.itemDelegateForColumn(1).createEditor(None, None, widget.model.index(1, 1)), + inputs.BaseInputWidget, + ) + assert isinstance( + widget.table.itemDelegateForColumn(2).createEditor(None, None, widget.model.index(1, 2)), + inputs.IntInputWidget, + ) + + +def test_edit_mode(classlist): + """Test that edit mode makes the expected changes.""" + widget = ProjectFieldWidget("test", parent) + widget.update_model(classlist) + widget.edit() + + assert widget.model.edit_mode + assert not widget.add_button.isHidden() + assert not widget.table.isColumnHidden(0) + + for row in [0, 1, 2]: + assert isinstance(widget.table.indexWidget(widget.model.index(row, 0)), QtWidgets.QPushButton) + + +def test_delete_button(classlist): + """Test that delete buttons work as expected.""" + widget = ProjectFieldWidget("Test", parent) + widget.update_model(classlist) + + delete_button = widget.make_delete_button(1) + delete_button.click() + + assert len(widget.model.classlist) == 2 + assert [m.name for m in widget.model.classlist] == ["A", "C"] + assert [m.value for m in widget.model.classlist] == [1, 18] + + +def test_parameter_edit_mode(param_classlist): + """Test that parameter tab edit mode makes the expected changes.""" + widget = ProjectFieldWidget("Test", parent) + widget.update_model(param_classlist([])) + widget.edit() + + assert widget.model.edit_mode + assert not widget.add_button.isHidden() + assert not widget.table.isColumnHidden(0) + + for row in [0, 1, 2]: + assert isinstance(widget.table.indexWidget(widget.model.index(row, 0)), QtWidgets.QPushButton) + + +@pytest.mark.parametrize("protected", ([], [0, 2], [1])) +@pytest.mark.parametrize("prior_type", ("uniform", "gaussian")) +def test_parameter_flags(param_model, prior_type, protected): + """Test that protected parameters are successfully recorded and flagged, and parameter flags are set correctly.""" + model = param_model(protected) + for param in model.classlist: + param.prior_type = prior_type + + assert model.protected_indices == protected + + model.edit_mode = True + + for row in [0, 1, 2]: + for column in range(1, model.columnCount()): + item_flags = model.flags(model.index(row, column)) + match model.headers[column - 1]: + case "name": + if row in protected: + assert not item_flags & QtCore.Qt.ItemFlag.ItemIsEditable + else: + assert item_flags & QtCore.Qt.ItemFlag.ItemIsEditable + case "fit": + assert item_flags & QtCore.Qt.ItemFlag.ItemIsUserCheckable + case "mu" | "sigma": + if prior_type == "uniform": + assert item_flags == QtCore.Qt.ItemFlag.NoItemFlags + else: + assert item_flags & QtCore.Qt.ItemFlag.ItemIsEditable + + +def test_param_item_delegates(param_classlist): + """Test that parameter models have the expected item delegates.""" + widget = ParameterFieldWidget("Test", parent) + widget.parent = MagicMock() + widget.update_model(param_classlist([])) + + for column, header in enumerate(widget.model.headers, start=1): + if header in ["min", "value", "max"]: + assert isinstance(widget.table.itemDelegateForColumn(column), delegates.ValueSpinBoxDelegate) + else: + assert isinstance(widget.table.itemDelegateForColumn(column), delegates.ValidatedInputDelegate) + + +def test_hidden_bayesian_columns(param_classlist): + """Test that Bayes columns are hidden when procedure is not Bayesian.""" + widget = ParameterFieldWidget("Test", parent) + widget.parent = MagicMock() + widget.update_model(param_classlist([])) + mock_controls = widget.parent.parent.parent_model.controls = MagicMock() + mock_controls.procedure = "calculate" + bayesian_columns = ["prior_type", "mu", "sigma"] + + widget.handle_bayesian_columns("calculate") + + for item in bayesian_columns: + index = widget.model.headers.index(item) + assert widget.table.isColumnHidden(index + 1) + + widget.handle_bayesian_columns("dream") + + for item in bayesian_columns: + index = widget.model.headers.index(item) + assert not widget.table.isColumnHidden(index + 1) diff --git a/tests/test_project.py b/tests/widgets/project/test_project.py similarity index 61% rename from tests/test_project.py rename to tests/widgets/project/test_project.py index e9881a4..8d4bac0 100644 --- a/tests/test_project.py +++ b/tests/widgets/project/test_project.py @@ -1,18 +1,30 @@ from unittest.mock import MagicMock +import pydantic import pytest +import RATapi from PyQt6 import QtCore, QtWidgets -from RATapi import Project from RATapi.utils.enums import Calculations, Geometries, LayerModels -from rascal2.widgets.project import ProjectWidget +from rascal2.widgets.project.models import ( + ClassListModel, + ParameterFieldWidget, + ParametersModel, + ProjectFieldWidget, +) +from rascal2.widgets.project.project import ( + ProjectTabWidget, + ProjectWidget, +) class MockModel(QtCore.QObject): def __init__(self): super().__init__() - self.project = Project() + self.project = RATapi.Project() + self.controls = MagicMock() self.project_updated = MagicMock() + self.controls_updated = MagicMock() class MockPresenter(QtWidgets.QMainWindow): @@ -28,6 +40,28 @@ def __init__(self): self.presenter = MockPresenter() +class DataModel(pydantic.BaseModel, validate_assignment=True): + """A test Pydantic model.""" + + name: str = "Test Model" + value: int = 15 + + +parent = MockMainWindow() + + +@pytest.fixture +def classlist(): + """A test ClassList.""" + return RATapi.ClassList([DataModel(name="A", value=1), DataModel(name="B", value=6), DataModel(name="C", value=18)]) + + +@pytest.fixture +def table_model(classlist): + """A test ClassListModel.""" + return ClassListModel(classlist, parent) + + @pytest.fixture def setup_project_widget(): parent = MockMainWindow() @@ -36,6 +70,28 @@ def setup_project_widget(): return project_widget +@pytest.fixture +def param_classlist(): + def _classlist(protected_indices): + return RATapi.ClassList( + [ + RATapi.models.ProtectedParameter(name=str(i)) if i in protected_indices else RATapi.models.Parameter() + for i in [0, 1, 2] + ] + ) + + return _classlist + + +@pytest.fixture +def param_model(param_classlist): + def _param_model(protected_indices): + model = ParametersModel(param_classlist(protected_indices), parent) + return model + + return _param_model + + def test_project_widget_initial_state(setup_project_widget): """ Tests the inital state of the ProjectWidget class. @@ -46,7 +102,7 @@ def test_project_widget_initial_state(setup_project_widget): assert project_widget.stacked_widget.currentIndex() == 0 assert project_widget.edit_project_button.isEnabled() - assert project_widget.edit_project_button.text() == " Edit Project" + assert project_widget.edit_project_button.text() == "Edit Project" assert project_widget.calculation_label.text() == "Calculation:" assert project_widget.calculation_type.text() == Calculations.NonPolarised @@ -62,10 +118,10 @@ def test_project_widget_initial_state(setup_project_widget): # Check the layout of the edit view assert project_widget.save_project_button.isEnabled() - assert project_widget.save_project_button.text() == " Save Project" + assert project_widget.save_project_button.text() == "Save Project" assert project_widget.cancel_button.isEnabled() - assert project_widget.cancel_button.text() == " Cancel" + assert project_widget.cancel_button.text() == "Cancel" assert project_widget.edit_calculation_label.text() == "Calculation:" assert project_widget.calculation_combobox.currentText() == Calculations.NonPolarised @@ -82,7 +138,7 @@ def test_project_widget_initial_state(setup_project_widget): for ix, geometry in enumerate(Geometries): assert project_widget.geometry_combobox.itemText(ix) == geometry - for ix, tab in enumerate(["Parameters", "Backgrounds", "Experimental Parameters", "Layers", "Data", "Contrasts"]): + for ix, tab in enumerate(project_widget.tabs): assert project_widget.project_tab.tabText(ix) == tab assert project_widget.edit_project_tab.tabText(ix) == tab @@ -124,12 +180,12 @@ def test_save_changes_to_model_project(setup_project_widget): project_widget.geometry_combobox.setCurrentText(Geometries.SubstrateLiquid) project_widget.model_combobox.setCurrentText(LayerModels.CustomXY) - assert project_widget.modified_project.geometry == Geometries.SubstrateLiquid - assert project_widget.modified_project.model == LayerModels.CustomXY - assert project_widget.modified_project.calculation == Calculations.Domains + assert project_widget.draft_project["geometry"] == Geometries.SubstrateLiquid + assert project_widget.draft_project["model"] == LayerModels.CustomXY + assert project_widget.draft_project["calculation"] == Calculations.Domains - project_widget.save_project_button.click() - assert project_widget.presenter.edit_project.call_count == 1 + project_widget.save_changes() + assert project_widget.parent.presenter.edit_project.call_count == 1 def test_cancel_changes_to_model_project(setup_project_widget): @@ -145,12 +201,12 @@ def test_cancel_changes_to_model_project(setup_project_widget): project_widget.geometry_combobox.setCurrentText(Geometries.SubstrateLiquid) project_widget.model_combobox.setCurrentText(LayerModels.CustomXY) - assert project_widget.modified_project.geometry == Geometries.SubstrateLiquid - assert project_widget.modified_project.model == LayerModels.CustomXY - assert project_widget.modified_project.calculation == Calculations.Domains + assert project_widget.draft_project["geometry"] == Geometries.SubstrateLiquid + assert project_widget.draft_project["model"] == LayerModels.CustomXY + assert project_widget.draft_project["calculation"] == Calculations.Domains project_widget.cancel_button.click() - assert project_widget.presenter.edit_project.call_count == 0 + assert project_widget.parent.presenter.edit_project.call_count == 0 assert project_widget.calculation_combobox.currentText() == Calculations.NonPolarised assert project_widget.calculation_type.text() == Calculations.NonPolarised @@ -167,16 +223,39 @@ def test_domains_tab(setup_project_widget): project_widget = setup_project_widget project_widget.edit_project_button.click() project_widget.calculation_combobox.setCurrentText(Calculations.Domains) - assert project_widget.modified_project.calculation == Calculations.Domains - project_widget.presenter.model.project.calculation = Calculations.Domains - project_widget.calculation_type.setText(Calculations.Domains) + assert project_widget.draft_project["calculation"] == Calculations.Domains project_widget.handle_domains_tab() - for ix, tab in enumerate( - ["Parameters", "Backgrounds", "Experimental Parameters", "Layers", "Data", "Contrasts", "Domains"] - ): - assert project_widget.project_tab.tabText(ix) == tab - assert project_widget.edit_project_tab.tabText(ix) == tab + domains_tab_index = 5 + assert project_widget.project_tab.isTabVisible(domains_tab_index) + assert project_widget.edit_project_tab.isTabVisible(domains_tab_index) - assert project_widget.project_tab.currentIndex() == 0 - assert project_widget.edit_project_tab.currentIndex() == 0 + +def test_project_tab_init(): + """Test that the project tab correctly creates field widgets.""" + fields = ["my_field", "parameters", "bulk_in"] + + tab = ProjectTabWidget(fields, parent) + + for field in fields: + if field in RATapi.project.parameter_class_lists: + assert isinstance(tab.tables[field], ParameterFieldWidget) + else: + assert isinstance(tab.tables[field], ProjectFieldWidget) + + +@pytest.mark.parametrize("edit_mode", [True, False]) +def test_project_tab_update_model(classlist, param_classlist, edit_mode): + """Test that updating a ProjectTabEditWidget produces the desired models.""" + + new_model = {"my_field": classlist, "parameters": param_classlist([])} + + tab = ProjectTabWidget(list(new_model), parent, edit_mode=edit_mode) + # change the parent to a mock to avoid spec issues + for table in tab.tables.values(): + table.parent = MagicMock() + tab.update_model(new_model) + + for field in new_model: + assert tab.tables[field].model.classlist == new_model[field] + assert tab.tables[field].model.edit_mode == edit_mode diff --git a/tests/test_controls.py b/tests/widgets/test_controls.py similarity index 100% rename from tests/test_controls.py rename to tests/widgets/test_controls.py diff --git a/tests/test_inputs.py b/tests/widgets/test_inputs.py similarity index 97% rename from tests/test_inputs.py rename to tests/widgets/test_inputs.py index ac787e6..8d4b203 100644 --- a/tests/test_inputs.py +++ b/tests/widgets/test_inputs.py @@ -34,7 +34,7 @@ def test_editor_type(field_info, expected_type, example_data): assert widget.get_data() == example_data -@pytest.mark.parametrize(("value", "decimals"), [("10.", 0), ("1e-5", 5), ("0.01144661", 8)]) +@pytest.mark.parametrize(("value", "decimals"), [("10.", 0), ("1e-5", 6), ("0.01144661", 8)]) def test_adaptive_spinbox(value, decimals): spinbox = AdaptiveDoubleSpinBox() spinbox.validate(value, 0) diff --git a/tests/test_terminal.py b/tests/widgets/test_terminal.py similarity index 100% rename from tests/test_terminal.py rename to tests/widgets/test_terminal.py diff --git a/tests/test_widgets.py b/tests/widgets/test_widgets.py similarity index 100% rename from tests/test_widgets.py rename to tests/widgets/test_widgets.py