diff --git a/Cargo.toml b/Cargo.toml
index bf96f58..c0bc62e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,7 @@ edition = "2021"
 
 [dependencies]
 arraydeque = "0.5.1"
-encoding = "0.2"
+encoding_rs = "0.8.33"
 hashlink = "0.8"
 
 [dev-dependencies]
diff --git a/src/yaml.rs b/src/yaml.rs
index e310795..d092d22 100644
--- a/src/yaml.rs
+++ b/src/yaml.rs
@@ -4,6 +4,7 @@
 
 use std::{collections::BTreeMap, convert::TryFrom, mem, ops::Index};
 
+use encoding_rs::{CoderResult, Decoder, Encoding};
 use hashlink::LinkedHashMap;
 
 use crate::parser::{Event, MarkedEventReceiver, Parser, Tag};
@@ -238,11 +239,22 @@ impl YamlLoader {
     }
 }
 
+/// The behavior [`YamlDecoder`] must have when a decoding error occurs.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum YAMLDecodingTrap {
+    /// Ignore the offending bytes, remove them from the output.
+    Ignore,
+    /// Error out.
+    Strict,
+    /// Replace them with the Unicode REPLACEMENT CHARACTER.
+    Replace,
+}
+
 /// `YamlDecoder` is a `YamlLoader` builder that allows you to supply your own encoding error trap.
 /// For example, to read a YAML file while ignoring Unicode decoding errors you can set the
-/// `encoding_trap` to `encoding::DecoderTrap::Ignore`.
+/// `encoding_trap` to [`YAMLDecodingTrap::Ignore`].
 /// ```rust
-/// use yaml_rust2::yaml::YamlDecoder;
+/// use yaml_rust2::yaml::{YamlDecoder, YAMLDecodingTrap};
 ///
 /// let string = b"---
 /// a\xa9: 1
@@ -250,13 +262,13 @@ impl YamlLoader {
 /// c: [1, 2]
 /// ";
 /// let out = YamlDecoder::read(string as &[u8])
-///     .encoding_trap(encoding::DecoderTrap::Ignore)
+///     .encoding_trap(YAMLDecodingTrap::Ignore)
 ///     .decode()
 ///     .unwrap();
 /// ```
 pub struct YamlDecoder<T: std::io::Read> {
     source: T,
-    trap: encoding::types::DecoderTrap,
+    trap: YAMLDecodingTrap,
 }
 
 impl<T: std::io::Read> YamlDecoder<T> {
@@ -264,12 +276,12 @@ impl<T: std::io::Read> YamlDecoder<T> {
     pub fn read(source: T) -> YamlDecoder<T> {
         YamlDecoder {
             source,
-            trap: encoding::DecoderTrap::Strict,
+            trap: YAMLDecodingTrap::Strict,
         }
     }
 
     /// Set the behavior of the decoder when the encoding is invalid.
-    pub fn encoding_trap(&mut self, trap: encoding::types::DecoderTrap) -> &mut Self {
+    pub fn encoding_trap(&mut self, trap: YAMLDecodingTrap) -> &mut Self {
         self.trap = trap;
         self
     }
@@ -282,14 +294,61 @@ impl<T: std::io::Read> YamlDecoder<T> {
         let mut buffer = Vec::new();
         self.source.read_to_end(&mut buffer)?;
 
-        // Decodes the input buffer using either UTF-8, UTF-16LE or UTF-16BE depending on the BOM codepoint.
-        // If the buffer doesn't start with a BOM codepoint, it will use a fallback encoding obtained by
-        // detect_utf16_endianness.
-        let (res, _) =
-            encoding::types::decode(&buffer, self.trap, detect_utf16_endianness(&buffer));
-        let s = res.map_err(LoadError::Decode)?;
-        YamlLoader::load_from_str(&s).map_err(LoadError::Scan)
+        // Check if `encoding_rs` can detect the encoding from the BOM, otherwise use
+        // `detect_utf16_endianness`.
+        let encoding = Encoding::for_bom(&buffer)
+            .map_or_else(|| detect_utf16_endianness(&buffer), |(encoding, _bom_len)| encoding);
+        let mut decoder = encoding.new_decoder();
+        let mut output = String::new();
+
+        // Decode the input buffer. We abort upon encountering an unknown character if and only if
+        // the decoding is set to `Strict`.
+        let has_replacements = decode_loop(
+            &buffer,
+            &mut output,
+            &mut decoder,
+            self.trap != YAMLDecodingTrap::Strict,
+        )?;
+
+        // Remove REPLACEMENT characters if asked to.
+        if self.trap == YAMLDecodingTrap::Ignore && has_replacements {
+            // Filter in place; `str::replace` would allocate a second copy of the document.
+            output.retain(|c| c != '\u{FFFD}');
+        }
+
+        YamlLoader::load_from_str(&output).map_err(LoadError::Scan)
+    }
+}
+
+/// Perform a loop of [`Decoder::decode_to_string`], reallocating `output` if needed.
+///
+/// # Returns
+/// This function returns whether any character was replaced by a Unicode REPLACEMENT CHARACTER in
+/// the output.
+fn decode_loop(
+    input: &[u8],
+    output: &mut String,
+    decoder: &mut Decoder,
+    allow_replacements: bool,
+) -> Result<bool, LoadError> {
+    output.reserve(input.len());
+    let mut total_bytes_read = 0;
+    let mut replacements = false;
+    loop {
+        let (result, bytes_read, replacement) =
+            decoder.decode_to_string(&input[total_bytes_read..], output, true);
+        if replacement && !allow_replacements {
+            return Err(LoadError::Decode("Invalid UTF character encountered".into()));
+        }
+        replacements |= replacement;
+        // `InputEmpty` ends the stream; its replacement flag has been accounted for above.
+        if matches!(result, CoderResult::InputEmpty) {
+            return Ok(replacements);
+        }
+        total_bytes_read += bytes_read;
+        // Grow by at least 4 bytes: `reserve(0)` on a short input would stall the decoder.
+        output.reserve((input.len() / 10).max(4));
     }
 }
 
 /// The encoding crate knows how to tell apart UTF-8 from UTF-16LE and utf-16BE, when the
@@ -301,15 +360,15 @@ impl<T: std::io::Read> YamlDecoder<T> {
 /// This allows the encoding to be deduced by the pattern of null (#x00) characters.
 //
 /// See spec at
-fn detect_utf16_endianness(b: &[u8]) -> encoding::types::EncodingRef {
+fn detect_utf16_endianness(b: &[u8]) -> &'static Encoding {
     if b.len() > 1 && (b[0] != b[1]) {
         if b[0] == 0 {
-            return encoding::all::UTF_16BE;
+            return encoding_rs::UTF_16BE;
         } else if b[1] == 0 {
-            return encoding::all::UTF_16LE;
+            return encoding_rs::UTF_16LE;
         }
     }
-    encoding::all::UTF_8
+    encoding_rs::UTF_8
 }
 
 macro_rules! define_as (
@@ -550,7 +609,7 @@ impl Iterator for YamlIter {
 
 #[cfg(test)]
 mod test {
-    use super::{Yaml, YamlDecoder};
+    use super::{YAMLDecodingTrap, Yaml, YamlDecoder};
 
     #[test]
     fn test_read_bom() {
@@ -623,7 +682,7 @@ b: 2.2
 c: [1, 2]
 ";
         let out = YamlDecoder::read(s as &[u8])
-            .encoding_trap(encoding::DecoderTrap::Ignore)
+            .encoding_trap(YAMLDecodingTrap::Ignore)
             .decode()
             .unwrap();
         let doc = &out[0];