-
Notifications
You must be signed in to change notification settings - Fork 1
/
prevalence_analyzer.py
66 lines (52 loc) · 2.9 KB
/
prevalence_analyzer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import starsim as ss
import numpy as np
import sciris as sc
class PrevalenceAnalyzer(ss.Analyzer):
    """Analyzer that records disease prevalence over time by age group and sex.

    For every configured disease the population is split into age groups whose
    lower edges are the keys of ``prevalence_data[disease]['male']``; the last
    group is open-ended. At each time step the mean of the disease's status
    array is computed per age group, separately for each sex, and stored in
    ``self.results`` under ``'<disease>_prevalence_male'`` and
    ``'<disease>_prevalence_female'``. Each result is an array of shape
    ``(number of time points, number of age groups)``.
    """

    def __init__(self, prevalence_data, diseases, *args, **kwargs):
        """Set up the per-disease age bins and an empty results container.

        Args:
            prevalence_data: Nested mapping indexed as
                ``prevalence_data[disease]['male']``. Only the keys of the
                ``'male'`` entry are read here; they are used as the lower
                edges of the age bins for both sexes.
            diseases: Disease names, e.g. ``['HIV', 'depression', 'diabetes']``.
                Each name must be a key of ``prevalence_data``.
            *args: Forwarded to ``ss.Analyzer.__init__``.
            **kwargs: Forwarded to ``ss.Analyzer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Assigned after the parent constructor, so this name always wins over
        # anything the parent set from ``args``/``kwargs``.
        self.name = 'prevalence_analyzer'
        self.prevalence_data = prevalence_data
        self.diseases = diseases

        self.age_bins = {}    # disease -> sorted list of age-bin lower edges
        self.age_groups = {}  # disease -> list of (start, end) pairs

        for disease in self.diseases:
            bins = sorted(prevalence_data[disease]['male'].keys())
            self.age_bins[disease] = bins
            # Consecutive edges form half-open groups [start, end); the final
            # edge opens an unbounded group (e.g. 80+).
            self.age_groups[disease] = list(zip(bins[:-1], bins[1:])) + [(bins[-1], float('inf'))]

        self.results = sc.odict()

    def init_pre(self, sim):
        """Allocate the zero-filled result arrays once the simulation is known.

        Args:
            sim: The simulation object; ``sim.npts`` gives the number of time
                points and therefore the number of rows in each result array.
        """
        super().init_pre(sim)
        npts = sim.npts

        # One (time x age group) array per disease and sex.
        for disease in self.diseases:
            n_groups = len(self.age_groups[disease])
            self.results[f'{disease}_prevalence_male'] = np.zeros((npts, n_groups))
            self.results[f'{disease}_prevalence_female'] = np.zeros((npts, n_groups))

        print(f"Initialized prevalence array with {npts} time points for {self.diseases}.")

    def apply(self, sim):
        """Record prevalence by age group and sex for the current time step.

        For each disease the module is looked up on ``sim.diseases`` by its
        lower-cased name, and its status array (``infected`` for ``'HIV'``,
        ``affected`` for everything else) is averaged over the agents in each
        age group and sex. Groups with no agents keep a value of 0. Row
        ``sim.ti`` of the matching result array is overwritten.

        Args:
            sim: The running simulation; ``sim.people.age``,
                ``sim.people.female``, ``sim.diseases`` and ``sim.ti`` are read.
        """
        # NOTE(review): this prints on every time step; kept as-is because it
        # is part of the current runtime output.
        print(f"Applying analyzer at time step {sim.ti}")
        ages = sim.people.age
        females = sim.people.female

        for disease in self.diseases:
            disease_obj = getattr(sim.diseases, disease.lower())
            status_attr = 'infected' if disease == 'HIV' else 'affected'
            # The status array does not depend on sex or age group, so it is
            # fetched once per disease rather than once per age group.
            status_array = getattr(disease_obj, status_attr)
            age_groups = self.age_groups[disease]

            # ``female`` is compared against 0 for males and 1 for females.
            for sex, label in zip([0, 1], ['male', 'female']):
                # Same for every age group of this sex, so build it only once.
                sex_mask = females == sex
                prevalence_by_age_group = np.zeros(len(age_groups))

                for i, (start, end) in enumerate(age_groups):
                    if end == float('inf'):
                        # Open-ended final group: no upper age bound.
                        age_mask = (ages >= start) & sex_mask
                    else:
                        age_mask = (ages >= start) & (ages < end) & sex_mask

                    # Skip empty groups so the entry stays 0 instead of
                    # becoming the NaN that a mean over nothing would give.
                    if np.sum(age_mask) > 0:
                        prevalence_by_age_group[i] = np.mean(status_array[age_mask])

                self.results[f'{disease}_prevalence_{label}'][sim.ti, :] = prevalence_by_age_group