Skip to content

Commit

Permalink
Support for alternative filter (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Jan 29, 2024
2 parents 5826a0d + 91c2a30 commit b8cdc10
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 46 deletions.
134 changes: 99 additions & 35 deletions spine_engine/project_item/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Provides connection classes for linking project items.
"""
""" Provides connection classes for linking project items. """
from dataclasses import asdict, dataclass, field
import os
import subprocess
Expand All @@ -21,6 +18,7 @@
from datapackage import Package
from spinedb_api import DatabaseMapping, SpineDBAPIError, SpineDBVersionError
from spinedb_api.filters.scenario_filter import SCENARIO_FILTER_TYPE
from spinedb_api.filters.alternative_filter import ALTERNATIVE_FILTER_TYPE
from spinedb_api.purge import purge_url
from spine_engine.project_item.project_item_resource import (
file_resource,
Expand All @@ -37,6 +35,9 @@
from spine_engine.utils.queue_logger import QueueLogger


SUPPORTED_FILTER_TYPES = {ALTERNATIVE_FILTER_TYPE, SCENARIO_FILTER_TYPE}


class ConnectionBase:
"""Base class for connections between two project items."""

Expand Down Expand Up @@ -163,6 +164,9 @@ def emit_flash(self):
self._logger.flash.emit()


DEFAULT_ENABLED_FILTER_TYPES = {ALTERNATIVE_FILTER_TYPE: False, SCENARIO_FILTER_TYPE: True}


@dataclass
class FilterSettings:
"""Filter settings for resource converting connections."""
Expand All @@ -171,6 +175,17 @@ class FilterSettings:
"""mapping from resource labels and filter types to filter online statuses"""
auto_online: bool = True
"""if True, set unknown filters automatically online"""
enabled_filter_types: dict = field(default_factory=DEFAULT_ENABLED_FILTER_TYPES.copy)

def __post_init__(self):
for resource, online_filters in self.known_filters.items():
supported_filters = {
filter_type: online
for filter_type, online in online_filters.items()
if filter_type in SUPPORTED_FILTER_TYPES
}
if supported_filters:
self.known_filters[resource] = supported_filters

def has_filters(self):
"""Tests if there are filters.
Expand All @@ -179,8 +194,8 @@ def has_filters(self):
bool: True if filters of any type exists, False otherwise
"""
for filters_by_type in self.known_filters.values():
for filters in filters_by_type.values():
if filters:
for filter_type, filters in filters_by_type.items():
if self.enabled_filter_types[filter_type] and filters:
return True
return False

Expand All @@ -191,8 +206,8 @@ def has_any_filter_online(self):
bool: True if any filter is online, False otherwise
"""
for filters_by_type in self.known_filters.values():
for filters in filters_by_type.values():
if any(filters.values()):
for filter_type, filters in filters_by_type.items():
if self.enabled_filter_types[filter_type] and any(filters.values()):
return True
return False

Expand All @@ -205,6 +220,8 @@ def has_filter_online(self, filter_type):
Returns:
bool: True if any filter of filter_type is online, False otherwise
"""
if not self.enabled_filter_types[filter_type]:
return False
for filters_by_type in self.known_filters.values():
if any(filters_by_type.get(filter_type, {}).values()):
return True
Expand Down Expand Up @@ -306,19 +323,32 @@ def require_filter_online(self, filter_type):
Returns:
bool: True if online filters are required, False otherwise
"""
return self.options.get("require_" + filter_type, False)
return self._filter_settings.enabled_filter_types[filter_type] and self.options.get(
"require_" + filter_type, False
)

def is_filter_type_enabled(self, filter_type):
"""Tests if filter type is enabled.
Args:
filter_type (str): filter type
Returns:
bool: True if filter type is enabled, False otherwise
"""
return self._filter_settings.enabled_filter_types[filter_type]

def notifications(self):
"""See base class."""
notifications = []
for filter_type in (SCENARIO_FILTER_TYPE,):
for filter_type in (SCENARIO_FILTER_TYPE, ALTERNATIVE_FILTER_TYPE):
filter_settings = self._filter_settings
if self.require_filter_online(filter_type) and (
not filter_settings.has_filter_online(filter_type)
if filter_settings.has_filters()
else not filter_settings.auto_online
):
filter_name = {SCENARIO_FILTER_TYPE: "scenario"}[filter_type]
filter_name = {SCENARIO_FILTER_TYPE: "scenario", ALTERNATIVE_FILTER_TYPE: "alternative"}[filter_type]
notifications.append(f"At least one {filter_name} filter must be active.")
return notifications

Expand Down Expand Up @@ -413,7 +443,7 @@ def _apply_use_datapackage(self, resources):

def ready_to_execute(self):
"""See base class."""
for filter_type in (SCENARIO_FILTER_TYPE,):
for filter_type in (SCENARIO_FILTER_TYPE, ALTERNATIVE_FILTER_TYPE):
if self.require_filter_online(filter_type) and not self._filter_settings.has_filter_online(filter_type):
return False
return True
Expand All @@ -427,7 +457,10 @@ def to_dict(self):
d = super().to_dict()
if self.options:
d["options"] = self.options.copy()
if self._filter_settings.has_filters():
if (
self._filter_settings.has_filters()
or self._filter_settings.enabled_filter_types != DEFAULT_ENABLED_FILTER_TYPES
):
d["filter_settings"] = self._filter_settings.to_dict()
return d

Expand Down Expand Up @@ -463,7 +496,7 @@ def __init__(
filter_settings (FilterSettings, optional): filter settings
"""
super().__init__(source_name, source_position, destination_name, destination_position, options, filter_settings)
self._enabled_filter_names = None
self._enabled_filter_values = None
self._source_visited = False

def visit_source(self):
Expand All @@ -485,36 +518,67 @@ def enabled_filters(self, resource_label):
Returns:
dict: mapping from filter type to list of online filter names
"""
if self._enabled_filter_names is None:
self._prepare_enabled_filter_names()
return self._enabled_filter_names.get(resource_label)
if self._enabled_filter_values is None:
self._prepare_enabled_filter_values()
return self._enabled_filter_values.get(resource_label)

def _prepare_enabled_filter_names(self):
def _prepare_enabled_filter_values(self):
"""Reads filter information from database."""
self._enabled_filter_names = {}
self._enabled_filter_values = {}
for resource in self._resources:
url = resource.url
if not url:
continue
try:
db_map = DatabaseMapping(url)
with DatabaseMapping(url) as db_map:
known_filters = self._filter_settings.known_filters.get(resource.label, {})
enabled_filter_values = self._enabled_filter_values.setdefault(resource.label, {})
if self._filter_settings.enabled_filter_types[SCENARIO_FILTER_TYPE]:
enabled_scenarios = self._fetch_scenario_filter_values(db_map, known_filters)
if enabled_scenarios:
enabled_filter_values[SCENARIO_FILTER_TYPE] = enabled_scenarios
if self._filter_settings.enabled_filter_types[ALTERNATIVE_FILTER_TYPE]:
enabled_alternatives = self._fetch_alternative_filter_values(db_map, known_filters)
if enabled_alternatives:
enabled_filter_values[ALTERNATIVE_FILTER_TYPE] = enabled_alternatives
except (SpineDBAPIError, SpineDBVersionError):
continue
try:
scenario_filter_settings = self._filter_settings.known_filters.get(resource.label, {}).get(
SCENARIO_FILTER_TYPE, {}
)
available_scenarios = {row.name for row in db_map.query(db_map.scenario_sq)}
enabled_scenarios = set()
for name in available_scenarios:
if scenario_filter_settings.get(name, self._filter_settings.auto_online):
enabled_scenarios.add(name)
if enabled_scenarios:
self._enabled_filter_names.setdefault(resource.label, {})[SCENARIO_FILTER_TYPE] = sorted(
list(enabled_scenarios)
)
finally:
db_map.close()

def _fetch_scenario_filter_values(self, db_map, known_filters):
"""Fetches scenario names from database and picks the ones that are enabled by filter settings.
Args:
db_map (DatabaseMapping): database mapping
known_filters (dict): mapping from filter type to filter settings
Returns:
list of str: scenario filter values
"""
filter_settings = known_filters.get(SCENARIO_FILTER_TYPE, {})
available_scenarios = {row.name for row in db_map.query(db_map.scenario_sq)}
enabled_scenarios = set()
for scenario_name in available_scenarios:
if filter_settings.get(scenario_name, self._filter_settings.auto_online):
enabled_scenarios.add(scenario_name)
return sorted(enabled_scenarios)

def _fetch_alternative_filter_values(self, db_map, known_filters):
"""Fetches enabled alternative names from database.
Args:
db_map (DatabaseMapping): database mapping
known_filters (dict): mapping from filter type to filter settings
Returns:
list of list of str: alternative filter values
"""
filter_settings = known_filters.get(ALTERNATIVE_FILTER_TYPE, {})
available_alternatives = {row.name for row in db_map.query(db_map.alternative_sq)}
enabled_alternatives = set()
for alternative_name in available_alternatives:
if filter_settings.get(alternative_name, self._filter_settings.auto_online):
enabled_alternatives.add(alternative_name)
return [list(enabled_alternatives)] if enabled_alternatives else []

@classmethod
def from_dict(cls, connection_dict):
Expand Down
10 changes: 3 additions & 7 deletions spine_engine/spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Contains the SpineEngine class for running Spine Toolbox DAGs.
"""
""" Contains the SpineEngine class for running Spine Toolbox DAGs. """
from enum import Enum, unique
import os
import threading
Expand Down Expand Up @@ -735,8 +731,8 @@ def _filter_stacks(self, item_name, provider_name, resource_label):
if filters is None:
return []
filter_configs_list = []
for filter_type, names in filters.items():
filter_configs = [filter_config(filter_type, name) for name in names]
for filter_type, values in filters.items():
filter_configs = [filter_config(filter_type, value) for value in values]
if not filter_configs:
continue
filter_configs_list.append(filter_configs)
Expand Down
23 changes: 19 additions & 4 deletions tests/project_item/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Uni tests for the ``connection`` module.
"""
""" Uni tests for the ``connection`` module. """
import os.path
from tempfile import TemporaryDirectory
import unittest
Expand Down Expand Up @@ -191,6 +188,13 @@ def test_has_filters_returns_true_when_filters_exist(self):
settings = FilterSettings({"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}})
self.assertTrue(settings.has_filters())

def test_has_filters_returns_false_when_filter_type_is_disabled(self):
settings = FilterSettings(
{"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}},
enabled_filter_types={SCENARIO_FILTER_TYPE: False},
)
self.assertFalse(settings.has_filters())

def test_has_filters_online_returns_false_when_no_filters_exist(self):
settings = FilterSettings()
self.assertFalse(settings.has_filter_online(SCENARIO_FILTER_TYPE))
Expand All @@ -207,6 +211,17 @@ def test_has_filter_online_works_when_there_are_no_known_filters(self):
settings = FilterSettings()
self.assertFalse(settings.has_filter_online(SCENARIO_FILTER_TYPE))

def test_has_any_filter_online_returns_false_when_no_filters_exist(self):
settings = FilterSettings()
self.assertFalse(settings.has_any_filter_online())

def test_has_any_filter_online_returns_false_when_filter_type_is_disabled(self):
settings = FilterSettings(
{"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": True}}},
enabled_filter_types={SCENARIO_FILTER_TYPE: False},
)
self.assertFalse(settings.has_any_filter_online())

def test_has_any_filter_online_returns_true_when_filters_are_online(self):
settings = FilterSettings(
{"database@Data Store": {SCENARIO_FILTER_TYPE: {"scenario_1": False, "scenario_2": True}}}
Expand Down

0 comments on commit b8cdc10

Please sign in to comment.