diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 27708729..9ab987db 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,7 @@ **3.2.1 - 11/13/24** - Fix mypy errors in vivarium/framework/results/context.py + - Fix mypy errors in vivarium/framework/time.py - Modernize type hinting - Remove unnecessary "from future import annotation" imports diff --git a/pyproject.toml b/pyproject.toml index ccc81506..108d524a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ exclude = [ 'src/vivarium/framework/results/manager.py', 'src/vivarium/framework/results/observer.py', 'src/vivarium/framework/state_machine.py', - 'src/vivarium/framework/time.py', 'src/vivarium/interface/cli.py', 'src/vivarium/interface/interactive.py', 'src/vivarium/testing_utilities.py', diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time.py index bb1e621a..f24c95d5 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ ==================== The Simulation Clock @@ -12,21 +11,24 @@ """ +from __future__ import annotations + import math from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd -from vivarium.types import ClockStepSize, ClockTime, NumberLike +from vivarium.types import ClockStepSize, ClockTime if TYPE_CHECKING: from vivarium.framework.engine import Builder from vivarium.framework.population.population_view import PopulationView from vivarium.framework.event import Event from vivarium.framework.population import SimulantData + from vivarium.framework.values import ValuesManager from vivarium.framework.values import list_combiner from vivarium.manager import Interface, Manager @@ -36,7 +38,7 @@ class SimulationClock(Manager): """A base clock that includes global clock and a pandas series of clocks for each simulant""" @property - def name(self): + def name(self) -> str: return "simulation_clock" @property @@ -89,24 +91,27 @@ def step_size(self) -> ClockStepSize: @property def event_time(self) -> ClockTime: "Convenience method for event time, or clock + step" - return self.time + self.step_size + return self.time + self.step_size # type: ignore [operator] @property def time_steps_remaining(self) -> int: - return math.ceil((self.stop_time - self.time) / self.step_size) - - def __init__(self): - self._clock_time: ClockTime = None - self._stop_time: ClockTime = None - self._minimum_step_size: ClockStepSize = None - self._standard_step_size: ClockStepSize = None - self._clock_step_size: ClockStepSize = None - self._individual_clocks: PopulationView = None + number_steps_remaining = (self.stop_time - self.time) / self.step_size # type: ignore [operator] + if not isinstance(number_steps_remaining, (float, int)): + raise ValueError("Invalid type for number of steps remaining") + return math.ceil(number_steps_remaining) + + def __init__(self) -> None: + self._clock_time: ClockTime | None = None + self._stop_time: ClockTime | None = None + self._minimum_step_size: ClockStepSize | None = None + self._standard_step_size: ClockStepSize | None = None + self._clock_step_size: ClockStepSize | None = None + self._individual_clocks: PopulationView | None = None self._pipeline_name = "simulant_step_size" # TODO: Delegate this functionality to "tracked" or similar when appropriate self._simulants_to_snooze = pd.Index([]) - def setup(self, builder: "Builder"): + def setup(self, builder: "Builder") -> None: self._step_size_pipeline = builder.value.register_value_producer( self._pipeline_name, source=lambda idx: [pd.Series(np.nan, index=idx).astype("timedelta64[ns]")], @@ -114,7 +119,9 @@ def setup(self, builder: "Builder"): preferred_post_processor=self.step_size_post_processor, ) self.register_step_modifier = partial( - builder.value.register_value_modifier, self._pipeline_name + builder.value.register_value_modifier, + self._pipeline_name, + component=self, ) builder.population.initializes_simulants(self, creates_columns=self.columns_created) builder.event.register_listener("post_setup", self.on_post_setup) @@ -140,7 +147,7 @@ def on_initialize_simulants(self, pop_data: "SimulantData") -> None: ) self._individual_clocks.update(clocks_to_initialize) - def simulant_next_event_times(self, index: pd.Index) -> pd.Series: + def simulant_next_event_times(self, index: pd.Index[int]) -> pd.Series[ClockTime]: """The next time each simulant will be updated.""" if not self._individual_clocks: return pd.Series(self.event_time, index=index) @@ -148,7 +155,7 @@ def simulant_next_event_times(self, index: pd.Index) -> pd.Series: "next_event_time" ] - def simulant_step_sizes(self, index: pd.Index) -> pd.Series: + def simulant_step_sizes(self, index: pd.Index[int]) -> pd.Series[ClockStepSize]: """The step size for each simulant.""" if not self._individual_clocks: return pd.Series(self.step_size, index=index) @@ -158,12 +165,14 @@ def simulant_step_sizes(self, index: pd.Index) -> pd.Series: def step_backward(self) -> None: """Rewinds the clock by the current step size.""" - self._clock_time -= self.step_size + if self._clock_time is None: + raise ValueError("No start time provided") + self._clock_time -= self.step_size # type: ignore [operator] - def step_forward(self, index: pd.Index) -> None: + def step_forward(self, index: pd.Index[int]) -> None: """Advances the clock by the current step size, and updates aligned simulant clocks.""" - self._clock_time += self.step_size - if self._individual_clocks and index.any(): + self._clock_time += self.step_size # type: ignore [assignment, operator] + if self._individual_clocks and not index.empty: update_index = self.get_active_simulants(index, self.time) clocks_to_update = self._individual_clocks.get(update_index) if not clocks_to_update.empty: @@ -171,7 +180,7 @@ def step_forward(self, index: pd.Index) -> None: # Simulants that were flagged to get moved to the end should have a next event time # of stop time + 1 minimum timestep clocks_to_update.loc[self._simulants_to_snooze, "step_size"] = ( - self.stop_time + self.minimum_step_size - self.time + self.stop_time + self.minimum_step_size - self.time # type: ignore [operator] ) # TODO: Delegate this functionality to "tracked" or similar when appropriate self._simulants_to_snooze = pd.Index([]) @@ -179,20 +188,20 @@ def step_forward(self, index: pd.Index) -> None: self.time + clocks_to_update["step_size"] ) self._individual_clocks.update(clocks_to_update) - self._clock_step_size = self.simulant_next_event_times(index).min() - self.time + self._clock_step_size = self.simulant_next_event_times(index).min() - self.time # type: ignore [operator] - def get_active_simulants(self, index: pd.Index, time: ClockTime) -> pd.Index: + def get_active_simulants(self, index: pd.Index[int], time: ClockTime) -> pd.Index[int]: """Gets population that is aligned with global clock""" if index.empty or not self._individual_clocks: return index next_event_times = self.simulant_next_event_times(index) return next_event_times[next_event_times <= time].index - def move_simulants_to_end(self, index: pd.Index) -> None: - if self._individual_clocks and index.any(): + def move_simulants_to_end(self, index: pd.Index[int]) -> None: + if self._individual_clocks and not index.empty: self._simulants_to_snooze = self._simulants_to_snooze.union(index) - def step_size_post_processor(self, values: list[NumberLike], _) -> pd.Series: + def step_size_post_processor(self, value: Any, manager: ValuesManager) -> Any: """Computes the largest feasible step size for each simulant. This is the smallest component-modified step size (rounded down to increments @@ -209,10 +218,10 @@ def step_size_post_processor(self, values: list[NumberLike], _) -> pd.Series: The largest feasible step size for each simulant """ - min_modified = pd.DataFrame(values).min(axis=0).fillna(self.standard_step_size) + min_modified = pd.DataFrame(value).min(axis=0).fillna(self.standard_step_size) ## Rescale pipeline values to global minimum step size discretized_step_sizes = ( - np.floor(min_modified / self.minimum_step_size).replace(0, 1) + np.floor(min_modified / self.minimum_step_size).replace(0, 1) # type: ignore [attr-defined, operator] * self.minimum_step_size ) ## Make sure we don't get zero @@ -232,10 +241,10 @@ class SimpleClock(SimulationClock): } @property - def name(self): + def name(self) -> str: return "simple_clock" - def setup(self, builder): + def setup(self, builder: Builder) -> None: super().setup(builder) time = builder.configuration.time self._clock_time = time.start @@ -246,11 +255,11 @@ def setup(self, builder): ) self._clock_step_size = self._standard_step_size - def __repr__(self): + def __repr__(self) -> str: return "SimpleClock()" -def get_time_stamp(time): +def get_time_stamp(time: dict[str, int]) -> pd.Timestamp: return pd.Timestamp(time["year"], time["month"], time["day"]) @@ -271,10 +280,10 @@ class DateTimeClock(SimulationClock): } @property - def name(self): + def name(self) -> str: return "datetime_clock" - def setup(self, builder): + def setup(self, builder: Builder) -> None: super().setup(builder) time = builder.configuration.time self._clock_time = get_time_stamp(time.start) @@ -291,12 +300,12 @@ def setup(self, builder): ) self._clock_step_size = self._minimum_step_size - def __repr__(self): + def __repr__(self) -> str: return "DateTimeClock()" class TimeInterface(Interface): - def __init__(self, manager: SimulationClock): + def __init__(self, manager: SimulationClock) -> None: self._manager = manager def clock(self) -> Callable[[], ClockTime]: @@ -307,24 +316,24 @@ def step_size(self) -> Callable[[], ClockStepSize]: """Gets a callable that returns the current simulation step size.""" return lambda: self._manager.step_size - def simulant_next_event_times(self) -> Callable[[pd.Index], pd.Series]: + def simulant_next_event_times(self) -> Callable[[pd.Index[int]], pd.Series[ClockTime]]: """Gets a callable that returns the next event times for simulants.""" return self._manager.simulant_next_event_times - def simulant_step_sizes(self) -> Callable[[pd.Index], pd.Series]: + def simulant_step_sizes(self) -> Callable[[pd.Index[int]], pd.Series[ClockStepSize]]: """Gets a callable that returns the simulant step sizes.""" return self._manager.simulant_step_sizes - def move_simulants_to_end(self) -> Callable[[pd.Index], None]: + def move_simulants_to_end(self) -> Callable[[pd.Index[int]], None]: """Gets a callable that moves simulants to the end of the simulation""" return self._manager.move_simulants_to_end def register_step_size_modifier( self, - modifier: Callable[[pd.Index], pd.Series], - requires_columns: list[str] = (), - requires_values: list[str] = (), - requires_streams: list[str] = (), + modifier: Callable[[pd.Index[int]], pd.Series[ClockStepSize]], + requires_columns: list[str] = [], + requires_values: list[str] = [], + requires_streams: list[str] = [], ) -> None: """Registers a step size modifier. @@ -344,5 +353,8 @@ def register_step_size_modifier( A list of the randomness streams that need to be properly sourced before the modifier is called.""" return self._manager.register_step_modifier( - modifier, requires_columns, requires_values, requires_streams + modifier=modifier, + requires_columns=requires_columns, + requires_values=requires_values, + requires_streams=requires_streams, ) diff --git a/src/vivarium/framework/values/manager.py b/src/vivarium/framework/values/manager.py index cb4bf3f1..9ad65a3a 100644 --- a/src/vivarium/framework/values/manager.py +++ b/src/vivarium/framework/values/manager.py @@ -109,7 +109,7 @@ def register_value_modifier( value_name: str, modifier: Callable[..., Any], # TODO [MIC-5452]: all calls should have a component - component: Component | None = None, + component: Component | Manager | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), @@ -371,7 +371,7 @@ def register_value_modifier( value_name: str, modifier: Callable[..., Any], # TODO [MIC-5452]: all calls should have a component - component: Component | None = None, + component: Component | Manager | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py index 6cc81a65..b7afe8eb 100644 --- a/src/vivarium/framework/values/pipeline.py +++ b/src/vivarium/framework/values/pipeline.py @@ -8,6 +8,7 @@ from vivarium import Component from vivarium.framework.resource import Resource from vivarium.framework.values.exceptions import DynamicValueError +from vivarium.manager import Manager if TYPE_CHECKING: from vivarium.framework.values.combiners import ValueCombiner @@ -52,7 +53,7 @@ def __init__( self, pipeline: Pipeline, modifier: Callable[..., Any], - component: Component | None, + component: Component | Manager | None, ) -> None: mutator_name = self._get_modifier_name(modifier) mutator_index = len(pipeline.mutators) + 1 @@ -190,7 +191,7 @@ def __hash__(self) -> int: return hash(self.name) def get_value_modifier( - self, modifier: Callable[..., Any], component: Component | None + self, modifier: Callable[..., Any], component: Component | Manager | None ) -> ValueModifier: """Add a value modifier to the pipeline and return it.