Skip to content

Commit

Permalink
Fix some bugs
Browse files Browse the repository at this point in the history
* Make sure == doesn't dispatch to memcmp if the mem's eltype is a union, since
  in that case the array's bits can be the same but its content differ due to
  the type metadata array. Also, test for this
* However, ensure that =='ing immutable and mutable arrays does dispatch to
  memchr, even if the two arrays type differ in mutability
* Fix an edge case in split_unaligned when no part of the array is aligned
* Minor refactor, using the internal truncate functions more. This change allows
  some branches to be omitted.
  • Loading branch information
jakobnissen committed Dec 12, 2024
1 parent 83b4218 commit 04e38c6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
29 changes: 23 additions & 6 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function Base.similar(::MemoryView{T1, M}, ::Type{T2}, dims::Tuple{Int}) where {
MemoryView{T2, M}(unsafe, memoryref(memory), len)
end

function Base.empty(mem::MemoryView{T1, M}, ::Type{T2}) where {T1, T2, M}
function Base.empty(::MemoryView{T1, M}, ::Type{T2}) where {T1, T2, M}
MemoryView{T2, M}(unsafe, memoryref(Memory{T2}()), 0)
end

Expand Down Expand Up @@ -83,7 +83,7 @@ function Base.getindex(v::MemoryView, idx::AbstractUnitRange)
end

Base.getindex(v::MemoryView, ::Colon) = v
Base.view(v::MemoryView, idx::AbstractUnitRange) = v[idx]
Base.@propagate_inbounds Base.view(v::MemoryView, idx::AbstractUnitRange) = v[idx]

function truncate(mem::MemoryView, include_last::Integer)
lst = Int(include_last)::Int
Expand All @@ -102,6 +102,15 @@ function truncate_start_nonempty(mem::MemoryView, from::Integer)
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
end

function truncate_start(mem::MemoryView, from::Integer)
frm = Int(from)::Int
@boundscheck if ((frm - 1) % UInt) > length(mem) % UInt
throw(BoundsError(mem, frm))
end
newref = @inbounds memoryref(mem.ref, frm - (from == length(mem) + 1))
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
end

function Base.unsafe_copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
iszero(length(src)) && return dst
@inbounds unsafe_copyto!(dst.ref, src.ref, length(src))
Expand Down Expand Up @@ -243,12 +252,20 @@ function memrchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, U
p == C_NULL ? nothing : (p - ptr) % Int + 1
end

const Bits =
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128, Char}
const BitsTypes =
(Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128, Char)
const Bits = Union{BitsTypes...}
const BitMemory = Union{map(T -> MemoryView{T}, BitsTypes)...}

# This dispatch makes sure that, if they have the same element bitstype, but the views
# are of different types due to mutability, we still dispatch to the correct methpd.
Base.:(==)(a::ImmutableMemoryView, b::MutableMemoryView) = a == ImmutableMemoryView(b)
Base.:(==)(a::MutableMemoryView, b::ImmutableMemoryView) = ImmutableMemoryView(a) == b

function Base.:(==)(a::MemoryView{T}, b::MemoryView{T}) where {T <: Bits}
# Make sure to only dispatch if it's the exact same memory type.
function Base.:(==)(a::Mem, b::Mem) where {Mem <: BitMemory}
length(a) == length(b) || return false
(T === Union{} || Base.issingletontype(T)) && return true
(eltype(a) === Union{} || Base.issingletontype(eltype(a))) && return true
a.ref === b.ref && return true
GC.@preserve a b begin
aptr = Ptr{Nothing}(pointer(a))
Expand Down
14 changes: 4 additions & 10 deletions src/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ ERROR: BoundsError: attempt to access 0-element MutableMemoryView{UInt8} at inde
"""
function split_first(v::MemoryView)
@boundscheck checkbounds(v, 1)
newref = @inbounds memoryref(v.ref, 1 + (length(v) > 1))
fst = @inbounds v[1]
(fst, typeof(v)(unsafe, newref, length(v) - 1))
(@inbounds(v[1]), @inbounds(truncate_start(v, 2)))
end

"""
Expand Down Expand Up @@ -59,8 +57,7 @@ ERROR: BoundsError: attempt to access 0-element MutableMemoryView{UInt8} at inde
"""
function split_last(v::MemoryView)
@boundscheck checkbounds(v, 1)
lst = @inbounds v[end]
(lst, typeof(v)(unsafe, v.ref, length(v) - 1))
(@inbounds(v[end]), @inbounds(truncate(v, length(v) - 1)))
end

"""
Expand All @@ -82,10 +79,7 @@ julia> split_at(MemoryView(Int8[1, 2, 3]), 4)
"""
function split_at(v::MemoryView, i::Int)
@boundscheck checkbounds(1:(lastindex(v) + 1), i)
fst = typeof(v)(unsafe, v.ref, i - 1)
ref = i > lastindex(v) ? v.ref : @inbounds memoryref(v.ref, i)
lst = typeof(v)(unsafe, ref, length(v) - i + 1)
(fst, lst)
(@inbounds(truncate(v, i - 1)), @inbounds(truncate_start(v, i)))
end

"""
Expand Down Expand Up @@ -118,6 +112,6 @@ function split_unaligned(v::MemoryView, ::Val{A}) where {A}
# this will be compiled away
iszero(sz) && return (typeof(v)(unsafe, v.ref, 0), v)
unaligned_bytes = ((alignment - (UInt(pointer(v)) & mask)) & mask)
n_elements = div(unaligned_bytes, sz % UInt) % Int
n_elements = min(length(v), div(unaligned_bytes, sz % UInt) % Int)
@inbounds split_at(v, n_elements + 1)
end
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ end
@test split_unaligned(v, Val(8)) == split_at(v, 8)
@test split_unaligned(v, Val(16)) == split_at(v, 16)

v = v[2:4]
@test split_unaligned(v, Val(16)) == split_at(v, length(v) + 1)
@test split_unaligned(v, Val(8)) == split_at(v, length(v) + 1)

v = MemoryView(collect(0x0000:0x003f))[3:end]
@test split_unaligned(v, Val(1)) == split_at(v, 1)
@test split_unaligned(v, Val(4)) == split_at(v, 1)
Expand Down Expand Up @@ -544,6 +548,11 @@ end
@test m1 != m2
m2 = m2[1:(end - 1)]
@test m1 == m2

# These only differ in the type metadata, make sure they are distinguished
m1 = MemoryView(Union{Int, UInt}[-1])
m2 = MemoryView(Union{Int, UInt}[typemax(UInt)])
@test m1 != m2
end

@testset "MemoryKind" begin
Expand Down

0 comments on commit 04e38c6

Please sign in to comment.