Skip to content

Commit

Permalink
make faster BigFloats
Browse files Browse the repository at this point in the history
We can coalesce the two required allocations for the MPFR BigFloat API
design into one allocation, hopefully giving an easy performance boost.
It would have been slightly easier and more efficient if MPFR BigFloat
was already a VLA instead of containing a pointer here, but that does
not prevent the optimization.
  • Loading branch information
vtjnash committed Sep 27, 2024
1 parent 6e33dfb commit 36ee135
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 90 deletions.
1 change: 0 additions & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ end
include("hashing.jl")
include("rounding.jl")
include("div.jl")
include("rawbigints.jl")
include("float.jl")
include("twiceprecision.jl")
include("complex.jl")
Expand Down
156 changes: 104 additions & 52 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ import
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
isone, big, _string_n, decompose, minmax, _precision_with_base_2,
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
RawBigIntRoundingIncrementHelper, truncated, RawBigInt

uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask

using .Base.Libc
import ..Rounding:
import ..Rounding: Rounding,
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
tie_breaker_is_to_even, correct_rounding_requires_increment

Expand All @@ -39,7 +37,6 @@ else
const libmpfr = "libmpfr.so.6"
end


version() = VersionNumber(unsafe_string(ccall((:mpfr_get_version,libmpfr), Ptr{Cchar}, ())))
patches() = split(unsafe_string(ccall((:mpfr_get_patches,libmpfr), Ptr{Cchar}, ())),' ')

Expand Down Expand Up @@ -120,69 +117,124 @@ const mpfr_special_exponent_zero = typemin(Clong) + true
const mpfr_special_exponent_nan = mpfr_special_exponent_zero + true
const mpfr_special_exponent_inf = mpfr_special_exponent_nan + true

# Describes how a `BigFloat`'s single `Memory{Limb}` buffer is laid out:
# a fixed header (`prec`, `sign`, `exp`, `d`) followed directly by the limbs.
# In the visible code this type is only used for `fieldoffset` and `sizeof`
# computations, so the field order and types below define the byte offsets
# that the property accessors and the C conversion rely on.
# NOTE(review): the header presumably mirrors MPFR's own struct layout — confirm.
struct BigFloatLayout
prec::Clong
sign::Cint
exp::Clong
# pointer slot; `unsafe_convert(Ref{BigFloat}, ...)` stores the address of `p` here
d::Ptr{Limb}
# first limb; stands in for a variable-length tail of limbs
p::Limb # Tuple{Vararg{Limb}}
end
# Byte offsets of each header field from the start of the buffer.
const offset_prec = fieldoffset(BigFloatLayout, 1)
const offset_sign = fieldoffset(BigFloatLayout, 2)
const offset_exp = fieldoffset(BigFloatLayout, 3)
const offset_d = fieldoffset(BigFloatLayout, 4)
# Byte offset of the first limb, i.e. the size of the header before the limb data.
const offset_p = fieldoffset(BigFloatLayout, 5)
# The same offset counted in whole `Limb` words (as an `Int`), for indexing
# the `Memory{Limb}` buffer directly.
const offset_p_limbs = (offset_p ÷ sizeof(Limb)) % Int

"""
BigFloat <: AbstractFloat
Arbitrary precision floating point number type.
"""
mutable struct BigFloat <: AbstractFloat
prec::Clong
sign::Cint
exp::Clong
d::Ptr{Limb}
# _d::Buffer{Limb} # Julia gc handle for memory @ d
_d::String # Julia gc handle for memory @ d (optimized)
struct BigFloat <: AbstractFloat
d::Memory{Limb}

