-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SFTs for the values in the AgeBasedHINT matrix.
- Loading branch information
1 parent
d88931c
commit 914c408
Showing
2 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
135 changes: 135 additions & 0 deletions
135
tests/sft_tests/hint_tests/test_age_base_hint_values/ep4_dir/dtk_post_process.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#!/usr/bin/python | ||
import os.path | ||
import json | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
|
||
from idm_test.dtk_test.sft_class import arg_parser, SFT | ||
|
||
|
||
class AgeBaseHINTValueTest(SFT): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
def get_age_base_hint_matrix(self): | ||
""" | ||
load the demographics json and get the age base HINT related data | ||
""" | ||
with open('Assets/demographics.json', 'r') as f: | ||
demog = json.load(f) | ||
age_bin_edges_in_years = demog['Defaults']['IndividualProperties'][0]['Age_Bin_Edges_In_Years'] | ||
transmission_matrix = demog['Defaults']['IndividualProperties'][0]['TransmissionMatrix'] | ||
contact = transmission_matrix['contact']['Matrix'] | ||
environmental = transmission_matrix['environmental']['Matrix'] | ||
return age_bin_edges_in_years, contact, environmental | ||
|
||
# overwrite the test method | ||
def test(self): | ||
self.success = True | ||
|
||
# Load Age Base HINT parameters from Demographics file | ||
age_bin_edges_in_years, contact, environmental = self.get_age_base_hint_matrix() | ||
|
||
# Replace the last age bin edge -1 with a large year of age. It will be used to generate the IntervalIndex for | ||
# dataframe later. | ||
age_bin_edges_in_years[-1] = 125 | ||
|
||
with open(self.report_name, "w") as outfile: | ||
if contact != environmental: | ||
self.success = False | ||
outfile.write( | ||
f"BAD: Expected the same transmission matix for both contact and environmental routes. Found:" | ||
f" contact = {contact}, environmental = {environmental}.\n") | ||
else: | ||
transmission_matrix = contact | ||
for transmission_rate in transmission_matrix: | ||
for idx, x in enumerate(transmission_matrix[0]): | ||
if transmission_rate[idx] != x: | ||
self.success = False | ||
outfile.write( | ||
f"BAD: Expected the same transmission rate for the same age bin. please check the " | ||
f"TransmissionMatrix.\n") | ||
break | ||
if self.success: | ||
with open(os.path.join(self.output_folder, "ReportEventRecorder.csv"), 'r') as infile: | ||
df = pd.read_csv(infile) | ||
df.columns = df.columns.to_series().apply(lambda x: x.strip()) | ||
# Filter the dataframe for NewInfectionEvent | ||
new_infected_df = df.loc[(df['Event_Name'] == 'NewInfectionEvent')] | ||
# Add a new 'Age_Year' column to the DataFrame by converting 'Age' to years. | ||
new_infected_df['Age_Year'] = new_infected_df['Age'] / 365.0 | ||
# Add a new 'Age_Bin' column to the DataFrame based on age bin edges. | ||
bins = pd.IntervalIndex.from_breaks(age_bin_edges_in_years) | ||
new_infected_df['Age_Bin'] = pd.cut(new_infected_df['Age_Year'], bins=bins, right=False).astype(str) | ||
|
||
# Save dataframes to csv for debugging | ||
new_infected_df.to_csv("ReportEventRecorder_AgeBin.csv") | ||
|
||
# ignore the first timestep where we have the outbreak | ||
new_infected_df = new_infected_df.loc[(new_infected_df['Year'] != 2005)] | ||
|
||
# Count NewInfectionEvent by Age_Bin | ||
count_dict = {age_bin: count for age_bin, count in new_infected_df.groupby('Age_Bin').size().items()} | ||
|
||
# Plot the dictionary into histogram plot for debugging | ||
if len(count_dict): | ||
self.plot_bar(count_dict) | ||
outfile.write(f'{count_dict}\n') | ||
|
||
for idx, rate in enumerate(transmission_matrix[0]): | ||
left = age_bin_edges_in_years[idx] | ||
right = age_bin_edges_in_years[idx + 1] | ||
# age_bin_interval = str(pd.Interval(left=left, right=right, closed='right')) | ||
# workaround for old Pandas library | ||
age_bin_interval = f'({left:.1f}, {right:.1f}]' | ||
if rate == 0: | ||
if age_bin_interval in count_dict: | ||
self.success = False | ||
outfile.write( | ||
f"BAD: Expected no transmission in age bin {age_bin_interval} since " | ||
f"transmission rate is 0.\n") | ||
else: | ||
outfile.write( | ||
f"GOOD: There is no transmission in age bin {age_bin_interval} since " | ||
f"transmission rate is 0.\n") | ||
else: | ||
if age_bin_interval in count_dict and count_dict[age_bin_interval] != 0: | ||
outfile.write( | ||
f"GOOD: There are {count_dict[age_bin_interval]} transmission in age bin " | ||
f"{age_bin_interval} since transmission rate is {rate}.\n") | ||
else: | ||
self.success = False | ||
outfile.write( | ||
f"BAD: There is no transmission in age bin {age_bin_interval} but " | ||
f"transmission rate is {rate} != 0.\n") | ||
return self.success | ||
|
||
def plot_bar(self, count_dict): | ||
# Extract keys and values from the dictionary | ||
keys = list(count_dict.keys()) | ||
values = list(count_dict.values()) | ||
# Create a histogram | ||
plt.bar(keys, values) | ||
# Set labels for the axes | ||
plt.xlabel('Age bin') | ||
plt.ylabel('New Infections') | ||
# Set the title of the graph | ||
plt.title('Histogram of New Infection Values') | ||
plt.savefig('new_infection_count_by_age_bin.png') | ||
plt.close() | ||
|
||
|
||
def application(output_folder="output", my_arg=None): | ||
if not my_arg: | ||
my_sft = AgeBaseHINTValueTest(stdout='stdout.txt') | ||
else: | ||
my_sft = AgeBaseHINTValueTest( | ||
output=my_arg.output, stdout='stdout.txt', json_report=my_arg.json_report, event_csv=my_arg.event_csv, | ||
config=my_arg.config, campaign=my_arg.campaign, report_name=my_arg.report_name, debug=my_arg.debug) | ||
my_sft.run() | ||
|
||
|
||
if __name__ == "__main__": | ||
# execute only if run as a script | ||
my_arg = arg_parser() | ||
application(my_arg=my_arg) |
65 changes: 65 additions & 0 deletions
65
tests/sft_tests/hint_tests/test_age_base_hint_values/test_age_base_hint_values.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
import sys | ||
from functools import partial | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
sys.path.append('../') | ||
from hint_test_helper import TestAgeBaseHint, build_camp, build_demog, set_param_fn, Age_Bin_Edges_In_Years, \ | ||
num_age_bin, build_demog_target_all_age_bin | ||
|
||
|
||
def build_demog_age_bins(): | ||
""" | ||
Build a demographics input file for the DTK using emod_api with AgeDependentTransmission matrix. The transmission | ||
matrix will have a constant value for one group and 0 for the other groups | ||
""" | ||
demog = build_demog() | ||
TransmissionMatrix = [[0] * num_age_bin for _ in range(num_age_bin)] | ||
|
||
group_indexes = [0, 1, 2, 3] | ||
values = [2, 1, 0.5, 0.1] | ||
|
||
for row in TransmissionMatrix: | ||
for group_index in group_indexes: | ||
row[group_index] = values[group_index] | ||
|
||
demog.AddAgeDependentTransmission( | ||
Age_Bin_Edges_In_Years=Age_Bin_Edges_In_Years.copy(), | ||
TransmissionMatrix=TransmissionMatrix | ||
) | ||
|
||
return demog | ||
|
||
|
||
class TestAgeBaseHintValues(TestAgeBaseHint): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.test_name = os.path.basename(__file__) | ||
|
||
def test_age_base_hint_values(self): | ||
self.age_base_hint_test(build_demog_age_bins) | ||
|
||
def test_age_base_hint_values_baseline(self): | ||
# This is a baseline simulation that is run without the AgeBadeHINT feature. | ||
self.age_base_hint_test(build_demog) | ||
|
||
def test_age_base_hint_values_all_0s(self): | ||
build_demog_age_bins_all_0s = partial(build_demog_target_all_age_bin, value=0) | ||
self.age_base_hint_test(build_demog_age_bins_all_0s) | ||
|
||
def test_age_base_hint_values_all_1s(self): | ||
build_demog_age_bins_all_1s = partial(build_demog_target_all_age_bin, value=1) | ||
self.age_base_hint_test(build_demog_age_bins_all_1s) | ||
|
||
def test_age_base_hint_values_all_2s(self): | ||
build_demog_age_bins_all_2s = partial(build_demog_target_all_age_bin, value=2) | ||
self.age_base_hint_test(build_demog_age_bins_all_2s) | ||
|
||
def test_age_base_hint_values_all_halves(self): | ||
build_demog_age_bins_all_halves = partial(build_demog_target_all_age_bin, value=0.5) | ||
self.age_base_hint_test(build_demog_age_bins_all_halves) | ||
|
||
|
||
if __name__ == '__main__': | ||
import unittest | ||
unittest.main() |