Skip to content

Commit

Permalink
Update lbwsg risk effect test
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Jan 10, 2025
1 parent bd1dbe8 commit 82cb274
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.0 - 1/10/25**

- Bugfix: Fix bug in LBWSGRiskEffect where relative risk pipeline was not properly created

**3.1.4 - 11/22/24**

- Feature: Enable initializing a population of all newborns
Expand Down
4 changes: 2 additions & 2 deletions src/vivarium_public_health/risks/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def setup(self, builder: Builder) -> None:
self.exposure = self.get_risk_exposure(builder)

self._relative_risk_source = self.get_relative_risk_source(builder)
self.relative_risk = self.get_relative_risk(builder)
self.relative_risk = self.get_relative_risk_pipeline(builder)

self.register_target_modifier(builder)
self.register_paf_modifier(builder)
Expand Down Expand Up @@ -297,7 +297,7 @@ def generate_relative_risk(index: pd.Index) -> pd.Series:

return generate_relative_risk

def get_relative_risk(self, builder: Builder) -> Pipeline:
def get_relative_risk_pipeline(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
f"{self.risk.name}_on_{self.target.name}.relative_risk",
self._relative_risk_source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,18 +361,11 @@ def get_population_attributable_fraction_source(
paf_data = builder.data.load(paf_key)
return paf_data, builder.data.value_columns()(paf_key)

def get_target_modifier(
self, builder: Builder
) -> Callable[[pd.Index, pd.Series], pd.Series]:
def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
return target * self.relative_risk(index)

return adjust_target

def register_target_modifier(self, builder: Builder) -> None:
builder.value.register_value_modifier(
self.target_pipeline_name,
modifier=self.target_modifier,
modifier=self.adjust_target,
component=self,
requires_values=[self.relative_risk_pipeline_name],
)

Expand All @@ -392,11 +385,12 @@ def get_age_intervals(self, builder: Builder) -> dict[str, pd.Interval]:
for age_start in exposed_age_group_starts
}

def get_relative_risk(self, builder: Builder) -> Pipeline:
def get_relative_risk_pipeline(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.relative_risk_pipeline_name,
source=self.get_relative_risk_source,
requires_columns=["age"] + self.rr_column_names,
source=self._relative_risk_source,
component=self,
required_resources=["age"] + self.rr_column_names,
)

def get_interpolator(self, builder: Builder) -> pd.Series:
Expand Down Expand Up @@ -469,7 +463,7 @@ def get_relative_risk_for_age_group(age_group: str) -> pd.Series:
# Pipeline sources and modifiers #
##################################

def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
def _get_relative_risk(self, index: pd.Index) -> pd.Series:
pop = self.population_view.get(index)
relative_risk = pd.Series(1.0, index=index, name=self.relative_risk_pipeline_name)

Expand All @@ -479,3 +473,6 @@ def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
age_group_mask, self.relative_risk_column_name(age_group)
]
return relative_risk

def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
return self._get_relative_risk
51 changes: 51 additions & 0 deletions tests/data/rr_interpolator.csv

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions tests/risks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,11 @@ def coverage_gap():
cg_data["affected_risk_factors"] = ["test_risk"]
cg_data["distribution"] = "dichotomous"
return Risk(f"coverage_gap.{cg}"), cg_data


@pytest.fixture
def mock_rr_interpolators() -> pd.DataFrame:
rr_interpolators = pd.read_csv("tests/data/rr_interpolators.csv")
idx_cols = [col for col in rr_interpolators.columns if "draw" not in col]
rr_interpolators = rr_interpolators.rename(columns={"draw_0": "value"})
return rr_interpolators
106 changes: 106 additions & 0 deletions tests/risks/test_low_birth_weight_and_short_gestation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import numpy as np
import pandas as pd
import pytest

from tests.risks.test_effect import _setup_risk_effect_simulation
from tests.test_utilities import make_age_bins
from vivarium_public_health.risks.implementations.low_birth_weight_and_short_gestation import (
LBWSGDistribution,
LBWSGRisk,
LBWSGRiskEffect,
)
from vivarium_public_health.utilities import to_snake_case