# Not recommended for general use:
# used internally by, e.g. deepcopy
global function _BigFloat(prec::Clong, sign::Cint, exp::Clong, d::String)
# ccall-based version, inlined below
#z = new(zero(Clong), zero(Cint), zero(Clong), C_NULL, d)
#ccall((:mpfr_custom_init,libmpfr), Cvoid, (Ptr{Limb}, Clong), d, prec) # currently seems to be a no-op in mpfr
#NAN_KIND = Cint(0)
#ccall((:mpfr_custom_init_set,libmpfr), Cvoid, (Ref{BigFloat}, Cint, Clong, Ptr{Limb}), z, NAN_KIND, prec, d)
#return z
return new(prec, sign, exp, pointer(d), d)
end
global _BigFloat(d::Memory{Limb}) = new(d)

function BigFloat(; precision::Integer=_precision_with_base_2(BigFloat))
precision < 1 && throw(DomainError(precision, "`precision` cannot be less than 1."))
nb = ccall((:mpfr_custom_get_size,libmpfr), Csize_t, (Clong,), precision)
nb = (nb + Core.sizeof(Limb) - 1) ÷ Core.sizeof(Limb) # align to number of Limb allocations required for this
#d = Vector{Limb}(undef, nb)
d = _string_n(nb * Core.sizeof(Limb))
EXP_NAN = mpfr_special_exponent_nan
return _BigFloat(Clong(precision), one(Cint), EXP_NAN, d) # +NAN
nl = (nb + Core.sizeof(BigFloatLayout) - 1) ÷ Core.sizeof(Limb) # align to number of Limb allocations required for this
d = Memory{Limb}(undef, nl % Int)
# ccall-based version, inlined below
z = _BigFloat(d) # initialize to +NAN
#ccall((:mpfr_custom_init,libmpfr), Cvoid, (Ptr{Limb}, Clong), BigFloatData(d), prec) # currently seems to be a no-op in mpfr
#NAN_KIND = Cint(0)
#ccall((:mpfr_custom_init_set,libmpfr), Cvoid, (Ref{BigFloat}, Cint, Clong, Ptr{Limb}), z, NAN_KIND, prec, BigFloatData(d))
z.prec = Clong(precision)
z.sign = one(Cint)
z.exp = mpfr_special_exponent_nan
return z
end
end

# The rounding mode here shouldn't matter.
significand_limb_count(x::BigFloat) = div(sizeof(x._d), sizeof(Limb), RoundToZero)
"""
Segment of raw words of bits interpreted as a big integer. Less
significant words come first. Each word is in machine-native bit-order.
"""
struct BigFloatData{Limb}
# NOTE(review): the type parameter is named `Limb`, so inside this definition
# it refers to the parameter rather than any module-level `Limb` — confirm
# this shadowing is intended.
# The complete buffer of the owning `BigFloat`; the indexing and `length`
# methods defined for this type skip the first `offset_p_limbs` words.
d::Memory{Limb}
end

# BigFloat interface
# Property access for `BigFloat`. The only real field is the backing buffer;
# `prec`, `sign` and `exp` are read from fixed byte offsets inside it, and `d`
# yields a `BigFloatData` wrapper around the buffer. Any other name throws a
# `FieldError`.
@inline function Base.getproperty(x::BigFloat, s::Symbol)
    mem = getfield(x, :d)
    base = Base.unsafe_convert(Ptr{Limb}, mem)
    s === :prec && return GC.@preserve mem unsafe_load(Ptr{Clong}(base) + offset_prec)
    s === :sign && return GC.@preserve mem unsafe_load(Ptr{Cint}(base) + offset_sign)
    s === :exp && return GC.@preserve mem unsafe_load(Ptr{Clong}(base) + offset_exp)
    s === :d && return BigFloatData(mem)
    return throw(FieldError(typeof(x), s))
end

