Skip to content

Commit

Permalink
caching numba compiled code
Browse files Browse the repository at this point in the history
  • Loading branch information
GertjanBisschop committed Jan 30, 2023
1 parent 60079b9 commit ca68fa3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Changelog

## [0.0.2] - 2023-XX-XX

- Caching `numba` compiled code. ({issue}`3`, {pr}`5`, {user}`GertjanBisschop`)

- Solving bug in episilon-based logic for whenever `ZeroDivisionErrors` occur during graph evaluation. ({issue}`2`, {pr}`4`, {user}`GertjanBisschop`)
48 changes: 24 additions & 24 deletions agemo/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

# numerical compensation algorithms
# algorithm from Ogita et al. 2005. Accurate sum and dot product. Journal of Scientific Computing
@numba.njit(numba.types.UniTuple(numba.float64, 2)(numba.float64, numba.float64))
@numba.njit(numba.types.UniTuple(numba.float64, 2)(numba.float64, numba.float64), cache=True)
def two_sum(a, b):
x = a + b
y = x - a
e = a - (x - y) + (b - y)
return x, e


@numba.njit(numba.float64(numba.float64[:]))
@numba.njit(numba.float64(numba.float64[:]), cache=True)
def casc_sum(arr):
s, t = 0, 0
for x in arr:
Expand All @@ -32,7 +32,7 @@ def casc_sum(arr):
numba.float64(numba.int16[:], numba.float64[:]),
numba.float64(numba.int32[:], numba.float64[:]),
numba.float64(numba.int64[:], numba.float64[:]),
]
], cache=True
)
def casc_dot_product(A, B):
s, t = 0, 0
Expand All @@ -53,7 +53,7 @@ def casc_dot_product(A, B):
numba.float64(numba.int16[:], numba.float64[:]),
numba.float64(numba.int32[:], numba.float64[:]),
numba.float64(numba.int64[:], numba.float64[:]),
]
], cache=True
)
def simple_dot_product(A, B):
m = A.size
Expand All @@ -74,7 +74,7 @@ def simple_dot_product(A, B):
numba.float64(numba.int16[:], numba.float64[:], numba.int64),
numba.float64(numba.int32[:], numba.float64[:], numba.int64),
numba.float64(numba.int64[:], numba.float64[:], numba.int64),
]
], cache=True
)
def simple_dot_product_setback(A, B, setback):
m = A.size
Expand All @@ -84,7 +84,7 @@ def simple_dot_product_setback(A, B, setback):
return s


