Skip to content

Commit

Permalink
Made use of numba vs c a bit more dynamic and automatic.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Bloedow committed Aug 19, 2024
1 parent 1bf4c7f commit bbd0d7a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
32 changes: 19 additions & 13 deletions nnmm/mods/maternal_immunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,46 @@
# # Maternal Immunity (Waning)
# All newborns come into the world with susceptibility=0. They call get a 6month timer. When that timer hits 0, they become susceptible.

use_nb = True
def init( model, istart, iend ):
# enable this after adding susceptibility property to the population (see cells below)
model.population.susceptibility[istart:iend] = 0 # newborns have maternal immunity
model.population.susceptibility_timer[istart:iend] = int(0.5*365) # 6 months

"""
# Define the function to decrement susceptibility_timer and update susceptibility
@nb.njit((nb.uint32, nb.uint8[:], nb.uint8[:]), parallel=True)
def _update_susceptibility_based_on_sus_timer(count, susceptibility_timer, susceptibility):
def _update_susceptibility_based_on_sus_timer_nb(count, susceptibility_timer, susceptibility):
for i in nb.prange(count):
if susceptibility_timer[i] > 0:
susceptibility_timer[i] -= 1
if susceptibility_timer[i] == 0:
susceptibility[i] = 1
"""


try:
# Load the shared library
lib = ctypes.CDLL('./libmi.so')
lib = ctypes.CDLL('./libmi.so')

# Define the function prototype
lib.update_susceptibility_based_on_sus_timer.argtypes = [ctypes.c_uint32,
ctypes.POINTER(ctypes.c_uint8),
ctypes.POINTER(ctypes.c_uint8)]
lib.update_susceptibility_based_on_sus_timer.restype = None
lib.update_susceptibility_based_on_sus_timer.argtypes = [ctypes.c_uint32,
ctypes.POINTER(ctypes.c_uint8),
ctypes.POINTER(ctypes.c_uint8)]
lib.update_susceptibility_based_on_sus_timer.restype = None
use_nb = False
except Exception as ex:
print( "Failed to load libmi.so. Will use numba." )

# Example usage

def _update_susceptibility_based_on_sus_timer(count, susceptibility_timer, susceptibility):
lib.update_susceptibility_based_on_sus_timer(count,
"""
def _update_susceptibility_based_on_sus_timer_c(count, susceptibility_timer, susceptibility):
lib._update_susceptibility_based_on_sus_timer(count,
susceptibility_timer.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
susceptibility.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)))
"""

def do_susceptibility_decay(model, tick):
_update_susceptibility_based_on_sus_timer(model.population.count, model.population.susceptibility_timer, model.population.susceptibility)
if use_nb:
_update_susceptibility_based_on_sus_timer_nb(model.population.count, model.population.susceptibility_timer, model.population.susceptibility)
else:
lib._update_susceptibility_based_on_sus_timer_c(model.population.count, model.population.susceptibility_timer, model.population.susceptibility)
return
29 changes: 18 additions & 11 deletions nnmm/mods/ri.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,22 @@
# ### "Step-Function"
# Timers get counted down each timestep and when they reach 0, susceptibility is set to 0.

lib = ctypes.CDLL('./libri.so')
use_nb = True
try:
lib = ctypes.CDLL('./libri.so')

# Define the argument types for the C function
lib.update_susceptibility_based_on_ri_timer.argtypes = [
ctypes.c_uint32, # count
np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # ri_timer
np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # susceptibility
#np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # age_at_vax
np.ctypeslib.ndpointer(dtype=np.int32, ndim=1, flags='C_CONTIGUOUS'), # dob
ctypes.c_int64 # tick
]
lib.update_susceptibility_based_on_ri_timer.argtypes = [
ctypes.c_uint32, # count
np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # ri_timer
np.ctypeslib.ndpointer(dtype=np.uint8, ndim=1, flags='C_CONTIGUOUS'), # susceptibility
#np.ctypeslib.ndpointer(dtype=np.uint16, ndim=1, flags='C_CONTIGUOUS'), # age_at_vax
np.ctypeslib.ndpointer(dtype=np.int32, ndim=1, flags='C_CONTIGUOUS'), # dob
ctypes.c_int64 # tick
]
use_nb = False
except Exception as ex:
print( "Failed to load libri.so. Will use numba." )


def add(model, count_births, istart, iend):
Expand Down Expand Up @@ -117,7 +122,9 @@ def _update_susceptibility_based_on_ri_timer(count, ri_timer, susceptibility, do
#_update_susceptibility_based_on_ri_timer(count, ri_timer, susceptibility, dob, tick)

def do_ri(model, tick):
#lib.update_susceptibility_based_on_ri_timer(count, ri_timer, susceptibility, dob, tick)
_update_susceptibility_based_on_ri_timer(model.population.count, model.population.ri_timer, model.population.susceptibility, model.population.dob, tick)
if use_nb:
_update_susceptibility_based_on_ri_timer(model.population.count, model.population.ri_timer, model.population.susceptibility, model.population.dob, tick)
else:
lib.update_susceptibility_based_on_ri_timer(count, ri_timer, susceptibility, dob, tick)
return

0 comments on commit bbd0d7a

Please sign in to comment.