# Write one of the header fields (`prec`, `sign`, `exp`) of `x` in place.
#
# `BigFloat` has a single real field, the backing buffer `d`; the header fields
# live at fixed byte offsets inside that buffer and are therefore written
# through a raw pointer while the buffer is kept alive with `GC.@preserve`.
# The buffer itself (`:d`) is not assignable, and any other name throws a
# `FieldError`.
@inline function Base.setproperty!(x::BigFloat, s::Symbol, v)
    d = getfield(x, :d)
    p = Base.unsafe_convert(Ptr{Limb}, d)
    if s === :prec
        return GC.@preserve d unsafe_store!(Ptr{Clong}(p) + offset_prec, v)
    elseif s === :sign
        return GC.@preserve d unsafe_store!(Ptr{Cint}(p) + offset_sign, v)
    elseif s === :exp
        return GC.@preserve d unsafe_store!(Ptr{Clong}(p) + offset_exp, v)
    # `:d` is deliberately not handled: the buffer is not mutable
    else
        # `FieldError` is constructed from the type, not the instance
        # (same as in `getproperty`); passing `x` itself would fail to
        # construct the exception.
        return throw(FieldError(typeof(x), s))
    end
end

# Ref interface: make sure the conversion to C is done properly
# Raw pointers and `Ref`s to a `BigFloat` object cannot be handed to mpfr;
# both conversions are rejected outright.
function Base.unsafe_convert(::Type{Ref{BigFloat}}, x::Ptr{BigFloat})
    return error("not compatible with mpfr")
end
function Base.unsafe_convert(::Type{Ref{BigFloat}}, x::Ref{BigFloat})
    return error("not compatible with mpfr")
end
# BigFloatData is the Ref type for BigFloat
function Base.cconvert(::Type{Ref{BigFloat}}, x::BigFloat)
    return x.d
end
# Produce the pointer passed to C for a `Ref{BigFloat}` argument. Before
# returning the start of the buffer, the header's `d` slot is (re)written with
# the address of the limb area inside the same buffer.
function Base.unsafe_convert(::Type{Ref{BigFloat}}, x::BigFloatData)
    mem = getfield(x, :d)
    base = Base.unsafe_convert(Ptr{Limb}, mem)
    d_slot = Ptr{Ptr{Limb}}(base) + offset_d
    limb_area = base + offset_p
    # :monotonic ensure that TSAN knows that this isn't a data race
    GC.@preserve mem unsafe_store!(d_slot, limb_area, :monotonic)
    return Ptr{BigFloat}(base)
end
# Pointer to the first limb: the start of the buffer advanced by `offset_p` bytes.
function Base.unsafe_convert(::Type{Ptr{Limb}}, fd::BigFloatData)
    buffer_start = Base.unsafe_convert(Ptr{Limb}, getfield(fd, :d))
    return buffer_start + offset_p
end
# Store `v` as the `i`-th (one-based) limb of `fd` and return `fd`.
#
# Limb indices are shifted by `offset_p_limbs` to skip the words at the start
# of the buffer that are not limbs. The index is therefore validated against
# the limb range here: relying on the buffer's own bounds check would let
# `i <= 0` silently overwrite those leading words instead of throwing.
# Throws a `BoundsError` unless `1 <= i <= length(d) - offset_p_limbs`.
@inline function Base.setindex!(fd::BigFloatData, v, i)
    d = getfield(fd, :d)
    @boundscheck 1 <= i <= length(d) - offset_p_limbs || throw(BoundsError(fd, i))
    @inbounds d[i + offset_p_limbs] = v
    return fd
end
# Return the `i`-th (one-based) limb of `fd`.
#
# Limb indices are shifted by `offset_p_limbs` to skip the words at the start
# of the buffer that are not limbs. The index is validated against the limb
# range here: relying on the buffer's own bounds check would let `i <= 0`
# silently read one of those leading words instead of throwing.
# Throws a `BoundsError` unless `1 <= i <= length(d) - offset_p_limbs`.
@inline function Base.getindex(fd::BigFloatData, i)
    d = getfield(fd, :d)
    @boundscheck 1 <= i <= length(d) - offset_p_limbs || throw(BoundsError(fd, i))
    return @inbounds d[i + offset_p_limbs]
end
# Number of limbs: the buffer length minus the leading `offset_p_limbs` words.
function Base.length(fd::BigFloatData)
    return length(getfield(fd, :d)) - offset_p_limbs
