diff --git a/src/idmlaser_cholera/cholera.py b/src/idmlaser_cholera/cholera.py index ba40a57..361e09f 100644 --- a/src/idmlaser_cholera/cholera.py +++ b/src/idmlaser_cholera/cholera.py @@ -120,7 +120,6 @@ class Model: from idmlaser_cholera.mods import init_prev from idmlaser_cholera.mods import transmission from idmlaser_cholera.mods import intrahost -from idmlaser_cholera.mods import maternal_immunity as mi from idmlaser_cholera.mods import ri from idmlaser_cholera.mods import sia from idmlaser_cholera.mods import fertility @@ -275,7 +274,7 @@ def check_for_cached(): #intrahost.step2, # type: ignore transmission.step, # type: ignore #ri.step, # type: ignore - mi.step, # type: ignore + immunity.step, # type: ignore #sia.step, # type: ignore ] @@ -292,7 +291,7 @@ def check_for_cached(): """ #""" - if tick == 40: + if tick == 40: # outbreak/seeding init_prev.init( model ) #""" metrics = [tick] diff --git a/src/idmlaser_cholera/mods/immunity.py b/src/idmlaser_cholera/mods/immunity.py index 2d3cd2a..5ae3173 100644 --- a/src/idmlaser_cholera/mods/immunity.py +++ b/src/idmlaser_cholera/mods/immunity.py @@ -1,6 +1,11 @@ import numpy as np import numba as nb +global use_nb +use_nb = True +global lib +lib = None + # initialize susceptibility based on age @nb.njit((nb.uint32, nb.int32[:], nb.uint8[:], nb.uint16[:]), parallel=True) def initialize_susceptibility(count, dob, susceptibility, susceptibility_timer): @@ -20,4 +25,67 @@ def initialize_susceptibility(count, dob, susceptibility, susceptibility_timer): return def init( model ): - return initialize_susceptibility( model.population.count, model.population.dob, model.population.susceptibility, model.population.susceptibility_timer ) + initialize_susceptibility( model.population.count, model.population.dob, model.population.susceptibility, model.population.susceptibility_timer ) + try: + # Load the shared library + shared_lib_path = resource_filename('idmlaser_cholera', 'mods/libmi.so') + lib = ctypes.CDLL(shared_lib_path) + + # Define the function prototype + lib.update_susceptibility_based_on_sus_timer.argtypes = [ + ctypes.c_int32, + np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # susceptibility_timer + np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # susceptibility + ] + + lib.update_susceptibility_based_on_sus_timer.restype = None + + lib.update_susceptibility_timer_strided_shards.argtypes = [ + ctypes.c_int32, + np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # susceptibility_timer + np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # susceptibility + ctypes.c_int32, + ctypes.c_int32, + ] + use_nb = False + print( "maternal immunity component initialized. Will use compiled C." ) + except Exception as ex: + print( "Failed to load libmi.so. Will use numba." ) + + +# Define the function to decrement susceptibility_timer and update susceptibility +@nb.njit((nb.uint32, nb.uint16[:], nb.uint8[:], nb.uint8, nb.uint8 ), parallel=True) +def _update_susceptibility_based_on_sus_timer_nb(count, susceptibility_timer, susceptibility, tick, delta): + shard_size = count // delta + + # Determine the start and end indices for the current shard + shard_index = tick % delta + start_index = shard_index * shard_size + end_index = start_index + shard_size + + # Handle the case where the last shard might be slightly larger due to rounding + if shard_index == delta - 1: + end_index = count + + # Loop through the current shard + for i in nb.prange(start_index, end_index): + if susceptibility_timer[i] > 0: + susceptibility_timer[i] = max(0, susceptibility_timer[i] - delta) + if susceptibility_timer[i] <= 0: + susceptibility[i] = 1 + +delta = 8 +def step(model, tick): + + global lib, use_nb + if use_nb: + _update_susceptibility_based_on_sus_timer_nb(model.population.count, model.population.susceptibility_timer, model.population.susceptibility, tick, delta) + else: + lib.update_susceptibility_timer_strided_shards( + model.population.count, + model.population.susceptibility_timer, + model.population.susceptibility, + delta, + tick + ) + return