Skip to content

Commit

Permalink
Change API regarding convergence information (JuliaMath#33)
Browse files Browse the repository at this point in the history
There is now a single function `lambertw` for computing the Lambert W
function. It takes a keyword argument, `info`. If `info` is false, the
default, then only the result of computation is returned. If it is `true`
then a triple giving the result and info on convergence is returned.

In neither case is a warning or error explicitly raised.
  • Loading branch information
jlapeyre authored Oct 2, 2024
1 parent 6617072 commit b6c186e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 38 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ also called the omega function or product logarithm.
```julia
lambertw(z,k) # Lambert W function for argument z and branch index k
lambertw(z) # the same as lambertw(z,0)
lambertw_check_convergence(z, k=0) # The same as above but throw an error if the computation failed to converge
lambertw(z; info=true) # Return a 3-tuple that includes convergence information.
```

`z` may be Complex or Real. `k` must be an integer. For Real
Expand All @@ -36,7 +36,7 @@ julia> lambertw(-pi/2 + 0im) / pi
4.6681174759251105e-18 + 0.5im
```

#### Note on `lambertw_check_convergence`
#### Note on `info=true`

You can use this for extra safety. But I have been unable to find any input for which the root finding fails to
converge quickly.
Expand Down
63 changes: 29 additions & 34 deletions src/LambertW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module LambertW

import IrrationalConstants

export lambertw, lambertwbp, lambertw_check_convergence
export lambertw, lambertwbp

const omega_const_bf_ = Ref{BigFloat}()

Expand Down Expand Up @@ -44,20 +44,23 @@ julia> lambertw(LambertW.lambertwbranchpoint, -1)
### Lambert W function

"""
lambertw(z, k::Integer=0, maxits::Integer=1000)
lambertw(z, k::Integer=0, maxits::Integer=1000; info::Bool=false)
Compute the `k`th branch of the Lambert W function of `z`.
If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of the branch
`k = -1` is `[-1/e, 0]` and the domain of the branch `k = 0` is `[-1/e, Inf]`. For
`Complex` `z`, and all `k`, the domain is the complex plane.
The result is computed via a root-finding loop. If the number of iterations exceeds
`maxits`, then the loop exits early, returning a result without warning about the failure
to converge. This will probably never happen. However, if you want to be more careful,
call `lambertw_check_convergence` instead. The latter function returns the result if
`maxits` was not reached, and otherwise throws an error.
If `info` is `false` then `lambertw` returns just the result of the computation.
If `info` is `true`, then it returns a 3-tuple. The first item is the the result of the
computation. The second item is `true` if the root-finding compution converged in fewer
than `maxits` iterations, and otherwise is `false`. The third item is the number of
iterations performed.
I have been unable to find a value of `z` for which the root-finding fails to converge
within ten iterations.
```jldoctest
julia> lambertw(-1/MathConstants.e, -1)
-1.0
Expand All @@ -75,24 +78,11 @@ julia> lambertw(Complex(-10.0, 3.0), 4)
-0.9274337508660128 + 26.37693445371142im
```
"""
lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits)[1]

"""
lambertw_check_convergence(z, k::Integer=0, maxits::Integer=1000)
This is the same as `lambertw` except that if the root finding fails to converge in `maxits` iterations,
an error is thrown.
"""
function lambertw_check_convergence(z, k::Integer=0, maxits::Integer=1000)
(w, converged) = _lambertw(float(z), k, maxits)
if ! converged
error("lambertw failed to converge in $maxits iterations")
end
w
function lambertw(z, k::Integer=0, maxits::Integer=1000; info::Bool=false)
result = _lambertw(float(z), k, maxits)
info ? result : result[1]
end

#lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits)

# lambertw(e + 0im, k) is ok for all k

### Real z
Expand All @@ -103,16 +93,19 @@ function _lambertw(x::Real, k, maxits)
throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1"))
end

# If we don't run root finding at all, return `true` for success with zero iterations.
_no_loop(w) = (w, true, 0)

# Real x, k = 0
# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf.
# There is a magic number here. It could be noted, or possibly removed.
# In particular, the fancy initial condition selection does not seem to help speed.
function lambertw_branch_zero(x::T, maxits) where T<:Real
isnan(x) && return(NaN)
x == Inf && return Inf # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf)
isnan(x) && return _no_loop(NaN)
x == Inf && return _no_loop(Inf) # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf)
one_t = one(T)
oneoe = -one_t / convert(T, MathConstants.e) # The branch point
x == oneoe && return -one_t
x == oneoe && return _no_loop(-one_t)
oneoe <= x || throw(DomainError(x))
itwo_t = 1 / convert(T, 2)
if x > one_t
Expand All @@ -128,9 +121,9 @@ end
# Real x, k = -1
function lambertw_branch_one(x::T, maxits) where T<:Real
oneoe = -one(T) / convert(T, MathConstants.e)
x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above
x == oneoe && return _no_loop(-one(T)) # W approaches -1 as x -> -1/e from above
oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e
x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below
x == zero(T) && return _no_loop(-convert(T, Inf)) # W decreases w/o bound as x -> 0 from below
x < zero(T) || throw(DomainError(x))
return lambertw_root_finding(x, log(-x), maxits)
end
Expand All @@ -143,8 +136,8 @@ function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real
pointseven = 7//10
if abs(z) <= one_t/convert(T, MathConstants.e)
if z == 0
k == 0 && return z
return complex(-convert(T, Inf), zero(T))
k == 0 && return _no_loop(z)
return _no_loop(complex(-convert(T, Inf), zero(T)))
end
if k == 0
w = z
Expand All @@ -158,10 +151,10 @@ function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real
w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven, pointseven) : complex(pointseven, -pointseven) : z
else
if real(z) == convert(T, Inf)
k == 0 && return z
k == 0 && return _no_loop(z)
return z + complex(0, 2*k*pi)
end
real(z) == -convert(T, Inf) && return -z + complex(0, (2*k+1)*pi)
real(z) == -convert(T, Inf) && return _no_loop(-z + complex(0, (2*k+1)*pi))
w = log(z)
k != 0 ? w += complex(0, 2*k*pi) : nothing
end
Expand All @@ -178,20 +171,22 @@ function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number
lastx = x
lastdiff = zero(T)
converged::Bool = false
for _ in 1:maxits
num_iters = 0
for iter_count in 1:maxits
ex = exp(x)
xexz = x * ex - z
x1 = x + 1
x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ))
xdiff = abs(lastx - x)
if xdiff <= 3 * eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle
converged = true
num_iters = iter_count
break
end
lastx = x
lastdiff = xdiff
end
return (x, converged)
return (x, converged, num_iters)
end

### Inverse of Lambert W function
Expand Down
13 changes: 11 additions & 2 deletions test/lambertw_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ end
@test string(LambertW.Omega()) == "ω"
end

@testset "lambertw_check_convergence" begin
@test lambertw_check_convergence(1.0) == lambertw(1.0)
@testset "lambertw info" begin
result = lambertw(1.0; info=true)
@test result[1] == lambertw(1.0)
@test result[2]
@test result[3] > 1 && result[3] < 10

for z in (10., complex(10), lambertwbranchpoint)
res = lambertw(1.0; info=true)
@test res[2]
@test length(res) == 3
end
end

0 comments on commit b6c186e

Please sign in to comment.