From a6fef8fd8f3ee5fb58a19b367fd4bd173627f3b0 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Mon, 18 Sep 2023 14:25:34 +0200 Subject: [PATCH] Fix the Base 64 symbol converter. (#212) This PR fixes util::base64::SymbolConverter to also include the final group in the output if there is padding. --- src/rdata/dnssec.rs | 4 +- src/utils/base64.rs | 102 ++++++++++++++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/src/rdata/dnssec.rs b/src/rdata/dnssec.rs index 1f3e8aa48..c237ab4e6 100644 --- a/src/rdata/dnssec.rs +++ b/src/rdata/dnssec.rs @@ -2437,10 +2437,10 @@ mod test { #[test] fn dnskey_compose_parse_scan() { - let rdata = Dnskey::new(10, 11, SecAlg::RsaSha1, b"key").unwrap(); + let rdata = Dnskey::new(10, 11, SecAlg::RsaSha1, b"key0").unwrap(); test_rdlen(&rdata); test_compose_parse(&rdata, |parser| Dnskey::parse(parser)); - test_scan(&["10", "11", "RSASHA1", "a2V5"], Dnskey::scan, &rdata); + test_scan(&["10", "11", "RSASHA1", "a2V5MA=="], Dnskey::scan, &rdata); } //--- Rrsig diff --git a/src/utils/base64.rs b/src/utils/base64.rs index 97829ed81..a080959c6 100644 --- a/src/utils/base64.rs +++ b/src/utils/base64.rs @@ -343,7 +343,7 @@ impl SymbolConverter { &mut self, ch: char, ) -> Result, Error> { - if self.next == 0xF0 { + if self.next == EOF_MARKER { return Err(Error::custom("trailing Base 64 data")); } @@ -352,7 +352,7 @@ impl SymbolConverter { if self.next < 2 { return Err(Error::custom("illegal Base 64 data")); } - 0x80 // Acts as a marker later on. + PAD_MARKER // Acts as a marker later on. } else { if ch > (127 as char) { return Err(Error::custom("illegal Base 64 data")); @@ -368,19 +368,29 @@ impl SymbolConverter { if self.next == 4 { self.output[0] = self.input[0] << 2 | self.input[1] >> 4; - if self.input[2] != 0x80 { - self.output[1] = self.input[1] << 4 | self.input[2] >> 2; - } - if self.input[3] != 0x80 { - if self.input[2] == 0x80 { - return Err(Error::custom("trailing Base 64 data")); + + if self.input[2] == PAD_MARKER { + // The second to last character is padding. The last one + // needs to be, too. + if self.input[3] == PAD_MARKER { + self.next = EOF_MARKER; + Ok(Some(&self.output[..1])) + } else { + Err(Error::custom("illegal Base 64 data")) } - self.output[2] = (self.input[2] << 6) | self.input[3]; - self.next = 0 } else { - self.next = 0xF0 + self.output[1] = self.input[1] << 4 | self.input[2] >> 2; + + if self.input[3] == PAD_MARKER { + // The last characters is padding. + self.next = EOF_MARKER; + Ok(Some(&self.output[..2])) + } else { + self.output[2] = (self.input[2] << 6) | self.input[3]; + self.next = 0; + Ok(Some(&self.output)) + } } - Ok(Some(&self.output)) } else { Ok(None) } @@ -500,26 +510,39 @@ const ENCODE_ALPHABET: [char; 64] = [ /// The padding character const PAD: char = '='; +/// The marker for padding. +const PAD_MARKER: u8 = 0x80; + +/// The marker for complete data. +const EOF_MARKER: usize = 0xF0; + //============ Test ========================================================== #[cfg(test)] mod test { + use super::*; + + #[allow(dead_code)] + const HAPPY_CASES: &[(&[u8], &str)] = &[ + (b"", ""), + (b"f", "Zg=="), + (b"fo", "Zm8="), + (b"foo", "Zm9v"), + (b"foob", "Zm9vYg=="), + (b"fooba", "Zm9vYmE="), + (b"foobar", "Zm9vYmFy"), + ]; + #[cfg(feature = "std")] #[test] fn decode_str() { - use super::DecodeError; - fn decode(s: &str) -> Result, DecodeError> { super::decode(s) } - assert_eq!(&decode("").unwrap(), b""); - assert_eq!(&decode("Zg==").unwrap(), b"f"); - assert_eq!(&decode("Zm8=").unwrap(), b"fo"); - assert_eq!(&decode("Zm9v").unwrap(), b"foo"); - assert_eq!(&decode("Zm9vYg==").unwrap(), b"foob"); - assert_eq!(&decode("Zm9vYmE=").unwrap(), b"fooba"); - assert_eq!(&decode("Zm9vYmFy").unwrap(), b"foobar"); + for (bin, text) in HAPPY_CASES { + assert_eq!(&decode(text).unwrap(), bin, "decode {}", text) + } assert_eq!(decode("FPucA").unwrap_err(), DecodeError::ShortInput); assert_eq!( @@ -537,6 +560,33 @@ mod test { ); } + #[cfg(feature = "std")] + #[test] + fn symbol_converter() { + use crate::base::scan::Symbols; + use std::vec::Vec; + + fn decode(s: &str) -> Result, std::io::Error> { + let mut convert = SymbolConverter::new(); + let convert: &mut dyn ConvertSymbols<_, std::io::Error> = + &mut convert; + let mut res = Vec::new(); + for sym in Symbols::new(s.chars()) { + if let Some(octs) = convert.process_symbol(sym)? { + res.extend_from_slice(octs); + } + } + if let Some(octs) = convert.process_tail()? { + res.extend_from_slice(octs); + } + Ok(res) + } + + for (bin, text) in HAPPY_CASES { + assert_eq!(&decode(text).unwrap(), bin, "convert {}", text) + } + } + #[test] #[cfg(feature = "std")] fn display_bytes() { @@ -548,12 +598,8 @@ mod test { out } - assert_eq!(fmt(b""), ""); - assert_eq!(fmt(b"f"), "Zg=="); - assert_eq!(fmt(b"fo"), "Zm8="); - assert_eq!(fmt(b"foo"), "Zm9v"); - assert_eq!(fmt(b"foob"), "Zm9vYg=="); - assert_eq!(fmt(b"fooba"), "Zm9vYmE="); - assert_eq!(fmt(b"foobar"), "Zm9vYmFy"); + for (bin, text) in HAPPY_CASES { + assert_eq!(&fmt(bin), text, "fmt {}", text); + } } }