@pytest.mark.parametrize(
Expand Down Expand Up @@ -32,3 +39,102 @@ def test_parsing_lbwsg_descriptions(description, expected_weight_values, expecte
assert weight_interval.right == expected_weight_values[1]
assert age_interval.left == expected_age_values[0]
assert age_interval.right == expected_age_values[1]


def test_lbwsg_risk_effect_rr_pipeline(
base_config, base_plugins, mocker, mock_rr_interpolators
):

risk = LBWSGRisk()
lbwsg_effect = LBWSGRiskEffect("cause.test_cause.cause_specific_mortality_rate")

# Add mock data to artifact
categories = {
"cat81": "Neonatal preterm and LBWSG (estimation years) - [28, 30) wks, [2500, 3000) g",
"cat82": "Neonatal preterm and LBWSG (estimation years) - [28, 30) wks, [3000, 3500) g",
}
# Create exposure with matching demograph index as age_bins
age_bins = make_age_bins()
agees = age_bins.drop(columns="age_group_name")
exposure_data = make_categorical_exposure_data(agees)

# Add data dict to add to artifact
data = {
f"{risk.name}.exposure": exposure_data,
f"{risk.name}.population_attributable_fraction": 0,
f"{risk.name}.categories": categories,
f"{risk.name}.relative_risk_interpolator": mock_rr_interpolators,
}

# Only have neontal age groups
age_start = 0.0
age_end = 28 / 365.0
base_config.update(
{
"population": {
"initialization_age_start": age_start,
"initialization_age_max": age_end,
}
}
)
sim = _setup_risk_effect_simulation(base_config, base_plugins, risk, lbwsg_effect, data)
pop = sim.get_population()

expected_pipeline_name = (
f"effect_of_{lbwsg_effect.risk.name}_on_{lbwsg_effect.target.name}.relative_risk"
)
assert expected_pipeline_name in sim.list_values()

# Get age group names to lookup rr interpolator later
def map_age_groups(value):
for i, row in age_bins.iterrows():
if row["age_start"] <= value <= row["age_end"]:
return row["age_group_name"]
return None

mapped_age_groups = pop["age"].apply(map_age_groups)
mapped_age_groups = mapped_age_groups.apply(to_snake_case)
sim_data = pop[["sex", "birth_weight_exposure", "gestational_age_exposure"]].copy()
sim_data["age_group_name"] = mapped_age_groups

# Test the 4 different demographic groups
for sex in ["Male", "Female"]:
for age_group_name in ["early_neonatal", "late_neonatal"]:
interpolator = lbwsg_effect.interpolator[sex, age_group_name]
demo_idx = sim_data.index[
(sim_data["sex"] == sex) & (sim_data["age_group_name"] == age_group_name)
]
sub_pop = sim_data.loc[demo_idx]
actual_rr = sim.get_value(expected_pipeline_name)(demo_idx)
sub_pop["expected_rr"] = np.exp(
interpolator(
sub_pop["gestational_age_exposure"],
sub_pop["birth_weight_exposure"],
grid=False,
)
)
assert (actual_rr == sub_pop["expected_rr"]).all()


def make_categorical_exposure_data(data: pd.DataFrame) -> pd.DataFrame:
# Takes age gropus and adds sex, years, categories, and values
exposure_dfs = []
for year in range(1990, 2017):
tmp = data.copy()
tmp["year_start"] = year
tmp["year_end"] = year + 1
p_81 = tmp.copy()
p_81["parameter"] = "cat81"
p_81["value"] = 0.75
p_82 = tmp.copy()
p_82["parameter"] = "cat82"
p_82["value"] = 0.25
categories_df = pd.concat([p_81, p_82])
male_tmp = categories_df.copy()
male_tmp["sex"] = "Male"
female_tmp = categories_df.copy()
female_tmp["sex"] = "Female"
age_sex_df = pd.concat([male_tmp, female_tmp])
exposure_dfs.append(age_sex_df)

return pd.concat(exposure_dfs)

0 comments on commit 82cb274

Please sign in to comment.