Skip to content

Commit

Permalink
Merge pull request #85 from DominiqueMakowski/master
Browse files Browse the repository at this point in the history
Add "a" parameter to softplus()  #83
  • Loading branch information
tpapp authored Dec 11, 2024
2 parents 289114f + 5f1d99d commit 76a23a7
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ logcosh
logabssinh
log1psq
log1pexp
softplus
invsoftplus
log1mexp
log2mexp
logexpm1
Expand Down
14 changes: 14 additions & 0 deletions ext/LogExpFunctionsChangesOfVariablesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real)
y = log1pexp(x)
return y, x - y
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real)
return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x)
end
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real)
y = f(x)
return y, f.x * (x - y)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real)
y = logexpm1(x)
return y, x - y
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real)
return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x)
end
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real)
y = f(x)
return y, f.x * (x - y)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real)
y = log1mexp(x)
Expand Down
10 changes: 10 additions & 0 deletions ext/LogExpFunctionsInverseFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,14 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic
InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic

InverseFunctions.inverse(::typeof(softplus)) = invsoftplus
function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real})
Base.Fix2(invsoftplus, f.x)
end

InverseFunctions.inverse(::typeof(invsoftplus)) = softplus
function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real})
Base.Fix2(softplus, f.x)
end

end # module
26 changes: 24 additions & 2 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`.
See:
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
"""
Expand Down Expand Up @@ -257,8 +260,27 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)

const softplus = log1pexp
const invsoftplus = logexpm1
"""
$(SIGNATURES)
The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control
the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is
equivalent to [`log1pexp`](@ref).
See:
* Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180.
"""
softplus(x::Real) = log1pexp(x)
softplus(x::Real, a::Real) = log1pexp(a * x) / a

"""
$(SIGNATURES)
The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softplus`](@ref).
"""
invsoftplus(y::Real) = logexpm1(y)
invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a


"""
$(SIGNATURES)
Expand Down
20 changes: 20 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ end
end
end

@testset "softplus" begin
for T in (Int, Float64, Float32, Float16)
@test @inferred(softplus(T(2))) === log1pexp(T(2))
@test @inferred(softplus(T(2), 1)) isa float(T)
@test @inferred(softplus(T(2), 1)) softplus(T(2))
@test @inferred(softplus(T(2), 5)) softplus(5 * T(2)) / 5
@test @inferred(softplus(T(2), 10)) softplus(10 * T(2)) / 10
end
end

@testset "log1mexp" begin
for T in (Float64, Float32, Float16)
@test @inferred(log1mexp(-T(1))) isa T
Expand All @@ -186,6 +196,16 @@ end
end
end

@testset "invsoftplus" begin
for T in (Int, Float64, Float32, Float16)
@test @inferred(invsoftplus(T(2))) === logexpm1(T(2))
@test @inferred(invsoftplus(T(2), 1)) isa float(T)
@test @inferred(invsoftplus(T(2), 1)) invsoftplus(T(2))
@test @inferred(invsoftplus(T(2), 5)) invsoftplus(5 * T(2)) / 5
@test @inferred(invsoftplus(T(2), 10)) invsoftplus(10 * T(2)) / 10
end
end

@testset "log1pmx" begin
@test iszero(log1pmx(0.0))
@test log1pmx(1.0) log(2.0) - 1.0
Expand Down
5 changes: 5 additions & 0 deletions test/inverse.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
@testset "inverse.jl" begin
InverseFunctions.test_inverse(log1pexp, randn())
InverseFunctions.test_inverse(softplus, randn())
InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn())

InverseFunctions.test_inverse(logexpm1, randexp())
InverseFunctions.test_inverse(invsoftplus, randexp())
InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp())

InverseFunctions.test_inverse(log1mexp, -randexp())

Expand Down
11 changes: 11 additions & 0 deletions test/with_logabsdet_jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
@testset "with_logabsdet_jacobian" begin
derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2]
derivative(::typeof(softplus), x) = derivative(log1pexp, x)
derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x)
derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x)
derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x)

x = randexp()
y = randexp()

ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)

Expand Down

0 comments on commit 76a23a7

Please sign in to comment.