Skip to content

Commit

Permalink
Add ChangesOfVariables definitions and extend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Dec 10, 2024
1 parent da5130f commit 5f1d99d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 8 deletions.
14 changes: 14 additions & 0 deletions ext/LogExpFunctionsChangesOfVariablesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real)
y = log1pexp(x)
return y, x - y
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real)
return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x)
end
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real)
y = f(x)
return y, f.x * (x - y)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real)
y = logexpm1(x)
return y, x - y
end
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real)
return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x)
end
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real)
y = f(x)
return y, f.x * (x - y)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real)
y = log1mexp(x)
Expand Down
7 changes: 7 additions & 0 deletions ext/LogExpFunctionsInverseFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic

InverseFunctions.inverse(::typeof(softplus)) = invsoftplus
function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real})
Base.Fix2(invsoftplus, f.x)
end

InverseFunctions.inverse(::typeof(invsoftplus)) = softplus
function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real})
Base.Fix2(softplus, f.x)
end

end # module
22 changes: 17 additions & 5 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,15 @@ end
end

@testset "softplus" begin
@test softplus(2) log1pexp(2)
@test softplus(2, 1) log1pexp(2)
@test softplus(2, 10) < log1pexp(2)
@test invsoftplus(softplus(2), 1) 2
for T in (Int, Float64, Float32, Float16)
@test @inferred(softplus(T(2))) === log1pexp(T(2))
@test @inferred(softplus(T(2), 1)) isa float(T)
@test @inferred(softplus(T(2), 1)) softplus(T(2))
@test @inferred(softplus(T(2), 5)) softplus(5 * T(2)) / 5
@test @inferred(softplus(T(2), 10)) softplus(10 * T(2)) / 10
end
end


@testset "log1mexp" begin
for T in (Float64, Float32, Float16)
@test @inferred(log1mexp(-T(1))) isa T
Expand All @@ -194,6 +196,16 @@ end
end
end

@testset "invsoftplus" begin
for T in (Int, Float64, Float32, Float16)
@test @inferred(invsoftplus(T(2))) === logexpm1(T(2))
@test @inferred(invsoftplus(T(2), 1)) isa float(T)
@test @inferred(invsoftplus(T(2), 1)) invsoftplus(T(2))
@test @inferred(invsoftplus(T(2), 5)) invsoftplus(5 * T(2)) / 5
@test @inferred(invsoftplus(T(2), 10)) invsoftplus(10 * T(2)) / 10
end
end

@testset "log1pmx" begin
@test iszero(log1pmx(0.0))
@test log1pmx(1.0) log(2.0) - 1.0
Expand Down
8 changes: 5 additions & 3 deletions test/inverse.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
@testset "inverse.jl" begin
InverseFunctions.test_inverse(log1pexp, randn())
InverseFunctions.test_inverse(softplus, randn())
InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn())

InverseFunctions.test_inverse(logexpm1, randexp())
InverseFunctions.test_inverse(invsoftplus, randexp())
InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp())

InverseFunctions.test_inverse(log1mexp, -randexp())

Expand All @@ -17,7 +22,4 @@

InverseFunctions.test_inverse(log1mlogistic, randexp())
InverseFunctions.test_inverse(logit1mexp, -randexp())

InverseFunctions.test_inverse(softplus, randn())
InverseFunctions.test_inverse(invsoftplus, randexp())
end
11 changes: 11 additions & 0 deletions test/with_logabsdet_jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
@testset "with_logabsdet_jacobian" begin
derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2]
derivative(::typeof(softplus), x) = derivative(log1pexp, x)
derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x)
derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x)
derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x)

x = randexp()
y = randexp()

ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)

Expand Down

0 comments on commit 5f1d99d

Please sign in to comment.