end
# for Random: copy `limbs` into the buffer starting at the first limb position
function Base.copyto!(fd::BigFloatData, limbs)
    first_limb = offset_p_limbs + 1
    return copyto!(getfield(fd, :d), first_limb, limbs)
end

include("rawbigfloats.jl")

# Current raw rounding mode for `BigFloat`: the scoped value when one is set,
# otherwise the global `ROUNDING_MODE[]`.
function rounding_raw(::Type{BigFloat})
    scoped = Base.ScopedValues.get(CURRENT_ROUNDING_MODE)
    return something(scoped, ROUNDING_MODE[])
end
# Set the global raw rounding mode.
function setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode)
    return ROUNDING_MODE[] = r
end
# Run `f()` with the scoped rounding mode bound to `r`.
setrounding_raw(f::Function, ::Type{BigFloat}, r::MPFRRoundingMode) =
    Base.ScopedValues.@with(CURRENT_ROUNDING_MODE => r, f())


# `RoundingMode`-typed wrappers: convert to/from `MPFRRoundingMode` and
# delegate to the `*_raw` functions.
function rounding(::Type{BigFloat})
    return convert(RoundingMode, rounding_raw(BigFloat))
end
function setrounding(::Type{BigFloat}, r::RoundingMode)
    return setrounding_raw(BigFloat, convert(MPFRRoundingMode, r))
end
function setrounding(f::Function, ::Type{BigFloat}, r::RoundingMode)
    return setrounding_raw(f, BigFloat, convert(MPFRRoundingMode, r))
end


# overload the definition of unsafe_convert to ensure that `x.d` is assigned
# it may have been dropped in the event that the BigFloat was serialized
Base.unsafe_convert(::Type{Ref{BigFloat}}, x::Ptr{BigFloat}) = x
@inline function Base.unsafe_convert(::Type{Ref{BigFloat}}, x::Ref{BigFloat})
x = x[]
if x.d == C_NULL
x.d = pointer(x._d)
end
return convert(Ptr{BigFloat}, Base.pointer_from_objref(x))
end

"""
BigFloat(x::Union{Real, AbstractString} [, rounding::RoundingMode=rounding(BigFloat)]; [precision::Integer=precision(BigFloat)])
Expand Down Expand Up @@ -283,17 +335,18 @@ function BigFloat(x::Float64, r::MPFRRoundingMode=rounding_raw(BigFloat); precis
nlimbs = (precision + 8*Core.sizeof(Limb) - 1) ÷ (8*Core.sizeof(Limb))

# Limb is a CLong which is a UInt32 on windows (thank M$) which makes this more complicated and slower.
zd = z.d
if Limb === UInt64
for i in 1:nlimbs-1
unsafe_store!(z.d, 0x0, i)
setindex!(zd, 0x0, i)
end
unsafe_store!(z.d, val, nlimbs)
setindex!(zd, val, nlimbs)
else
for i in 1:nlimbs-2
unsafe_store!(z.d, 0x0, i)
setindex!(zd, 0x0, i)
end
unsafe_store!(z.d, val % UInt32, nlimbs-1)
unsafe_store!(z.d, (val >> 32) % UInt32, nlimbs)
setindex!(zd, val % UInt32, nlimbs-1)
setindex!(zd, (val >> 32) % UInt32, nlimbs)
end
z
end
Expand Down Expand Up @@ -440,12 +493,12 @@ function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
ret_u = if is_regular & !rounds_to_inf & !rounds_to_zero
if !exp_is_huge_p
# significand
v = RawBigInt{Limb}(x._d, significand_limb_count(x))
v = x.d::BigFloatData
len = max(ieee_precision + min(exp_diff, 0), 0)::Int
signif = truncated(U, v, len) & significand_mask(T)

# round up if necessary
rh = RawBigIntRoundingIncrementHelper(v, len)
rh = BigFloatDataRoundingIncrementHelper(v, len)
incr = correct_rounding_requires_increment(rh, rm, sb)

# exponent
Expand Down Expand Up @@ -1193,10 +1246,8 @@ set_emin!(x) = check_exponent_err(ccall((:mpfr_set_emin, libmpfr), Cint, (Clong,

function Base.deepcopy_internal(x::BigFloat, stackdict::IdDict)
get!(stackdict, x) do
# d = copy(x._d)
d = x._d
d′ = GC.@preserve d unsafe_string(pointer(d), sizeof(d)) # creates a definitely-new String
y = _BigFloat(x.prec, x.sign, x.exp, d′)
d′ = copy(getfield(x, :d))
y = _BigFloat(d′)
#ccall((:mpfr_custom_move,libmpfr), Cvoid, (Ref{BigFloat}, Ptr{Limb}), y, d) # unnecessary
return y
end::BigFloat
Expand All @@ -1210,7 +1261,8 @@ function decompose(x::BigFloat)::Tuple{BigInt, Int, Int}
s.size = cld(x.prec, 8*sizeof(Limb)) # limbs
b = s.size * sizeof(Limb) # bytes
ccall((:__gmpz_realloc2, libgmp), Cvoid, (Ref{BigInt}, Culong), s, 8b) # bits
memcpy(s.d, x.d, b)
xd = x.d
GC.@preserve xd memcpy(s.d, Base.unsafe_convert(Ptr{Limb}, xd), b)
s, x.exp - 8b, x.sign
end

Expand Down
57 changes: 22 additions & 35 deletions base/rawbigints.jl → base/rawbigfloats.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,47 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
Segment of raw words of bits interpreted as a big integer. Less
significant words come first. Each word is in machine-native bit-order.
"""
struct RawBigInt{T<:Unsigned}
d::String
word_count::Int

