Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v1.9 #9

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@ example_scripts/**/*
!example_scripts/*
!example_scripts/**/*.py
configs/tests.yaml
build_pip.sh
.DS_Store
**/*pkl
**/*yaml
.vscode/settings.json
30 changes: 17 additions & 13 deletions grad_june/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ data_path: '@grad_june/test/data/data.pkl'

save_path: ./example
age_bins_to_save: [0, 18, 65, 100]
store_differentiable_deaths: true
store_cases_by_age: true

timer:
total_days: 15
Expand Down Expand Up @@ -43,31 +45,33 @@ timer:
- household

infection_seed:
log_fraction_initial_cases: -1
type: InfectionSeedByFraction
params:
log_fraction: -1.5

networks:
household:
log_beta: -0.4
log_beta: 0.4
company:
log_beta: -0.3
log_beta: 0.3
school:
log_beta: -0.3
log_beta: 0.3
pub:
log_beta: -1.2
log_beta: 0.2
gym:
log_beta: -1.2
log_beta: 0.2
grocery:
log_beta: -1.2
log_beta: 0.2
visit:
log_beta: -1.2
log_beta: 0.2
cinema:
log_beta: -1.2
log_beta: 0.2
university:
log_beta: -0.5
log_beta: 0.5
care_visit:
log_beta: -0.4
log_beta: 0.4
care_home:
log_beta: -0.4
log_beta: 0.4

policies:
interaction:
Expand Down Expand Up @@ -112,7 +116,7 @@ transmission:
scale: 0.03
shift:
dist: Normal
loc: -2.12
loc: 2.12
scale: 0.1

symptoms:
Expand Down
63 changes: 63 additions & 0 deletions grad_june/demographics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
from torch_geometric.data import HeteroData

def store_differentiable_deaths(data: HeteroData, dead_idx: int):
"""
Returns differentiable deaths. The results are stored
in data["results"]
"""
symptoms = data["agent"].symptoms
#dead_idx = self.model.symptoms_updater.stages_ids[-1]
deaths = (
(symptoms["current_stage"] == dead_idx)
* symptoms["current_stage"]
/ dead_idx
)
if data["results"]["deaths_per_timestep"] is not None:
data["results"]["deaths_per_timestep"] = torch.hstack(
(data["results"]["deaths_per_timestep"], deaths.sum())
)
else:
data["results"]["deaths_per_timestep"] = deaths.sum()

def get_cases_by_age(data: HeteroData, age_bins: torch.Tensor):
device = age_bins.device
ret = torch.zeros(age_bins.shape[0] - 1, device=device)
for i in range(1, age_bins.shape[0]):
mask1 = data["agent"].age < age_bins[i]
mask2 = data["agent"].age > age_bins[i - 1]
mask = mask1 * mask2
ret[i - 1] = (data["agent"].is_infected * mask).sum()
return ret

def get_people_by_age(ages: torch.Tensor, age_bins: torch.Tensor):
ret = {}
for i in range(1, age_bins.shape[0]):
mask1 = ages < age_bins[i]
mask2 = ages > age_bins[i - 1]
mask = mask1 * mask2
ret[int(age_bins[i].item())] = mask.sum()
return ret

def get_cases_by_ethnicity(data: HeteroData, ethnicities):
device = ethnicities.device
ret = torch.zeros(len(ethnicities), device=device)
for i, ethnicity in enumerate(ethnicities):
mask = torch.tensor(
data["agent"].ethnicity == ethnicity, device=device
)
ret[i] = (mask * data["agent"].is_infected).sum()
return ret

def get_people_per_area(agent_ids: torch.Tensor, area_ids: torch.Tensor):
"""Gets people ids in each area.

**Arguments:**

- `agent_ids`: Ids of all agents.
- `area_ids`: Area ids of all agents.
"""
people_per_area = {}
for area_id in torch.unique(area_ids):
people_per_area[area_id.item()] = agent_ids[area_ids == area_id]
return people_per_area
53 changes: 15 additions & 38 deletions grad_june/infection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import torch
from torch_geometric.data import HeteroData

from grad_june.demographics import get_people_per_area


class IsInfectedSampler(torch.nn.Module):
def forward(self, not_infected_probs):
Expand All @@ -17,47 +21,20 @@ def forward(self, not_infected_probs):
is_infected = 1.0 - infection[0, :]
return is_infected

def infect_people(data: HeteroData, time: int, new_infected: torch.Tensor):
"""
Sets the `new_infected` individuals to infected at time `time`.

**Arguments:**

def infect_people(data, timer, new_infected):
- `data`: the graph data
- `time`: the time step at which the infection happens
- `new_infected`: a tensor of size [N] where N is the number of agents.
"""
data["agent"].susceptibility = torch.clamp(
data["agent"].susceptibility - new_infected, min=0.0
)
data["agent"].is_infected = data["agent"].is_infected + new_infected
data["agent"].infection_time = data["agent"].infection_time + new_infected * (
timer.now - data["agent"].infection_time
)


def infect_fraction_of_people(
data, timer, symptoms_updater, fraction, device
):
n_infections = data["agent"].susceptibility.shape[0]
n_agents = data["agent"].id.shape[0]
probs = fraction * torch.ones(n_agents, device=device)
sampler = IsInfectedSampler()
new_infected = sampler(
1.0 - probs
) # sampler takes not inf probs
infect_people(data, timer, new_infected)
return new_infected


