Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of inverse trigamma #415

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/SpecialFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export
invdigamma,
polygamma,
trigamma,
invtrigamma,
gamma_inc,
beta_inc,
beta_inc_inv,
Expand Down Expand Up @@ -93,7 +94,8 @@ include("chainrules.jl")
include("deprecated.jl")

for f in (:digamma, :erf, :erfc, :erfcinv, :erfcx, :erfi, :erfinv, :logerfc, :logerfcx,
:eta, :gamma, :invdigamma, :logfactorial, :lgamma, :trigamma, :ellipk, :ellipe)
:eta, :gamma, :invdigamma, :invtrigamma, :logfactorial, :lgamma, :trigamma,
:ellipk, :ellipe)
@eval $(f)(::Missing) = missing
end
for f in (:beta, :lbeta)
Expand Down
5 changes: 4 additions & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ ChainRulesCore.@scalar_rule(
inv(trigamma(invdigamma(x))),
)
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))

ChainRulesCore.@scalar_rule(
invtrigamma(x),
inv(polygamma(2, invtrigamma(x))),
)
# Bessel functions
ChainRulesCore.@scalar_rule(
besselj(ν, x),
Expand Down
43 changes: 43 additions & 0 deletions src/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,49 @@ function _invdigamma(y::Float64)
return x_new
end

"""
invtrigamma(x)
Compute the inverse [`trigamma`](@ref) function of `x`.
Comment on lines +402 to +403
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invtrigamma(x)
Compute the inverse [`trigamma`](@ref) function of `x`.
invtrigamma(x)
Compute the inverse of the [`trigamma`](@ref) function at the positive real value `x`.
This is the solution `y` to the equation `trigamma(y) = x`.

The line break is for consistency with other docstrings. I realize the text is modified from that of invdigamma but in this case I think it's worth noting the restriction on the domain of x. The added line is just for some extra clarity.

"""
invtrigamma(y::Number) = _invtrigamma(float(y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invtrigamma(y::Number) = _invtrigamma(float(y))
invtrigamma(x::Number) = _invtrigamma(float(x))

It's kind of breaking my brain that what we're calling x and y are the reverse of how they're used in the paper you linked.


function _invtrigamma(y::Float64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function _invtrigamma(y::Float64)
function _invtrigamma(x::Float64)

(See above)

# Implementation of Newton algorithm described in
# "Linear Models and Empirical Bayes Methods for Assessing
# Differential Expression in Microarray Experiments"
# (Appendix "Inversion of Trigamma Function")
# by Gordon K. Smyth, 2004

if y <= 0
throw(DomainError(y, "Only positive `y` supported."))
end

if y > 1e7
return inv(sqrt(y))
elseif y < 1e-6
return inv(y)
end
Comment on lines +414 to +422
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if y <= 0
throw(DomainError(y, "Only positive `y` supported."))
end
if y > 1e7
return inv(sqrt(y))
elseif y < 1e-6
return inv(y)
end
if x <= 0
throw(DomainError(x, "`x` must be positive."))
elseif x > 1e7
return inv(sqrt(x))
elseif x < 1e-6
return inv(x)
end

This brings the error message more in line with the text used in other DomainError messages in this package, e.g. from the Bessel functions. Condensing the conditional is just for brevity.


x_old = inv(y) + 0.5
x_new = x_old

# Newton iteration
δ = Inf
iteration = 0
while δ > 1e-8 && iteration <= 25
iteration += 1
f_x_old = trigamma(x_old)
δx = f_x_old*(1-f_x_old/y) / polygamma(2, x_old)
x_new = x_old + δx
δ = - δx / x_new
x_old = x_new
end

return x_new
Comment on lines +424 to +439
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x_old = inv(y) + 0.5
x_new = x_old
# Newton iteration
δ = Inf
iteration = 0
while δ > 1e-8 && iteration <= 25
iteration += 1
f_x_old = trigamma(x_old)
δx = f_x_old*(1-f_x_old/y) / polygamma(2, x_old)
x_new = x_old + δx
δ = - δx / x_new
x_old = x_new
end
return x_new
# Newton iteration
invx = inv(x)
y_new = y_old = invx + 0.5
for _ in 1:26
ψ′ = trigamma(y_old)
δ = ψ′ * (1 - ψ′ * invx) / polygamma(2, y_old)
y_new = y_old + δ
-δ / y_new < 1e-8 && break
y_old = y_new
end
return y_new

AFAICT this is equivalent and more directly maps to the paper, as it avoids introducing a second step size variable. You can also avoid dividing by the input at every iteration by inverting once then multiplying by the inverse.

I assume the number of iterations was chosen to match invdigamma. Do you know whether the algorithm generally converges within the given number of iterations and under what circumstances it may not? When it doesn't, how inaccurate is the result?

end


Comment on lines +441 to +442
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Just some excess space


"""
zeta(s)

Expand Down
4 changes: 4 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
test_scalar(invdigamma, x)
end

if x isa Real && x > 0
test_scalar(invtrigamma, x)
end

if x isa Real && 0 < x < 1
test_scalar(erfinv, x)
test_scalar(erfcinv, x)
Expand Down
13 changes: 13 additions & 0 deletions test/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@
@test abs(invdigamma(2)) == abs(invdigamma(2.))
end

@testset "invtrigamma" begin
for val in [0.001, 0.01, 0.1, 1.0, 10.0]
@test invtrigamma(trigamma(val)) ≈ val
end

for val in [1e-8, 0.001, 0.01, 0.1, 1.0, 10.0, 1e7, 1e9]
@test trigamma(invtrigamma(val)) ≈ val
end

@test_throws DomainError invtrigamma(-1.0)
@test invtrigamma(2) == invtrigamma(2.)
end

@testset "polygamma" begin
@test polygamma(20, 7.) ≈ -4.644616027240543262561198814998587152547
@test polygamma(20, Float16(7.)) ≈ -4.644616027240543262561198814998587152547
Expand Down
4 changes: 2 additions & 2 deletions test/other_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end

@testset "missing data" begin
for f in (digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, eta, gamma,
invdigamma, logfactorial, trigamma)
invdigamma, invtrigamma, logfactorial, trigamma)
@test f(missing) === missing
end
@test beta(1.0, missing) === missing
Expand All @@ -90,7 +90,7 @@ end
for n in numbers
@test abs(n) == SpecialFunctions.fastabs(n)
end

numbers = [1im, 2 + 2im, 0 + 100im, 1e3 + 1e-10im]
for n in numbers
@test abs(real(n)) + abs(imag(n)) == SpecialFunctions.fastabs(n)
Expand Down