Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up search operations for SeqOrView #325

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ docs/site/
Manifest.toml

.DS_Store

LocalPreferences.toml

TODO.md
4 changes: 3 additions & 1 deletion src/alphabet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ If `s` cannot be encoded to the given alphabet, throw an `EncodeError`
return y === nothing ? throw(EncodeError(A, s)) : y
end

tryencode(A::Alphabet, s::BioSymbol) = throw(EncodeError(A, s))
tryencode(A::Alphabet, s::Any) = nothing

"""
tryencode(::Alphabet, x::S)
Expand Down Expand Up @@ -387,3 +387,5 @@ function guess_alphabet(v::AbstractVector{UInt8})
end
end
guess_alphabet(s::AbstractString) = guess_alphabet(codeunits(s))

const KNOWN_ALPHABETS = Union{DNAAlphabet, RNAAlphabet, AminoAcidAlphabet}
24 changes: 12 additions & 12 deletions src/biosequence/find.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@

# Finding specific symbols

Base.findnext(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = Base.findnext(isequal(x), seq, start)
Base.findnext(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = findnext(isequal(x), seq, start)
Base.findnext(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = findnext(isequal(x), seq, start)
Base.findnext(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = findnext(isequal(x), seq, start)

Check warning on line 42 in src/biosequence/find.jl

View check run for this annotation

Codecov / codecov/patch

src/biosequence/find.jl#L41-L42

Added lines #L41 - L42 were not covered by tests

Base.findprev(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = Base.findprev(isequal(x), seq, start)
Base.findprev(x::DNA, seq::BioSequence{<:DNAAlphabet}, start::Integer) = findprev(isequal(x), seq, start)
Base.findprev(x::RNA, seq::BioSequence{<:RNAAlphabet}, start::Integer) = findprev(isequal(x), seq, start)
Base.findprev(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}, start::Integer) = findprev(isequal(x), seq, start)

Check warning on line 46 in src/biosequence/find.jl

View check run for this annotation

Codecov / codecov/patch

src/biosequence/find.jl#L44-L46

Added lines #L44 - L46 were not covered by tests

Base.findfirst(x::DNA, seq::BioSequence{<:DNAAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::RNA, seq::BioSequence{<:RNAAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = Base.findfirst(isequal(x), seq)
Base.findfirst(x::DNA, seq::BioSequence{<:DNAAlphabet}) = findfirst(isequal(x), seq)
Base.findfirst(x::RNA, seq::BioSequence{<:RNAAlphabet}) = findfirst(isequal(x), seq)
Base.findfirst(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = findfirst(isequal(x), seq)

Check warning on line 50 in src/biosequence/find.jl

View check run for this annotation

Codecov / codecov/patch

src/biosequence/find.jl#L48-L50

Added lines #L48 - L50 were not covered by tests

Base.findlast(x::DNA, seq::BioSequence{<:DNAAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::RNA, seq::BioSequence{<:RNAAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = Base.findlast(isequal(x), seq)
Base.findlast(x::DNA, seq::BioSequence{<:DNAAlphabet}) = findlast(isequal(x), seq)
Base.findlast(x::RNA, seq::BioSequence{<:RNAAlphabet}) = findlast(isequal(x), seq)
Base.findlast(x::AminoAcid, seq::BioSequence{AminoAcidAlphabet}) = findlast(isequal(x), seq)

Check warning on line 54 in src/biosequence/find.jl

View check run for this annotation

Codecov / codecov/patch

src/biosequence/find.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
2 changes: 0 additions & 2 deletions src/counting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# Instead, we could emit the head and tail chunk separately, then an iterator of
# all the full chunks loaded as single elements from the underlying vector.

const KNOWN_ALPHABETS = Union{DNAAlphabet, RNAAlphabet, AminoAcidAlphabet}

trunc_seq(x::LongSequence, len::Int) = typeof(x)(x.data, len % UInt)
trunc_seq(x::LongSubSeq, len::Int) = typeof(x)(x.data, first(x.part):first(x.part)+len-1)

Expand Down
75 changes: 75 additions & 0 deletions src/longsequences/chunk_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,78 @@
b = iter_inbounds(it.b, last(state))
((first(a), first(b)), (last(a), last(b)))
end

# This returns (head, body, tail), where:
# - head and tail are Tuple{UInt64, UInt8}, with a coding element and the number
# of coding bits in that element. Head is the partial coding element before any
# full elements, and tail is the partial coding element after any full elements.
# If head or tail is empty, the UInt8 is set to zero. By definition, it can be
# at most set to 63.
# If the sequence is composed of only one partial element, tail is nonempty
# and head is empty.
# - body is a Tuple{UInt, UInt} with the (start, stop) indices of coding elements.
# If stop < start, there are no such elements.
# TODO: The body should probably be a MemoryView in 1.11
# Method for LongSequence: the head is always empty, the body starts at index 1 of
# `seq.data`, and the tail is the last data element when it is only partially filled.
function parts(seq::LongSequence)

Check warning on line 145 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L145

Added line #L145 was not covered by tests
# LongSequence never has coding bits before the first chunks
head = (zero(UInt64), zero(UInt8))
len = length(seq)

Check warning on line 148 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L147-L148

Added lines #L147 - L148 were not covered by tests
# Shortcut to prevent annoying edge cases in the rest of the code:
# an empty sequence has empty head and tail, and a body with stop < start.
if iszero(len)
return (head, (UInt(1), UInt(0)), (zero(UInt64), zero(UInt8)))

Check warning on line 151 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L150-L151

Added lines #L150 - L151 were not covered by tests
end
# NOTE(review): the result of this call is discarded; unless `lastbitindex` has a
# side effect not visible here, this statement looks like a leftover and can go.
lastbitindex(seq)
# Bit offset of the position one past the last symbol, masked to 0..63.
# Zero is treated below as "there is no partial tail element".
bits_in_tail = (offset(bitindex(seq, len + 1)) % UInt8) & 0x3f
lbi = bitindex(seq, len)
# Index into `seq.data` of the element holding the last symbol
lbii = index(lbi)
tail = if iszero(bits_in_tail)
# Reuse the empty (zero, zero) tuple
head

Check warning on line 158 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L153-L158

Added lines #L153 - L158 were not covered by tests
else
(@inbounds(seq.data[lbii]), bits_in_tail)

Check warning on line 160 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L160

Added line #L160 was not covered by tests
end
# If we have bits in the tail, then the last bitindex points to the tail element,
# i.e. one past the last full chunk, so the body must stop one element earlier.
body = (UInt(1), (lbii - !iszero(bits_in_tail)) % UInt)
(head, body, tail)

Check warning on line 165 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L164-L165

Added lines #L164 - L165 were not covered by tests
end

# Method for LongSubSeq: returns (head, body, tail) where head and tail are
# (chunk, number of coding bits) tuples with the coding bits shifted down to the
# lowest bit positions, and body is a (start, stop) pair of indices into `seq.data`.
function parts(seq::LongSubSeq)
data = seq.data
zero_end = (zero(UInt64), zero(UInt8))
len = length(seq)
# Again: Avoid annoying edge cases later.
# An empty view has empty head and tail, and a body with stop < start.
if iszero(len)
return (zero_end, (UInt(1), UInt(0)), zero_end)

Check warning on line 174 in src/longsequences/chunk_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/chunk_iterator.jl#L174

Added line #L174 was not covered by tests
end
# NOTE(review): the result of this call is discarded; unless `lastbitindex` has a
# side effect not visible here, this statement looks like a leftover and can go.
lastbitindex(seq)
lbi = bitindex(seq, len)
# Index into `data` of the element holding the last symbol
lbii = index(lbi)
fbi = firstbitindex(seq)
# Index into `data` of the element holding the first symbol
fbii = index(fbi)
# Number of bits from the first symbol's offset to the end of its element, masked
# to 0..63 (so an offset of zero, i.e. a full first element, gives zero).
bits_in_head_naive = (((64 - offset(fbi)) % UInt8) & 0x3f)
# If first and last chunk index is the same, there are actually zero
# bits in head, as they are all in the tail
bits_in_head = bits_in_head_naive * (lbii != fbii)
# For the head, there are some non-coding lower bits. We need to shift
# the head right by this number.
head_shift = ((0x40 - bits_in_head_naive) & 0x3f)
head = if iszero(bits_in_head)
zero_end
else
chunk = @inbounds(data[fbii]) >> head_shift
(chunk, bits_in_head)
end
# However, if last and first chunk index is the same, there is no head
# chunk, and thus no head chunk to shift, but the TAIL chunk may not have coding bits at the lowest
# position.
tail_shift = (head_shift * (lbii == fbii)) & 63
# Bit offset of the position one past the last symbol, masked to 0..63 ...
bits_in_tail = (offset(bitindex(seq, len + 1)) % UInt8) & 0x3f
# ... minus the non-coding low bits that are shifted away from the tail below.
bits_in_tail -= tail_shift % UInt8
tail = if iszero(bits_in_tail)
zero_end
else
(@inbounds(data[lbii]) >> tail_shift, bits_in_tail)
end
# The body excludes the first element if it is a partial head, and the last
# element if it is a partial tail.
# NOTE(review): unlike the LongSequence method, no `% UInt` conversion is applied
# here; confirm the element type of this tuple matches the documented
# Tuple{UInt, UInt} for the body.
body = (fbii + !iszero(bits_in_head), lbii - !iszero(bits_in_tail))
(head, body, tail)
end
137 changes: 137 additions & 0 deletions src/longsequences/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,141 @@
end

return true
end

## Search

# We only dispatch on known alphabets, because new alphabets may implement == in surprising ways
function Base.findnext(
    cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
    seq::SeqOrView{<:KNOWN_ALPHABETS},
    i::Integer,
)
    # Start positions below 1 are clamped to 1; starting past the end finds nothing.
    start = max(Int(i)::Int, 1)
    if start > length(seq)
        return nothing
    end
    # When the queried symbol has no encoding in this alphabet, report no match.
    encoded = tryencode(Alphabet(seq), cmp.x)
    if encoded === nothing
        return nothing
    end
    # `start` is within 1:length(seq) here, so the view can skip its bounds check.
    suffix = @inbounds view(seq, start:lastindex(seq))
    hit = _findfirst(suffix, encoded)
    # Translate the position inside the view back to a position in `seq`.
    return hit === nothing ? nothing : hit + start - 1
end

# Return the 1-based position in `seq` of the first symbol whose encoding equals `enc`,
# or `nothing` if there is none. The data is searched in the order head, body, tail,
# as returned by `parts(seq)`.
function _findfirst(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
data = seq.data
# Repeat the encoding in every symbol slot of the 64-bit word
# (assuming `enc` fits within one symbol's bits).
enc *= encoding_expansion(BitsPerSymbol(seq))
((head, head_bits), (body_i, body_stop), (tail, tail_bits)) = parts(seq)
symbols_in_head = div(head_bits, bits_per_symbol(Alphabet(seq))) % Int
# The idea here is that we xor with the coding elements, then check for the first
# occurrence of a zeroed symbol, if any.
if !iszero(head_bits)
# Number of non-matching symbols before the first match, counted from the low bits
tu = trailing_unsets(BitsPerSymbol(seq), head ⊻ enc)
# A match beyond the coding bits of the head is not a real match
tu < symbols_in_head && return tu + 1
end
# `i` is the sequence position of the first symbol in the current data element
i = symbols_in_head + 1
while body_i ≤ body_stop
chunk = @inbounds data[body_i] ⊻ enc
# `ze` has the lowest bit set in every symbol slot that matched
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
if !iszero(ze)
return i + div(trailing_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
body_i += 1
i += symbols_per_data_element(seq)
end
if !iszero(tail_bits)
tu = trailing_unsets(BitsPerSymbol(seq), tail ⊻ enc)
# Only matches within the coding bits of the tail count
tu < div(tail_bits, bits_per_symbol(Alphabet(seq))) && return tu + i
end
nothing

Check warning on line 283 in src/longsequences/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/operators.jl#L283

Added line #L283 was not covered by tests
end

# Find the last position `<= i` in `seq` whose symbol satisfies `cmp` (an `==` or
# `isequal` comparison against `cmp.x`). Returns the position, or `nothing` if there
# is no match or `i < 1`. Throws a `BoundsError` if `i > length(seq)`.
function Base.findprev(
    cmp::Base.Fix2{<:Union{typeof(==), typeof(isequal)}},
    seq::SeqOrView{<:KNOWN_ALPHABETS},
    i::Integer,
)
    i = Int(i)::Int
    i < 1 && return nothing
    # The view below is built under `@inbounds`, so an end position past the
    # sequence must be rejected here; otherwise `_findlast` could read beyond
    # the sequence's data.
    i > length(seq) && throw(BoundsError(seq, i))
    symbol = cmp.x
    enc = tryencode(Alphabet(seq), symbol)
    # When the queried symbol has no encoding in this alphabet, report no match.
    enc === nothing && return nothing
    # `i` is within 1:length(seq) here, so skipping the bounds check is safe.
    vw = @inbounds view(seq, 1:i)
    # The view starts at position 1, so its positions are positions in `seq`.
    return _findlast(vw, enc)
end

# See comments in findfirst
# Return the 1-based position in `seq` of the last symbol whose encoding equals `enc`,
# or `nothing` if there is none. The data is searched in the order tail, body, head.
function _findlast(seq::SeqOrView{<:KNOWN_ALPHABETS}, enc::UInt64)
data = seq.data
# Repeat the encoding in every symbol slot of the 64-bit word
# (assuming `enc` fits within one symbol's bits).
enc *= encoding_expansion(BitsPerSymbol(seq))
# Note the swapped names: the loop below runs from the body's stop index
# (`body_i`) down to its start index (`body_stop`).
((head, head_bits), (body_stop, body_i), (tail, tail_bits)) = parts(seq)
# `i` is the sequence position of the last symbol not yet ruled out
i = lastindex(seq)
# This part is slightly different, because the coding bits are shifted to the right,
# but we need to count the leading bits.
# So, we need to mask off the top bits by OR'ing them with a bunch of 1's,
# and then ignore the number of symbols we've masked off when counting the number
# of leading nonzero symbols in the encoding
if !iszero(tail_bits)
symbols_in_tail = div(tail_bits, bits_per_symbol(Alphabet(seq))) % Int
# `<<` binds tighter than `-` in Julia, so the mask is ~((1 << tail_bits) - 1):
# ones in every bit above the coding bits.
tail = (tail ⊻ enc) | ~(UInt64(1) << (tail_bits & 0x3f) - 1)
# Number of symbol slots covered by the masked-off top bits
masked_unsets = div((0x40 - tail_bits), bits_per_symbol(Alphabet(seq)))
lu = leading_unsets(BitsPerSymbol(seq), tail) - masked_unsets
lu < symbols_in_tail && return (i - lu) % Int
# No match in the tail: step over all of its symbols
i -= lu
end
while body_i ≥ body_stop
chunk = @inbounds data[body_i] ⊻ enc
# `ze` has the lowest bit set in every symbol slot that matched
ze = set_zero_encoding(BitsPerSymbol(seq), chunk)
if !iszero(ze)
return i - div(leading_zeros(ze) % UInt, bits_per_symbol(Alphabet(seq))) % Int
end
body_i -= 1
i -= symbols_per_data_element(seq)
end
if !iszero(head_bits)
symbols_in_head = div(head_bits, bits_per_symbol(Alphabet(seq))) % Int
# Same masking as for the tail: force the non-coding top bits to ones
head = (head ⊻ enc) | ~(UInt64(1) << (head_bits & 0x3f) - 1)
masked_unsets = div((0x40 - head_bits), bits_per_symbol(Alphabet(seq)))
lu = leading_unsets(BitsPerSymbol(seq), head) - masked_unsets
lu < symbols_in_head && return (i - lu) % Int
end
nothing

Check warning on line 335 in src/longsequences/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/longsequences/operators.jl#L335

Added line #L335 was not covered by tests
end

# A 64-bit word with only the lowest bit set in every 8-, 4- or 2-bit slot.
function encoding_expansion(::BitsPerSymbol{8})
    return 0x0101010101010101
end

function encoding_expansion(::BitsPerSymbol{4})
    return 0x1111111111111111
end

function encoding_expansion(::BitsPerSymbol{2})
    return 0x5555555555555555
end

# For every 8-bit chunk, if the chunk is all zeros, set the lowest bit in the chunk,
# else, zero the chunk.
# E.g. 0x_0a_b0_0c_00_fe_00_ff_4e -> 0x_00_00_00_01_00_01_00_00
function set_zero_encoding(B::BitsPerSymbol{8}, enc::UInt64)
    # After inversion, a byte that was all zeros is all ones.
    flipped = ~enc
    # AND-fold every byte onto its lowest bit: 8 bits -> 4 -> 2 -> 1.
    fold4 = flipped & (flipped >> 4)
    fold2 = fold4 & (fold4 >> 2)
    fold1 = fold2 & (fold2 >> 1)
    # Keep only the lowest bit of each byte.
    return fold1 & encoding_expansion(B)
end

# Same as the 8-bit method, but for 4-bit slots: the lowest bit of each slot is set
# if the slot was all zeros, and every other bit is cleared.
function set_zero_encoding(B::BitsPerSymbol{4}, enc::UInt64)
    flipped = ~enc
    # AND-fold every 4-bit slot onto its lowest bit: 4 bits -> 2 -> 1.
    fold2 = flipped & (flipped >> 2)
    fold1 = fold2 & (fold2 >> 1)
    return fold1 & encoding_expansion(B)
end

# Same as the 8-bit method, but for 2-bit slots: the lowest bit of each slot is set
# if the slot was all zeros, and every other bit is cleared.
function set_zero_encoding(B::BitsPerSymbol{2}, enc::UInt64)
    flipped = ~enc
    # AND the two bits of every slot together into the slot's lowest bit.
    fold1 = flipped & (flipped >> 1)
    return fold1 & encoding_expansion(B)
end

# Count the trailing (lowest) B-bit chunks of the encoding that are not all zeros
function trailing_unsets(::BitsPerSymbol{B}, enc::UInt64) where B
    # One marker bit (the lowest bit of the chunk) for every all-zero chunk
    zero_markers = set_zero_encoding(BitsPerSymbol{B}(), enc)
    # Bits below the lowest marker, converted to a whole number of chunks
    low_bits = trailing_zeros(zero_markers) % UInt
    return div(low_bits, B) % Int
end

# Count the leading (highest) B-bit chunks of the encoding that are not all zeros
function leading_unsets(::BitsPerSymbol{B}, enc::UInt64) where B
    # One marker bit (the lowest bit of the chunk) for every all-zero chunk
    zero_markers = set_zero_encoding(BitsPerSymbol{B}(), enc)
    # Bits above the highest marker, converted to a whole number of chunks
    high_bits = leading_zeros(zero_markers) % UInt
    return div(high_bits, B) % Int
end
14 changes: 7 additions & 7 deletions src/longsequences/seqview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,25 @@ function LongSubSeq{A}(seq::LongSubSeq{A}) where A
return LongSubSeq{A}(seq.data, seq.part)
end

function LongSubSeq{A}(seq::LongSequence{A}, part::AbstractUnitRange{<:Integer}) where A
Base.@propagate_inbounds function LongSubSeq{A}(seq::LongSequence{A}, part::AbstractUnitRange{<:Integer}) where A
@boundscheck checkbounds(seq, part)
return LongSubSeq{A}(seq.data, UnitRange{Int}(part))
end

function LongSubSeq{A}(seq::LongSubSeq{A}, part::AbstractUnitRange{<:Integer}) where A
Base.@propagate_inbounds function LongSubSeq{A}(seq::LongSubSeq{A}, part::AbstractUnitRange{<:Integer}) where A
@boundscheck checkbounds(seq, part)
newpart = first(part) + first(seq.part) - 1 : last(part) + first(seq.part) - 1
return LongSubSeq{A}(seq.data, newpart)
end

function LongSubSeq(seq::SeqOrView{A}, i) where A
Base.@propagate_inbounds function LongSubSeq(seq::SeqOrView{A}, i) where A
return LongSubSeq{A}(seq, i)
end

LongSubSeq(seq::SeqOrView, ::Colon) = LongSubSeq(seq, 1:lastindex(seq))
LongSubSeq(seq::BioSequence{A}) where A = LongSubSeq{A}(seq)
Base.@propagate_inbounds LongSubSeq(seq::SeqOrView, ::Colon) = LongSubSeq(seq, 1:lastindex(seq))
Base.@propagate_inbounds LongSubSeq(seq::BioSequence{A}) where A = LongSubSeq{A}(seq)

Base.view(seq::SeqOrView, part::AbstractUnitRange) = LongSubSeq(seq, part)
Base.@propagate_inbounds Base.view(seq::SeqOrView, part::AbstractUnitRange) = LongSubSeq(seq, part)

function (::Type{T})(seq::SeqOrView{<:NucleicAcidAlphabet{2}}) where
{T<:LongSequence{<:NucleicAcidAlphabet{4}}}
Expand Down Expand Up @@ -145,7 +145,7 @@ function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}) where
T(seq.data, 1:length(seq))
end

function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}, part::AbstractUnitRange{<:Integer}) where
Base.@propagate_inbounds function (::Type{T})(seq::LongSequence{<:NucleicAcidAlphabet{N}}, part::AbstractUnitRange{<:Integer}) where
{N, T<:LongSubSeq{<:NucleicAcidAlphabet{N}}}
@boundscheck checkbounds(seq, part)
T(seq.data, UnitRange{Int}(part))
Expand Down
4 changes: 2 additions & 2 deletions test/alphabet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ end
@test tryencode(DNAAlphabet{2}(), DNA_M) === nothing
@test tryencode(DNAAlphabet{2}(), DNA_N) === nothing
@test tryencode(DNAAlphabet{2}(), DNA_Gap) === nothing
@test_throws EncodeError tryencode(DNAAlphabet{2}(), RNA_G)
@test tryencode(DNAAlphabet{2}(), RNA_G) === nothing

# 4 bits
for nt in BioSymbols.alphabet(DNA)
Expand All @@ -154,7 +154,7 @@ end
@test tryencode(RNAAlphabet{2}(), RNA_M) === nothing
@test tryencode(RNAAlphabet{2}(), RNA_N) === nothing
@test tryencode(RNAAlphabet{2}(), RNA_Gap) === nothing
@test_throws EncodeError tryencode(RNAAlphabet{2}(), DNA_G)
@test tryencode(RNAAlphabet{2}(), DNA_G) === nothing

# 4 bits
for nt in BioSymbols.alphabet(RNA)
Expand Down
Loading
Loading