From 13e231a0a22e716426b73cb87ff3b8b24e33aaf1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 Jan 2022 22:52:28 +0100 Subject: [PATCH] Promote arguments of `norminvcdf` and `norminvccdf` (#132) * Promote arguments of `norminvcdf` and `norminvccdf` * Simplify promotions * Improve clarity of tests Co-authored-by: Seth Axen * Simplify tests * Add method with `rtol` * Simplification Co-authored-by: Seth Axen --- Project.toml | 2 +- src/distrs/norm.jl | 26 +++++++++++++++----------- test/rmath.jl | 12 +++++++++++- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 8a93cb8..175e92a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StatsFuns" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.14" +version = "0.9.15" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/distrs/norm.jl b/src/distrs/norm.jl index 7461dcf..9184e02 100644 --- a/src/distrs/norm.jl +++ b/src/distrs/norm.jl @@ -39,15 +39,15 @@ function normlogpdf(μ::Real, σ::Real, x::Number) z = zval(μ, σ, x) end normlogpdf(z) - log(σ) -end +end # cdf normcdf(z::Number) = erfc(-z * invsqrt2)/2 function normcdf(μ::Real, σ::Real, x::Number) if iszero(σ) && x == μ z = zval(zero(μ), σ, one(x)) - else - z = zval(μ, σ, x) + else + z = zval(μ, σ, x) end normcdf(z) end @@ -56,8 +56,8 @@ normccdf(z::Number) = erfc(z * invsqrt2)/2 function normccdf(μ::Real, σ::Real, x::Number) if iszero(σ) && x == μ z = zval(zero(μ), σ, one(x)) - else - z = zval(μ, σ, x) + else + z = zval(μ, σ, x) end normccdf(z) end @@ -69,8 +69,8 @@ normlogcdf(z::Number) = z < -1.0 ? function normlogcdf(μ::Real, σ::Real, x::Number) if iszero(σ) && x == μ z = zval(zero(μ), σ, one(x)) - else - z = zval(μ, σ, x) + else + z = zval(μ, σ, x) end normlogcdf(z) end @@ -82,17 +82,21 @@ normlogccdf(z::Number) = z > 1.0 ? function normlogccdf(μ::Real, σ::Real, x::Number) if iszero(σ) && x == μ z = zval(zero(μ), σ, one(x)) - else - z = zval(μ, σ, x) + else + z = zval(μ, σ, x) end normlogccdf(z) end norminvcdf(p::Real) = -erfcinv(2*p) * sqrt2 -norminvcdf(μ::Real, σ::Real, p::Real) = xval(μ, σ, norminvcdf(p)) +# Promote to ensure that we don't compute erfcinv in low precision and then promote +norminvcdf(μ::Real, σ::Real, p::Real) = norminvcdf(promote(μ, σ, p)...) +norminvcdf(μ::T, σ::T, p::T) where {T<:Real} = xval(μ, σ, norminvcdf(p)) norminvccdf(p::Real) = erfcinv(2*p) * sqrt2 -norminvccdf(μ::Real, σ::Real, p::Real) = xval(μ, σ, norminvccdf(p)) +# Promote to ensure that we don't compute erfcinv in low precision and then promote +norminvccdf(μ::Real, σ::Real, p::Real) = norminvccdf(promote(μ, σ, p)...) +norminvccdf(μ::T, σ::T, p::T) where {T<:Real} = xval(μ, σ, norminvccdf(p)) # invlogcdf. Fixme! Support more precisions than Float64 norminvlogcdf(lp::Union{Float16,Float32}) = convert(typeof(lp), _norminvlogcdf_impl(Float64(lp))) diff --git a/test/rmath.jl b/test/rmath.jl index cf22c47..2034b3a 100644 --- a/test/rmath.jl +++ b/test/rmath.jl @@ -17,7 +17,15 @@ end get_statsfun(fname) = eval(Symbol(fname)) get_rmathfun(fname) = eval(Meta.parse(string("RFunctions.", fname))) -function rmathcomp(basename, params, X::AbstractArray, rtol=100eps(float(one(eltype(X))))) +function rmathcomp(basename, params, X::AbstractArray) + # compute default tolerance: + # has to take into account `params` as well since otherwise e.g. `X::Array{<:Rational}` + # always uses a tolerance based on `eps(one(Float64))` even when parameters are of type + # Float32 + rtol = 100 * eps(float(one(promote_type(Base.promote_typeof(params...), eltype(X))))) + rmathcomp(basename, params, X, rtol) +end +function rmathcomp(basename, params, X::AbstractArray, rtol) # tackle pdf specially has_pdf = true if basename == "srdist" @@ -264,6 +272,8 @@ end ((0.0, 0.0), -6.0:0.1:6.0), ((0f0, 1f0), -6f0:0.01f0:6f0), ((0.0, 1.0), -6f0:0.01f0:6f0), + ((0, 2), -6//1:1//2:6//1), + ((0f0, 2f0), -6//1:1//2:6//1), # Fail since `SpecialFunctions.erfcx` is not implemented for `Float16` #((Float16(0), Float16(1)), -Float16(6):Float16(0.01):Float16(6)), #((0f0, 1f0), -Float16(6):Float16(0.01):Float16(6)),