Skip to content

Commit

Permalink
Merge pull request #113 from broadinstitute/irwin-math
Browse files Browse the repository at this point in the history
Use math.erf, math.gamma, and math.lgamma to simplify computations
  • Loading branch information
dpark01 committed Mar 14, 2015
2 parents 0f21fae + 587576a commit adf364e
Showing 1 changed file with 20 additions and 62 deletions.
82 changes: 20 additions & 62 deletions util/stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''A few pure-python statistical tools to avoid the need to install scipy. '''
from __future__ import division # Division of integers with / should never round!
from math import exp, log, pi, sqrt
from math import exp, log, pi, sqrt, gamma, lgamma, erf
import itertools

__author__ = "[email protected], [email protected]"
Expand Down Expand Up @@ -99,6 +99,8 @@ def fisher_exact(contingencyTable) :
raise ValueError('Not all rows have the same length')
if any(x != int(x) for row in contingencyTable for x in row) :
raise ValueError('Some table entry is not an integer')
if any(x < 0 for row in contingencyTable for x in row) :
raise ValueError('Some table entry is negative')

# Eliminate rows and columns with 0 sum
colSums = [sum(row[col] for row in contingencyTable)
Expand Down Expand Up @@ -150,65 +152,27 @@ def prob_of_table(firstRow) :

return result

def log_stirling(n) :
    """Return Stirling's approximation for log(n!) using terms up to 1/n^7.

    Provides the exact right answer (when exponentiated and rounded to int)
    for 16! and lower. Correct to 10 digits for 5! and above.

    n must be positive; a ValueError is raised otherwise.
    """
    if n <= 0 :
        # log(n) is undefined here; fail with a clearer message than the
        # bare "math domain error" that log() itself would raise.
        raise ValueError('%s is not positive' % n)
    n2 = n * n
    n3 = n * n2
    n5 = n3 * n2
    n7 = n5 * n2
    # Leading Stirling term followed by the first four correction terms of
    # the asymptotic series.
    return (n * log(n) - n + 0.5 * log(2 * pi * n)
            + 1 / 12.0 / n - 1 / 360.0 / n3 + 1 / 1260.0 / n5 - 1 / 1680.0 / n7)

def log_choose(n, k) :
    """Return log(n choose k), the natural log of the binomial coefficient.

    Computed using lgamma(x + 1) = log(x!), so no large factorial is ever
    formed explicitly.

    Raises ValueError unless 0 <= k <= n.
    """
    if not (0 <= k <= n) :
        # Without this check an out-of-range k would yield a meaningless
        # value instead of an error.
        raise ValueError('%d is negative or more than %d' % (k, n))
    return lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)

def factorial(n) :
    """Return n factorial as an exact integer.

    n must be a non-negative integer (an integral float such as 5.0 is
    accepted); otherwise a ValueError is raised.
    """
    if n < 0 or n != int(n) :
        raise ValueError('%s is not a non-negative integer' % n)
    # Exact integer product: stays correct for every n and cannot overflow,
    # unlike rounding exp() of a floating-point approximation to log(n!).
    result = 1
    for ii in range(2, int(n) + 1) :
        result *= ii
    return result

def gamma(s) :
    """ Gamma function = integral from 0 to infinity of t ** (s-1) exp(-t) dt.
        Implemented only for s > 0; raises ValueError otherwise.
        Returns a float (delegates to the standard library's math.gamma).
    """
    # scipy equivalent: scipy.special.gamma(s)
    # Function-scope import: this function's own name hides any module-level
    # "gamma" imported from math, so go through the module explicitly.
    import math
    if s <= 0 :
        raise ValueError('%s is not positive' % s)
    return math.gamma(s)

def gammainc(s, x) :
def gammainc_halfint(s, x) :
""" Lower incomplete gamma function =
integral from 0 to x of t ** (s-1) exp(-t) dt divided by gamma(s),
i.e., the fraction of gamma that you get if you integrate only until
x instead of all the way to infinity.
Implemented only for s > 0.
Implemented here only if s is a positive multiple of 0.5.
"""
# scipy equivalent: scipy.special.gammainc(s,x)

if s <= 0 :
raise ValueError('%s is not positive' % s)
if x < 0 :
raise ValueError('%s < 0' % x)
if s * 2 != int(s * 2) :
raise NotImplementedError('%s is not a multiple of 0.5' % s)

# Handle integers analytically
if s == int(s) :
Expand All @@ -219,26 +183,20 @@ def gammainc(s, x) :
total += term
return 1 - exp(-x) * total

# Otherwise use infinite series:
# gammainc(s,x) = x ** s * exp(-x) / s / gamma(s) *
# sum_k=0_to_infinity(x ** k / product_j=1_to_k(s + j))
# which follows from the recursion formula:
# gammainc(s, x) = gammainc(s - 1, x) - x ** (s - 1) * exp(-x) / gamma(s)
relTol = 1e-15
term = 1
total = 1
for k in itertools.count(1) :
term *= x / (s + k)
total += term
if term <= relTol * total :
break
return min(1.0, total * x ** s * exp(-x) / s / gamma(s))
# Otherwise s is integer + 0.5. Decrease to 0.5 using recursion formula:
result = 0.0
while s > 1 :
result -= x ** (s - 1) * exp(-x) / gamma(s)
s = s - 1
# Then use gammainc(0.5, x) = erf(sqrt(x))
result += erf(sqrt(x))
return result

def pchisq(x, k) :
    """Cumulative distribution function of chi squared with k degrees of freedom.

    x is the non-negative point at which the CDF is evaluated and k is a
    positive integer number of degrees of freedom.

    Raises ValueError if k is not a positive integer or if x < 0.
    """
    if k < 1 or k != int(k) :
        raise ValueError('%s is not a positive integer' % k)
    if x < 0 :
        raise ValueError('%s < 0' % x)
    # k / 2 is always a positive multiple of 0.5, which is exactly the case
    # gammainc_halfint supports. Dividing by 2.0 keeps this true division
    # regardless of the interpreter's integer-division rules.
    return gammainc_halfint(k / 2.0, x / 2.0)

0 comments on commit adf364e

Please sign in to comment.