sumlog #48

Open · wants to merge 40 commits into master

Changes from 18 commits

40 commits
581b9df
sumlog
cscherrer May 2, 2022
5725aa9
Update src/sumlog.jl
cscherrer May 2, 2022
77aa3d9
Update src/sumlog.jl
cscherrer May 2, 2022
88d6fb1
Update src/sumlog.jl
cscherrer May 2, 2022
9ecf589
Update src/sumlog.jl
cscherrer May 2, 2022
9db732f
Update src/sumlog.jl
cscherrer May 2, 2022
76533e1
fall-back method
cscherrer May 2, 2022
5747205
more tests
cscherrer May 2, 2022
0f5a927
bump version
cscherrer May 2, 2022
977723d
cast to floating point when possible
cscherrer May 2, 2022
4d488cd
docstring fixes
cscherrer May 2, 2022
afa5d94
performance fix
cscherrer May 2, 2022
e400483
inline _sumlog
cscherrer May 2, 2022
cc1aaac
qualify IrrationalConstants.logtwo
cscherrer May 3, 2022
07809b7
update comment
cscherrer May 3, 2022
16ee153
Update src/sumlog.jl
cscherrer May 3, 2022
1af518b
bugfix
cscherrer May 3, 2022
0eaf8d2
Make it work (and be fast) for Tuples and NamedTuples
cscherrer May 3, 2022
1f478d0
add sumlog to docs
cscherrer May 3, 2022
0807f7a
tests
cscherrer May 3, 2022
2a0004d
comment that `eltype` of a `Base.Generator` returns `Any`
cscherrer May 3, 2022
e5809d1
saturday
mcabbott May 7, 2022
eb1b524
Merge pull request #1 from mcabbott/iterate
cscherrer May 7, 2022
6fe8bb1
Update src/sumlog.jl
cscherrer May 7, 2022
0def97d
change to `logprod`
cscherrer May 10, 2022
3ad95b2
fix sign bit
cscherrer May 10, 2022
a0a9348
Update docs/src/index.md
cscherrer May 10, 2022
fa667ec
Update src/LogExpFunctions.jl
cscherrer May 10, 2022
989a111
Update src/logprod.jl
cscherrer May 10, 2022
a54a024
Update src/LogExpFunctions.jl
cscherrer May 10, 2022
dc48433
Update src/logprod.jl
cscherrer May 10, 2022
39ca989
cleaning up
cscherrer May 10, 2022
207fce2
Merge branch 'master' of https://github.com/cscherrer/LogExpFunctions.jl
cscherrer May 10, 2022
55d125e
Update src/logprod.jl
cscherrer May 10, 2022
bef4728
Update src/logprod.jl
cscherrer May 10, 2022
9572e48
Update src/logprod.jl
cscherrer May 10, 2022
3848848
Update src/logprod.jl
cscherrer May 10, 2022
c4c3e89
Update src/logprod.jl
cscherrer May 10, 2022
e0f410e
Update src/logprod.jl
cscherrer May 10, 2022
23b5bf1
Update src/logprod.jl
cscherrer May 11, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LogExpFunctions"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
version = "0.3.14"
version = "0.3.15"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3 changes: 2 additions & 1 deletion src/LogExpFunctions.jl
@@ -11,12 +11,13 @@ import LinearAlgebra

export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh
softmax!, logcosh, sumlog

include("basicfuns.jl")
include("logsumexp.jl")
include("chainrules.jl")
include("inverse.jl")
include("with_logabsdet_jacobian.jl")
include("sumlog.jl")

end # module
42 changes: 42 additions & 0 deletions src/sumlog.jl
@@ -0,0 +1,42 @@
"""
$(SIGNATURES)

Compute `sum(log.(X))` with a single `log` evaluation.

This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in
particular as the size of `X` increases.

This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``,
allowing us to write
```math
\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
```
Since ``\\log{2}`` is constant, `sumlog` only requires a single `log`
evaluation.
"""
function sumlog(x)
    T = float(eltype(x))
    _sumlog(T, values(x))
end

@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat}
    sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj)
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end

@inline function _sumlog_op((sig1, ex1), (sig2, ex2))
    sig = sig1 * sig2
    ex = ex1 + ex2
    # Significands are in the range [1,2), so multiplication will eventually overflow
    if sig > floatmax(typeof(sig)) / 2
        ex += exponent(sig)
        sig = significand(sig)
    end
    return sig, ex
end

# `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics
@inline _sumlog(::Type{T}, x) where {T} = sum(log, x)
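
As an aside (not part of the diff), the identity the docstring relies on is easy to check directly with Base's `significand` and `exponent`; the numbers below are arbitrary examples:

# Sanity check of the decomposition x == significand(x) * 2^exponent(x),
# and hence log(x) == log(significand(x)) + exponent(x) * log(2).
x = 6.75                            # 6.75 == 1.6875 * 2^2
@assert (significand(x), exponent(x)) == (1.6875, 2)
@assert log(x) ≈ log(significand(x)) + exponent(x) * log(2)

# Across a vector, only the product of significands needs a `log` call:
xs = [0.5, 3.0, 6.75]
@assert sum(log, xs) ≈ log(prod(significand, xs)) + sum(exponent, xs) * log(2)
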
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -14,3 +14,4 @@ include("basicfuns.jl")
include("chainrules.jl")
include("inverse.jl")
include("with_logabsdet_jacobian.jl")
include("sumlog.jl")
7 changes: 7 additions & 0 deletions test/sumlog.jl
@@ -0,0 +1,7 @@
@testset "sumlog" begin

Some surprises:

julia> sumlog([1,2,-0.1,-0.2])
-3.2188758248682

julia> sumlog([1,2,NaN])
ERROR: DomainError with NaN:
Cannot be NaN or Inf.
Stacktrace:
  [1] (::Base.Math.var"#throw1#5")(x::Float64)
    @ Base.Math ./math.jl:845
  [2] exponent
    @ ./math.jl:848 [inlined]

julia> sumlog([-0.0])
ERROR: DomainError with -0.0:
Cannot be ±0.0.

julia> sum(log, -0.0)
-Inf
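
For reference, the generic sum(log, x) fallback treats these same inputs differently (standard Base behavior, shown here as a sketch), which is what makes the results above surprising:

# Base's log propagates NaN but throws for negative reals, so sum(log, x) differs:
sum(log, [1, 2, NaN])                # returns NaN instead of throwing
try
    sum(log, [1, 2, -0.1, -0.2])     # log(-0.1) throws rather than returning a finite value
catch err
    @assert err isa DomainError
end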


Calling Base.Math._exponent_finite_nonzero works around the NaN problem (although not for BigFloats).

Adding xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) doesn't seem to cost much speed.
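
For concreteness, a sketch of where that guard could sit in the mapreduce body of _sumlog (adapted from the suggestion above; this is not code from the PR itself):

@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat}
    sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
        # Reject negative inputs up front, mirroring what `log` itself does.
        xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj)
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end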


Also, note that the tests right now only cover Float64.

    for T in [Int, Float16, Float32, Float64, BigFloat]
Member:

I noticed that you removed the type restriction. We should therefore extend the tests, e.g. to check more general iterables (also with different types, abstract eltype, etc., since sum(log, x) would work for them) and complex numbers as well.
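
One possible shape for such tests (an illustrative sketch only, assuming the runtests.jl context with `using Test` and the package loaded; it assumes complex inputs fall back to sum(log, x) while integers promote via float, and leaves generators aside since their eltype is discussed just below):

@testset "sumlog, more input types" begin
    # Other real element types should still hit the fast path.
    @test sumlog(Float32[1.5, 2.5, 3.5]) ≈ sum(log, Float32[1.5, 2.5, 3.5])
    @test sumlog([1, 2, 3, 4]) ≈ sum(log, [1, 2, 3, 4])
    # Complex numbers are not AbstractFloat, so they should use the sum(log, x) fallback.
    z = [1.0 + 2.0im, 3.0 - 1.0im]
    @test sumlog(z) ≈ sum(log, z)
end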

Author:

eltype doesn't work well for Base.Generators. Usually this is when I'd turn to something like

julia> Core.Compiler.return_type(gen.f, Tuple{eltype(gen.iter)})
Float64

We could instead have it fall back on the default, but I'd guess that will sacrifice performance.
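
A small illustration of that point (a hypothetical generator, using the compiler-internal return_type the comment mentions):

# `eltype` of a Generator is `Any` even when the mapped function is type-stable,
# so `float(eltype(x))` cannot select the fast AbstractFloat method.
gen = (2x for x in [1.0, 2.0, 3.0])
eltype(gen)                                                  # Any
# The compiler-internal query recovers the concrete type, but it is not public API:
Core.Compiler.return_type(gen.f, Tuple{eltype(gen.iter)})    # Float64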


I bet you could write an equally fast version which explicitly calls iterate, and widens if the type changes. (But usually the compiler will prove that it won't.)

One reason to keep mapreduce for arrays is that you can give it dims.


Should check more carefully, but this appears to work & is as fast as current version:

function sumlog(x)
    iter = iterate(x)
    if isnothing(iter)
        return eltype(x) <: Number ? zero(float(eltype(x))) : 0.0
    end
    x1 = float(iter[1])
    x1 isa AbstractFloat || return sum(log, x)
    sig, ex = significand(x1), exponent(x1)
    iter = iterate(x, iter[2])
    while iter !== nothing
        xj = float(iter[1])
        xj isa AbstractFloat || return sum(log, x)  # maybe not ideal, re-starts iterator
        sig, ex = _sumlog_op((sig, ex), (significand(xj), exponent(xj)))
        iter = iterate(x, iter[2])
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end

sumlog(f, x) = sumlog(Iterators.map(f, x))
sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...))

And for dims:

sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)

function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
    sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj) 
    end
    return log(sig) + IrrationalConstants.logtwo * ex
end

function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
    sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), zero(exponent(one(T))))) do xj
        float_xj = float(xj)
        significand(float_xj), exponent(float_xj) 
    end
    map(sig_ex) do (sig, ex)
        log(sig) + IrrationalConstants.logtwo * ex
    end
end

Should I make a PR to the PR?


cscherrer#1 is a tidier version of the above.

        for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)]
            @test (@inferred sumlog(x)) ≈ sum(log, x)
        end
    end
end