@numba.njit
@numba.njit(cache=True)
def taylor_coeff_inverse_polynomial_legacy(
denom, var_array, diff_array, num_branchtypes, dot_product
):
Expand All @@ -102,7 +102,7 @@ def taylor_coeff_inverse_polynomial_legacy(
return (-1) ** (total_diff_count) * nomd / denomd


@numba.njit
@numba.njit(cache=True)
def taylor_coeff_inverse_polynomial(
denom, theta, diff_array, num_branchtypes, dot_product, mutypes_shape
):
Expand All @@ -125,7 +125,7 @@ def taylor_coeff_inverse_polynomial(
return (-1) ** (total_diff_count) * nomd / denomd


@numba.njit
@numba.njit(cache=True)
def taylor_coeff_exponential_legacy(
c, f, exponential_part, diff_array, num_branchtypes
):
Expand All @@ -139,7 +139,7 @@ def taylor_coeff_exponential_legacy(
return p1 * exponential_part / fact


@numba.njit
@numba.njit(cache=True)
def taylor_coeff_exponential(
c, f, dot_product, diff_array, num_branchtypes, theta, mutypes_shape
):
Expand All @@ -160,7 +160,7 @@ def taylor_coeff_exponential(


# combining taylor series
@numba.njit
@numba.njit(cache=True)
def series_product_legacy(arr1, arr2, subsetdict):
# arr1*arr2
shape = arr1.shape
Expand All @@ -174,7 +174,7 @@ def series_product_legacy(arr1, arr2, subsetdict):
return result


@numba.njit
@numba.njit(cache=True)
def series_product(arr1, arr2, subsetdict):
# arr1*arr2
size = arr1.size
Expand All @@ -185,7 +185,7 @@ def series_product(arr1, arr2, subsetdict):
return result


@numba.njit
@numba.njit(cache=True)
def series_quotient_legacy(arr1, arr2, subsetdict):
# arr1/arr2
shape = arr1.shape
Expand All @@ -206,7 +206,7 @@ def series_quotient_legacy(arr1, arr2, subsetdict):
[
numba.int64(numba.int64[:], numba.int64[:]),
numba.int64(numba.uint64[:], numba.uint64[:]),
]
], cache=True
)
def ravel_multi_index(multi_index, shape):
shape_prod = np.cumprod(shape[:0:-1])[::-1]
Expand Down Expand Up @@ -238,7 +238,7 @@ def return_strictly_smaller_than_idx(idx, shape):


# making subsetdict with marginals
@numba.njit
@numba.njit(cache=True)
def increment_marginal(arr, idx, max_value, reset_value):
if idx < 0:
return -1
Expand All @@ -254,7 +254,7 @@ def increment_marginal(arr, idx, max_value, reset_value):
return result


@numba.njit
@numba.njit(cache=True)
def return_smaller_than_idx_marg(start, max_value, shape):
reset_value = start.copy()
yield ravel_multi_index(start, shape)
Expand Down Expand Up @@ -327,15 +327,15 @@ def quotient_subsetdict(shape):


# deconstructing equations:
@numba.njit
@numba.njit(cache=True)
def quotient_f_g(subsetdict, f, g):
result = np.zeros_like(f)
for idx, (fs, gs) in enumerate(zip(f, g)):
result[idx] = series_quotient_legacy(fs, gs, subsetdict)
return result


@numba.njit
@numba.njit(cache=True)
def product_f(subsetdict, f):
if len(f) == 1:
return f[0]
Expand All @@ -346,7 +346,7 @@ def product_f(subsetdict, f):
return result


@numba.njit
@numba.njit(cache=True)
def product_f_g(subsetdict, f, g, signs):
if g.shape[0] == 0:
return signs * f
Expand All @@ -357,7 +357,7 @@ def product_f_g(subsetdict, f, g, signs):
return result


@numba.njit
@numba.njit(cache=True)
def all_polynomials_legacy(eq_matrix, shape, var_array, num_branchtypes, mutype_shape):
num_equations = eq_matrix.shape[0]
if num_equations == 0:
Expand Down Expand Up @@ -407,7 +407,7 @@ def all_polynomials(
return result


@numba.njit
@numba.njit(cache=True)
def all_exponentials_legacy(
eq_matrix, shape, var_array, time, num_branchtypes, mutype_shape
):
Expand Down Expand Up @@ -448,7 +448,7 @@ def all_exponentials(
return result


@numba.njit
@numba.njit(cache=True)
def product_pairwise_diff_inverse_polynomial(polynomial_f, shape, combos, subsetdict):
if shape[0] <= 1:
return polynomial_f
Expand Down Expand Up @@ -881,7 +881,7 @@ def taylor_to_probability_coeffs(mutype_array, mutype_shape, include_marginals=F
return temp


@numba.njit
@numba.njit(cache=True)
def taylor_to_probability(precomp, theta):
max_idx = np.max(precomp) + 1
temp = np.zeros(max_idx, dtype=np.float64)
Expand Down Expand Up @@ -960,7 +960,7 @@ def iterate_graph(sequence, graph, adjacency_matrix, evaluated_eqs, subsetdict):
return node_values


@numba.njit
@numba.njit(cache=True)
def iterate_eq_graph_legacy(sequence, graph, evaluated_eqs, subsetdict):
shape = evaluated_eqs[0].shape
num_nodes = len(sequence)
Expand All @@ -981,7 +981,7 @@ def iterate_eq_graph_legacy(sequence, graph, evaluated_eqs, subsetdict):
return node_values[0]


@numba.njit
@numba.njit(cache=True)
def iterate_eq_graph(sequence, graph, evaluated_eqs, subsetdict):
size = len(evaluated_eqs[0])
num_nodes = len(sequence)
Expand Down

0 comments on commit ca68fa3

Please sign in to comment.