Skip to content

Commit

Permalink
Speed up search operations for SeqOrView (#325)
Browse files Browse the repository at this point in the history
This commit adds new methods for findnext and findprev for SeqOrView with known
alphabets, which use bitparallel operations. This in turns speeds up most search
ops which are defined in terms of these.
The new code is 4-20 times faster depending on circumstances.

It's only implemented for known alphabets because new alphabets may overload ==
in surprising ways, which makes the bitparallel ops invalid.

The commit also introduces a new internal abstraction, the `parts` function,
which may be useful for other operations down the line. It's similar to the
existing chunk iterators, but may be more efficient for subsequences, and can
be reversed.

There is also some minor cleanup that could have been its own PR, but whatever.
  • Loading branch information
jakobnissen authored Oct 25, 2024
1 parent 295ba89 commit 1473594
Show file tree
Hide file tree
Showing 9 changed files with 322 additions and 48 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ docs/site/
Manifest.toml

.DS_Store

LocalPreferences.toml

TODO.md
4 changes: 3 additions & 1 deletion src/alphabet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ If `s` cannot be encoded to the given alphabet, throw an `EncodeError`
return y === nothing ? throw(EncodeError(A, s)) : y
end

tryencode(A::Alphabet, s::BioSymbol) = throw(EncodeError(A, s))
tryencode(A::Alphabet, s::Any) = nothing

"""
tryencode(::Alphabet, x::S)
Expand Down Expand Up @@ -387,3 +387,5 @@ function guess_alphabet(v::AbstractVector{UInt8})
end
end
guess_alphabet(s::AbstractString) = guess_alphabet(codeunits(s))

const KNOWN_ALPHABETS = Union{DNAAlphabet, RNAAlphabet, AminoAcidAlphabet}
24 changes: 12 additions & 12 deletions src/biosequence/find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ Base.findlast(f::Function, seq::BioSequence) = findprev(f, seq, lastindex(seq))

# Finding specific symbols

Base.findnext(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = findnext(isequal(x), seq, start)
Base.findnext(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = findnext(isequal(x), seq, start)
Base.findnext(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = findnext(isequal(x), seq, start)

Base.findprev(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = findprev(isequal(x), seq, start)
Base.findprev(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = findprev(isequal(x), seq, start)
Base.findprev(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = findprev(isequal(x), seq, start)

Base.findfirst(x::DNA, seq::BioSequence{<:DNAAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::RNA, seq::BioSequence{<:RNAAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::DNA, seq::BioSequence{<:DNAAlphabet}) = findfirst(isequal(x), seq)
Base.findfirst(x::RNA, seq::BioSequence{<:RNAAlphabet}) = findfirst(isequal(x), seq)
Base.findfirst(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = findfirst(isequal(x), seq)

Base.findlast(x::DNA, seq::BioSequence{<:DNAAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::RNA, seq::BioSequence{<:RNAAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::DNA, seq::BioSequence{<:DNAAlphabet}) = findlast(isequal(x), seq)
Base.findlast(x::RNA, seq::BioSequence{<:RNAAlphabet}) = findlast(isequal(x), seq)
Base.findlast(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = findlast(isequal(x), seq)
2 changes: 0 additions & 2 deletions src/counting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# Instead, we could emit the head and tail chunk seperately, then an iterator of
# all the full chunks loaded as single elements from the underlying vector.

const KNOWN_ALPHABETS = Union{DNAAlphabet, RNAAlphabet, AminoAcidAlphabet}

trunc_seq(x::LongSequence, len::Int) = typeof(x)(x.data, len % UInt)
trunc_seq(x::LongSubSeq, len::Int) = typeof(x)(x.data, first(x.part):first(x.part)+len-1)

Expand Down
75 changes: 75 additions & 0 deletions src/longsequences/chunk_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,78 @@ Base.eltype(::Type{<:PairedChunkIterator}) = NTuple{2, UInt64}
b = iter_inbounds(it.b, last(state))
((first(a), first(b)), (last(a), last(b)))
end

# This returns (head, body, tail), where:
# - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number
# of coding bits in that element. Head is the partial coding element before any
# full elements, and tail is the partial after any coding elements.
# If head or tail is empty, the UInt8 is set to zero. By definition, it can be
# at most set to 63.
# If the sequence is composed of only one partial element, tail is nonempty
# and head is empty.
# - body is a Tuple{UInt, UInt} with the (start, stop) indices of coding elements.
# If stop < start, there are no such elements.
# TODO: The body should probably be a MemoryView in 1.11
function parts(seq::LongSequence)
# LongSequence never has coding bits before the first chunks
head = (zero(UInt64), zero(UInt8))
len = length(seq)
# Shortcut to prevent annoying edge cases in the rest of the code
if iszero(len)
return (head, (UInt(1), UInt(0)), (zero(UInt64), zero(UInt8)))
end
lastbitindex(seq)
bits_in_tail = (offset(bitindex(seq, len + 1)) % UInt8) & 0x3f
lbi = bitindex(seq, len)
lbii = index(lbi)
tail = if iszero(bits_in_tail)
head
else
(@inbounds(seq.data[lbii]), bits_in_tail)
end
# If we have bits in the tail, then clearly those bits means the last bitindex
# points to one past the last full chunk
body = (UInt(1), (lbii - !iszero(bits_in_tail)) % UInt)
(head, body, tail)
end

function parts(seq::LongSubSeq)
data = seq.data
zero_end = (zero(UInt64), zero(UInt8))
len = length(seq)
# Again: Avoid annoying edge cases later
if iszero(len)
return (zero_end, (UInt(1), UInt(0)), zero_end)
end
lastbitindex(seq)
lbi = bitindex(seq, len)
lbii = index(lbi)
fbi = firstbitindex(seq)
fbii = index(fbi)
bits_in_head_naive = (((64 - offset(fbi)) % UInt8) & 0x3f)
# If first and last chunk index is the same, there are actually zero
# bits in head, as they are all in the tail
bits_in_head = bits_in_head_naive * (lbii != fbii)
# For the head, there are some uncoding lower bits. We need to shift
# the head right with this number.
head_shift = ((0x40 - bits_in_head_naive) & 0x3f)
head = if iszero(bits_in_head)
zero_end
else
chunk = @inbounds(data[fbii]) >> head_shift
(chunk, bits_in_head)
end
# However, if last and first chunk index is the same, there is no head
# chunk, and thus no head chunk to shift, but the TAIL chunk may not have coding bits at the lowest
# position.
tail_shift = (head_shift * (lbii == fbii)) & 63
bits_in_tail = (offset(bitindex(seq, len + 1)) % UInt8) & 0x3f
bits_in_tail -= tail_shift % UInt8
tail = if iszero(bits_in_tail)
zero_end
else
(@inbounds(data[lbii]) >> tail_shift, bits_in_tail)
end
body = (fbii + !iszero(bits_in_head), lbii - !iszero(bits_in_tail))
(head, body, tail)
end
137 changes: 137 additions & 0 deletions src/longsequences/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,141 @@ function Base.:(==)(seq1::LongSequence{A}, seq2::LongSequence{A}) where {A <: Al
end

return true
end

## Search

# We only dispatch on known alphabets, because new alphabets may implement == in surprising ways
function Base.findnext(
cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
seq::SeqOrView{<:KNOWN_ALPHABETS},
i::Integer,
)
i = max(Int(i)::Int, 1)
i > length(seq) && return nothing
symbol = cmp.x
enc = tryencode(Alphabet(seq), symbol)
enc === nothing && return nothing
vw = @inbounds view(seq, i:lastindex(seq))
res = _findfirst(vw, enc)
res === nothing ? nothing : res + i - 1
end

function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
data = seq.data
enc *= encoding_expansion(BitsPerSymbol(seq))
((head, head_bits), (body_i, body_stop), (tail, tail_bits)) = parts(seq)
symbols_in_head = div(head_bits, bits_per_symbol(Alphabet(seq))) % Int
# The idea here is that we xor with the coding elements, then check for the first
# occurrence of a zerod symbol, if any.
if !iszero(head_bits)
tu = trailing_unsets(BitsPerSymbol(seq), head enc)
tu < symbols_in_head && return tu + 1
end
i = symbols_in_head + 1
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
if !iszero(ze)
return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
body_i += 1
i += symbols_per_data_element(seq)
end
if !iszero(tail_bits)
tu = trailing_unsets(BitsPerSymbol(seq), tail enc)
tu < div(tail_bits, bits_per_symbol(Alphabet(seq))) && return tu + i
end
nothing
end

function Base.findprev(
cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
seq::SeqOrView{<:KNOWN_ALPHABETS},
i::Integer,
)
i = Int(i)::Int
i < 1 && return nothing
symbol = cmp.x
enc = tryencode(Alphabet(seq), symbol)
enc === nothing && return nothing
vw = @inbounds view(seq, 1:i)
_findlast(vw, enc)
end

# See comments in findfirst
function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
data = seq.data
enc *= encoding_expansion(BitsPerSymbol(seq))
((head, head_bits), (body_stop, body_i), (tail, tail_bits)) = parts(seq)
i = lastindex(seq)
# This part is slightly different, because the coding bits are shifted to the right,
# but we need to count the leading bits.
# So, we need to mask off the top bits by OR'ing them with a bunch of 1's,
# and then ignore the number of symbols we've masked off when counting the number
# of leading nonzero symbols un the encoding
if !iszero(tail_bits)
symbols_in_tail = div(tail_bits, bits_per_symbol(Alphabet(seq))) % Int
tail = (tail enc) | ~(UInt64(1) << (tail_bits & 0x3f) - 1)
masked_unsets = div((0x40 - tail_bits), bits_per_symbol(Alphabet(seq)))
lu = leading_unsets(BitsPerSymbol(seq), tail) - masked_unsets
lu < symbols_in_tail && return (i - lu) % Int
i -= lu
end
while body_i body_stop
chunk = @inbounds data[body_i] enc
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
if !iszero(ze)
return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
body_i -= 1
i -= symbols_per_data_element(seq)
end
if !iszero(head_bits)
symbols_in_head = div(head_bits, bits_per_symbol(Alphabet(seq))) % Int
head = (head enc) | ~(UInt64(1) << (head_bits & 0x3f) - 1)
masked_unsets = div((0x40 - head_bits), bits_per_symbol(Alphabet(seq)))
lu = leading_unsets(BitsPerSymbol(seq), head) - masked_unsets
lu < symbols_in_head && return (i - lu) % Int
end
nothing
end

encoding_expansion(::BitsPerSymbol{8}) = 0x0101010101010101
encoding_expansion(::BitsPerSymbol{4}) = 0x1111111111111111
encoding_expansion(::BitsPerSymbol{2}) = 0x5555555555555555

# For every 8-bit chunk, if the chunk is all zeros, set the lowest bit in the chunk,
# else, zero the chunk.
# E.g. 0x_0a_b0_0c_00_fe_00_ff_4e -> 0x_00_00_00_01_00_01_00_00
function set_zero_encoding(B::BitsPerSymbol{8}, enc::UInt64)
enc = ~enc
enc &= enc >> 4
enc &= enc >> 2
enc &= enc >> 1
enc & encoding_expansion(B)
end

function set_zero_encoding(B::BitsPerSymbol{4}, enc::UInt64)
enc = ~enc
enc &= enc >> 2
enc &= enc >> 1
enc & encoding_expansion(B)
end

function set_zero_encoding(B::BitsPerSymbol{2}, enc::UInt64)
enc = ~enc
enc &= enc >> 1
enc & encoding_expansion(B)
end

# Count how many trailing chunks of B bits in encoding that are not all zeros
function trailing_unsets(::BitsPerSymbol{B}, enc::UInt64) where B
u = set_zero_encoding(BitsPerSymbol{B}(), enc)
div(trailing_zeros(u) % UInt, B) % Int
end

function leading_unsets(::BitsPerSymbol{B}, enc::UInt64) where B
u = set_zero_encoding(BitsPerSymbol{B}(), enc)
div(leading_zeros(u) % UInt, B) % Int
end
14 changes: 7 additions & 7 deletions src/longsequences/seqview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,25 @@ function LongSubSeq{A}(seq::LongSubSeq{A}) where A
return LongSubSeq{A}(seq.data, seq.part)
end

function LongSubSeq{A}(seq::LongSequence{A}, part::AbstractUnitRange{<:Integer}) where A
Base.@propagate_inbounds function LongSubSeq{A}(seq::LongSequence{A}, part::AbstractUnitRange{<:Integer}) where A
@boundscheck checkbounds(seq, part)
return LongSubSeq{A}(seq.data, UnitRange{Int}(part))
end

function LongSubSeq{A}(seq::LongSubSeq{A}, part::AbstractUnitRange{<:Integer}) where A
Base.@propagate_inbounds function LongSubSeq{A}(seq::LongSubSeq{A}, part::AbstractUnitRange{<:Integer}) where A
@boundscheck checkbounds(seq, part)
newpart = first(part) + first(seq.part) - 1 : last(part) + first(seq.part) - 1
return LongSubSeq{A}(seq.data, newpart)
end

function LongSubSeq(seq::SeqOrView{A}, i) where A
Base.@propagate_inbounds function LongSubSeq(seq::SeqOrView{A}, i) where A
return LongSubSeq{A}(seq, i)
end

LongSubSeq(seq::SeqOrView, ::Colon) = LongSubSeq(seq, 1:lastindex(seq))
LongSubSeq(seq::BioSequence{A}) where A = LongSubSeq{A}(seq)
Base.@propagate_inbounds LongSubSeq(seq::SeqOrView, ::Colon) = LongSubSeq(seq, 1:lastindex(seq))
Base.@propagate_inbounds LongSubSeq(seq::BioSequence{A}) where A = LongSubSeq{A}(seq)

Base.view(seq::SeqOrView, part::AbstractUnitRange) = LongSubSeq(seq, part)
Base.@propagate_inbounds Base.view(seq::SeqOrView, part::AbstractUnitRange) = LongSubSeq(seq, part)

function (::Type{T})(seq::SeqOrView{<:NucleicAcidAlphabet{2}}) where
{T<:LongSequence{<:NucleicAcidAlphabet{4}}}
Expand Down Expand Up @@ -145,7 +145,7 @@ function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}) where
T(seq.data, 1:length(seq))
end

function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}, part::AbstractUnitRange{<:Integer}) where
Base.@propagate_inbounds function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}, part::AbstractUnitRange{<:Integer}) where
{N, T<:LongSubSeq{<:NucleicAcidAlphabet{N}}}
@boundscheck checkbounds(seq, part)
T(seq.data, UnitRange{Int}(part))
Expand Down
4 changes: 2 additions & 2 deletions test/alphabet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ end
@test tryencode(DNAAlphabet{2}(), DNA_M) === nothing
@test tryencode(DNAAlphabet{2}(), DNA_N) === nothing
@test tryencode(DNAAlphabet{2}(), DNA_Gap) === nothing
@test_throws EncodeError tryencode(DNAAlphabet{2}(), RNA_G)
@test tryencode(DNAAlphabet{2}(), RNA_G) === nothing

# 4 bits
for nt in BioSymbols.alphabet(DNA)
Expand All @@ -154,7 +154,7 @@ end
@test tryencode(RNAAlphabet{2}(), RNA_M) === nothing
@test tryencode(RNAAlphabet{2}(), RNA_N) === nothing
@test tryencode(RNAAlphabet{2}(), RNA_Gap) === nothing
@test_throws EncodeError tryencode(RNAAlphabet{2}(), DNA_G)
@test tryencode(RNAAlphabet{2}(), DNA_G) === nothing

# 4 bits
for nt in BioSymbols.alphabet(RNA)
Expand Down
Loading

0 comments on commit 1473594

Please sign in to comment.