From 91c2a305f3155e41b25b100f4fbca4afa408912c Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 29 Jan 2024 15:18:45 +0200 Subject: [PATCH] Add support for alternative filters into connections Alternative filters are executed in a single thread unlike scenario filters. Filter types can also be enabled or disabled which allows excluding scenario filters while including alternative filters and vice-versa. Re spine-tools/Spine-Toolbox#2147 --- spine_engine/project_item/connection.py | 134 +++++++++++++++++------- spine_engine/spine_engine.py | 10 +- tests/project_item/test_connection.py | 23 +++- 3 files changed, 121 insertions(+), 46 deletions(-) diff --git a/spine_engine/project_item/connection.py b/spine_engine/project_item/connection.py index 9f7bebf4..982e2843 100644 --- a/spine_engine/project_item/connection.py +++ b/spine_engine/project_item/connection.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Provides connection classes for linking project items. - -""" +""" Provides connection classes for linking project items. """ from dataclasses import asdict, dataclass, field import os import subprocess @@ -21,6 +18,7 @@ from datapackage import Package from spinedb_api import DatabaseMapping, SpineDBAPIError, SpineDBVersionError from spinedb_api.filters.scenario_filter import SCENARIO_FILTER_TYPE +from spinedb_api.filters.alternative_filter import ALTERNATIVE_FILTER_TYPE from spinedb_api.purge import purge_url from spine_engine.project_item.project_item_resource import ( file_resource, @@ -37,6 +35,9 @@ from spine_engine.utils.queue_logger import QueueLogger +SUPPORTED_FILTER_TYPES = {ALTERNATIVE_FILTER_TYPE, SCENARIO_FILTER_TYPE} + + class ConnectionBase: """Base class for connections between two project items.""" @@ -163,6 +164,9 @@ def emit_flash(self): self._logger.flash.emit() +DEFAULT_ENABLED_FILTER_TYPES = {ALTERNATIVE_FILTER_TYPE: False, SCENARIO_FILTER_TYPE: True} + + @dataclass class FilterSettings: """Filter settings for resource converting connections.""" @@ -171,6 +175,17 @@ class FilterSettings: """mapping from resource labels and filter types to filter online statuses""" auto_online: bool = True """if True, set unknown filters automatically online""" + enabled_filter_types: dict = field(default_factory=DEFAULT_ENABLED_FILTER_TYPES.copy) + + def __post_init__(self): + for resource, online_filters in self.known_filters.items(): + supported_filters = { + filter_type: online + for filter_type, online in online_filters.items() + if filter_type in SUPPORTED_FILTER_TYPES + } + if supported_filters: + self.known_filters[resource] = supported_filters def has_filters(self): """Tests if there are filters. @@ -179,8 +194,8 @@ def has_filters(self): bool: True if filters of any type exists, False otherwise """ for filters_by_type in self.known_filters.values(): - for filters in filters_by_type.values(): - if filters: + for filter_type, filters in filters_by_type.items(): + if self.enabled_filter_types[filter_type] and filters: return True return False @@ -191,8 +206,8 @@ def has_any_filter_online(self): bool: True if any filter is online, False otherwise """ for filters_by_type in self.known_filters.values(): - for filters in filters_by_type.values(): - if any(filters.values()): + for filter_type, filters in filters_by_type.items(): + if self.enabled_filter_types[filter_type] and any(filters.values()): return True return False @@ -205,6 +220,8 @@ def has_filter_online(self, filter_type): Returns: bool: True if any filter of filter_type is online, False otherwise """ + if not self.enabled_filter_types[filter_type]: + return False for filters_by_type in self.known_filters.values(): if any(filters_by_type.get(filter_type, {}).values()): return True @@ -306,19 +323,32 @@ def require_filter_online(self, filter_type): Returns: bool: True if online filters are required, False otherwise """ - return self.options.get("require_" + filter_type, False) + return self._filter_settings.enabled_filter_types[filter_type] and self.options.get( + "require_" + filter_type, False + ) + + def is_filter_type_enabled(self, filter_type): + """Tests if filter type is enabled. + + Args: + filter_type (str): filter type + + Returns: + bool: True if filter type is enabled, False otherwise + """ + return self._filter_settings.enabled_filter_types[filter_type] def notifications(self): """See base class.""" notifications = [] - for filter_type in (SCENARIO_FILTER_TYPE,): + for filter_type in (SCENARIO_FILTER_TYPE, ALTERNATIVE_FILTER_TYPE): filter_settings = self._filter_settings if self.require_filter_online(filter_type) and ( not filter_settings.has_filter_online(filter_type) if filter_settings.has_filters() else not filter_settings.auto_online ): - filter_name = {SCENARIO_FILTER_TYPE: "scenario"}[filter_type] + filter_name = {SCENARIO_FILTER_TYPE: "scenario", ALTERNATIVE_FILTER_TYPE: "alternative"}[filter_type] notifications.append(f"At least one {filter_name} filter must be active.") return notifications @@ -413,7 +443,7 @@ def _apply_use_datapackage(self, resources): def ready_to_execute(self): """See base class.""" - for filter_type in (SCENARIO_FILTER_TYPE,): + for filter_type in (SCENARIO_FILTER_TYPE, ALTERNATIVE_FILTER_TYPE): if self.require_filter_online(filter_type) and not self._filter_settings.has_filter_online(filter_type): return False return True @@ -427,7 +457,10 @@ def to_dict(self): d = super().to_dict() if self.options: d["options"] = self.options.copy() - if self._filter_settings.has_filters(): + if ( + self._filter_settings.has_filters() + or self._filter_settings.enabled_filter_types != DEFAULT_ENABLED_FILTER_TYPES + ): d["filter_settings"] = self._filter_settings.to_dict() return d @@ -463,7 +496,7 @@ def __init__( filter_settings (FilterSettings, optional): filter settings """ super().__init__(source_name, source_position, destination_name, destination_position, options, filter_settings) - self._enabled_filter_names = None + self._enabled_filter_values = None self._source_visited = False def visit_source(self): @@ -485,36 +518,67 @@ def enabled_filters(self, resource_label): Returns: dict: mapping from filter type to list of online filter names """ - if self._enabled_filter_names is None: - self._prepare_enabled_filter_names() - return self._enabled_filter_names.get(resource_label) + if self._enabled_filter_values is None: + self._prepare_enabled_filter_values() + return self._enabled_filter_values.get(resource_label) - def _prepare_enabled_filter_names(self): + def _prepare_enabled_filter_values(self): """Reads filter information from database.""" - self._enabled_filter_names = {} + self._enabled_filter_values = {} for resource in self._resources: url = resource.url if not url: continue try: - db_map = DatabaseMapping(url) + with DatabaseMapping(url) as db_map: + known_filters = self._filter_settings.known_filters.get(resource.label, {}) + enabled_filter_values = self._enabled_filter_values.setdefault(resource.label, {}) + if self._filter_settings.enabled_filter_types[SCENARIO_FILTER_TYPE]: + enabled_scenarios = self._fetch_scenario_filter_values(db_map, known_filters) + if enabled_scenarios: + enabled_filter_values[SCENARIO_FILTER_TYPE] = enabled_scenarios + if self._filter_settings.enabled_filter_types[ALTERNATIVE_FILTER_TYPE]: + enabled_alternatives = self._fetch_alternative_filter_values(db_map, known_filters) + if enabled_alternatives: + enabled_filter_values[ALTERNATIVE_FILTER_TYPE] = enabled_alternatives except (SpineDBAPIError, SpineDBVersionError): continue - try: - scenario_filter_settings = self._filter_settings.known_filters.get(resource.label, {}).get( - SCENARIO_FILTER_TYPE, {} - ) - available_scenarios = {row.name for row in db_map.query(db_map.scenario_sq)} - enabled_scenarios = set() - for name in available_scenarios: - if scenario_filter_settings.get(name, self._filter_settings.auto_online): - enabled_scenarios.add(name) - if enabled_scenarios: - self._enabled_filter_names.setdefault(resource.label, {})[SCENARIO_FILTER_TYPE] = sorted( - list(enabled_scenarios) - ) - finally: - db_map.close() + + def _fetch_scenario_filter_values(self, db_map, known_filters): + """Fetches scenario names from database and picks the ones that are enabled by filter settings. + + Args: + db_map (DatabaseMapping): database mapping + known_filters (dict): mapping from filter type to filter settings + + Returns: + list of str: scenario filter values + """ + filter_settings = known_filters.get(SCENARIO_FILTER_TYPE, {}) + available_scenarios = {row.name for row in db_map.query(db_map.scenario_sq)} + enabled_scenarios = set() + for scenario_name in available_scenarios: + if filter_settings.get(scenario_name, self._filter_settings.auto_online): + enabled_scenarios.add(scenario_name) + return sorted(enabled_scenarios) + + def _fetch_alternative_filter_values(self, db_map, known_filters): + """Fetches enabled alternative names from database. + + Args: + db_map (DatabaseMapping): database mapping + known_filters (dict): mapping from filter type to filter settings + + Returns: + list of list of str: alternative filter values + """ + filter_settings = known_filters.get(ALTERNATIVE_FILTER_TYPE, {}) + available_alternatives = {row.name for row in db_map.query(db_map.alternative_sq)} + enabled_alternatives = set() + for alternative_name in available_alternatives: + if filter_settings.get(alternative_name, self._filter_settings.auto_online): + enabled_alternatives.add(alternative_name) + return [list(enabled_alternatives)] if enabled_alternatives else [] @classmethod def from_dict(cls, connection_dict): diff --git a/spine_engine/spine_engine.py b/spine_engine/spine_engine.py index 87468451..dc2246aa 100644 --- a/spine_engine/spine_engine.py +++ b/spine_engine/spine_engine.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Contains the SpineEngine class for running Spine Toolbox DAGs. - -""" +""" Contains the SpineEngine class for running Spine Toolbox DAGs. """ from enum import Enum, unique import os import threading @@ -735,8 +731,8 @@ def _filter_stacks(self, item_name, provider_name, resource_label): if filters is None: return [] filter_configs_list = [] - for filter_type, names in filters.items(): - filter_configs = [filter_config(filter_type, name) for name in names] + for filter_type, values in filters.items(): + filter_configs = [filter_config(filter_type, value) for value in values] if not filter_configs: continue filter_configs_list.append(filter_configs) diff --git a/tests/project_item/test_connection.py b/tests/project_item/test_connection.py index e19835c5..707c6897 100644 --- a/tests/project_item/test_connection.py +++ b/tests/project_item/test_connection.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Uni tests for the ``connection`` module. - -""" +""" Uni tests for the ``connection`` module. """ import os.path from tempfile import TemporaryDirectory import unittest @@ -191,6 +188,13 @@ def test_has_filters_returns_true_when_filters_exist(self): settings = FilterSettings({"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}}) self.assertTrue(settings.has_filters()) + def test_has_filters_returns_false_when_filter_type_is_disabled(self): + settings = FilterSettings( + {"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}}, + enabled_filter_types={SCENARIO_FILTER_TYPE: False}, + ) + self.assertFalse(settings.has_filters()) + def test_has_filters_online_returns_false_when_no_filters_exist(self): settings = FilterSettings() self.assertFalse(settings.has_filter_online(SCENARIO_FILTER_TYPE)) @@ -207,6 +211,17 @@ def test_has_filter_online_works_when_there_are_no_known_filters(self): settings = FilterSettings() self.assertFalse(settings.has_filter_online(SCENARIO_FILTER_TYPE)) + def test_has_any_filter_online_returns_false_when_no_filters_exist(self): + settings = FilterSettings() + self.assertFalse(settings.has_any_filter_online()) + + def test_has_any_filter_online_returns_false_when_filter_type_is_disabled(self): + settings = FilterSettings( + {"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}}, + enabled_filter_types={SCENARIO_FILTER_TYPE: False}, + ) + self.assertFalse(settings.has_any_filter_online()) + def test_has_any_filter_online_returns_true_when_filters_are_online(self): settings = FilterSettings( {"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": False, "scenario_2": True}}}