def infect_people_at_indices(data, indices, device="cpu"):
susc = data["agent"]["susceptibility"].cpu().numpy()
is_inf = data["agent"]["is_infected"].cpu().numpy()
inf_t = data["agent"]["infection_time"].cpu().numpy()
next_stage = data["agent"]["symptoms"]["next_stage"].cpu().numpy()
current_stage = data["agent"]["symptoms"]["current_stage"].cpu().numpy()
susc[indices] = 0.0
is_inf[indices] = 1.0
inf_t[indices] = 0.0
next_stage[indices] = 2
current_stage[indices] = 1
data["agent"]["susceptibility"] = torch.tensor(susc, device=device)
data["agent"]["is_infected"] = torch.tensor(is_inf, device=device)
data["agent"]["infection_time"] = torch.tensor(inf_t, device=device)
data["agent"]["symptoms"]["next_stage"] = torch.tensor(next_stage, device=device)
data["agent"]["symptoms"]["current_stage"] = torch.tensor(
current_stage, device=device
)
return data
time - data["agent"].infection_time
)
23 changes: 13 additions & 10 deletions grad_june/infection_networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class InfectionNetwork(MessagePassing):
def __init__(self, log_beta, device="cpu"):
super().__init__( aggr="add", node_dim=-1)
super().__init__(aggr="add", node_dim=-1)
self.device = device
if type(log_beta) != torch.nn.Parameter:
self.log_beta = torch.tensor(float(log_beta))
Expand All @@ -33,16 +33,22 @@ def _get_edge_index(self, data):
def _get_reverse_edge_index(self, data):
return data["rev_attends_" + self.name].edge_index

def _get_beta_factor(self, data):
if not hasattr(data["region"], "beta_factor"):
return 1.0
return data["region"].beta_factor[data[self.name].region]

def _get_beta(self, policies, timer, data):
beta_factor = self._get_beta_factor(data)
interaction_policies = policies.interaction_policies
beta = 10.0**self.log_beta
if interaction_policies:
beta = interaction_policies.apply(beta=beta, name=self.name, timer=timer)
beta = beta * torch.ones(len(data[self.name]["id"]), device=self.device)
return beta
beta = beta * beta_factor
return beta

def _get_people_per_group(self, data):
return data[self.name]["people"]
def _get_people_per_group(self, data, timer):
return data[self.name].people

def _get_transmissions(self, data, policies, timer):
if policies.quarantine_policies:
Expand All @@ -60,7 +66,7 @@ def _get_susceptibilities(self, data, policies, timer):

def forward(self, data, timer, policies):
beta = self._get_beta(policies=policies, timer=timer, data=data)
people_per_group = self._get_people_per_group(data)
people_per_group = self._get_people_per_group(data, timer)
p_contact = torch.maximum(
torch.minimum(
1.0 / (people_per_group - 1), torch.tensor(1.0, device=self.device)
Expand Down Expand Up @@ -134,7 +140,7 @@ def forward(
network = self.networks[activity]
trans_susc += network(data=data, timer=timer, policies=policies)
trans_susc = torch.clamp(
trans_susc, min=1e-6, max = 100
trans_susc, min=1e-6, max=100
) # this is necessary to avoid gradient nans
not_infected_probs = torch.exp(-trans_susc * delta_time)
not_infected_probs = torch.clamp(not_infected_probs, min=0.0, max=1.0)
Expand All @@ -148,8 +154,6 @@ def _get_transmissions(self, data, policies, timer):
def _get_susceptibilities(self, data, policies, timer):
return data["agent"].susceptibility

pass


class CareHomeNetwork(InfectionNetwork):
pass
Expand All @@ -165,4 +169,3 @@ class CompanyNetwork(InfectionNetwork):

class UniversityNetwork(InfectionNetwork):
pass

22 changes: 19 additions & 3 deletions grad_june/infection_networks/leisure_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,32 @@ def _get_edge_index(self, data):
def _get_reverse_edge_index(self, data):
return data["rev_attends_leisure"].edge_index

def _get_beta_factor(self, data):
if not hasattr(data["region"], "beta_factor"):
return 1.0
return data["region"].beta_factor[data["leisure"].region]

def _get_beta(self, policies, timer, data):
beta_factor = self._get_beta_factor(data)
interaction_policies = policies.interaction_policies
beta = 10.0**self.log_beta
if interaction_policies:
beta = interaction_policies.apply(beta=beta, name=self.name, timer=timer)
beta = beta * torch.ones(len(data["leisure"]["id"]), device=self.device)
beta = beta * beta_factor
return beta

def _get_people_per_group(self, data):
return data["leisure"]["people"]
def _get_people_per_group(self, data, timer):
if self.weekday_probabilities is None:
self.initialize_leisure_probabilities(data)
if timer.day_type == "weekday":
leisure_mask = self.weekday_probabilities
else:
leisure_mask = self.weekend_probabilities
aux = torch.ones(len(data["leisure"]["id"]), device=self.device)
prob_leisure = leisure_mask
edge_index = self._get_edge_index(data)
people_per_group = self.propagate(edge_index, x=prob_leisure, y=aux)
return people_per_group

def _get_transmissions(self, data, policies, timer):
if self.weekday_probabilities is None:
Expand Down
Loading
Loading