Skip to content

Commit

Permalink
Albrja/mic-5546/mypy framework time (#539)
Browse files Browse the repository at this point in the history
Albrja/mic-5546/mypy framework time

Fix mypy erros in framework/time.py
- *Category*: Other
- *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5546
  • Loading branch information
albrja authored Nov 14, 2024
1 parent 13c1fbd commit c71111b
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 51 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
**3.2.1 - 11/13/24**

- Fix mypy errors in vivarium/framework/results/context.py
- Fix mypy errors in vivarium/framework/time.py
- Modernize type hinting
- Remove unnecessary "from future import annotation" imports

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ exclude = [
'src/vivarium/framework/results/manager.py',
'src/vivarium/framework/results/observer.py',
'src/vivarium/framework/state_machine.py',
'src/vivarium/framework/time.py',
'src/vivarium/interface/cli.py',
'src/vivarium/interface/interactive.py',
'src/vivarium/testing_utilities.py',
Expand Down
104 changes: 58 additions & 46 deletions src/vivarium/framework/time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
====================
The Simulation Clock
Expand All @@ -12,21 +11,24 @@
"""

from __future__ import annotations

import math
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd

from vivarium.types import ClockStepSize, ClockTime, NumberLike
from vivarium.types import ClockStepSize, ClockTime

if TYPE_CHECKING:
from vivarium.framework.engine import Builder
from vivarium.framework.population.population_view import PopulationView
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.values import ValuesManager

from vivarium.framework.values import list_combiner
from vivarium.manager import Interface, Manager
Expand All @@ -36,7 +38,7 @@ class SimulationClock(Manager):
"""A base clock that includes global clock and a pandas series of clocks for each simulant"""

@property
def name(self):
def name(self) -> str:
return "simulation_clock"

@property
Expand Down Expand Up @@ -89,32 +91,37 @@ def step_size(self) -> ClockStepSize:
@property
def event_time(self) -> ClockTime:
"Convenience method for event time, or clock + step"
return self.time + self.step_size
return self.time + self.step_size # type: ignore [operator]

@property
def time_steps_remaining(self) -> int:
return math.ceil((self.stop_time - self.time) / self.step_size)

def __init__(self):
self._clock_time: ClockTime = None
self._stop_time: ClockTime = None
self._minimum_step_size: ClockStepSize = None
self._standard_step_size: ClockStepSize = None
self._clock_step_size: ClockStepSize = None
self._individual_clocks: PopulationView = None
number_steps_remaining = (self.stop_time - self.time) / self.step_size # type: ignore [operator]
if not isinstance(number_steps_remaining, (float, int)):
raise ValueError("Invalid type for number of steps remaining")
return math.ceil(number_steps_remaining)

def __init__(self) -> None:
self._clock_time: ClockTime | None = None
self._stop_time: ClockTime | None = None
self._minimum_step_size: ClockStepSize | None = None
self._standard_step_size: ClockStepSize | None = None
self._clock_step_size: ClockStepSize | None = None
self._individual_clocks: PopulationView | None = None
self._pipeline_name = "simulant_step_size"
# TODO: Delegate this functionality to "tracked" or similar when appropriate
self._simulants_to_snooze = pd.Index([])

def setup(self, builder: "Builder"):
def setup(self, builder: "Builder") -> None:
self._step_size_pipeline = builder.value.register_value_producer(
self._pipeline_name,
source=lambda idx: [pd.Series(np.nan, index=idx).astype("timedelta64[ns]")],
preferred_combiner=list_combiner,
preferred_post_processor=self.step_size_post_processor,
)
self.register_step_modifier = partial(
builder.value.register_value_modifier, self._pipeline_name
builder.value.register_value_modifier,
self._pipeline_name,
component=self,
)
builder.population.initializes_simulants(self, creates_columns=self.columns_created)
builder.event.register_listener("post_setup", self.on_post_setup)
Expand All @@ -140,15 +147,15 @@ def on_initialize_simulants(self, pop_data: "SimulantData") -> None:
)
self._individual_clocks.update(clocks_to_initialize)

def simulant_next_event_times(self, index: pd.Index) -> pd.Series:
def simulant_next_event_times(self, index: pd.Index[int]) -> pd.Series[ClockTime]:
"""The next time each simulant will be updated."""
if not self._individual_clocks:
return pd.Series(self.event_time, index=index)
return self._individual_clocks.subview(["next_event_time", "tracked"]).get(index)[
"next_event_time"
]

def simulant_step_sizes(self, index: pd.Index) -> pd.Series:
def simulant_step_sizes(self, index: pd.Index[int]) -> pd.Series[ClockStepSize]:
"""The step size for each simulant."""
if not self._individual_clocks:
return pd.Series(self.step_size, index=index)
Expand All @@ -158,41 +165,43 @@ def simulant_step_sizes(self, index: pd.Index) -> pd.Series:

def step_backward(self) -> None:
"""Rewinds the clock by the current step size."""
self._clock_time -= self.step_size
if self._clock_time is None:
raise ValueError("No start time provided")
self._clock_time -= self.step_size # type: ignore [operator]

def step_forward(self, index: pd.Index) -> None:
def step_forward(self, index: pd.Index[int]) -> None:
"""Advances the clock by the current step size, and updates aligned simulant clocks."""
self._clock_time += self.step_size
if self._individual_clocks and index.any():
self._clock_time += self.step_size # type: ignore [assignment, operator]
if self._individual_clocks and not index.empty:
update_index = self.get_active_simulants(index, self.time)
clocks_to_update = self._individual_clocks.get(update_index)
if not clocks_to_update.empty:
clocks_to_update["step_size"] = self._step_size_pipeline(update_index)
# Simulants that were flagged to get moved to the end should have a next event time
# of stop time + 1 minimum timestep
clocks_to_update.loc[self._simulants_to_snooze, "step_size"] = (
self.stop_time + self.minimum_step_size - self.time
self.stop_time + self.minimum_step_size - self.time # type: ignore [operator]
)
# TODO: Delegate this functionality to "tracked" or similar when appropriate
self._simulants_to_snooze = pd.Index([])
clocks_to_update["next_event_time"] = (
self.time + clocks_to_update["step_size"]
)
self._individual_clocks.update(clocks_to_update)
self._clock_step_size = self.simulant_next_event_times(index).min() - self.time
self._clock_step_size = self.simulant_next_event_times(index).min() - self.time # type: ignore [operator]

def get_active_simulants(self, index: pd.Index, time: ClockTime) -> pd.Index:
def get_active_simulants(self, index: pd.Index[int], time: ClockTime) -> pd.Index[int]:
"""Gets population that is aligned with global clock"""
if index.empty or not self._individual_clocks:
return index
next_event_times = self.simulant_next_event_times(index)
return next_event_times[next_event_times <= time].index

def move_simulants_to_end(self, index: pd.Index) -> None:
if self._individual_clocks and index.any():
def move_simulants_to_end(self, index: pd.Index[int]) -> None:
if self._individual_clocks and not index.empty:
self._simulants_to_snooze = self._simulants_to_snooze.union(index)

def step_size_post_processor(self, values: list[NumberLike], _) -> pd.Series:
def step_size_post_processor(self, value: Any, manager: ValuesManager) -> Any:
"""Computes the largest feasible step size for each simulant.
This is the smallest component-modified step size (rounded down to increments
Expand All @@ -209,10 +218,10 @@ def step_size_post_processor(self, values: list[NumberLike], _) -> pd.Series:
The largest feasible step size for each simulant
"""

min_modified = pd.DataFrame(values).min(axis=0).fillna(self.standard_step_size)
min_modified = pd.DataFrame(value).min(axis=0).fillna(self.standard_step_size)
## Rescale pipeline values to global minimum step size
discretized_step_sizes = (
np.floor(min_modified / self.minimum_step_size).replace(0, 1)
np.floor(min_modified / self.minimum_step_size).replace(0, 1) # type: ignore [attr-defined, operator]
* self.minimum_step_size
)
## Make sure we don't get zero
Expand All @@ -232,10 +241,10 @@ class SimpleClock(SimulationClock):
}

@property
def name(self):
def name(self) -> str:
return "simple_clock"

def setup(self, builder):
def setup(self, builder: Builder) -> None:
super().setup(builder)
time = builder.configuration.time
self._clock_time = time.start
Expand All @@ -246,11 +255,11 @@ def setup(self, builder):
)
self._clock_step_size = self._standard_step_size

def __repr__(self):
def __repr__(self) -> str:
return "SimpleClock()"


def get_time_stamp(time):
def get_time_stamp(time: dict[str, int]) -> pd.Timestamp:
return pd.Timestamp(time["year"], time["month"], time["day"])


Expand All @@ -271,10 +280,10 @@ class DateTimeClock(SimulationClock):
}

@property
def name(self):
def name(self) -> str:
return "datetime_clock"

def setup(self, builder):
def setup(self, builder: Builder) -> None:
super().setup(builder)
time = builder.configuration.time
self._clock_time = get_time_stamp(time.start)
Expand All @@ -291,12 +300,12 @@ def setup(self, builder):
)
self._clock_step_size = self._minimum_step_size

def __repr__(self):
def __repr__(self) -> str:
return "DateTimeClock()"


class TimeInterface(Interface):
def __init__(self, manager: SimulationClock):
def __init__(self, manager: SimulationClock) -> None:
self._manager = manager

def clock(self) -> Callable[[], ClockTime]:
Expand All @@ -307,24 +316,24 @@ def step_size(self) -> Callable[[], ClockStepSize]:
"""Gets a callable that returns the current simulation step size."""
return lambda: self._manager.step_size

def simulant_next_event_times(self) -> Callable[[pd.Index], pd.Series]:
def simulant_next_event_times(self) -> Callable[[pd.Index[int]], pd.Series[ClockTime]]:
"""Gets a callable that returns the next event times for simulants."""
return self._manager.simulant_next_event_times

def simulant_step_sizes(self) -> Callable[[pd.Index], pd.Series]:
def simulant_step_sizes(self) -> Callable[[pd.Index[int]], pd.Series[ClockStepSize]]:
"""Gets a callable that returns the simulant step sizes."""
return self._manager.simulant_step_sizes

def move_simulants_to_end(self) -> Callable[[pd.Index], None]:
def move_simulants_to_end(self) -> Callable[[pd.Index[int]], None]:
"""Gets a callable that moves simulants to the end of the simulation"""
return self._manager.move_simulants_to_end

def register_step_size_modifier(
self,
modifier: Callable[[pd.Index], pd.Series],
requires_columns: list[str] = (),
requires_values: list[str] = (),
requires_streams: list[str] = (),
modifier: Callable[[pd.Index[int]], pd.Series[ClockStepSize]],
requires_columns: list[str] = [],
requires_values: list[str] = [],
requires_streams: list[str] = [],
) -> None:
"""Registers a step size modifier.
Expand All @@ -344,5 +353,8 @@ def register_step_size_modifier(
A list of the randomness streams that need to be properly sourced
before the modifier is called."""
return self._manager.register_step_modifier(
modifier, requires_columns, requires_values, requires_streams
modifier=modifier,
requires_columns=requires_columns,
requires_values=requires_values,
requires_streams=requires_streams,
)
4 changes: 2 additions & 2 deletions src/vivarium/framework/values/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def register_value_modifier(
value_name: str,
modifier: Callable[..., Any],
# TODO [MIC-5452]: all calls should have a component
component: Component | None = None,
component: Component | Manager | None = None,
requires_columns: Iterable[str] = (),
requires_values: Iterable[str] = (),
requires_streams: Iterable[str] = (),
Expand Down Expand Up @@ -371,7 +371,7 @@ def register_value_modifier(
value_name: str,
modifier: Callable[..., Any],
# TODO [MIC-5452]: all calls should have a component
component: Component | None = None,
component: Component | Manager | None = None,
requires_columns: Iterable[str] = (),
requires_values: Iterable[str] = (),
requires_streams: Iterable[str] = (),
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium/framework/values/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vivarium import Component
from vivarium.framework.resource import Resource
from vivarium.framework.values.exceptions import DynamicValueError
from vivarium.manager import Manager

if TYPE_CHECKING:
from vivarium.framework.values.combiners import ValueCombiner
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(
self,
pipeline: Pipeline,
modifier: Callable[..., Any],
component: Component | None,
component: Component | Manager | None,
) -> None:
mutator_name = self._get_modifier_name(modifier)
mutator_index = len(pipeline.mutators) + 1
Expand Down Expand Up @@ -190,7 +191,7 @@ def __hash__(self) -> int:
return hash(self.name)

def get_value_modifier(
self, modifier: Callable[..., Any], component: Component | None
self, modifier: Callable[..., Any], component: Component | Manager | None
) -> ValueModifier:
"""Add a value modifier to the pipeline and return it.
Expand Down

0 comments on commit c71111b

Please sign in to comment.