Skip to content

Commit

Permalink
do_transmission_update now works more efficiently. First it calls rep…
Browse files Browse the repository at this point in the history
…ort, which does the SEIRW census but also creates the node-to-susceptibles map since it makes no sense to do a second susceptible census. That map is now a static varible in the ctypes extension module and used in the tx_inner_loop, and that function is now shorter and simpler. The calculate_new_infections is still done in between but we don't recalculate the number of infections per node since we already have that.
  • Loading branch information
Jonathan Bloedow committed Oct 7, 2024
1 parent 3e94b2d commit 4166d9d
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 47 deletions.
132 changes: 101 additions & 31 deletions nnmm/tx.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,22 @@ void tx_inner_nodes_v1(
}
}

//static std::vector<std::vector<int>> local_node2sus(num_nodes);
//This is clearly a nasty hardcoding. TBD, don't do this.
//We are making this static so the container can be written in report and read in tx_inner_nodes.
static std::vector<std::vector<int>> local_node2sus(419);

// This function now assumes that report has been called first. But there is no check for that yet.
// report is what populates local_node2sus so we don't have to do a second census of susceptibles.
void tx_inner_nodes(
uint32_t count,
unsigned int num_nodes,
uint16_t * agent_node,
uint8_t * susceptibility,
uint8_t * incubation_timer,
uint8_t * infection_timer,
uint16_t * new_infections_array,
float incubation_period_constant,
uint32_t * infected_ids // Output: an array of arrays for storing infected IDs
) {
// Local maps for each thread
std::vector<std::vector<int>> local_node2sus(num_nodes);

uint32_t offsets[num_nodes]; // To store starting index for each node

// Calculate offsets
Expand All @@ -194,39 +196,18 @@ void tx_inner_nodes(
offsets[node] = offsets[node - 1] + new_infections_array[node - 1];
}

// First pass: gather susceptible individuals by node in parallel
#pragma omp parallel
{
// Thread-local buffers to collect susceptible indices by node
std::vector<std::vector<int>> thread_local_node2sus(num_nodes);

#pragma omp for nowait
for (unsigned long int i = 0; i < count; ++i) {
if (susceptibility[i] == 1) {
int node = agent_node[i];
thread_local_node2sus[node].push_back(i);
}
}

// Combine thread-local results
#pragma omp critical
{
for (unsigned int node = 0; node < num_nodes; ++node) {
local_node2sus[node].insert(local_node2sus[node].end(),
thread_local_node2sus[node].begin(),
thread_local_node2sus[node].end());
}
}
}

// Second pass: Infect individuals by node in parallel
#pragma omp parallel for schedule(dynamic)
for (unsigned int node = 0; node < num_nodes; ++node) {
unsigned int new_infections = new_infections_array[node];
//printf( "Finding %d new infections in node %d\n", new_infections, node );

if (new_infections > 0) {
std::vector<int> &susceptible_indices = local_node2sus[node];
int num_susceptible = susceptible_indices.size();
/*if( num_susceptible == 0 ) {
printf( "WARNING: 0 susceptibles in node!\n" );
}*/
int step = (new_infections >= num_susceptible) ? 1 : num_susceptible / new_infections;

// Get the starting index for this node's infections
Expand All @@ -238,12 +219,101 @@ void tx_inner_nodes(
susceptibility[selected_id] = 0;
selected_count++;
// Write the infected ID into the pre-allocated array
//printf( "Writing new infected id to index %d.\n", start_index + selected_count );
//printf( "Writing new infected id to index %d for node %d.\n", start_index + selected_count, node );
infected_ids[start_index + selected_count] = selected_id;
}
}
}
}

