sumlog #48

Open: wants to merge 40 commits into base: master
Changes from 26 commits
Commits (40):
581b9df
sumlog
cscherrer May 2, 2022
5725aa9
Update src/sumlog.jl
cscherrer May 2, 2022
77aa3d9
Update src/sumlog.jl
cscherrer May 2, 2022
88d6fb1
Update src/sumlog.jl
cscherrer May 2, 2022
9ecf589
Update src/sumlog.jl
cscherrer May 2, 2022
9db732f
Update src/sumlog.jl
cscherrer May 2, 2022
76533e1
fall-back method
cscherrer May 2, 2022
5747205
more tests
cscherrer May 2, 2022
0f5a927
bump version
cscherrer May 2, 2022
977723d
cast to floating point when possible
cscherrer May 2, 2022
4d488cd
docstring fixes
cscherrer May 2, 2022
afa5d94
performance fix
cscherrer May 2, 2022
e400483
inline _sumlog
cscherrer May 2, 2022
cc1aaac
qualify IrrationalConstants.logtwo
cscherrer May 3, 2022
07809b7
update comment
cscherrer May 3, 2022
16ee153
Update src/sumlog.jl
cscherrer May 3, 2022
1af518b
bugfix
cscherrer May 3, 2022
0eaf8d2
Make it work (and be fast) for Tuples and NamedTuples
cscherrer May 3, 2022
1f478d0
add sumlog to docs
cscherrer May 3, 2022
0807f7a
tests
cscherrer May 3, 2022
2a0004d
comment that `eltype` of a `Base.Generator` returns `Any`
cscherrer May 3, 2022
e5809d1
saturday
mcabbott May 7, 2022
eb1b524
Merge pull request #1 from mcabbott/iterate
cscherrer May 7, 2022
6fe8bb1
Update src/sumlog.jl
cscherrer May 7, 2022
0def97d
change to `logprod`
cscherrer May 10, 2022
3ad95b2
fix sign bit
cscherrer May 10, 2022
a0a9348
Update docs/src/index.md
cscherrer May 10, 2022
fa667ec
Update src/LogExpFunctions.jl
cscherrer May 10, 2022
989a111
Update src/logprod.jl
cscherrer May 10, 2022
a54a024
Update src/LogExpFunctions.jl
cscherrer May 10, 2022
dc48433
Update src/logprod.jl
cscherrer May 10, 2022
39ca989
cleaning up
cscherrer May 10, 2022
207fce2
Merge branch 'master' of https://github.com/cscherrer/LogExpFunctions.jl
cscherrer May 10, 2022
55d125e
Update src/logprod.jl
cscherrer May 10, 2022
bef4728
Update src/logprod.jl
cscherrer May 10, 2022
9572e48
Update src/logprod.jl
cscherrer May 10, 2022
3848848
Update src/logprod.jl
cscherrer May 10, 2022
c4c3e89
Update src/logprod.jl
cscherrer May 10, 2022
e0f410e
Update src/logprod.jl
cscherrer May 10, 2022
23b5bf1
Update src/logprod.jl
cscherrer May 11, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LogExpFunctions"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
version = "0.3.14"
version = "0.3.15"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -26,6 +26,7 @@ logaddexp
logsubexp
logsumexp
logsumexp!
sumlog
softmax!
softmax
```
4 changes: 3 additions & 1 deletion src/LogExpFunctions.jl
@@ -11,12 +11,14 @@ import LinearAlgebra

export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh
softmax!, logcosh, sumlog

include("basicfuns.jl")
include("logsumexp.jl")
include("chainrules.jl")
include("inverse.jl")
include("with_logabsdet_jacobian.jl")
# include("sumlog.jl")
include("logprod.jl")

end # module
78 changes: 78 additions & 0 deletions src/logprod.jl
@@ -0,0 +1,78 @@
"""
logprod(X::AbstractArray{T}; dims)
Review comment (Member):

Suggested change
logprod(X::AbstractArray{T}; dims)
logprod(x)


Compute `sum(log.(X))` with a single `log` evaluation,
provided `float(T) <: AbstractFloat`.

This is faster than computing `sum(log, X)`, especially for large `X`.
It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}``,
allowing us to write
```math
\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
```
"""
logprod(x) = first(logabsprod(x))

export logabsprod

Comment on lines +19 to +20
Review comment (Member):

Suggested change
export logabsprod

function logabsprod(x::AbstractArray{T}) where {T}
Review comment (Member):

This does not work in general, e.g., if T === Int. We need something like

Suggested change
function logabsprod(x::AbstractArray{T}) where {T}
function logabsprod(x::AbstractArray{<:Number})
T = float(eltype(x))
if !(T <: AbstractFloat)
y = prod(x)
return log(abs(y)), sign(y)
end

Review comment (Member):

Can you also add a docstring?

sig, ex = mapreduce(_logabsprod_op, x; init=frexp(one(T))) do xj
float_xj = float(xj)
frexp(float_xj)
end
sgn = signbit(sig) ? -one(T) : one(T)
return (log(abs(sig)) + IrrationalConstants.logtwo * T(ex), sgn)
end

@inline function _logabsprod_op((sig1, ex1), (sig2, ex2))
sig = sig1 * sig2
# sig = ifelse(sig2<0, sig2, sig1 * sig2)
Review comment (Member):

Suggested change
# sig = ifelse(sig2<0, sig2, sig1 * sig2)

ex = ex1 + ex2
# Significands are in the range [1,2), so multiplication will eventually overflow
Review comment (Member):

No? According to the docstring of frexp:

  frexp(val)

  Return (x,exp) such that x has a magnitude in the interval [1/2, 1) or 0, and val is equal to x \times 2^{exp}.
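
A quick check of the two conventions, as a minimal sketch using only exported Base functions (the value 6.0 and the vector of ones are arbitrary illustrations, not taken from the PR). `frexp` returns a fraction in [1/2, 1), whereas `significand` returns a value in [1, 2), so a long running product of `frexp` fractions shrinks toward zero rather than growing toward `floatmax`.

```julia
x = 6.0

frac, e = frexp(x)                      # fraction in [1/2, 1)
sig, ex = significand(x), exponent(x)   # significand in [1, 2)

@assert 0.5 <= frac < 1
@assert 1 <= sig < 2
@assert frac * 2.0^e == x
@assert sig * 2.0^ex == x

# frexp(1.0) == (0.5, 1), so multiplying 2000 such fractions gives 2^-2000,
# which underflows to zero in Float64.
fracs = map(v -> first(frexp(v)), fill(1.0, 2000))
@assert prod(fracs) == 0.0
```

With `frexp`, any renormalisation guard would therefore need to watch for underflow of the running product, not overflow.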

if sig > floatmax(typeof(sig)) / 2
(new_sig, Δex) = frexp(sig)
ex += Δex
sig = new_sig
end
return sig, ex
end

"""
logprod(x)
logprod(f, x, ys...)

For any iterator which produces `AbstractFloat` elements,
this can use `logprod`'s fast reduction strategy.

Signature with `f` is equivalent to `sum(log, map(f, x, ys...))`
or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations.

Does not accept a `dims` keyword.
"""
logprod(f, x) = logprod(Iterators.map(f, x))
logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...))
Comment on lines +43 to +56
Review comment (Member):

Suggested change
"""
logprod(x)
logprod(f, x, ys...)
For any iterator which produces `AbstractFloat` elements,
this can use `logprod`'s fast reduction strategy.
Signature with `f` is equivalent to `sum(log, map(f, x, ys...))`
or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations.
Does not accept a `dims` keyword.
"""
logprod(f, x) = logprod(Iterators.map(f, x))
logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...))


# Iterator version, uses the same `_logprod_op`, should be the same speed.
function logprod(x)
iter = iterate(x)
if isnothing(iter)
Review comment (Member):

Suggested change
if isnothing(iter)
if iter === nothing

Review comment (Author):

@devmotion can you help me understand why you prefer this?

Review comment (Member):

Because it's more efficient in older Julia versions (IIRC it doesn't matter in more recent versions, probably >= 1.6?).

T = Base._return_type(first, Tuple{typeof(x)})
return T <: Number ? zero(float(T)) : 0.0
end
x1 = float(iter[1])
x1 isa AbstractFloat || return sum(log, x)
x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
sig, ex = significand(x1), _exponent(x1)
Review comment (Member):

Suggested change
sig, ex = significand(x1), _exponent(x1)
sig, ex = frexp(x1)

nonfloat = zero(x1)
Review comment (Member):

Suggested change
nonfloat = zero(x1)

iter = iterate(x, iter[2])
while iter !== nothing
xj = float(iter[1])
if xj isa AbstractFloat
sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj)))
Review comment (Member):

Suggested change
sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj)))
sig, ex = _logabsprod_op((sig, ex), frexp(xj))

else
nonfloat += log(xj)
Review comment (Member):

Suggested change
nonfloat += log(xj)
y = prod(x)
return log(abs(y)), sign(y)

end
iter = iterate(x, iter[2])
end
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
Review comment (Member):

Suggested change
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
return (log(abs(sig)) + IrrationalConstants.logtwo * oftype(sig, ex), sign(sig))

end
94 changes: 94 additions & 0 deletions src/sumlog.jl
@@ -0,0 +1,94 @@
"""
sumlog(X::AbstractArray{T}; dims)
Review comment (Member):

Suggested change
sumlog(X::AbstractArray{T}; dims)
sumlog(x::AbstractArray{<:Real}; dims=:)

Review comment (Author):

I agree with x instead of X. I tried to match the original docs, but now I see they use lowercase. I may have been looking at an old version.

I think in general it makes sense to match the Julia docs for sum. For arrays, that looks like sum(f, A::AbstractArray; dims).

We've already discussed that we don't need <:Real for dispatch, and it disallows some types that would be useful to include.


Compute `sum(log.(X))` with a single `log` evaluation,
provided `float(T) <: AbstractFloat`.
Comment on lines +4 to +5
Review comment (Member):

This sounds as if the function requires `float(T) <: AbstractFloat`, which is not the case.

Suggested change
Compute `sum(log.(X))` with a single `log` evaluation,
provided `float(T) <: AbstractFloat`.
Compute `sum(log, x; dims=dims)`.
If `float(eltype(x)) <: AbstractFloat` the computation is performed with a single `log` evaluation.


This is faster than computing `sum(log, X)`, especially for large `X`.
Review comment (Member):

Suggested change
This is faster than computing `sum(log, X)`, especially for large `X`.
This is faster than computing `sum(log, x; dims=dims)`, especially for large `x`.

It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}``,
allowing us to write
```math
\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
```
"""
sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)
Review comment (Member):

I think it would be better to keep this restricted to `Real`, like the other methods in LogExpFunctions?

In any case, the type parameter seems unnecessary, and dispatching on `dims` is not needed (see below).

Suggested change
sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)
sumlog(x::AbstractArray{<:Real}; dims=:) = _sumlog(float(eltype(x)), x; dims=dims)

Review comment:

Real would be fine. The tests use complex numbers only as examples of things for which float(x) isn't an AbstractFloat. (Which in reality could be weird numbers like Dual, etc.)


function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
sig, ex = mapreduce(_sumlog_op, x; init=(one(T), 0)) do xj
xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
float_xj = float(xj)
significand(float_xj), _exponent(float_xj)
end
return log(sig) + IrrationalConstants.logtwo * T(ex)
end

function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
Review comment (Member):

Suggested change
function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
function _sumlog(::Type{T}, x; dims) where {T<:AbstractFloat}

sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), 0)) do xj
xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
Review comment (Member):

I think it's better to not rely on some internal error function in base, it seems simple enough to throw a custom error message.

Suggested change
xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
xj < zero(xj) && throw(DomainError(xj, "log requires a non-negative argument"))

float_xj = float(xj)
significand(float_xj), _exponent(float_xj)
end
map(sig_ex) do (sig, ex)
log(sig) + IrrationalConstants.logtwo * T(ex)
end
end

# Fallback: `float(T)` is not always `<: AbstractFloat`, e.g. complex, dual numbers or symbolics
_sumlog(::Type, dims, x) = sum(log, x; dims)
Review comment (Member):

Suggested change
_sumlog(::Type, dims, x) = sum(log, x; dims)
_sumlog(::Type, x; dims) = sum(log, x; dims=dims)


@inline function _sumlog_op((sig1, ex1), (sig2, ex2))
sig = sig1 * sig2
# sig = ifelse(sig2<0, sig2, sig1 * sig2)
Review comment (Member):

Suggested change
# sig = ifelse(sig2<0, sig2, sig1 * sig2)

Review comment (@mcabbott, May 7, 2022):

This, BTW, was my alternative attempt at error checking. Instead of checking on every iteration, if you ensure that sig ends up negative, then the final log will throw the right error automatically. I wasn't able to make this as fast, which I'm surprised by, maybe it can be done.

I think maybe focusing on the big questions first might be better than immediately nitpicking.

The other big one is how to handle Float16. At the moment it's super-inaccurate. Maybe accumulation should happen in higher precision. Maybe that should happen for Float32 too for accuracy. But I have not run any careful accuracy tests.
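
One way to explore the accuracy question is a small stand-alone sketch of the decomposition from the docstring. `sumlog_sketch` is a hypothetical helper, not the PR's implementation: it uses only the exported `significand` and `exponent`, renormalises after every multiplication instead of comparing against `floatmax`, and assumes finite, strictly positive inputs (the exported `exponent` throws for zero, `NaN` and `Inf`). The test data is arbitrary.

```julia
# Hypothetical helper: sum of logs with a single `log` call on the significand product.
function sumlog_sketch(xs::AbstractVector{T}) where {T<:AbstractFloat}
    sig = one(T)   # running product of significands, kept in [1, 2)
    ex = 0         # running sum of binary exponents
    for x in xs
        x < zero(x) && throw(DomainError(x, "log requires a non-negative argument"))
        sig *= significand(x)
        ex += exponent(x)
        # Renormalise: move any power of two from `sig` into `ex`
        ex += exponent(sig)
        sig = significand(sig)
    end
    return log(sig) + log(T(2)) * ex
end

xs = collect(range(1.0, 11.0; length=1000))   # strictly positive Float64 data
reference = sum(log, xs)

@assert sumlog_sketch(xs) ≈ reference

# Same data in lower precision, measured against the Float64 reference
err32 = abs(sumlog_sketch(Float32.(xs)) - reference)
err16 = abs(sumlog_sketch(Float16.(xs)) - reference)
```

Comparing `err16` and `err32` against the errors of `sum(log, Float16.(xs))` and `sum(log, Float32.(xs))` would give a first indication of whether accumulating in higher precision is worthwhile.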

ex = ex1 + ex2
# Significands are in the range [1,2), so multiplication will eventually overflow
if sig > floatmax(typeof(sig)) / 2
ex += _exponent(sig)
sig = significand(sig)
end
return sig, ex
end

# The exported `exponent(x)` checks for `NaN` etc, this function doesn't, which is fine as `sig` keeps track.
_exponent(x::Base.IEEEFloat) = Base.Math._exponent_finite_nonzero(x)
Base.@assume_effects :nothrow _exponent(x::AbstractFloat) = Int(exponent(x)) # e.g. for BigFloat
Comment on lines +51 to +53
Review comment (Member):

I'm strongly against using any internal non-exported functions in Base in a package such as LogExpFunctions. I really think we should just use exponent.


"""
sumlog(x)
sumlog(f, x, ys...)

For any iterator which produces `AbstractFloat` elements,
this can use `sumlog`'s fast reduction strategy.

Signature with `f` is equivalent to `sum(log, map(f, x, ys...))`
or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations.

Does not accept a `dims` keyword.
"""
sumlog(f, x) = sumlog(Iterators.map(f, x))
sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...))
Comment on lines +55 to +68
Review comment (Member):

Is this really needed? Users could just create these iterators themselves and call sumlog(x).

Review comment:

Lets you use a do block, which is tidier than writing a generator yourself.
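
For illustration, the two call styles side by side. `sumlog_demo` is a hypothetical stand-in defined here with plain `sum(log, ...)`, so the snippet runs without the PR branch; it mirrors only the signatures under discussion, not the fast reduction. `Iterators.map` assumes Julia 1.6 or later.

```julia
# Hypothetical stand-ins with the same shape as the methods under review
sumlog_demo(x) = sum(log, x)
sumlog_demo(f, x) = sumlog_demo(Iterators.map(f, x))

m = [1.0 2.0; 3.0 4.0]

# Writing the generator by hand
a = sumlog_demo(abs2(v) for v in m)

# Using the method that takes a function, with do-block syntax
b = sumlog_demo(m) do v
    abs2(v)
end

@assert a ≈ b ≈ 2 * sum(log, m)
```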


# Iterator version, uses the same `_sumlog_op`, should be the same speed.
function sumlog(x)
iter = iterate(x)
if isnothing(iter)
Review comment (Member):

In older Julia versions

Suggested change
if isnothing(iter)
if iter === nothing

should be better

T = Base._return_type(first, Tuple{typeof(x)})
return T <: Number ? zero(float(T)) : 0.0
Comment on lines +74 to +75
Review comment (Member):

Again, I don't think we should use such internals.
We could just return

Suggested change
T = Base._return_type(first, Tuple{typeof(x)})
return T <: Number ? zero(float(T)) : 0.0
return zero(log(sum(x)))

or

Suggested change
T = Base._return_type(first, Tuple{typeof(x)})
return T <: Number ? zero(float(T)) : 0.0
return zero(float(sum(x)))

Unfortunately, sum(log, x) doesn't seem to work for empty arrays - but this would even allow us to throw an error for empty arrays to be consistent with sum(log, x).

end
x1 = float(iter[1])
x1 isa AbstractFloat || return sum(log, x)
x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
Review comment (Member):

Again, no internal function:

Suggested change
x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
x1 < zero(x1) && throw(DomainError(x1, "log requires a non-negative argument"))

sig, ex = significand(x1), _exponent(x1)
nonfloat = zero(x1)
Review comment (Member):

Why is this needed?

Suggested change
nonfloat = zero(x1)

Review comment:

Because there is no guarantee that, halfway through the iterator, you won't encounter a non-float element. There's a test for this exact case.

Review comment (Member):

Sure but in this case returning sum(log, x) seems sufficient?

Review comment:

It's avoiding re-starting the iterator.

Review comment (Member):

But why is it not restarted in the case of the first element? In general, reiterating is not guaranteed to yield the same values.

Maybe we should just error if float(xi) is not an AbstractFloat.
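
The concern about re-iteration can be seen with a stateful iterator. This is a minimal sketch using `Iterators.Stateful` from Base with arbitrary data: a fallback that calls `sum(log, x)` after some elements have already been consumed only sees what remains.

```julia
it = Iterators.Stateful([2.0, 3.0, 4.0])

x1 = popfirst!(it)          # consumes the first element
remaining = sum(log, it)    # only sees 3.0 and 4.0

@assert x1 == 2.0
@assert remaining ≈ log(12.0)   # not log(24.0)
@assert isempty(it)             # nothing is left for a second pass
```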

iter = iterate(x, iter[2])
while iter !== nothing
xj = float(iter[1])
if xj isa AbstractFloat
xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj)))
else
nonfloat += log(xj)
end
Comment on lines +85 to +90
Review comment (Member):

Suggested change
if xj isa AbstractFloat
xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj)))
else
nonfloat += log(xj)
end
xj isa AbstractFloat || return sum(log, x)
xj < zero(xj) && throw(DomainError(xj, "log requires a non-negative argument"))
sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj)))

iter = iterate(x, iter[2])
end
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
Review comment (Member):

Suggested change
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
log_sig = log(sig)
return log_sig + IrrationalConstants.logtwo * oftype(log_sig, ex)

Review comment:

You already know sig is an AbstractFloat.

Review comment (Member):

Yes, if we assume that typeof(log(sig)) === typeof(sig) this could be simplified to

Suggested change
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex)

The assumption should hold in almost all cases, but it didn't seem harmful to not rely on this fact.

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -14,3 +14,4 @@ include("basicfuns.jl")
include("chainrules.jl")
include("inverse.jl")
include("with_logabsdet_jacobian.jl")
include("sumlog.jl")
63 changes: 63 additions & 0 deletions test/sumlog.jl
@@ -0,0 +1,63 @@
@testset "sumlog" begin
Review comment:

Some surprises:

julia> sumlog([1,2,-0.1,-0.2])
-3.2188758248682

julia> sumlog([1,2,NaN])
ERROR: DomainError with NaN:
Cannot be NaN or Inf.
Stacktrace:
  [1] (::Base.Math.var"#throw1#5")(x::Float64)
    @ Base.Math ./math.jl:845
  [2] exponent
    @ ./math.jl:848 [inlined]

julia> sumlog([-0.0])
ERROR: DomainError with -0.0:
Cannot be ±0.0.

julia> sum(log, -0.0)
-Inf
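
The first surprise follows directly from the decomposition: without a per-element sign check, an even number of negative inputs yields a positive significand product, so the final `log` succeeds. A minimal sketch with exported Base functions only (it does not call the PR's `sumlog`):

```julia
xs = [1, 2, -0.1, -0.2]        # promoted to Vector{Float64}

sig = prod(significand, xs)    # the two negative significands cancel
ex = sum(exponent, xs)

@assert sig > 0
result = log(sig) + log(2) * ex
@assert result ≈ log(0.04)     # log of the product 1 * 2 * (-0.1) * (-0.2)
```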

Review comment:

Calling Base.Math._exponent_finite_nonzero works around the NaN problem (although not for BigFloats).

Adding xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) doesn't seem to cost much speed.

Review comment:

Also, note that tests right now only test Float64

@testset for T in [Float16, Float32, Float64, BigFloat]
for x in (
T[1,2,3],
10 .* rand(T, 1000),
fill(nextfloat(T(1.0)), 1000),
fill(prevfloat(T(2.0)), 1000),
)
@test sumlog(x) isa T

@test (@inferred sumlog(x)) ≈ sum(log, x)

y = @view x[1:min(end, 100)]
@test (@inferred sumlog(y')) ≈ sum(log, y)

tup = tuple(y...)
@test (@inferred sumlog(tup)) ≈ sum(log, tup)
#
# gen = (sqrt(a) for a in y)
# # `eltype` of a `Base.Generator` returns `Any`
# @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen)

# nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup)
# @test (@inferred sumlog(y)) ≈ sum(log, y)

z = x .+ im .* Random.shuffle(x)
@test (@inferred sumlog(z)) ≈ sum(log, z)
end

# With dims
m = 1 .+ rand(T, 10, 10)
@test sumlog(m; dims=1) ≈ sum(log, m; dims=1)
@test sumlog(m; dims=2) ≈ sum(log, m; dims=2)

# Iterator
@test sumlog(x^2 for x in m) ≈ sumlog(abs2, m) ≈ sumlog(*, m, m) ≈ sum(log.(m.^2))
@test sumlog(x for x in Any[1, 2, 3+im, 4]) ≈ sum(log, Any[1, 2, 3+im, 4])

# NaN, Inf
if T != BigFloat # exponent fails here
@test isnan(sumlog(T[1, 2, NaN]))
@test isinf(sumlog(T[1, 2, Inf]))
@test sumlog(T[1, 2, 0.0]) == -Inf
@test sumlog(T[1, 2, -0.0]) == -Inf
end

# Empty
@test sumlog(T[]) isa T
@test eltype(sumlog(T[]; dims=1)) == T
@test sumlog(x for x in T[]) isa T

# Negative
@test_throws DomainError sumlog(T[1, -2, 3]) # easy
@test_throws DomainError sumlog(T[1, -2, -3]) # harder

end
@testset "Int" begin
@test sumlog([1,2,3]) isa Float64
@test sumlog([1,2,3]) ≈ sum(log, [1,2,3])
@test sumlog([1 2; 3 4]; dims=1) ≈ sum(log, [1 2; 3 4]; dims=1)
@test sumlog(Int(x) for x in Float64[1,2,3]) ≈ sum(log, [1,2,3])
end
end