Skip to content

Commit

Permalink
Switch from encoding to encoding_rs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ethiraric committed Mar 23, 2024
1 parent eaa4fc3 commit fc07ee3
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ edition = "2021"

[dependencies]
arraydeque = "0.5.1"
encoding = "0.2"
encoding_rs = "0.8.33"
hashlink = "0.8"

[dev-dependencies]
Expand Down
95 changes: 77 additions & 18 deletions src/yaml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use std::{collections::BTreeMap, convert::TryFrom, mem, ops::Index};

use encoding_rs::{CoderResult, Decoder, Encoding};
use hashlink::LinkedHashMap;

use crate::parser::{Event, MarkedEventReceiver, Parser, Tag};
Expand Down Expand Up @@ -238,38 +239,49 @@ impl YamlLoader {
}
}

/// The behavior [`YamlDecoder`] must have when a decoding error occurs.
///
/// A decoding error is a byte sequence in the input that is not valid in the detected encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum YAMLDecodingTrap {
    /// Ignore the offending bytes and remove them from the output.
    Ignore,
    /// Error out: decoding fails instead of producing any output.
    Strict,
    /// Replace the offending bytes with the Unicode REPLACEMENT CHARACTER (U+FFFD).
    Replace,
}

/// `YamlDecoder` is a `YamlLoader` builder that allows you to supply your own encoding error trap.
/// For example, to read a YAML file while ignoring Unicode decoding errors you can set the
/// `encoding_trap` to [`YAMLDecodingTrap::Ignore`].
/// ```rust
/// use yaml_rust2::yaml::YamlDecoder;
/// use yaml_rust2::yaml::{YamlDecoder, YAMLDecodingTrap};
///
/// let string = b"---
/// a\xa9: 1
/// b: 2.2
/// c: [1, 2]
/// ";
/// let out = YamlDecoder::read(string as &[u8])
/// .encoding_trap(encoding::DecoderTrap::Ignore)
/// .encoding_trap(YAMLDecodingTrap::Ignore)
/// .decode()
/// .unwrap();
/// ```
pub struct YamlDecoder<T: std::io::Read> {
source: T,
trap: encoding::types::DecoderTrap,
trap: YAMLDecodingTrap,
}