void report(
unsigned long int count,
int num_nodes,
int32_t *age,
uint16_t *node,
unsigned char *infectious_timer, // max 255
unsigned char *incubation_timer, // max 255
bool *susceptibility, // yes or no
unsigned char *susceptibility_timer, // max 255
int *dod, // sim day
uint32_t *susceptible_count,
uint32_t *incubating_count,
uint32_t *infectious_count,
uint32_t *waning_count,
uint32_t *recovered_count,
unsigned int delta = 1
) {
//printf( "%s: count=%ld, num_nodes=%d", __FUNCTION__, count, num_nodes );
#pragma omp parallel
{
std::vector<std::vector<int>> thread_local_node2sus(num_nodes);
for (unsigned int node = 0; node < num_nodes; ++node) {
local_node2sus[node].clear(); // Clear before inserting new data
}

// Thread-local buffers
int *local_infectious_count = (int*) calloc(num_nodes, sizeof(int));
int *local_incubating_count = (int*) calloc(num_nodes, sizeof(int));
int *local_recovered_count = (int*) calloc(num_nodes, sizeof(int));
int *local_susceptible_count = (int*) calloc(num_nodes, sizeof(int));
int *local_waning_count = (int*) calloc(num_nodes, sizeof(int));

#pragma omp for
for (size_t i = 0; i <= count; i++) {
// Collect report
if (dod[i]>0) {
int node_id = node[i];
//printf( "Found live person at node %d: etimer=%d, itimer=%d, sus=%d.\n", node_id, incubation_timer[i], infectious_timer[i], susceptibility[i] );
if (incubation_timer[i] > 0) {
//printf( "Found E in node %d.\n", node_id );
local_incubating_count[node_id]++;
} else if (infectious_timer[i] > 0) {
//printf( "Found I in node %d.\n", node_id );
local_infectious_count[node_id]++;
} else if (susceptibility[i]==0) {
//printf( "Found R in node %d.\n", node_id );
if (susceptibility_timer[i]>0) {
local_waning_count[node_id]++;
} else {
//printf( "ERROR? recording %lu as recovered: susceptibility_timer = %d.\n", i, susceptibility_timer[i] );
local_recovered_count[node_id]++;
}
} else {
//printf( "Found S in node %d.\n", node_id );
local_susceptible_count[node_id]++;
thread_local_node2sus[node_id].push_back(i);
}
}
}

// Combine thread-local results
#pragma omp critical
{
for (unsigned int node = 0; node < num_nodes; ++node) {
local_node2sus[node].insert(local_node2sus[node].end(),
thread_local_node2sus[node].begin(),
thread_local_node2sus[node].end());
}
}

// Combine local counts into global counts
#pragma omp critical
{
for (int j = 0; j < num_nodes; ++j) {
susceptible_count[j] += local_susceptible_count[j];
incubating_count[j] += local_incubating_count[j];
infectious_count[j] += local_infectious_count[j];
waning_count[j] += local_waning_count[j];
recovered_count[j] += local_recovered_count[j];
}
}

// Free local buffers
free(local_susceptible_count);
free(local_incubating_count);
free(local_infectious_count);
free(local_waning_count);
free(local_recovered_count);
}
}
} // extern C (for C++)
70 changes: 54 additions & 16 deletions src/idmlaser_cholera/mods/transmission.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#infected_ids_type = ctypes.POINTER(ctypes.c_uint32)

# Define the maximum number of infections you expect
MAX_INFECTIONS = 10000000 # Adjust this to your expected maximum
MAX_INFECTIONS = 100000000 # Adjust this to your expected maximum

# Allocate a flat array for infected IDs
infected_ids_buffer = (ctypes.c_uint32 * (MAX_INFECTIONS))()
Expand Down Expand Up @@ -101,17 +101,29 @@ def init( model ):
lib.tx_inner_nodes.argtypes = [
ctypes.c_uint32, # count
ctypes.c_uint32, # num_nodes
np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # nodeids
np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # susceptibility
np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # itimers
np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # etimers
np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # new_infections,
ctypes.c_float, # exp_mean
#ctypes.c_float, # exp_std
#np.ctypeslib.ndpointer(dtype=np.uint32, ndim=1, flags='C_CONTIGUOUS'), # new_ids_out,
#ctypes.POINTER(ctypes.POINTER(ctypes.c_uint32)), # new_ids_out
ctypes.POINTER(ctypes.c_uint32) # new_ids_out (pointer to uint32)
]
lib.report.argtypes = [
ctypes.c_int64, # count
ctypes.c_int, # num_nodes
np.ctypeslib.ndpointer(dtype=np.int32, flags='C_CONTIGUOUS'), # age
np.ctypeslib.ndpointer(dtype=np.uint16, flags='C_CONTIGUOUS'), # node
np.ctypeslib.ndpointer(dtype=np.uint8, flags='C_CONTIGUOUS'), # infectious_timer
np.ctypeslib.ndpointer(dtype=np.uint8, flags='C_CONTIGUOUS'), # incubation_timer
np.ctypeslib.ndpointer(dtype=np.uint8, flags='C_CONTIGUOUS'), # immunity
np.ctypeslib.ndpointer(dtype=np.uint16, flags='C_CONTIGUOUS'), # susceptibility_timer
np.ctypeslib.ndpointer(dtype=np.int32, flags='C_CONTIGUOUS'), # expected_lifespan
np.ctypeslib.ndpointer(dtype=np.uint32, flags='C_CONTIGUOUS'), # infectious_count
np.ctypeslib.ndpointer(dtype=np.uint32, flags='C_CONTIGUOUS'), # incubating_count
np.ctypeslib.ndpointer(dtype=np.uint32, flags='C_CONTIGUOUS'), # susceptible_count
np.ctypeslib.ndpointer(dtype=np.uint32, flags='C_CONTIGUOUS'), # waning_count
np.ctypeslib.ndpointer(dtype=np.uint32, flags='C_CONTIGUOUS'), # recovered_count
ctypes.c_int # delta
]
global use_nb
use_nb = False
except Exception as ex:
Expand Down Expand Up @@ -234,7 +246,8 @@ def get_enviro_beta_from_psi( beta_env0, psi ):
beta_env = beta_env0 * (1 + (psi - psi_avg[-1]) / psi_avg[-1])
return beta_env