function RawBigInt{T}(d::String, word_count::Int) where {T<:Unsigned}
new{T}(d, word_count)
end
end
# Some operations on BigFloat can be done more directly by treating the data portion ("BigFloatData") as a BigInt

elem_count(x::RawBigInt, ::Val{:words}) = x.word_count
elem_count(x::BigFloatData, ::Val{:words}) = length(x)
elem_count(x::Unsigned, ::Val{:bits}) = sizeof(x) * 8
word_length(::RawBigInt{T}) where {T} = elem_count(zero(T), Val(:bits))
elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T} = word_length(x) * elem_count(x, Val(:words))
word_length(::BigFloatData{T}) where {T} = elem_count(zero(T), Val(:bits))
elem_count(x::BigFloatData{T}, ::Val{:bits}) where {T} = word_length(x) * elem_count(x, Val(:words))
# Convert a zero-based index `i` into an `n`-element sequence into the
# zero-based index of the same element counted from the opposite end.
function reversed_index(n::Int, i::Int)
    return (n - 1) - i
end
# Same, with `n` taken from `elem_count(x, v)`.
function reversed_index(x, i::Int, v::Val)
    total = elem_count(x, v)
    return reversed_index(total, i)::Int
end
split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero)
split_bit_index(x::BigFloatData, i::Int) = divrem(i, word_length(x), RoundToZero)

"""
`i` is the zero-based index of the wanted word in `x`, starting from
the less significant words.
"""
function get_elem(x::RawBigInt{T}, i::Int, ::Val{:words}, ::Val{:ascending}) where {T}
# `i` must be non-negative and less than `x.word_count`
d = x.d
(GC.@preserve d unsafe_load(Ptr{T}(pointer(d)), i + 1))::T
Base.@propagate_inbounds function get_elem(x::BigFloatData{T}, i::Int, ::Val{:words}, ::Val{:ascending}) where {T}
return x[i + 1]::T
end

# Descending-order access: translate the index with `reversed_index` and
# defer to the ascending-order method.
function get_elem(x, i::Int, v::Val, ::Val{:descending})
    return get_elem(x, reversed_index(x, i, v), v, Val(:ascending))
