From 5f87c31d928a75b47db3c24a70a36f06beb3a43b Mon Sep 17 00:00:00 2001 From: Niko Aarnio Date: Tue, 3 Dec 2024 16:14:08 +0200 Subject: [PATCH] feat(wizard-history): save and display validation metrics in modeling history --- .../modeling/machine_learning/training.py | 8 +++-- eis_qgis_plugin/eis_wizard/wizard_history.py | 32 +++++++++++++++---- .../resources/ui/wizard_model_history.ui | 12 +++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/eis_qgis_plugin/eis_wizard/modeling/machine_learning/training.py b/eis_qgis_plugin/eis_wizard/modeling/machine_learning/training.py index da71435b..2d76f4ec 100644 --- a/eis_qgis_plugin/eis_wizard/modeling/machine_learning/training.py +++ b/eis_qgis_plugin/eis_wizard/modeling/machine_learning/training.py @@ -125,7 +125,7 @@ def on_algorithm_executor_finished(self, result, execution_time): **self.get_common_parameter_values(), **self.get_validation_settings(as_str = True) } - self.save_info(model_parameters_as_str, execution_time) + self.save_info(result, model_parameters_as_str, execution_time) self.training_feedback.pushInfo(f"\nTraining time: {execution_time}") else: self.training_feedback.report_failed_run() @@ -260,7 +260,7 @@ def cancel(self): self.executor.cancel() - def save_info(self, model_parameters: dict, execution_time: Optional[float] = None): + def save_info(self, results: dict, model_parameters: dict, execution_time: Optional[float] = None): """Save model info with ModelManager.""" model_info = MLModelInfo( model_instance_name=self.train_model_instance_name.text(), @@ -273,6 +273,10 @@ def save_info(self, model_parameters: dict, execution_time: Optional[float] = No evidence_data=[(layer.name(), layer.source()) for layer in self.train_evidence_data.get_layers()], label_data=(self.get_training_label_layer().name(), self.get_training_label_layer().source()), parameters=model_parameters, + validation_metrics={ + key: float(value) for key, value in results.items() + if key.capitalize() in self.model_main.get_valid_metrics() + } ) self.model_main.get_model_manager().save_model_info(model_info) diff --git a/eis_qgis_plugin/eis_wizard/wizard_history.py b/eis_qgis_plugin/eis_wizard/wizard_history.py index ddb8e2a6..6b861385 100644 --- a/eis_qgis_plugin/eis_wizard/wizard_history.py +++ b/eis_qgis_plugin/eis_wizard/wizard_history.py @@ -52,6 +52,8 @@ def __init__(self, parent=None, model_manager: ModelManager = None) -> None: self.label_data_box: QGroupBox self.parameters_box: QGroupBox self.parameters_layout: QFormLayout + self.validation_metrics_box: QGroupBox + self.validation_metrics_layout: QFormLayout self.label_layer_name: QLabel self.label_filepath: QLabel @@ -108,6 +110,7 @@ def update_viewed_model(self, model_id: str): self.clear_evidence_data() self.clear_label_data() self.clear_parameter_data() + self.clear_validation_metrics_data() else: if not info.check_model_file(): self.model_file_label.setText("Model file (MISSING!)") @@ -117,6 +120,7 @@ def update_viewed_model(self, model_id: str): self.load_evidence_data(info) self.load_label_data(info) self.load_parameter_data(info) + self.load_validation_metrics_data(info) def update_model_file(self): @@ -145,16 +149,28 @@ def load_label_data(self, info: MLModelInfo): self.label_layer_name.setText(info.label_data[0]) self.label_filepath.setText(info.label_data[1]) + + def _create_label_and_value_widgets(self, name: str, value: int) -> tuple[QLabel, QLineEdit]: + name_label = QLabel() + name_label.setText(name) + value_widget = QLineEdit() + value_widget.setText(str(value)) + value_widget.setReadOnly(True) + return name_label, value_widget + def load_parameter_data(self, info: MLModelInfo): self.clear_parameter_data() for parameter_name, parameter_value in info.parameters.items(): - name_label = QLabel() - name_label.setText(parameter_name) - value_widget = QLineEdit() - value_widget.setText(str(parameter_value)) - value_widget.setReadOnly(True) - self.parameters_layout.addRow(name_label, value_widget) + self.parameters_layout.addRow(*self._create_label_and_value_widgets(parameter_name, parameter_value)) + + + def load_validation_metrics_data(self, info: MLModelInfo): + self.clear_validation_metrics_data() + if not info.validation_metrics: + return + for metric_name, metric_value in info.validation_metrics.items(): + self.validation_metrics_layout.addRow(*self._create_label_and_value_widgets(metric_name, metric_value)) def clear_summary_data(self): @@ -177,6 +193,10 @@ def clear_parameter_data(self): clear_layout(self.parameters_layout) + def clear_validation_metrics_data(self): + clear_layout(self.validation_metrics_layout) + + def _on_export_clicked(self): EISMessageManager().show_message("Model history exporting not implemented yet!", "invalid") diff --git a/eis_qgis_plugin/resources/ui/wizard_model_history.ui b/eis_qgis_plugin/resources/ui/wizard_model_history.ui index ffa89616..5fa18789 100644 --- a/eis_qgis_plugin/resources/ui/wizard_model_history.ui +++ b/eis_qgis_plugin/resources/ui/wizard_model_history.ui @@ -225,6 +225,18 @@ + + + + Validation metrics + + + + + + + +