# Sometimes I think it might be faster not numba-ing this function but I
# want to try a compiled C version of it at some point.
@nb.njit(
nb.float32[:](
nb.float32[:],
Expand Down Expand Up @@ -382,13 +395,31 @@ def calculate_new_infections_by_node(total_forces, susceptibles):

def do_transmission_update(model, tick) -> None:

delta = 1
nodes = model.nodes
population = model.population

global lib
lib.report(
len(population),
len(nodes),
model.population.age,
model.population.nodeid,
model.population.itimer,
model.population.etimer,
model.population.susceptibility,
model.population.susceptibility_timer,
model.population.dod,
model.nodes.S[tick],
model.nodes.E[tick],
model.nodes.I[tick],
model.nodes.W[tick],
model.nodes.R[tick],
delta
)

contagion = nodes.cases[:, tick].astype(np.float32) # we will accumulate current infections into this array
nodeids = population.nodeid[:population.count] # just look at the active agent indices
itimers = population.itimer[:population.count] # just look at the active agent indices
np.add.at(contagion, nodeids[itimers > 0], 1) # increment by the number of active agents with non-zero itimer
contagion += model.nodes.I[tick]

network = nodes.network
transfer = (contagion * network).round().astype(np.uint32)
Expand Down Expand Up @@ -433,13 +464,14 @@ def do_transmission_update(model, tick) -> None:
#total_forces = forces

new_infections = calculate_new_infections_by_node(total_forces, model.nodes.S[tick])


total_infections = np.sum(new_infections)
#print( f"total new infections={total_infections}" )
if total_infections > MAX_INFECTIONS:
raise ValueError( f"Number of new infections ({total_infections}) > than allocated array size (MAX_INFECTIONS)!" )
raise ValueError( f"Number of new infections ({total_infections}) > than allocated array size ({MAX_INFECTIONS})!" )

if use_nb:
#calculated_incidence =
tx_inner_nodes(
population.susceptibility,
population.nodeid,
Expand All @@ -453,15 +485,11 @@ def do_transmission_update(model, tick) -> None:
)
else:
num_nodes = len(new_infections) # Assume number of nodes is the length of new_infections_by_node

global lib
lib.tx_inner_nodes(
population.count,
num_nodes,
population.nodeid, # uint32_t * agent_node,
population.susceptibility,# uint8_t *susceptibility,
population.etimer,# unsigned char * incubation_timer,
population.itimer,# unsigned char * infection_timer,
new_infections, # int * new_infections_array,
model.params.exp_mean, # unsigned char incubation_period_constant
infected_ids_buffer
Expand All @@ -484,3 +512,13 @@ def report_linelist():

return

"""
Latest design thought:
- Skip current census reporting in ages code.
- Do census to get number of active infections by node, and number of susceptibles by node. Probably in C. save this somewhere.
- Calculate foi in Python
- Calculate number of new infections by node. Python.
- Do new infections.
- See if this can all be time-sharded. Divide population into 7 or 8. Process 1/7 of population each day of the week.
"""

0 comments on commit 4166d9d

Please sign in to comment.