impl<T: std::io::Read> YamlDecoder<T> {
/// Create a `YamlDecoder` decoding the given source.
pub fn read(source: T) -> YamlDecoder<T> {
YamlDecoder {
source,
trap: encoding::DecoderTrap::Strict,
trap: YAMLDecodingTrap::Strict,
}
}

/// Set the behavior of the decoder when the encoding is invalid.
pub fn encoding_trap(&mut self, trap: encoding::types::DecoderTrap) -> &mut Self {
pub fn encoding_trap(&mut self, trap: YAMLDecodingTrap) -> &mut Self {
self.trap = trap;
self
}
Expand All @@ -282,14 +294,61 @@ impl<T: std::io::Read> YamlDecoder<T> {
let mut buffer = Vec::new();
self.source.read_to_end(&mut buffer)?;

// Decodes the input buffer using either UTF-8, UTF-16LE or UTF-16BE depending on the BOM codepoint.
// If the buffer doesn't start with a BOM codepoint, it will use a fallback encoding obtained by
// detect_utf16_endianness.
let (res, _) =
encoding::types::decode(&buffer, self.trap, detect_utf16_endianness(&buffer));
let s = res.map_err(LoadError::Decode)?;
YamlLoader::load_from_str(&s).map_err(LoadError::Scan)
// Check if the `encoding_rs` library can detect the encoding from the BOM, otherwise fall
// back to `detect_utf16_endianness`.
let (encoding, _) =
Encoding::for_bom(&buffer).unwrap_or_else(|| (detect_utf16_endianness(&buffer), 2));
let mut decoder = encoding.new_decoder();
let mut output = String::new();

// Decode the input buffer. We abort upon encountering an unknown character if and only if
// the decoding is set to `Strict`.
let has_replacements = decode_loop(
&buffer,
&mut output,
&mut decoder,
self.trap != YAMLDecodingTrap::Strict,
)?;

// Remove REPLACEMENT characters if asked to
if self.trap == YAMLDecodingTrap::Ignore && has_replacements {
// One day we'll be able to use [`std::string::String::remove_matches`].
output = output.replace('\u{FFFD}', "");
}

YamlLoader::load_from_str(&output).map_err(LoadError::Scan)
}
}

/// Perform a loop of [`Decoder::decode_to_string`], growing `output` whenever it fills up.
///
/// `input` is decoded in full (the decoder is always called with `last = true`) and the decoded
/// text is appended to `output`. Malformed sequences are written by the decoder as the Unicode
/// REPLACEMENT CHARACTER.
///
/// # Returns
/// This function returns whether any character was replaced by a Unicode REPLACEMENT CHARACTER in
/// the output.
///
/// # Errors
/// Returns [`LoadError::Decode`] as soon as the decoder reports a replacement while
/// `allow_replacements` is `false`.
fn decode_loop(
    input: &[u8],
    output: &mut String,
    decoder: &mut Decoder,
    allow_replacements: bool,
) -> Result<bool, LoadError> {
    // Lower bound for each capacity increase. `input.len() / 10` is 0 for inputs shorter than 10
    // bytes, and reserving 0 extra bytes would not let the decoder make progress after an
    // `OutputFull`.
    const MIN_GROWTH: usize = 16;

    output.reserve(input.len());
    let mut total_bytes_read = 0;
    let mut replacements = false;

    loop {
        let (result, bytes_read, replacement) =
            decoder.decode_to_string(&input[total_bytes_read..], output, true);

        // The replacement flag must be checked on every call, including the final one that
        // returns `InputEmpty`: malformed bytes are reported by whichever call consumed them.
        if replacement && !allow_replacements {
            return Err(LoadError::Decode(std::borrow::Cow::Borrowed(
                "Invalid UTF character encountered",
            )));
        }
        replacements |= replacement;
        total_bytes_read += bytes_read;

        match result {
            // The whole input has been consumed.
            CoderResult::InputEmpty => return Ok(replacements),
            // Arbitrarily grow the output's capacity and resume where the decoder stopped.
            CoderResult::OutputFull => output.reserve((input.len() / 10).max(MIN_GROWTH)),
        }
    }
}

/// The `encoding_rs` crate knows how to tell apart UTF-8 from UTF-16LE and UTF-16BE, when the
Expand All @@ -301,15 +360,15 @@ impl<T: std::io::Read> YamlDecoder<T> {
/// This allows the encoding to be deduced by the pattern of null (#x00) characters.
///
/// See spec at <https://yaml.org/spec/1.2/spec.html#id2771184>
fn detect_utf16_endianness(b: &[u8]) -> encoding::types::EncodingRef {
fn detect_utf16_endianness(b: &[u8]) -> &'static Encoding {
if b.len() > 1 && (b[0] != b[1]) {
if b[0] == 0 {
return encoding::all::UTF_16BE;
return encoding_rs::UTF_16BE;
} else if b[1] == 0 {
return encoding::all::UTF_16LE;
return encoding_rs::UTF_16LE;
}
}
encoding::all::UTF_8
encoding_rs::UTF_8
}

macro_rules! define_as (
Expand Down Expand Up @@ -550,7 +609,7 @@ impl Iterator for YamlIter {

#[cfg(test)]
mod test {
use super::{Yaml, YamlDecoder};
use super::{YAMLDecodingTrap, Yaml, YamlDecoder};

#[test]
fn test_read_bom() {
Expand Down Expand Up @@ -623,7 +682,7 @@ b: 2.2
c: [1, 2]
";
let out = YamlDecoder::read(s as &[u8])
.encoding_trap(encoding::DecoderTrap::Ignore)
.encoding_trap(YAMLDecodingTrap::Ignore)
.decode()
.unwrap();
let doc = &out[0];
Expand Down

0 comments on commit fc07ee3

Please sign in to comment.