diff --git a/src/Inflate.jl b/src/Inflate.jl index c6e0008..0ba295c 100644 --- a/src/Inflate.jl +++ b/src/Inflate.jl @@ -33,11 +33,14 @@ module Inflate export inflate, inflate_zlib, inflate_gzip, InflateStream, InflateZlibStream, InflateGzipStream -# Huffman codes are internally represented by Vector{Vector{Int}}, +"integer type for literals" +const LInt = Int + +# Huffman codes are internally represented by Vector{Vector{LInt}}, # where code[k] are a vector of the values with code words of length # k. Codes are assigned in order from shorter to longer codes and in # the order listed. E.g. -# [[], [2, 7], [1, 3, 5], [6, 4]] +# Vector{LInt}[[], [2, 7], [1, 3, 5], [6, 4]] # would be the code # 00 - 2 # 01 - 7 @@ -47,31 +50,31 @@ export inflate, inflate_zlib, inflate_gzip, # 1110 - 6 # 1111 - 4 -const fixed_literal_or_length_table = (Vector{Int})[Int[], - Int[], - Int[], - Int[], - Int[], - Int[], - Int[256:279;], - Int[0:143; 280:287], - Int[144:255;]] - -const fixed_distance_table = (Vector{Int})[Int[], - Int[], - Int[], - Int[], - Int[0:31;]] +const fixed_literal_or_length_table = (Vector{LInt})[[], + [], + [], + [], + [], + [], + [256:279;], + [0:143; 280:287], + [144:255;]] + +const fixed_distance_table = (Vector{LInt})[[], + [], + [], + [], + [0:31;]] abstract type AbstractInflateData end mutable struct InflateData <: AbstractInflateData bytes::Vector{UInt8} - current_byte::Int + current_byte::UInt8 bytepos::Int bitpos::Int - literal_or_length_code::Vector{Vector{Int}} - distance_code::Vector{Vector{Int}} + literal_or_length_code::Vector{Vector{LInt}} + distance_code::Vector{Vector{LInt}} update_input_crc::Bool crc::UInt32 end @@ -81,6 +84,9 @@ function InflateData(source::Vector{UInt8}) fixed_distance_table, false, init_crc()) end +""" + get_input_byte(data::InflateData) -> UInt8 +""" function get_input_byte(data::InflateData) byte = data.bytes[data.bytepos] data.bytepos += 1 @@ -92,34 +98,47 @@ end # This isn't called when reading Gzip header, so no need to # 
consider updating crc. +""" + get_input_bytes(data::InflateData, n::Int) -> SubArray{UInt8, 1} +""" function get_input_bytes(data::InflateData, n::Int) bytes = @view data.bytes[data.bytepos:(data.bytepos + n - 1)] data.bytepos += n return bytes end +""" + getbit(data::AbstractInflateData) -> Bool +""" function getbit(data::AbstractInflateData) if data.bitpos == 0 - data.current_byte = Int(get_input_byte(data)) - end - b = data.current_byte & 1 - data.bitpos += 1 - if data.bitpos == 8 - data.bitpos = 0 - else - data.current_byte >>= 1 + data.current_byte = get_input_byte(data) end + b = data.current_byte % Bool + data.current_byte >>= 0x1 + data.bitpos = (data.bitpos + 1) & 0x7 return b end +""" + getbits(data::AbstractInflateData, n::Int) -> UInt32 +""" function getbits(data::AbstractInflateData, n::Int) - b = 0 - for i = 0:(n-1) - b |= getbit(data) << i + b = UInt32(0) + for i = 0x00:UInt8(n-1) + b |= UInt32(getbit(data)) << i end return b end +""" + getbitsint(data::AbstractInflateData, n::Int) -> Int +""" +function getbitsint(data::AbstractInflateData, n::Int) + n < 32 || error("too long bit length `n`") + return getbits(data, n) % Int +end + function skip_bits_to_byte_boundary(data::AbstractInflateData) data.bitpos = 0 return @@ -127,12 +146,18 @@ end # It is the responsibility of the caller to make sure that bitpos is # at zero, e.g. by calling skip_bits_to_byte_boundary. 
+""" + get_aligned_byte(data::AbstractInflateData) -> UInt8 +""" function get_aligned_byte(data::AbstractInflateData) return get_input_byte(data) end +""" + get_value_from_code(data::AbstractInflateData, code) -> LInt +""" function get_value_from_code(data::AbstractInflateData, - code::Vector{Vector{Int}}) + code::Vector{Vector{LInt}}) v = 0 for i = 1:length(code) v = (v << 1) | getbit(data) @@ -144,40 +169,51 @@ function get_value_from_code(data::AbstractInflateData, error("incomplete code table") end +""" + get_literal_or_length(data::AbstractInflateData) -> LInt +""" function get_literal_or_length(data::AbstractInflateData) return get_value_from_code(data, data.literal_or_length_code) end const base_length = [11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227] const extra_length_bits = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] - +""" + getlength(data::AbstractInflateData, v::Int) -> Int +""" function getlength(data::AbstractInflateData, v::Int) if v <= 264 return v - 254 elseif v <= 284 - return base_length[v - 264] + getbits(data, extra_length_bits[v - 264]) + return base_length[v - 264] + getbitsint(data, extra_length_bits[v - 264]) else return 258 end end +""" + getdist(data::AbstractInflateData) -> Int +""" function getdist(data::AbstractInflateData) - b = get_value_from_code(data, data.distance_code) + b = get_value_from_code(data, data.distance_code) % Int if b <= 3 return b + 1 else extra_bits = fld(b - 2, 2) - return 1 + ((2 + b % 2) << extra_bits) + getbits(data, extra_bits) + return 1 + ((2 + b % 2) << extra_bits) + getbitsint(data, extra_bits) end end +""" + transform_code_lengths_to_code(code_lengths::Vector{Int}) -> Vector{Vector{LInt}} +""" function transform_code_lengths_to_code(code_lengths::Vector{Int}) - code = Vector{Int}[] + code = Vector{LInt}[] for i = 1:length(code_lengths) n = code_lengths[i] if n > 0 while n > length(code) - push!(code, Int[]) + push!(code, LInt[]) end push!(code[n], i 
- 1) end @@ -188,31 +224,31 @@ end const order = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15] function read_code_tables(data::AbstractInflateData) - hlit = getbits(data, 5) + 257 - hdist = getbits(data, 5) + 1 - hclen = getbits(data, 4) + 4 + hlit = getbitsint(data, 5) + 257 + hdist = getbitsint(data, 5) + 1 + hclen = getbitsint(data, 4) + 4 code_length_code_lengths = zeros(Int, 19) for i = 1:hclen - code_length_code_lengths[1 + order[i]] = getbits(data, 3) + code_length_code_lengths[1 + order[i]] = getbitsint(data, 3) end code_length_code = transform_code_lengths_to_code(code_length_code_lengths) code_lengths = zeros(Int, hlit + hdist) i = 1 while i <= hlit + hdist - c = get_value_from_code(data, code_length_code) + c = get_value_from_code(data, code_length_code) % Int n = 1 l = 0 if c < 16 l = c elseif c == 16 - n = 3 + getbits(data, 2) + n = 3 + getbitsint(data, 2) l = code_lengths[i-1] elseif c == 17 - n = 3 + getbits(data, 3) + n = 3 + getbitsint(data, 3) else # A code of length 19 can only yield values between 0 and # 18 so we can only get here if c == 18. 
- n = 11 + getbits(data, 7) + n = 11 + getbitsint(data, 7) end code_lengths[i:(i+n-1)] .= l i += n @@ -225,21 +261,21 @@ function _inflate(data::InflateData) out = UInt8[] final_block = false while !final_block - final_block = getbits(data, 1) == 1 - compression_mode = getbits(data, 2) - if compression_mode == 0 + final_block = getbit(data) + compression_mode = getbits(data, 2) % UInt8 + if compression_mode === 0x0 skip_bits_to_byte_boundary(data) len = getbits(data, 16) nlen = getbits(data, 16) if len ⊻ nlen != 0xffff error("corrupted data") end - append!(out, get_input_bytes(data, len)) + append!(out, get_input_bytes(data, len % Int)) continue - elseif compression_mode == 1 + elseif compression_mode === 0x1 data.literal_or_length_code = fixed_literal_or_length_table data.distance_code = fixed_distance_table - elseif compression_mode == 2 + elseif compression_mode === 0x2 read_code_tables(data) else error("invalid block compression mode 3") @@ -247,12 +283,12 @@ function _inflate(data::InflateData) while true v = get_literal_or_length(data) - if v < 256 - push!(out, UInt8(v)) - elseif v == 256 + if v < 0x100 + push!(out, v % UInt8) + elseif v == 0x100 break else - length = getlength(data, v) + length = getlength(data, v % Int) distance = getdist(data) if length <= distance append!(out, @view out[(end - distance + 1):(end - distance + length)]) @@ -268,10 +304,16 @@ function _inflate(data::InflateData) return out end +""" + init_adler() -> Tuple{Int, Int} +""" function init_adler() return (0, 1) end +""" + update_adler(adler::Tuple{Int, Int}, x::UInt8) -> Tuple{Int, Int} +""" function update_adler(adler::Tuple{Int, Int}, x::UInt8) s2, s1 = adler s1 += x @@ -285,11 +327,18 @@ function update_adler(adler::Tuple{Int, Int}, x::UInt8) return (s2, s1) end -function finish_adler(adler) + +""" + finish_adler(adler::Tuple{Int, Int}) -> UInt32 +""" +function finish_adler(adler::Tuple{Int, Int}) s2, s1 = adler return (UInt32(s2) << 16) | UInt32(s1) end +""" +
compute_adler_checksum(x::Vector{UInt8}) -> UInt32 +""" @inline function compute_adler_checksum(x::Vector{UInt8}) adler = init_adler() for b in x @@ -298,6 +347,18 @@ end return finish_adler(adler) end +""" + read_stored_adler(data::AbstractInflateData) -> UInt32 +""" +function read_stored_adler(data::AbstractInflateData) + skip_bits_to_byte_boundary(data) + stored_adler = UInt32(0) + for i = 1:4 + stored_adler = stored_adler << 0x8 | get_aligned_byte(data) + end + return stored_adler +end + const crc_table = zeros(UInt32, 256) function make_crc_table() @@ -305,30 +366,42 @@ function make_crc_table() c = UInt32(n - 1) for k = 1:8 if (c & 0x00000001) != 0 - c = 0xedb88320 ⊻ (c >> 1) + c = 0xedb88320 ⊻ (c >> 0x1) else - c >>= 1 + c >>= 0x1 end end crc_table[n] = c end end +""" + init_crc() -> UInt32 +""" function init_crc() - if crc_table[1] == 0 - make_crc_table() + if crc_table[1] == 0 # FIXME: `crc_table[1]` is always zero. + make_crc_table() # `crc_table` should be initialized at compile-time. 
end return 0xffffffff end +""" + update_crc(c::UInt32, x::UInt8) -> UInt32 +""" @inline function update_crc(c::UInt32, x::UInt8) - @inbounds return crc_table[1 + ((c ⊻ x) & 0xff)] ⊻ (c >> 8) + @inbounds return crc_table[1 + ((c % UInt8) ⊻ x)] ⊻ (c >> 0x8) end -function finish_crc(c) +""" + finish_crc(c::UInt32) -> UInt32 +""" +function finish_crc(c::UInt32) return c ⊻ 0xffffffff end +""" + crc(x::Vector{UInt8}) -> UInt32 +""" function crc(x::Vector{UInt8}) c = init_crc() for b in x @@ -337,12 +410,15 @@ function crc(x::Vector{UInt8}) return finish_crc(c) end +""" + read_zero_terminated_data(data::AbstractInflateData) -> Vector{UInt8} +""" function read_zero_terminated_data(data::AbstractInflateData) s = UInt8[] while true c = get_aligned_byte(data) push!(s, c) - if c == 0 + if c === 0x00 break end end @@ -353,19 +429,19 @@ function read_zlib_header(data::AbstractInflateData) CMF = get_aligned_byte(data) FLG = get_aligned_byte(data) CM = CMF & 0x0f - CINFO = CMF >> 4 - FLEVEL = FLG >> 6 - FDICT = (FLG >> 5) & 0x01 - if CM != 8 + CINFO = CMF >> 0x4 + FLEVEL = FLG >> 0x6 + FDICT = (FLG >> 0x5) & 0x01 + if CM !== 0x8 error("unsupported compression method") end - if CINFO > 7 + if CINFO > 0x7 error("invalid LZ77 window size") end - if FDICT != 0 + if FDICT !== 0x0 error("preset dictionary not supported") end - if mod((UInt(CMF) << 8) | FLG, 31) != 0 + if mod((UInt16(CMF) << 0x8) | FLG, 31) != 0 error("header checksum error") end end @@ -374,11 +450,11 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) data.update_input_crc = compute_crc ID1 = get_aligned_byte(data) ID2 = get_aligned_byte(data) - if ID1 != 0x1f || ID2 != 0x8b + if ID1 !== 0x1f || ID2 !== 0x8b error("not gzipped data") end CM = get_aligned_byte(data) - if CM != 8 + if CM !== 0x8 error("unsupported compression method") end FLG = get_aligned_byte(data) @@ -391,7 +467,7 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) headers["os"] = OS end - if (FLG & 
0x04) != 0 # FLG.FEXTRA + if (FLG & 0x04) !== 0x0 # FLG.FEXTRA xlen = getbits(data, 16) if headers != nothing headers["fextra"] = zeros(UInt8, xlen) @@ -405,30 +481,30 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) end end - if (FLG & 0x08) != 0 # FLG.FNAME + if (FLG & 0x08) !== 0x0 # FLG.FNAME name = read_zero_terminated_data(data) if headers != nothing headers["fname"] = String(name[1:end-1]) end end - if (FLG & 0x10) != 0 # FLG.FCOMMENT + if (FLG & 0x10) !== 0x0 # FLG.FCOMMENT comment = read_zero_terminated_data(data) if headers != nothing headers["fcomment"] = String(comment[1:end-1]) end end - if (FLG & 0xe0) != 0 + if (FLG & 0xe0) !== 0x0 error("reserved FLG bit set") end data.update_input_crc = false - if (FLG & 0x02) != 0 # FLG.FHCRC - crc16 = getbits(data, 16) + if (FLG & 0x02) !== 0x0 # FLG.FHCRC + crc16 = getbits(data, 16) % UInt16 if compute_crc header_crc = finish_crc(data.crc) - if crc16 != (header_crc & 0xffff) + if crc16 !== (header_crc % UInt16) error("corrupted data, header crc check failed") end end @@ -465,11 +541,7 @@ function inflate_zlib(source::Vector{UInt8}; ignore_checksum = false) out = _inflate(data) - skip_bits_to_byte_boundary(data) - stored_adler = 0 - for i = [24, 16, 8, 0] - stored_adler |= Int(get_aligned_byte(data)) << i - end + stored_adler = read_stored_adler(data) if !ignore_checksum && compute_adler_checksum(out) != stored_adler error("corrupted data, adler checksum error") end @@ -516,7 +588,7 @@ function inflate_gzip(source::Vector{UInt8}; headers = nothing, end """ - inflate_gzip(filename::AbstractString) + inflate_gzip(filename::AbstractString) -> String Convenience wrapper for reading a gzip compressed text file. The result is returned as a string. 
@@ -536,10 +608,10 @@ mutable struct StreamingInflateData <: AbstractInflateData stream::IO input_buffer::Vector{UInt8} input_buffer_pos::Int - current_byte::Int + current_byte::UInt8 bitpos::Int - literal_or_length_code::Vector{Vector{Int}} - distance_code::Vector{Vector{Int}} + literal_or_length_code::Vector{Vector{LInt}} + distance_code::Vector{Vector{LInt}} output_buffer::Vector{UInt8} write_pos::Int read_pos::Int @@ -559,9 +631,12 @@ function StreamingInflateData(stream::IO) true, 0, -2, false, false, init_crc()) end +""" + get_input_byte(data::StreamingInflateData) -> UInt8 +""" function get_input_byte(data::StreamingInflateData) if data.input_buffer_pos > length(data.input_buffer) - data.input_buffer = read(data.stream, 65536) + data.input_buffer = read(data.stream, 65536) # TODO: Remove the magic number data.input_buffer_pos = 1 end byte = data.input_buffer[data.input_buffer_pos] @@ -574,9 +649,12 @@ end # This isn't called when reading Gzip header, so no need to # consider updating crc. 
+""" + get_input_bytes(data::StreamingInflateData, n) -> SubArray{UInt8, 1} +""" function get_input_bytes(data::StreamingInflateData, n) if data.input_buffer_pos > length(data.input_buffer) - data.input_buffer = read(data.stream, 65536) + data.input_buffer = read(data.stream, 65536) # TODO: Remove the magic number data.input_buffer_pos = 1 end n = min(n, length(data.input_buffer) - data.input_buffer_pos + 1) @@ -676,11 +754,7 @@ read_trailer(stream::InflateStream) = nothing function read_trailer(stream::InflateZlibStream) computed_adler = finish_adler(stream.adler) - skip_bits_to_byte_boundary(stream.data) - stored_adler = 0 - for i = [24, 16, 8, 0] - stored_adler |= Int(get_aligned_byte(stream.data)) << i - end + stored_adler = read_stored_adler(stream.data) if stream.compute_adler && computed_adler != stored_adler error("corrupted data, adler checksum error") end @@ -690,15 +764,18 @@ function read_trailer(stream::InflateGzipStream) crc = finish_crc(stream.crc) skip_bits_to_byte_boundary(stream.data) crc32 = getbits(stream.data, 32) - if stream.compute_crc && crc32 != crc + if stream.compute_crc && crc32 !== crc error("corrupted data, crc check failed") end isize = getbits(stream.data, 32) - if isize != stream.num_bytes % UInt32 + if isize !== stream.num_bytes % UInt32 error("corrupted data, length check failed") end end +""" + read_output_byte(data::StreamingInflateData) -> UInt8 +""" @inline function read_output_byte(data::StreamingInflateData) @inbounds byte = data.output_buffer[data.read_pos] data.read_pos += 1 @@ -728,6 +805,9 @@ function read_output_bytes!(data::StreamingInflateData, out, i) return n end +""" + read_output_byte(stream::AbstractInflateStream) -> UInt8 +""" function read_output_byte(stream::AbstractInflateStream) byte = read_output_byte(stream.data) getbyte(stream) @@ -808,9 +888,12 @@ function write_to_buffer(stream::InflateGzipStream, x::AbstractVector{UInt8}) stream.num_bytes += length(x) end +""" + getbyte(stream::AbstractInflateStream) 
+""" function getbyte(stream::AbstractInflateStream) if stream.data.write_pos != stream.data.read_pos - return + return nothing end if stream.data.pending_bytes > 0 @@ -834,16 +917,16 @@ function getbyte(stream::AbstractInflateStream) stream.data.pending_bytes -= n write_to_buffer(stream, @view stream.data.output_buffer[pos:(pos + n - 1)]) end - return + return nothing end if stream.data.waiting_for_new_block if stream.data.reading_final_block - return + return nothing end - stream.data.reading_final_block = getbits(stream.data, 1) == 1 - compression_mode = getbits(stream.data, 2) - if compression_mode == 0 + stream.data.reading_final_block = getbit(stream.data) + compression_mode = getbits(stream.data, 2) % UInt8 + if compression_mode === 0x0 skip_bits_to_byte_boundary(stream.data) len = getbits(stream.data, 16) nlen = getbits(stream.data, 16) @@ -853,11 +936,11 @@ function getbyte(stream::AbstractInflateStream) stream.data.distance = -1 stream.data.pending_bytes = len getbyte(stream) - return - elseif compression_mode == 1 + return nothing + elseif compression_mode === 0x1 stream.data.literal_or_length_code = fixed_literal_or_length_table stream.data.distance_code = fixed_distance_table - elseif compression_mode == 2 + elseif compression_mode === 0x2 read_code_tables(stream.data) else error("invalid block compression mode 3") @@ -866,16 +949,17 @@ function getbyte(stream::AbstractInflateStream) end v = get_literal_or_length(stream.data) - if v < 256 + if v < 0x100 write_to_buffer(stream, UInt8(v)) - elseif v == 256 + elseif v == 0x100 stream.data.waiting_for_new_block = true getbyte(stream) else - stream.data.pending_bytes = getlength(stream.data, v) + stream.data.pending_bytes = getlength(stream.data, v % Int) stream.data.distance = getdist(stream.data) getbyte(stream) end + return nothing end function Base.eof(stream::AbstractInflateStream)