end

word_is_nonzero(x::RawBigInt, i::Int, v::Val) = !iszero(get_elem(x, i, Val(:words), v))
word_is_nonzero(x::BigFloatData, i::Int, v::Val) = !iszero(get_elem(x, i, Val(:words), v))

word_is_nonzero(x::RawBigInt, v::Val) = let x = x
word_is_nonzero(x::BigFloatData, v::Val) = let x = x
i -> word_is_nonzero(x, i, v)
end

"""
Returns a `Bool` indicating whether the `len` least significant words
of `x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
function tail_is_nonzero(x::BigFloatData, len::Int, ::Val{:words})
any(word_is_nonzero(x, Val(:ascending)), 0:(len - 1))
end

"""
Returns a `Bool` indicating whether the `len` least significant bits of
the `i`-th (zero-based index) word of `x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
function tail_is_nonzero(x::BigFloatData, len::Int, i::Int, ::Val{:word})
!iszero(len) &&
!iszero(get_elem(x, i, Val(:words), Val(:ascending)) << (word_length(x) - len))
end
Expand All @@ -63,7 +50,7 @@ end
Returns a `Bool` indicating whether the `len` least significant bits of
`x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
function tail_is_nonzero(x::BigFloatData, len::Int, ::Val{:bits})
if 0 < len
word_count, bit_count_in_word = split_bit_index(x, len)
tail_is_nonzero(x, bit_count_in_word, word_count, Val(:word)) ||
Expand All @@ -83,7 +70,7 @@ end
"""
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
"""
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
function get_elem(x::BigFloatData, i::Int, ::Val{:bits}, v::Val{:ascending})
vb = Val(:bits)
if 0 ≤ i < elem_count(x, vb)
word_index, bit_index_in_word = split_bit_index(x, i)
Expand All @@ -98,7 +85,7 @@ end
Returns an integer of type `R`, consisting of the `len` most
significant bits of `x`.
"""
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
function truncated(::Type{R}, x::BigFloatData, len::Int) where {R<:Integer}
ret = zero(R)
if 0 < len
word_count, bit_count_in_word = split_bit_index(x, len)
Expand All @@ -120,30 +107,30 @@ function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
ret::R
end

struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
n::RawBigInt{T}
struct BigFloatDataRoundingIncrementHelper{T<:Unsigned}
n::BigFloatData{T}
trunc_len::Int

final_bit::Bool
round_bit::Bool

function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
function BigFloatDataRoundingIncrementHelper{T}(n::BigFloatData{T}, len::Int) where {T<:Unsigned}
vals = (Val(:bits), Val(:descending))
f = get_elem(n, len - 1, vals...)
r = get_elem(n, len , vals...)
new{T}(n, len, f, r)
end
end

function RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
RawBigIntRoundingIncrementHelper{T}(n, len)
function BigFloatDataRoundingIncrementHelper(n::BigFloatData{T}, len::Int) where {T<:Unsigned}
BigFloatDataRoundingIncrementHelper{T}(n, len)
end

(h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit) = h.final_bit
(h::BigFloatDataRoundingIncrementHelper)(::Rounding.FinalBit) = h.final_bit

(h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit) = h.round_bit
(h::BigFloatDataRoundingIncrementHelper)(::Rounding.RoundBit) = h.round_bit

function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
function (h::BigFloatDataRoundingIncrementHelper)(::Rounding.StickyBit)
v = Val(:bits)
n = h.n
tail_is_nonzero(n, elem_count(n, v) - h.trunc_len - 1, v)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Random/src/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function _rand!(rng::AbstractRNG, z::BigFloat, sp::SamplerBigFloat)
limbs[end] |= Limb_high_bit
end
z.sign = 1
GC.@preserve limbs unsafe_copyto!(z.d, pointer(limbs), sp.nlimbs)
copyto!(z.d, limbs)
randbool
end

Expand Down
Loading

0 comments on commit 36ee135

Please sign in to comment.