-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #113 from broadinstitute/irwin-math
Use math.erf, math.gamma, and math.lgamma to simplify computations
- Loading branch information
Showing
1 changed file
with
20 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
'''A few pure-python statistical tools to avoid the need to install scipy. ''' | ||
from __future__ import division # Division of integers with / should never round! | ||
from math import exp, log, pi, sqrt | ||
from math import exp, log, pi, sqrt, gamma, lgamma, erf | ||
import itertools | ||
|
||
__author__ = "[email protected], [email protected]" | ||
|
@@ -99,6 +99,8 @@ def fisher_exact(contingencyTable) : | |
raise ValueError('Not all rows have the same length') | ||
if any(x != int(x) for row in contingencyTable for x in row) : | ||
raise ValueError('Some table entry is not an integer') | ||
if any(x < 0 for row in contingencyTable for x in row) : | ||
raise ValueError('Some table entry is negative') | ||
|
||
# Eliminate rows and columns with 0 sum | ||
colSums = [sum(row[col] for row in contingencyTable) | ||
|
@@ -150,65 +152,27 @@ def prob_of_table(firstRow) : | |
|
||
return result | ||
|
||
def log_stirling(n) : | ||
"""Return Stirling's approximation for log(n!) using up to n^7 term. | ||
Provides exact right answer (when rounded to int) for 16! and lower. | ||
Correct to 10 digits for 5! and above.""" | ||
n2 = n * n | ||
n3 = n * n2 | ||
n5 = n3 * n2 | ||
n7 = n5 * n2 | ||
return n * log(n) - n + 0.5 * log(2 * pi * n) + \ | ||
1 / 12.0 / n - 1 / 360.0 / n3 + 1 / 1260.0 / n5 - 1 / 1680.0 / n7 | ||
|
||
def log_choose(n, k) : | ||
k = min(k, n - k) | ||
if k <= 10 : | ||
result = 0.0 | ||
for ii in range(1, k + 1) : | ||
result += log(n - ii + 1) - log(ii) | ||
else : | ||
result = log_stirling(n) - log_stirling(k) - log_stirling(n-k) | ||
return result | ||
# Return log(n choose k). Compute using lgamma(x + 1) = log(x!) | ||
if not (0 <= k <=n) : | ||
raise ValueError('%d is negative or more than %d' % (k, n)) | ||
return lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1) | ||
|
||
def factorial(n) : | ||
"""Return n factorial exactly up to n = 16, otherwise approximate to 14 digits. | ||
n must be a non-negative integer.""" | ||
if n < 0 or n != int(n) : | ||
raise ValueError('%s is not a non-negative integer' % n) | ||
return int(round(exp(log_stirling(n)))) if n > 0 else 1 | ||
|
||
def gamma(s) : | ||
""" Gamma function = integral from 0 to infinity of t ** (s-1) exp(-t) dt. | ||
Implemented only for s >= 0, | ||
""" | ||
# Accurate to better than 1 in 1e11 | ||
# scipy equivalent: scipy.special.gamma(s) | ||
if s <= 0 : | ||
raise ValueError('%s is not positive' % s) | ||
if s == int(s) : | ||
return factorial(int(s - 1)) | ||
elif 2 * s == int(2 * s) : | ||
return sqrt(pi) * factorial(int(2 * s - 1)) / factorial(int(s - 0.5)) /\ | ||
4 ** (s - 0.5) | ||
else : | ||
# stirling is more accurate for larger values, so call it for a | ||
# larger value of s and use gamma recursion to get back to lower value. | ||
return exp(log_stirling(s + 9)) / product(s + i for i in range(10)) | ||
|
||
def gammainc(s, x) : | ||
def gammainc_halfint(s, x) : | ||
""" Lower incomplete gamma function = | ||
integral from 0 to x of t ** (s-1) exp(-t) dt divided by gamma(s), | ||
i.e., the fraction of gamma that you get if you integrate only until | ||
x instead of all the way to infinity. | ||
Implemented only for s > 0. | ||
Implemented here only if s is a positive multiple of 0.5. | ||
""" | ||
# scipy equivalent: scipy.special.gammainc(s,x) | ||
|
||
if s <= 0 : | ||
raise ValueError('%s is not positive' % s) | ||
if x < 0 : | ||
raise ValueError('%s < 0' % x) | ||
if s * 2 != int(s * 2) : | ||
raise NotImplementedError('%s is not a multiple of 0.5' % s) | ||
|
||
# Handle integers analytically | ||
if s == int(s) : | ||
|
@@ -219,26 +183,20 @@ def gammainc(s, x) : | |
total += term | ||
return 1 - exp(-x) * total | ||
|
||
# Otherwise use infinite series: | ||
# gammainc(s,x) = x ** s * exp(-x) / s / gamma(s) * | ||
# sum_k=0_to_infinity(x ** k / product_j=1_to_k(s + j)) | ||
# which follows from the recursion formula: | ||
# gammainc(s, x) = gammainc(s - 1, x) - x ** (s - 1) * exp(-x) / gamma(s) | ||
relTol = 1e-15 | ||
term = 1 | ||
total = 1 | ||
for k in itertools.count(1) : | ||
term *= x / (s + k) | ||
total += term | ||
if term <= relTol * total : | ||
break | ||
return min(1.0, total * x ** s * exp(-x) / s / gamma(s)) | ||
# Otherwise s is integer + 0.5. Decrease to 0.5 using recursion formula: | ||
result = 0.0 | ||
while s > 1 : | ||
result -= x ** (s - 1) * exp(-x) / gamma(s) | ||
s = s - 1 | ||
# Then use gammainc(0.5, x) = erf(sqrt(x)) | ||
result += erf(sqrt(x)) | ||
return result | ||
|
||
def pchisq(x, k) : | ||
"Cumulative distribution function of chi squared with k degrees of freedom." | ||
if k < 1 or k != int(k) : | ||
raise ValueError('%s is not a positive integer' % k) | ||
if x < 0 : | ||
raise ValueError('%s < 0' % x) | ||
return gammainc(k / 2, x / 2) | ||
return gammainc_halfint(k / 2, x / 2) | ||
|