From 13c1fbd8a2c5314e42f3c5fc3072c85f04042a62 Mon Sep 17 00:00:00 2001 From: Hussain Jafari Date: Wed, 13 Nov 2024 15:51:49 -0800 Subject: [PATCH] drop python 3.9 changes (#540) Category: CI JIRA issue: https://jira.ihme.washington.edu/browse/MIC-5536 Changes and notes Modernize type hinting Remove unnecessary "from future import annotation" imports Testing Ran pytest. --- CHANGELOG.rst | 4 +- docs/source/concepts/results.rst | 4 +- src/vivarium/component.py | 34 +++++----- src/vivarium/examples/boids/forces.py | 4 +- .../examples/disease_model/disease.py | 2 - .../examples/disease_model/intervention.py | 6 +- .../examples/disease_model/mortality.py | 6 +- .../examples/disease_model/observer.py | 6 +- .../examples/disease_model/population.py | 6 +- src/vivarium/examples/disease_model/risk.py | 8 +-- src/vivarium/framework/artifact/artifact.py | 2 - src/vivarium/framework/components/manager.py | 39 ++++++------ src/vivarium/framework/components/parser.py | 18 +++--- src/vivarium/framework/configuration.py | 2 - src/vivarium/framework/engine.py | 25 ++++---- src/vivarium/framework/event.py | 1 - src/vivarium/framework/lifecycle.py | 1 - src/vivarium/framework/logging/manager.py | 1 - .../framework/lookup/interpolation.py | 5 +- src/vivarium/framework/lookup/manager.py | 3 +- src/vivarium/framework/plugins.py | 2 - src/vivarium/framework/results/interface.py | 62 +++++++++---------- src/vivarium/framework/results/manager.py | 48 +++++++------- src/vivarium/framework/results/observer.py | 4 +- .../framework/results/stratification.py | 1 - src/vivarium/framework/state_machine.py | 16 ++--- src/vivarium/framework/time.py | 15 ++--- src/vivarium/framework/utilities.py | 12 ++-- src/vivarium/interface/interactive.py | 15 ++--- src/vivarium/interface/utilities.py | 2 - src/vivarium/testing_utilities.py | 10 +-- src/vivarium/types.py | 13 ++-- tests/conftest.py | 2 + tests/framework/artifact/test_manager.py | 4 +- tests/framework/components/mocks.py | 4 +- tests/framework/components/test_component.py | 2 + tests/framework/components/test_manager.py | 4 +- tests/framework/population/test_manager.py | 2 - .../population/test_population_view.py | 3 +- tests/framework/randomness/test_crn.py | 5 +- tests/framework/randomness/test_stream.py | 4 +- .../framework/resource/test_resource_group.py | 2 - tests/framework/results/helpers.py | 3 +- tests/framework/test_engine.py | 5 +- tests/framework/test_time.py | 3 +- tests/helpers.py | 18 +++--- 46 files changed, 206 insertions(+), 232 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 982bc0291..277087296 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,8 @@ -**3.2.1 - TBD/TBD/TBD** +**3.2.1 - 11/13/24** - Fix mypy errors in vivarium/framework/results/context.py + - Modernize type hinting + - Remove unnecessary "from future import annotation" imports **3.2.0 - 11/12/24** diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst index 4463a7fb8..6ac81471c 100644 --- a/docs/source/concepts/results.rst +++ b/docs/source/concepts/results.rst @@ -66,7 +66,7 @@ to the existing number of people who have died from previous time steps. .. testcode:: - from typing import Any, Optional + from typing import Any import pandas as pd @@ -84,7 +84,7 @@ to the existing number of people who have died from previous time steps. } @property - def columns_required(self) -> Optional[list[str]]: + def columns_required(self) -> list[str] | None: return ["age", "alive"] def register_observations(self, builder: Builder) -> None: diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 76f2caf1b..cf0b7cd70 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -13,12 +13,12 @@ import re import warnings from abc import ABC -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime, timedelta from importlib import import_module from inspect import signature from numbers import Number -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import pandas as pd from layered_config_tree import ConfigurationError, LayeredConfigTree @@ -87,7 +87,7 @@ class Component(ABC): """ - CONFIGURATION_DEFAULTS: Dict[str, Any] = {} + CONFIGURATION_DEFAULTS: dict[str, Any] = {} """A dictionary containing the defaults for any configurations managed by this component. An empty dictionary indicates no managed configurations. """ @@ -187,7 +187,7 @@ def population_view(self) -> PopulationView: return self._population_view @property - def sub_components(self) -> List["Component"]: + def sub_components(self) -> list["Component"]: """Provide components managed by this component. Returns @@ -198,7 +198,7 @@ def sub_components(self) -> List["Component"]: return self._sub_components @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: """Provides a dictionary containing the defaults for any configurations managed by this component. @@ -213,7 +213,7 @@ def configuration_defaults(self) -> Dict[str, Any]: return self.CONFIGURATION_DEFAULTS @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: """Provides names of columns created by the component. Returns @@ -224,7 +224,7 @@ def columns_created(self) -> List[str]: return [] @property - def columns_required(self) -> Optional[List[str]]: + def columns_required(self) -> list[str] | None: """Provides names of columns required by the component. Returns @@ -244,7 +244,7 @@ def initialization_requirements( return [] @property - def population_view_query(self) -> Optional[str]: + def population_view_query(self) -> str | None: """Provides a query to use when filtering the component's `PopulationView`. Returns @@ -334,14 +334,12 @@ def __init__(self) -> None: """ self._repr: str = "" self._name: str = "" - self._sub_components: List["Component"] = [] - self.logger: Optional[Logger] = None - self.get_value_columns: Optional[ - Callable[[Union[str, pd.DataFrame]], List[str]] - ] = None - self.configuration: Optional[LayeredConfigTree] = None - self._population_view: Optional[PopulationView] = None - self.lookup_tables: Dict[str, LookupTable] = {} + self._sub_components: list["Component"] = [] + self.logger: Logger | None = None + self.get_value_columns: Callable[[str | pd.DataFrame], list[str]] | None = None + self.configuration: LayeredConfigTree | None = None + self._population_view: PopulationView | None = None + self.lookup_tables: dict[str, LookupTable] = {} def setup_component(self, builder: "Builder") -> None: """Sets up the component for a Vivarium simulation. @@ -501,7 +499,7 @@ def on_simulation_end(self, event: Event) -> None: # Helper methods # ################## - def get_initialization_parameters(self) -> Dict[str, Any]: + def get_initialization_parameters(self) -> dict[str, Any]: """Retrieves the values of all parameters specified in the `__init__` that have an attribute with the same name. @@ -521,7 +519,7 @@ def get_initialization_parameters(self) -> Dict[str, Any]: if hasattr(self, parameter_name) } - def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]: + def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None: """Retrieves the configuration for this component from the builder. This method retrieves the configuration for this component from the diff --git a/src/vivarium/examples/boids/forces.py b/src/vivarium/examples/boids/forces.py index cc71e1f07..c960502bd 100644 --- a/src/vivarium/examples/boids/forces.py +++ b/src/vivarium/examples/boids/forces.py @@ -1,6 +1,6 @@ # mypy: ignore-errors from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any import numpy as np import pandas as pd @@ -14,7 +14,7 @@ class Force(Component, ABC): # Properties # ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return { self.__class__.__name__.lower(): { "max_force": 0.03, diff --git a/src/vivarium/examples/disease_model/disease.py b/src/vivarium/examples/disease_model/disease.py index 770619e41..619c67fe6 100644 --- a/src/vivarium/examples/disease_model/disease.py +++ b/src/vivarium/examples/disease_model/disease.py @@ -1,6 +1,4 @@ # mypy: ignore-errors -from __future__ import annotations - import pandas as pd from vivarium import Component diff --git a/src/vivarium/examples/disease_model/intervention.py b/src/vivarium/examples/disease_model/intervention.py index f09cc593e..0395c3239 100644 --- a/src/vivarium/examples/disease_model/intervention.py +++ b/src/vivarium/examples/disease_model/intervention.py @@ -1,5 +1,5 @@ # mypy: ignore-errors -from typing import Any, Dict +from typing import Any import pandas as pd @@ -8,7 +8,7 @@ class TreatmentIntervention(Component): - CONFIGURATION_DEFAULTS: Dict[str, Any] = { + CONFIGURATION_DEFAULTS: dict[str, Any] = { "intervention": { "effect_size": 0.5, } @@ -19,7 +19,7 @@ class TreatmentIntervention(Component): ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {self.intervention: self.CONFIGURATION_DEFAULTS["intervention"]} ##################### diff --git a/src/vivarium/examples/disease_model/mortality.py b/src/vivarium/examples/disease_model/mortality.py index 191188828..806d09dd7 100644 --- a/src/vivarium/examples/disease_model/mortality.py +++ b/src/vivarium/examples/disease_model/mortality.py @@ -1,5 +1,5 @@ # mypy: ignore-errors -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np import pandas as pd @@ -15,7 +15,7 @@ class Mortality(Component): ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: """A set of default configuration values for this component. These can be overwritten in the simulation model specification or by @@ -28,7 +28,7 @@ def configuration_defaults(self) -> Dict[str, Any]: } @property - def columns_required(self) -> Optional[List[str]]: + def columns_required(self) -> list[str] | None: return ["tracked", "alive"] ##################### diff --git a/src/vivarium/examples/disease_model/observer.py b/src/vivarium/examples/disease_model/observer.py index 35b143fcc..bdad69181 100644 --- a/src/vivarium/examples/disease_model/observer.py +++ b/src/vivarium/examples/disease_model/observer.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pandas as pd @@ -14,7 +14,7 @@ class DeathsObserver(Observer): ############## @property - def columns_required(self) -> Optional[list[str]]: + def columns_required(self) -> list[str] | None: return ["alive"] ################# @@ -39,7 +39,7 @@ class YllsObserver(Observer): ############## @property - def columns_required(self) -> Optional[list[str]]: + def columns_required(self) -> list[str] | None: return ["age", "alive"] @property diff --git a/src/vivarium/examples/disease_model/population.py b/src/vivarium/examples/disease_model/population.py index 5717a3ba4..9edc42221 100644 --- a/src/vivarium/examples/disease_model/population.py +++ b/src/vivarium/examples/disease_model/population.py @@ -1,5 +1,5 @@ # mypy: ignore-errors -from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -17,7 +17,7 @@ class BasePopulation(Component): ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: """A set of default configuration values for this component. These can be overwritten in the simulation model specification or by @@ -33,7 +33,7 @@ def configuration_defaults(self) -> Dict[str, Any]: } @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["age", "sex", "alive", "entrance_time"] ##################### diff --git a/src/vivarium/examples/disease_model/risk.py b/src/vivarium/examples/disease_model/risk.py index 0969813f9..931c33164 100644 --- a/src/vivarium/examples/disease_model/risk.py +++ b/src/vivarium/examples/disease_model/risk.py @@ -1,7 +1,7 @@ # mypy: ignore-errors from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any import pandas as pd @@ -24,11 +24,11 @@ class Risk(Component): ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {self.risk: self.CONFIGURATION_DEFAULTS["risk"]} @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return [self.propensity_column] @property @@ -91,7 +91,7 @@ class RiskEffect(Component): ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {self.risk: self.CONFIGURATION_DEFAULTS["risk_effect"]} ##################### diff --git a/src/vivarium/framework/artifact/artifact.py b/src/vivarium/framework/artifact/artifact.py index 934645e01..1a2ff0f69 100644 --- a/src/vivarium/framework/artifact/artifact.py +++ b/src/vivarium/framework/artifact/artifact.py @@ -10,8 +10,6 @@ archive file for convenient access and inspection. """ -from __future__ import annotations - import re import warnings from collections import defaultdict diff --git a/src/vivarium/framework/components/manager.py b/src/vivarium/framework/components/manager.py index 4abb8aa9f..2c2567093 100644 --- a/src/vivarium/framework/components/manager.py +++ b/src/vivarium/framework/components/manager.py @@ -18,7 +18,8 @@ """ import inspect -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Sequence, Tuple, Union +from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING, Any from layered_config_tree import ( ConfigurationError, @@ -49,8 +50,8 @@ class OrderedComponentSet: """ - def __init__(self, *args: Union[Component, Manager]): - self.components: List[Union[Component, Manager]] = [] + def __init__(self, *args: Component | Manager): + self.components: list[Component | Manager] = [] if args: self.update(args) @@ -63,21 +64,21 @@ def add(self, component: Component) -> None: def update( self, - components: Union[List[Union[Component, Manager]], Tuple[Union[Component, Manager]]], + components: list[Component | Manager] | tuple[Component | Manager], ) -> None: for c in components: self.add(c) - def pop(self) -> Union[Component, Manager]: + def pop(self) -> Component | Manager: component = self.components.pop(0) return component - def __contains__(self, component: Union[Component, Manager]) -> bool: + def __contains__(self, component: Component | Manager) -> bool: if not hasattr(component, "name"): raise ComponentConfigError(f"Component {component} has no name attribute") return component.name in [c.name for c in self.components] - def __iter__(self) -> Iterator[Union[Component, Manager]]: + def __iter__(self) -> Iterator[Component | Manager]: return iter(self.components) def __len__(self) -> int: @@ -149,7 +150,7 @@ def setup( self.list_components, restrict_during=["initialization"] ) - def add_managers(self, managers: Union[List[Manager], Tuple[Manager]]) -> None: + def add_managers(self, managers: list[Manager] | tuple[Manager]) -> None: """Registers new managers with the component manager. Managers are configured and setup before components. @@ -163,7 +164,7 @@ def add_managers(self, managers: Union[List[Manager], Tuple[Manager]]) -> None: self.apply_configuration_defaults(m) self._managers.add(m) - def add_components(self, components: Union[List[Component], Tuple[Component]]) -> None: + def add_components(self, components: list[Component] | tuple[Component]) -> None: """Register new components with the component manager. Components are configured and setup after managers. @@ -178,8 +179,8 @@ def add_components(self, components: Union[List[Component], Tuple[Component]]) - self._components.add(c) def get_components_by_type( - self, component_type: Union[type, Sequence[type]] - ) -> List[Component]: + self, component_type: type | Sequence[type] + ) -> list[Component]: """Get all components that are an instance of ``component_type``. Parameters @@ -218,7 +219,7 @@ def get_component(self, name: str) -> Component: return c raise ValueError(f"No component found with name {name}") - def list_components(self) -> Dict[str, Component]: + def list_components(self) -> dict[str, Component]: """Get a mapping of component names to components held by the manager. Returns @@ -244,7 +245,7 @@ def setup_components(self, builder: "Builder") -> None: """ self._setup_components(builder, self._managers + self._components) - def apply_configuration_defaults(self, component: Union[Component, Manager]) -> None: + def apply_configuration_defaults(self, component: Component | Manager) -> None: try: self.configuration.update( component.configuration_defaults, @@ -271,7 +272,7 @@ def apply_configuration_defaults(self, component: Union[Component, Manager]) -> ) @staticmethod - def _get_file(component: Union[Component, Manager]) -> str: + def _get_file(component: Component | Manager) -> str: if component.__module__ == "__main__": # This is defined directly in a script or notebook so there's no # file to attribute it to. @@ -280,9 +281,7 @@ def _get_file(component: Union[Component, Manager]) -> str: return inspect.getfile(component.__class__) @staticmethod - def _flatten( - components: List[Union[Component, Manager]] - ) -> List[Union[Component, Manager]]: + def _flatten(components: list[Component | Manager]) -> list[Component | Manager]: out = [] components = components[::-1] while components: @@ -340,8 +339,8 @@ def get_component(self, name: str) -> Component: return self._manager.get_component(name) def get_components_by_type( - self, component_type: Union[type, Sequence[type]] - ) -> List[Component]: + self, component_type: type | Sequence[type] + ) -> list[Component]: """Get all components that are an instance of ``component_type``. Parameters @@ -356,7 +355,7 @@ def get_components_by_type( """ return self._manager.get_components_by_type(component_type) - def list_components(self) -> Dict[str, Component]: + def list_components(self) -> dict[str, Component]: """Get a mapping of component names to components held by the manager. Returns diff --git a/src/vivarium/framework/components/parser.py b/src/vivarium/framework/components/parser.py index 529184c09..45eea8fe9 100644 --- a/src/vivarium/framework/components/parser.py +++ b/src/vivarium/framework/components/parser.py @@ -22,8 +22,6 @@ """ -from typing import Dict, List, Tuple, Union - from layered_config_tree.main import LayeredConfigTree from vivarium.framework.utilities import import_by_path @@ -62,8 +60,8 @@ class ComponentConfigurationParser: """ def get_components( - self, component_config: Union[LayeredConfigTree, List[str]] - ) -> List[Component]: + self, component_config: LayeredConfigTree | list[str] + ) -> list[Component]: """Extracts component specifications from configuration information and returns initialized components. @@ -95,7 +93,7 @@ def get_components( ] return component_list - def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]: + def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]: """ Helper function for parsing a ``LayeredConfigTree`` into a flat list of Components. @@ -119,8 +117,8 @@ def parse_component_config(self, component_config: LayeredConfigTree) -> List[Co return self.process_level(component_config.to_dict(), []) def process_level( - self, level: Union[str, List[str], Dict[str, Union[Dict, List]]], prefix: List[str] - ) -> List[Component]: + self, level: str | list[str] | dict[str, dict | list], prefix: list[str] + ) -> list[Component]: """Helper function for parsing hierarchical component configuration into a flat list of Components. @@ -208,7 +206,7 @@ def create_component_from_string(self, component_string: str) -> Component: component = self.import_and_instantiate_component(component_path, args) return component - def prep_component(self, component_string: str) -> Tuple[str, Tuple]: + def prep_component(self, component_string: str) -> tuple[str, tuple]: """Transform component description string into a tuple of component paths and required arguments. @@ -226,7 +224,7 @@ def prep_component(self, component_string: str) -> Tuple[str, Tuple]: return path, cleaned_args @staticmethod - def _clean_args(args: List, path: str) -> Tuple: + def _clean_args(args: list, path: str) -> tuple: """Transform component arguments into a tuple, validating that each argument is a string. @@ -257,7 +255,7 @@ def _clean_args(args: List, path: str) -> Tuple: return tuple(out) @staticmethod - def import_and_instantiate_component(component_path: str, args: Tuple[str]) -> Component: + def import_and_instantiate_component(component_path: str, args: tuple[str]) -> Component: """Transform a tuple representing a Component into an actual instantiated component object. diff --git a/src/vivarium/framework/configuration.py b/src/vivarium/framework/configuration.py index f21def6e4..da3d6246e 100644 --- a/src/vivarium/framework/configuration.py +++ b/src/vivarium/framework/configuration.py @@ -8,8 +8,6 @@ :term:`configurations `. """ -from __future__ import annotations - from pathlib import Path from typing import Any diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index 90acf77f5..ce32f5aeb 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -23,7 +23,6 @@ from pathlib import Path from pprint import pformat from time import time -from typing import Dict, List, Optional, Set, Union import dill import numpy as np @@ -51,10 +50,10 @@ class SimulationContext: - _created_simulation_contexts: Set[str] = set() + _created_simulation_contexts: set[str] = set() @staticmethod - def _get_context_name(sim_name: Union[str, None]) -> str: + def _get_context_name(sim_name: str | None) -> str: """Get a unique name for a simulation context. Parameters @@ -99,11 +98,11 @@ def _clear_context_cache(): def __init__( self, - model_specification: Optional[Union[str, Path, LayeredConfigTree]] = None, - components: Optional[Union[List[Component], Dict, LayeredConfigTree]] = None, - configuration: Optional[Union[Dict, LayeredConfigTree]] = None, - plugin_configuration: Optional[Union[Dict, LayeredConfigTree]] = None, - sim_name: Optional[str] = None, + model_specification: str | Path | LayeredConfigTree | None = None, + components: list[Component] | dict | LayeredConfigTree | None = None, + configuration: dict | LayeredConfigTree | None = None, + plugin_configuration: dict | LayeredConfigTree | None = None, + sim_name: str | None = None, logging_verbosity: int = 1, ): self._name = self._get_context_name(sim_name) @@ -112,7 +111,7 @@ def __init__( component_configuration = ( components if isinstance(components, (dict, LayeredConfigTree)) else None ) - self._additional_components = components if isinstance(components, List) else [] + self._additional_components = components if isinstance(components, list) else [] self.model_specification = build_model_specification( model_specification, component_configuration, configuration, plugin_configuration ) @@ -214,7 +213,7 @@ def current_time(self) -> ClockTime: """Returns the current simulation time.""" return self._clock.time - def get_results(self) -> Dict[str, pd.DataFrame]: + def get_results(self) -> dict[str, pd.DataFrame]: """Return the formatted results.""" return self._results.get_results() @@ -268,8 +267,8 @@ def step(self) -> None: def run( self, - backup_path: Optional[Path] = None, - backup_freq: Optional[Union[int, float]] = None, + backup_path: Path | None = None, + backup_freq: int | float = None, ) -> None: if backup_freq: time_to_save = time() + backup_freq @@ -339,7 +338,7 @@ def get_performance_metrics(self) -> pd.DataFrame: performance_metrics = pd.DataFrame(records) return performance_metrics - def add_components(self, component_list: List[Component]) -> None: + def add_components(self, component_list: list[Component]) -> None: """Adds new components to the simulation.""" self._component_manager.add_components(component_list) diff --git a/src/vivarium/framework/event.py b/src/vivarium/framework/event.py index 075f2d816..802df1600 100644 --- a/src/vivarium/framework/event.py +++ b/src/vivarium/framework/event.py @@ -26,7 +26,6 @@ :ref:`concept note `. """ - from __future__ import annotations from collections.abc import Callable diff --git a/src/vivarium/framework/lifecycle.py b/src/vivarium/framework/lifecycle.py index 8db550356..0ddcd3e76 100644 --- a/src/vivarium/framework/lifecycle.py +++ b/src/vivarium/framework/lifecycle.py @@ -30,7 +30,6 @@ The tools here also allow for introspection of the simulation life cycle. """ - from __future__ import annotations import functools diff --git a/src/vivarium/framework/logging/manager.py b/src/vivarium/framework/logging/manager.py index c8f6fb1c9..0e4427610 100644 --- a/src/vivarium/framework/logging/manager.py +++ b/src/vivarium/framework/logging/manager.py @@ -4,7 +4,6 @@ ===================== """ - from __future__ import annotations import loguru diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index bf34b0d63..076f5891d 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -7,15 +7,12 @@ simulations. """ -from __future__ import annotations - from collections.abc import Hashable, Sequence -from typing import Union import numpy as np import pandas as pd -_SubTablesType = list[tuple[Union[tuple[Hashable, ...], Hashable, None], pd.DataFrame]] +_SubTablesType = list[tuple[tuple[Hashable, ...] | Hashable | None, pd.DataFrame]] class Interpolation: diff --git a/src/vivarium/framework/lookup/manager.py b/src/vivarium/framework/lookup/manager.py index cc1c5b78c..8284517fa 100644 --- a/src/vivarium/framework/lookup/manager.py +++ b/src/vivarium/framework/lookup/manager.py @@ -13,9 +13,10 @@ """ +from collections.abc import Sequence from datetime import datetime, timedelta from numbers import Number -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING import pandas as pd diff --git a/src/vivarium/framework/plugins.py b/src/vivarium/framework/plugins.py index 9edcd5f3f..893fc5939 100644 --- a/src/vivarium/framework/plugins.py +++ b/src/vivarium/framework/plugins.py @@ -9,8 +9,6 @@ """ -from __future__ import annotations - from dataclasses import dataclass from layered_config_tree.main import LayeredConfigTree diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index 16d4fb922..6a105b4f2 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -10,7 +10,8 @@ """ -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING import pandas as pd @@ -72,17 +73,14 @@ def name(self) -> str: def register_stratification( self, name: str, - categories: List[str], - excluded_categories: Optional[List[str]] = None, - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ] = None, + categories: list[str], + excluded_categories: list[str] | None = None, + mapper: Callable[[pd.Series | pd.DataFrame], pd.Series] + | Callable[[ScalarValue], str] + | None = None, is_vectorized: bool = False, - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], ) -> None: """Registers a stratification that can be used by stratified observations. @@ -125,11 +123,11 @@ def register_binned_stratification( self, target: str, binned_column: str, - bin_edges: List[Union[int, float]] = [], - labels: List[str] = [], - excluded_categories: Optional[List[str]] = None, + bin_edges: list[int | float] = [], + labels: list[str] = [], + excluded_categories: list[str] | None = None, target_type: str = "column", - **cut_kwargs: Dict, + **cut_kwargs: dict, ) -> None: """Registers a binned stratification that can be used by stratified observations. @@ -173,18 +171,18 @@ def register_stratified_observation( name: str, pop_filter: str = "tracked==True", when: str = "collect_metrics", - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], results_updater: Callable[ [pd.DataFrame, pd.DataFrame], pd.DataFrame ] = _required_function_placeholder, results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, - additional_stratifications: List[str] = [], - excluded_stratifications: List[str] = [], - aggregator_sources: Optional[List[str]] = None, - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len, + additional_stratifications: list[str] = [], + excluded_stratifications: list[str] = [], + aggregator_sources: list[str] | None = None, + aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers a stratified observation to the results system. @@ -249,8 +247,8 @@ def register_unstratified_observation( name: str, pop_filter: str = "tracked==True", when: str = "collect_metrics", - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], results_gatherer: Callable[ [pd.DataFrame], pd.DataFrame ] = _required_function_placeholder, @@ -317,15 +315,15 @@ def register_adding_observation( name: str, pop_filter: str = "tracked==True", when: str = "collect_metrics", - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results.reset_index(), - additional_stratifications: List[str] = [], - excluded_stratifications: List[str] = [], - aggregator_sources: Optional[List[str]] = None, - aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len, + additional_stratifications: list[str] = [], + excluded_stratifications: list[str] = [], + aggregator_sources: list[str] | None = None, + aggregator: Callable[[pd.DataFrame], float | pd.Series] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: """Registers an adding observation to the results system. @@ -388,8 +386,8 @@ def register_concatenating_observation( name: str, pop_filter: str = "tracked==True", when: str = "collect_metrics", - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], results_formatter: Callable[ [str, pd.DataFrame], pd.DataFrame ] = lambda measure, results: results, @@ -440,7 +438,7 @@ def register_concatenating_observation( @staticmethod def _check_for_required_callables( - observation_name: str, required_callables: Dict[str, Callable] + observation_name: str, required_callables: dict[str, Callable] ) -> None: """Raises a ValueError if any required callable arguments are missing.""" missing = [] diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index ee6864760..afeb19f18 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -7,8 +7,9 @@ """ from collections import defaultdict +from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING import pandas as pd @@ -55,7 +56,7 @@ def __init__(self) -> None: def name(self) -> str: return self._name - def get_results(self) -> Dict[str, pd.DataFrame]: + def get_results(self) -> dict[str, pd.DataFrame]: """Return the measure-specific formatted results in a dictionary. Notes @@ -187,17 +188,14 @@ def set_default_stratifications(self, builder: "Builder") -> None: def register_stratification( self, name: str, - categories: List[str], - excluded_categories: Optional[List[str]], - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ], + categories: list[str], + excluded_categories: list[str] | None, + mapper: Callable[[pd.Series | pd.DataFrame], pd.Series] + | Callable[[ScalarValue], str] + | None, is_vectorized: bool, - requires_columns: List[str] = [], - requires_values: List[str] = [], + requires_columns: list[str] = [], + requires_values: list[str] = [], ) -> None: """Manager-level stratification registration. @@ -242,9 +240,9 @@ def register_binned_stratification( self, target: str, binned_column: str, - bin_edges: List[Union[int, float]], - labels: List[str], - excluded_categories: Optional[List[str]], + bin_edges: list[int | float], + labels: list[str], + excluded_categories: list[str] | None, target_type: str, **cut_kwargs, ) -> None: @@ -273,7 +271,7 @@ def register_binned_stratification( Keyword arguments for :meth: pandas.cut. """ - def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: + def _bin_data(data: pd.Series | pd.DataFrame) -> pd.Series: """Use pandas.cut to bin continuous values""" data = data.squeeze() if not isinstance(data, pd.Series): @@ -302,13 +300,13 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type: Type[Observation], + observation_type: type[Observation], is_stratified: bool, name: str, pop_filter: str, when: str, - requires_columns: List[str], - requires_values: List[str], + requires_columns: list[str], + requires_values: list[str], **kwargs, ) -> None: """Manager-level observation registration. @@ -375,10 +373,10 @@ def register_observation( def _get_stratifications( self, - stratifications: List[str] = [], - additional_stratifications: List[str] = [], - excluded_stratifications: List[str] = [], - ) -> Tuple[str, ...]: + stratifications: list[str] = [], + additional_stratifications: list[str] = [], + excluded_stratifications: list[str] = [], + ) -> tuple[str, ...]: """Resolve the stratifications required for the observation.""" stratifications = list( set( @@ -391,7 +389,7 @@ def _get_stratifications( # Makes sure measure identifiers have fields in the same relative order. return tuple(sorted(stratifications)) - def _add_resources(self, target: List[str], target_type: SourceType) -> None: + def _add_resources(self, target: list[str], target_type: SourceType) -> None: """Add required resources to the manager's list of required columns and values.""" if len(target) == 0: return # do nothing on empty lists @@ -416,7 +414,7 @@ def _prepare_population(self, event: Event) -> pd.DataFrame: return population def _warn_check_stratifications( - self, additional_stratifications: List[str], excluded_stratifications: List[str] + self, additional_stratifications: list[str], excluded_stratifications: list[str] ) -> None: """Check additional and excluded stratifications if they'd not affect stratifications (i.e., would be NOP), and emit warning.""" diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index a04c71b2b..9cab55a62 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any from vivarium import Component from vivarium.framework.engine import Builder @@ -35,7 +35,7 @@ def __init__(self) -> None: self.results_dir = None @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return { "stratification": { self.get_configuration_name(): { diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index e0d4d1f7a..933c5c470 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -4,7 +4,6 @@ =============== """ - from __future__ import annotations from dataclasses import dataclass diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index c44d7b6f7..5a4c11a64 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -9,9 +9,9 @@ """ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Iterable from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -67,8 +67,8 @@ def _next_state( def _groupby_new_state( - index: pd.Index, outputs: List, decisions: pd.Series -) -> List[Tuple[str, pd.Index]]: + index: pd.Index, outputs: list, decisions: pd.Series +) -> list[tuple[str, pd.Index]]: """Groups the simulants in the index by their new output state. Parameters @@ -363,7 +363,7 @@ def setup(self, builder: Builder) -> None: # Public methods # ################## - def choose_new_state(self, index: pd.Index) -> Tuple[List, pd.Series]: + def choose_new_state(self, index: pd.Index) -> tuple[list, pd.Series]: """Chooses a new state for each simulant in the index. Parameters @@ -488,7 +488,7 @@ def sub_components(self): return self.states @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return [self.state_column] @property @@ -598,7 +598,7 @@ def cleanup(self, index: pd.Index, event_time: ClockTime) -> None: if not affected.empty: state.cleanup_effect(affected.index, event_time) - def _get_state_pops(self, index: pd.Index) -> List[Tuple[State, pd.DataFrame]]: + def _get_state_pops(self, index: pd.Index) -> list[tuple[State, pd.DataFrame]]: population = self.population_view.get(index) return [ (state, population[population[self.state_column] == state.state_id]) @@ -609,7 +609,7 @@ def _get_state_pops(self, index: pd.Index) -> List[Tuple[State, pd.DataFrame]]: # Helper methods # ################## - def get_initialization_parameters(self) -> Dict[str, Any]: + def get_initialization_parameters(self) -> dict[str, Any]: """ Gets the values of the state column specified in the __init__`. diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time.py index 6e894f318..bb1e621ad 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -13,8 +13,9 @@ """ import math +from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -39,11 +40,11 @@ def name(self): return "simulation_clock" @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["next_event_time", "step_size"] @property - def columns_required(self) -> List[str]: + def columns_required(self) -> list[str]: return ["tracked"] @property @@ -191,7 +192,7 @@ def move_simulants_to_end(self, index: pd.Index) -> None: if self._individual_clocks and index.any(): self._simulants_to_snooze = self._simulants_to_snooze.union(index) - def step_size_post_processor(self, values: List[NumberLike], _) -> pd.Series: + def step_size_post_processor(self, values: list[NumberLike], _) -> pd.Series: """Computes the largest feasible step size for each simulant. This is the smallest component-modified step size (rounded down to increments @@ -321,9 +322,9 @@ def move_simulants_to_end(self) -> Callable[[pd.Index], None]: def register_step_size_modifier( self, modifier: Callable[[pd.Index], pd.Series], - requires_columns: List[str] = (), - requires_values: List[str] = (), - requires_streams: List[str] = (), + requires_columns: list[str] = (), + requires_values: list[str] = (), + requires_streams: list[str] = (), ) -> None: """Registers a step size modifier. diff --git a/src/vivarium/framework/utilities.py b/src/vivarium/framework/utilities.py index bdaa6c500..79cf4dea5 100644 --- a/src/vivarium/framework/utilities.py +++ b/src/vivarium/framework/utilities.py @@ -7,13 +7,11 @@ Collection of utility functions shared by the ``vivarium`` framework. """ -from __future__ import annotations - import functools from bdb import BdbQuit -from collections.abc import Sequence +from collections.abc import Callable, Sequence from importlib import import_module -from typing import Any, Callable, Optional, Union +from typing import Any import numpy as np @@ -28,7 +26,7 @@ def to_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike: return value / (time_step.total_seconds() / (60 * 60 * 24 * 365.0)) -def rate_to_probability(rate: Union[Sequence[float], NumberLike]) -> NumericArray: +def rate_to_probability(rate: Sequence[float] | NumberLike) -> NumericArray: # encountered underflow from rate > 30k # for rates greater than 250, exp(-rate) evaluates to 1e-109 # beware machine-specific floating point issues @@ -39,14 +37,14 @@ def rate_to_probability(rate: Union[Sequence[float], NumberLike]) -> NumericArra return probability -def probability_to_rate(probability: Union[Sequence[float], NumberLike]) -> NumericArray: +def probability_to_rate(probability: Sequence[float] | NumberLike) -> NumericArray: probability = np.array(probability) rate: NumericArray = -np.log(1 - probability) return rate def collapse_nested_dict( - d: dict[str, Any], prefix: Optional[str] = None + d: dict[str, Any], prefix: str | None = None ) -> list[tuple[str, Any]]: results = [] for k, v in d.items(): diff --git a/src/vivarium/interface/interactive.py b/src/vivarium/interface/interactive.py index 51129d861..4a45822a1 100644 --- a/src/vivarium/interface/interactive.py +++ b/src/vivarium/interface/interactive.py @@ -14,8 +14,9 @@ """ +from collections.abc import Callable from math import ceil -from typing import Any, Callable, Dict, List, Optional +from typing import Any import pandas as pd @@ -43,7 +44,7 @@ def setup(self): super().setup() self.initialize_simulants() - def step(self, step_size: Optional[ClockStepSize] = None) -> None: + def step(self, step_size: ClockStepSize | None = None) -> None: """Advance the simulation one step. Parameters @@ -133,7 +134,7 @@ def run_until(self, end_time: ClockTime, with_logging: bool = True) -> int: def take_steps( self, number_of_steps: int = 1, - step_size: Optional[ClockStepSize] = None, + step_size: ClockStepSize | None = None, with_logging: bool = True, ): """Run the simulation for the given number of steps. @@ -174,7 +175,7 @@ def get_population(self, untracked: bool = False) -> pd.DataFrame: """ return self._population.get_population(untracked) - def list_values(self) -> List[str]: + def list_values(self) -> list[str]: """List the names of all pipelines in the simulation.""" return list(self._values.keys()) @@ -182,11 +183,11 @@ def get_value(self, value_pipeline_name: str) -> Pipeline: """Get the value pipeline associated with the given name.""" return self._values.get_value(value_pipeline_name) - def list_events(self) -> List[str]: + def list_events(self) -> list[str]: """List all event types registered with the simulation.""" return self._events.list_events() - def get_listeners(self, event_type: str) -> Dict[int, List[Callable]]: + def get_listeners(self, event_type: str) -> dict[int, list[Callable]]: """Get all listeners of a particular type of event. Available event types can be found by calling @@ -225,7 +226,7 @@ def get_emitter(self, event_type: str) -> Callable: raise ValueError(f"No event {event_type} in system.") return self._events.get_emitter(event_type) - def list_components(self) -> Dict[str, Any]: + def list_components(self) -> dict[str, Any]: """Get a mapping of component names to components currently in the simulation. Returns diff --git a/src/vivarium/interface/utilities.py b/src/vivarium/interface/utilities.py index e6624a8ee..d344d09b6 100644 --- a/src/vivarium/interface/utilities.py +++ b/src/vivarium/interface/utilities.py @@ -7,8 +7,6 @@ interfaces for ``vivarium``. """ -from __future__ import annotations - import functools from collections.abc import Callable, Generator, Sequence from datetime import datetime diff --git a/src/vivarium/testing_utilities.py b/src/vivarium/testing_utilities.py index 36922ff62..9b640ab6f 100644 --- a/src/vivarium/testing_utilities.py +++ b/src/vivarium/testing_utilities.py @@ -10,7 +10,7 @@ from itertools import product from pathlib import Path -from typing import Any, Dict, List +from typing import Any import numpy as np import pandas as pd @@ -33,7 +33,7 @@ class NonCRNTestPopulation(Component): } @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["age", "sex", "location", "alive", "entrance_time", "exit_time"] def setup(self, builder: Builder) -> None: @@ -152,12 +152,12 @@ def _non_crn_build_population( def build_table( value: Any, - parameter_columns: Dict = { + parameter_columns: dict = { "age": (0, 125), "year": (1990, 2020), }, - key_columns: Dict = {"sex": ("Female", "Male")}, - value_columns: List = ["value"], + key_columns: dict = {"sex": ("Female", "Male")}, + value_columns: list = ["value"], ) -> pd.DataFrame: """ diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 0de4ea498..f625c5bb9 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -10,13 +10,14 @@ NumericArray = npt.NDArray[np.number[npt.NBitBase]] # todo need to use TypeVars here -Time = Union[pd.Timestamp, datetime] -Timedelta = Union[pd.Timedelta, timedelta] -ClockTime = Union[Time, int] -ClockStepSize = Union[Timedelta, int] +Time = pd.Timestamp | datetime +Timedelta = pd.Timedelta | timedelta +ClockTime = Time | int +ClockStepSize = Timedelta | int + +ScalarValue = Number | Timedelta | Time +LookupTableData = ScalarValue | pd.DataFrame | list[ScalarValue] | tuple[ScalarValue] -ScalarValue = Union[Number, Timedelta, Time] -LookupTableData = Union[ScalarValue, pd.DataFrame, list[ScalarValue], tuple[ScalarValue]] # TODO: For some of the uses of NumberLike, we probably want a TypeVar here instead. NumberLike = Union[ NumericArray, diff --git a/tests/conftest.py b/tests/conftest.py index 7100cacb1..d6a0e3c4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import pytest diff --git a/tests/framework/artifact/test_manager.py b/tests/framework/artifact/test_manager.py index e58eff90d..3e7f5c29c 100644 --- a/tests/framework/artifact/test_manager.py +++ b/tests/framework/artifact/test_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random from pathlib import Path @@ -98,7 +100,7 @@ def test_parse_artifact_path_relative_no_source(base_config): def test_parse_artifact_path_relative(base_config, test_data_dir): base_config.update( {"input_data": {"artifact_path": "../../test_data/artifact.hdf"}}, - **metadata(__file__) + **metadata(__file__), ) assert parse_artifact_path_config(base_config) == str(test_data_dir / "artifact.hdf") diff --git a/tests/framework/components/mocks.py b/tests/framework/components/mocks.py index 39459635b..3883f5677 100644 --- a/tests/framework/components/mocks.py +++ b/tests/framework/components/mocks.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from vivarium import Component from vivarium.framework.engine import Builder @@ -66,7 +66,7 @@ def name(self) -> str: return self._name @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {self.name: self.CONFIGURATION_DEFAULTS["component"]} def __init__(self, name: str): diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 6a9887863..9d340dd92 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pandas as pd import pytest from layered_config_tree.exceptions import ConfigurationError diff --git a/tests/framework/components/test_manager.py b/tests/framework/components/test_manager.py index e26e1039e..dacdb1964 100644 --- a/tests/framework/components/test_manager.py +++ b/tests/framework/components/test_manager.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import pytest @@ -223,7 +223,7 @@ def test_apply_configuration_defaults_duplicate(): def test_apply_configuration_defaults_bad_structure(): class BadConfigComponent(MockComponentA): @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {"test_component": "val"} config = build_simulation_configuration() diff --git a/tests/framework/population/test_manager.py b/tests/framework/population/test_manager.py index 5c6eed123..fc3dc4a4f 100644 --- a/tests/framework/population/test_manager.py +++ b/tests/framework/population/test_manager.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import pytest from vivarium import Component diff --git a/tests/framework/population/test_population_view.py b/tests/framework/population/test_population_view.py index 043dfa89c..aafdfd85e 100644 --- a/tests/framework/population/test_population_view.py +++ b/tests/framework/population/test_population_view.py @@ -1,7 +1,6 @@ import itertools import math import random -from typing import Union import pandas as pd import pytest @@ -71,7 +70,7 @@ def update_index(request) -> pd.Index: BASE_POPULATION[COL_NAMES[0]].copy(), ] ) -def population_update(request, update_index) -> Union[pd.Series, pd.DataFrame]: +def population_update(request, update_index) -> pd.Series | pd.DataFrame: return request.param.loc[update_index] diff --git a/tests/framework/randomness/test_crn.py b/tests/framework/randomness/test_crn.py index cb9dda228..f94470b3d 100644 --- a/tests/framework/randomness/test_crn.py +++ b/tests/framework/randomness/test_crn.py @@ -3,8 +3,9 @@ """ +from collections.abc import Iterator from itertools import cycle -from typing import Iterator, List, Type +from typing import Type import numpy as np import pandas as pd @@ -66,7 +67,7 @@ def name(self): return "population" @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["crn_attr1", "crn_attr2", "other_attr1"] def __init__(self, with_crn: bool, sims_to_add: Iterator = cycle([0])): diff --git a/tests/framework/randomness/test_stream.py b/tests/framework/randomness/test_stream.py index 2633691fd..f4c5b4b4b 100644 --- a/tests/framework/randomness/test_stream.py +++ b/tests/framework/randomness/test_stream.py @@ -1,5 +1,3 @@ -from typing import Dict - import numpy as np import pandas as pd import pytest @@ -157,7 +155,7 @@ def test_sample_from_distribution_bad_args( ], ) def test_sample_from_distribution_using_scipy( - index: pd.Index, distribution: stats.rv_continuous, params: Dict + index: pd.Index, distribution: stats.rv_continuous, params: dict ): randomness_stream = RandomnessStream( "test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap() diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index e7f867903..3f7975aff 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import datetime import pytest diff --git a/tests/framework/results/helpers.py b/tests/framework/results/helpers.py index 7b99d3b5c..cf7809b5e 100644 --- a/tests/framework/results/helpers.py +++ b/tests/framework/results/helpers.py @@ -1,5 +1,4 @@ import itertools -from typing import List import numpy as np import pandas as pd @@ -55,7 +54,7 @@ class Hogwarts(Component): @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return [ "student_id", "student_house", diff --git a/tests/framework/test_engine.py b/tests/framework/test_engine.py index 5146f4e99..dd888ac11 100644 --- a/tests/framework/test_engine.py +++ b/tests/framework/test_engine.py @@ -2,7 +2,6 @@ from itertools import product from pathlib import Path from time import time -from typing import Dict, List import dill import pandas as pd @@ -66,7 +65,7 @@ def log(mocker): return mocker.patch("vivarium.framework.logging.manager.logger") -def test_simulation_with_non_components(SimulationContext, components: List[Component]): +def test_simulation_with_non_components(SimulationContext, components: list[Component]): class NonComponent: def __init__(self): self.name = "non_component" @@ -447,7 +446,7 @@ def test_get_results_formatting(SimulationContext, base_config): #################### # HELPER FUNCTIONS # #################### -def _convert_to_datetime(date_dict: Dict[str, int]) -> pd.Timestamp: +def _convert_to_datetime(date_dict: dict[str, int]) -> pd.Timestamp: return pd.to_datetime( "-".join([str(val) for val in date_dict.values()]), format="%Y-%m-%d" ) diff --git a/tests/framework/test_time.py b/tests/framework/test_time.py index 6b6286e2c..149a823b9 100644 --- a/tests/framework/test_time.py +++ b/tests/framework/test_time.py @@ -1,5 +1,4 @@ import math -from typing import List import numpy as np import pandas as pd @@ -171,7 +170,7 @@ class StepModifierWithUntracking(StepModifierWithRatePipeline): """Add an event step that untracks/tracks even simulants every timestep""" @property - def columns_required(self) -> List[str]: + def columns_required(self) -> list[str]: return ["tracked"] def on_time_step(self, event: Event) -> None: diff --git a/tests/helpers.py b/tests/helpers.py index f9c0a0e25..d4e00c37d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,7 +1,7 @@ # mypy: ignore-errors from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any import pandas as pd @@ -84,7 +84,7 @@ def name(self) -> str: return self._name @property - def configuration_defaults(self) -> Dict[str, Any]: + def configuration_defaults(self) -> dict[str, Any]: return {self.name: self.CONFIGURATION_DEFAULTS["component"]} def __init__(self, name: str): @@ -142,7 +142,7 @@ def on_simulation_end(self, event: Event) -> None: class ColumnCreator(Component): @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["test_column_1", "test_column_2", "test_column_3"] def setup(self, builder: Builder) -> None: @@ -193,7 +193,7 @@ class SingleLookupCreator(ColumnCreator): class OrderedColumnsLookupCreator(Component): @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["foo", "bar"] def on_initialize_simulants(self, pop_data: SimulantData) -> None: @@ -232,17 +232,17 @@ def build_all_lookup_tables(self, builder: "Builder") -> None: class ColumnRequirer(Component): @property - def columns_required(self) -> List[str]: + def columns_required(self) -> list[str]: return ["test_column_1", "test_column_2"] class ColumnCreatorAndRequirer(Component): @property - def columns_required(self) -> List[str]: + def columns_required(self) -> list[str]: return ["test_column_1", "test_column_2"] @property - def columns_created(self) -> List[str]: + def columns_created(self) -> list[str]: return ["test_column_4"] @property @@ -260,13 +260,13 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: class AllColumnsRequirer(Component): @property - def columns_required(self) -> List[str]: + def columns_required(self) -> list[str]: return [] class FilteredPopulationView(ColumnRequirer): @property - def population_view_query(self) -> Optional[str]: + def population_view_query(self) -> str | None: return "test_column_1 == 5"