From 881d6b8df66a3b1400b617b01ca26d2d520211ae Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:30:02 -0700 Subject: [PATCH] Sbachmei/mic 5549/mypy results context (#538) --- CHANGELOG.rst | 4 ++ docs/source/concepts/results.rst | 26 ++++----- pyproject.toml | 1 - src/vivarium/framework/results/context.py | 54 +++++++++---------- src/vivarium/framework/results/manager.py | 5 +- src/vivarium/framework/results/observation.py | 17 +++--- src/vivarium/framework/results/observer.py | 2 +- .../framework/results/stratification.py | 7 +-- src/vivarium/types.py | 4 ++ 9 files changed, 65 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c25fbcd9e..982bc0291 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.1 - TBD/TBD/TBD** + + - Fix mypy errors in vivarium/framework/results/context.py + **3.2.0 - 11/12/24** - Feature: Supports passing callables directly when building lookup tables diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst index 3581aed59..4463a7fb8 100644 --- a/docs/source/concepts/results.rst +++ b/docs/source/concepts/results.rst @@ -303,7 +303,7 @@ A couple other more specific and commonly used observations are provided as well that gathers new results and concatenates them to any existing results. Ideally, all concrete classes should inherit from the -:class:`BaseObservation ` +:class:`Observation ` abstract base class, which contains the common attributes between observation types: .. list-table:: **Common Observation Attributes** @@ -312,40 +312,40 @@ abstract base class, which contains the common attributes between observation ty * - Attribute - Description - * - | :attr:`name ` + * - | :attr:`name ` - | Name of the observation. It will also be the name of the output results file | for this particular observation. - * - | :attr:`pop_filter ` + * - | :attr:`pop_filter ` - | A Pandas query filter string to filter the population down to the simulants | who should be considered for the observation. - * - | :attr:`when ` + * - | :attr:`when ` - | Name of the lifecycle phase the observation should happen. Valid values are: | "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - * - | :attr:`results_initializer ` + * - | :attr:`results_initializer ` - | Method or function that initializes the raw observation results | prior to starting the simulation. This could return, for example, an empty | DataFrame or one with a complete set of stratifications as the index and | all values set to 0.0. - * - | :attr:`results_gatherer ` + * - | :attr:`results_gatherer ` - | Method or function that gathers the new observation results. - * - | :attr:`results_updater ` + * - | :attr:`results_updater ` - | Method or function that updates existing raw observation results with newly | gathered results. - * - | :attr:`results_formatter ` + * - | :attr:`results_formatter ` - | Method or function that formats the raw observation results. - * - | :attr:`stratifications ` + * - | :attr:`stratifications ` - | Optional tuple of column names for the observation to stratify by. - * - | :attr:`to_observe ` + * - | :attr:`to_observe ` - | Method or function that determines whether to perform an observation on this Event. -The **BaseObservation** also contains the -:meth:`observe ` +The **Observation** also contains the +:meth:`observe ` method which is called at each :ref:`event ` and :ref:`time step ` to determine whether or not the observation should be recorded, and if so, gathers the results and stores them in the results system. .. note:: - All four observation types discussed above inherit from the **BaseObservation** + All four observation types discussed above inherit from the **Observation** abstract base class. What differentiates them are the assigned attributes (e.g. defining the **results_updater** to be an adding method for the **AddingObservation**) or adding other attributes as necessary (e.g. diff --git a/pyproject.toml b/pyproject.toml index 3f89b0c98..ccc81506c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ exclude = [ 'src/vivarium/framework/lookup/manager.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/results/context.py', 'src/vivarium/framework/results/interface.py', 'src/vivarium/framework/results/manager.py', 'src/vivarium/framework/results/observer.py', diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 353cecbae..987c5777d 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ =============== Results Context @@ -6,8 +5,11 @@ """ +from __future__ import annotations + from collections import defaultdict -from typing import Callable, Generator, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Generator +from typing import Any import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -15,13 +17,13 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.results.exceptions import ResultsConfigurationError -from vivarium.framework.results.observation import BaseObservation +from vivarium.framework.results.observation import Observation from vivarium.framework.results.stratification import ( Stratification, get_mapped_col_name, get_original_col_name, ) -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper class ResultsContext: @@ -52,10 +54,12 @@ class ResultsContext: """ def __init__(self) -> None: - self.default_stratifications: List[str] = [] - self.stratifications: List[Stratification] = [] + self.default_stratifications: list[str] = [] + self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} - self.observations: defaultdict = defaultdict(lambda: defaultdict(list)) + self.observations: defaultdict[ + str, defaultdict[tuple[str, tuple[str, ...] | None], list[Observation]] + ] = defaultdict(lambda: defaultdict(list)) @property def name(self) -> str: @@ -73,7 +77,7 @@ def setup(self, builder: Builder) -> None: ) # noinspection PyAttributeOutsideInit - def set_default_stratifications(self, default_grouping_columns: List[str]) -> None: + def set_default_stratifications(self, default_grouping_columns: list[str]) -> None: """Set the default stratifications to be used by stratified observations. Parameters @@ -96,15 +100,10 @@ def set_default_stratifications(self, default_grouping_columns: List[str]) -> No def add_stratification( self, name: str, - sources: List[str], - categories: List[str], - excluded_categories: Optional[List[str]], - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ], + sources: list[str], + categories: list[str], + excluded_categories: list[str] | None, + mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> None: """Add a stratification to the results context. @@ -187,11 +186,11 @@ def add_stratification( def register_observation( self, - observation_type: Type[BaseObservation], + observation_type: type[Observation], name: str, pop_filter: str, when: str, - **kwargs, + **kwargs: Any, ) -> None: """Add an observation to the results context. @@ -242,10 +241,10 @@ def register_observation( def gather_results( self, population: pd.DataFrame, lifecycle_phase: str, event: Event ) -> Generator[ - Tuple[ - Optional[pd.DataFrame], - Optional[str], - Optional[Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]], + tuple[ + pd.DataFrame | None, + str | None, + Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] | None, ], None, None, @@ -302,6 +301,7 @@ def gather_results( if filtered_pop.empty: yield None, None, None else: + pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] if stratification_names is None: pop = filtered_pop else: @@ -317,7 +317,7 @@ def _filter_population( self, population: pd.DataFrame, pop_filter: str, - stratification_names: Optional[tuple[str, ...]], + stratification_names: tuple[str, ...] | None, ) -> pd.DataFrame: """Filter out simulants not to observe.""" pop = population.query(pop_filter) if pop_filter else population.copy() @@ -334,8 +334,8 @@ def _filter_population( @staticmethod def _get_groups( - stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame - ) -> DataFrameGroupBy: + stratifications: tuple[str, ...], filtered_pop: pd.DataFrame + ) -> DataFrameGroupBy[tuple[str, ...] | str]: """Group the population by stratification. Notes @@ -356,7 +356,7 @@ def _get_groups( ) else: pop_groups = filtered_pop.groupby(lambda _: "all") - return pop_groups + return pop_groups # type: ignore[return-value] def _rename_stratification_columns(self, results: pd.DataFrame) -> None: """Convert the temporary stratified mapped index names back to their original names.""" diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 6c18d279d..ee6864760 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -8,12 +8,13 @@ from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union import pandas as pd from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import Observation from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarValue @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type, + observation_type: Type[Observation], is_stratified: bool, name: str, pop_filter: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f400900d4..8037573f4 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -6,7 +6,7 @@ An observation is a class object that records simulation results; they are responsible for initializing, gathering, updating, and formatting results. -The provided :class:`BaseObservation` class is an abstract base class that should +The provided :class:`Observation` class is an abstract base class that should be subclassed by concrete observations. While there are no required abstract methods to define when subclassing, the class does provide common attributes as well as an `observe` method that determines whether to observe results for a given event. @@ -24,7 +24,6 @@ from abc import ABC from collections.abc import Callable from dataclasses import dataclass -from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype @@ -37,7 +36,7 @@ @dataclass -class BaseObservation(ABC): +class Observation(ABC): """An abstract base dataclass to be inherited by concrete observations. This class includes an :meth:`observe ` method that determines whether @@ -60,7 +59,8 @@ class BaseObservation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame + [pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None], + pd.DataFrame, ] """Method or function that gathers the new observation results.""" results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] @@ -76,7 +76,7 @@ class BaseObservation(ABC): def observe( self, event: Event, - df: pd.DataFrame | DataFrameGroupBy[str], + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -100,7 +100,7 @@ def observe( return self.results_gatherer(df, stratifications) -class UnstratifiedObservation(BaseObservation): +class UnstratifiedObservation(Observation): """Concrete class for observing results that are not stratified. The parent class `stratifications` are set to None and the `results_initializer` @@ -139,7 +139,8 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], + _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): raise TypeError( @@ -181,7 +182,7 @@ def create_empty_df( return pd.DataFrame() -class StratifiedObservation(BaseObservation): +class StratifiedObservation(Observation): """Concrete class for observing stratified results. The parent class `results_initializer` and `results_gatherer` methods are diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 7a37a5db6..a04c71b2b 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -5,7 +5,7 @@ ========= An observer is a component that is responsible for registering -:class:`observations ` +:class:`observations ` to the simulation. The provided :class:`Observer` class is an abstract base class that should be subclassed diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 7be52813b..e0d4d1f7a 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -4,19 +4,20 @@ =============== """ + from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype +from vivarium.types import ScalarMapper, VectorMapper + STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" # TODO: Parameterizing pandas objects fails below python 3.12 -VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] -ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] @dataclass diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 89da4389b..0de4ea498 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime, timedelta from numbers import Number from typing import Union @@ -25,3 +26,6 @@ float, int, ] + +VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] +ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]