diff --git a/Cargo.toml b/Cargo.toml index d26915d88..69ad7fd8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,9 @@ name = "domain" path = "src/lib.rs" [dependencies] -octseq = { version = "0.3.2", default-features = false } -time = { version = "0.3.1", default-features = false } +octseq = { version = "0.3.2", default-features = false } +pin-project-lite = "0.2" +time = { version = "0.3.1", default-features = false } rand = { version = "0.8", optional = true } bytes = { version = "1.0", optional = true, default-features = false } @@ -30,7 +31,11 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time"] } +tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } +tokio-rustls = { version = "0.24", optional = true, features = [] } + +# XXX Force proc-macro2 to at least 1.0.69 for minimal-version build +proc-macro2 = "1.0.69" [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work @@ -41,24 +46,31 @@ default = ["std", "rand"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] interop = ["bytes", "ring"] -resolv = ["bytes", "futures-util", "smallvec", "std", "tokio", "libc", "rand"] +resolv = ["net", "smallvec", "std", "rand", "unstable-client-transport"] resolv-sync = ["resolv", "tokio/rt"] serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = ["bytes?/std", "octseq/std", "time/std"] +net = ["bytes", "futures-util", "std", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] +# Unstable features +unstable-client-transport = [] + # This feature should include all features that the CI should include for a # test run. Which is everything except interop. -ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] +ci-test = ["net", "resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] [dev-dependencies] +rustls = { version = "0.21.9" } serde_test = "1.0.130" serde_yaml = "0.9" tokio = { version = "1", features = ["rt-multi-thread", "io-util", "net"] } +tokio-test = "0.4" +webpki-roots = { version = "0.25" } [package.metadata.docs.rs] all-features = true @@ -84,3 +96,6 @@ required-features = ["resolv-sync"] name = "client" required-features = ["std", "rand"] +[[example]] +name = "client-transports" +required-features = ["net"] diff --git a/examples/client-transports.rs b/examples/client-transports.rs new file mode 100644 index 000000000..acda31f91 --- /dev/null +++ b/examples/client-transports.rs @@ -0,0 +1,233 @@ +/// Using the `domain::net::client` module for sending a query. +use domain::base::Dname; +use domain::base::MessageBuilder; +use domain::base::Rtype::Aaaa; +use domain::net::client::dgram; +use domain::net::client::dgram_stream; +use domain::net::client::multi_stream; +use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; +use domain::net::client::redundant; +use domain::net::client::request::{RequestMessage, SendRequest}; +use domain::net::client::stream; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; + +#[tokio::main] +async fn main() { + // Create DNS request message. + // + // Transports currently take a `RequestMessage` as their input to be able + // to add options along the way. + // + // In the future, it will also be possible to pass in a message or message + // builder directly as input but for now it needs to be converted into a + // `RequestMessage` manually. + let mut msg = MessageBuilder::new_vec(); + msg.header_mut().set_rd(true); + let mut msg = msg.question(); + msg.push((Dname::vec_from_str("example.com").unwrap(), Aaaa)) + .unwrap(); + let req = RequestMessage::new(msg); + + // Destination for UDP and TCP + let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + + let mut stream_config = stream::Config::new(); + stream_config.set_response_timeout(Duration::from_millis(100)); + let multi_stream_config = + multi_stream::Config::from(stream_config.clone()); + + // Create a new UDP+TCP transport connection. Pass the destination address + // and port as parameter. + let mut dgram_config = dgram::Config::new(); + dgram_config.set_max_parallel(1); + dgram_config.set_read_timeout(Duration::from_millis(1000)); + dgram_config.set_max_retries(1); + dgram_config.set_udp_payload_size(Some(1400)); + let dgram_stream_config = dgram_stream::Config::from_parts( + dgram_config.clone(), + multi_stream_config.clone(), + ); + let udp_connect = UdpConnect::new(server_addr); + let tcp_connect = TcpConnect::new(server_addr); + let (udptcp_conn, transport) = dgram_stream::Connection::with_config( + udp_connect, + tcp_connect, + dgram_stream_config, + ); + + // Start the run function in a separate task. The run function will + // terminate when all references to the connection have been dropped. + // Make sure that the task does not accidentally get a reference to the + // connection. + tokio::spawn(async move { + transport.run().await; + println!("UDP+TCP run exited"); + }); + + // Send a query message. + let mut request = udptcp_conn.send_request(req.clone()); + + // Get the reply + println!("Wating for UDP+TCP reply"); + let reply = request.get_response().await; + println!("UDP+TCP reply: {:?}", reply); + + // The query may have a reference to the connection. Drop the query + // when it is no longer needed. + drop(request); + + // Create a new TCP connections object. Pass the destination address and + // port as parameter. + let tcp_connect = TcpConnect::new(server_addr); + + // A muli_stream transport connection sets up new TCP connections when + // needed. + let (tcp_conn, transport) = multi_stream::Connection::with_config( + tcp_connect, + multi_stream_config.clone(), + ); + + // Get a future for the run function. The run function receives + // the connection stream as a parameter. + tokio::spawn(async move { + transport.run().await; + println!("multi TCP run exited"); + }); + + // Send a query message. + let mut request = tcp_conn.send_request(req.clone()); + + // Get the reply. A multi_stream connection does not have any timeout. + // Wrap get_result in a timeout. + println!("Wating for multi TCP reply"); + let reply = + timeout(Duration::from_millis(500), request.get_response()).await; + println!("multi TCP reply: {:?}", reply); + + drop(request); + + // Some TLS boiler plate for the root certificates. + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + + // TLS config + let client_config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + + // Currently the only support TLS connections are the ones that have a + // valid certificate. Use a well known public resolver. + let google_server_addr = + SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); + + // Create a new TLS connections object. We pass the TLS config, the name of + // the remote server and the destination address and port. + let tls_connect = TlsConnect::new( + client_config, + "dns.google".try_into().unwrap(), + google_server_addr, + ); + + // Again create a multi_stream transport connection. + let (tls_conn, transport) = multi_stream::Connection::with_config( + tls_connect, + multi_stream_config, + ); + + // Start the run function. + tokio::spawn(async move { + transport.run().await; + println!("TLS run exited"); + }); + + let mut request = tls_conn.send_request(req.clone()); + println!("Wating for TLS reply"); + let reply = + timeout(Duration::from_millis(500), request.get_response()).await; + println!("TLS reply: {:?}", reply); + + drop(request); + + // Create a transport connection for redundant connections. + let (redun, transp) = redundant::Connection::new(); + + // Start the run function on a separate task. + let run_fut = transp.run(); + tokio::spawn(async move { + run_fut.await; + println!("redundant run terminated"); + }); + + // Add the previously created transports. + redun.add(Box::new(udptcp_conn)).await.unwrap(); + redun.add(Box::new(tcp_conn)).await.unwrap(); + redun.add(Box::new(tls_conn)).await.unwrap(); + + // Start a few queries. + for i in 1..10 { + let mut request = redun.send_request(req.clone()); + let reply = request.get_response().await; + if i == 2 { + println!("redundant connection reply: {:?}", reply); + } + } + + drop(redun); + + // Create a new datagram transport connection. Pass the destination address + // and port as parameter. This transport does not retry over TCP if the + // reply is truncated. This transport does not have a separate run + // function. + let udp_connect = UdpConnect::new(server_addr); + let dgram_conn = + dgram::Connection::with_config(udp_connect, dgram_config); + + // Send a message. + let mut request = dgram_conn.send_request(req.clone()); + // + // Get the reply + let reply = request.get_response().await; + println!("Dgram reply: {:?}", reply); + + // Create a single TCP transport connection. This is usefull for a + // single request or a small burst of requests. + let tcp_conn = match TcpStream::connect(server_addr).await { + Ok(conn) => conn, + Err(err) => { + println!( + "TCP Connection to {} failed: {}, exiting", + server_addr, err + ); + return; + } + }; + + let (tcp, transport) = stream::Connection::new(tcp_conn); + tokio::spawn(async move { + transport.run().await; + println!("single TCP run terminated"); + }); + + // Send a request message. + let mut request = tcp.send_request(req); + + // Get the reply + let reply = request.get_response().await; + println!("TCP reply: {:?}", reply); + + drop(tcp); +} diff --git a/examples/readzone.rs b/examples/readzone.rs index c854eade0..e07cafc89 100644 --- a/examples/readzone.rs +++ b/examples/readzone.rs @@ -15,7 +15,7 @@ fn main() { start.elapsed().unwrap().as_secs_f32() ); let mut i = 0; - while let Some(_) = zone.next_entry().unwrap() { + while zone.next_entry().unwrap().is_some() { i += 1; if i % 100_000_000 == 0 { eprintln!( diff --git a/src/base/message.rs b/src/base/message.rs index 83cd2b80b..79db7997c 100644 --- a/src/base/message.rs +++ b/src/base/message.rs @@ -171,6 +171,18 @@ impl Message { Ok(unsafe { Self::from_octets_unchecked(octets) }) } + /// Creates a message from octets, returning the octets if it fails. + pub fn try_from_octets(octets: Octs) -> Result + where + Octs: AsRef<[u8]>, + { + if Message::check_slice(octets.as_ref()).is_err() { + Err(octets) + } else { + Ok(unsafe { Self::from_octets_unchecked(octets) }) + } + } + /// Creates a message from a bytes value without checking. /// /// # Safety @@ -1194,6 +1206,12 @@ impl From for CopyRecordsError { } } +impl From for CopyRecordsError { + fn from(err: PushError) -> Self { + CopyRecordsError::Push(err) + } +} + //--- Display and Error impl fmt::Display for CopyRecordsError { diff --git a/src/base/message_builder.rs b/src/base/message_builder.rs index 6702bf427..c38ce2ddc 100644 --- a/src/base/message_builder.rs +++ b/src/base/message_builder.rs @@ -300,14 +300,16 @@ impl> MessageBuilder { /// # Conversions /// -impl MessageBuilder { +impl MessageBuilder { /// Converts the message builder into a message builder /// /// This is a no-op. pub fn builder(self) -> MessageBuilder { self } +} +impl MessageBuilder { /// Converts the message builder into a question builder. pub fn question(self) -> QuestionBuilder { QuestionBuilder::new(self) @@ -340,15 +342,14 @@ impl MessageBuilder { pub fn finish(self) -> Target { self.target } +} +impl MessageBuilder { /// Converts the builder into a message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message<::Octets> - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { unsafe { Message::from_octets_unchecked(self.target.freeze()) } } } @@ -448,6 +449,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: MessageBuilder) -> Self { + src.into_message() + } +} + //--- AsRef // // XXX Should we deref down to target? @@ -554,14 +564,16 @@ impl QuestionBuilder { } } -impl QuestionBuilder { +impl QuestionBuilder { /// Converts the question builder into a question builder. /// /// In other words, doesn’t do anything. pub fn question(self) -> QuestionBuilder { self } +} +impl QuestionBuilder { /// Converts the question builder into an answer builder. pub fn answer(self) -> AnswerBuilder { AnswerBuilder::new(self.builder) @@ -587,15 +599,14 @@ impl QuestionBuilder { pub fn finish(self) -> Target { self.builder.finish() } +} +impl QuestionBuilder { /// Converts the question builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.builder.into_message() } } @@ -650,6 +661,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: QuestionBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for QuestionBuilder { @@ -831,15 +851,14 @@ impl AnswerBuilder { pub fn finish(self) -> Target { self.builder.finish() } +} +impl AnswerBuilder { /// Converts the answer builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.builder.into_message() } } @@ -894,6 +913,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AnswerBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AnswerBuilder { @@ -1055,9 +1083,7 @@ impl AuthorityBuilder { self.rewind(); self.answer } -} -impl AuthorityBuilder { /// Converts the authority builder into an authority builder. /// /// This is identical to the identity function. @@ -1076,15 +1102,14 @@ impl AuthorityBuilder { pub fn finish(self) -> Target { self.answer.finish() } +} +impl AuthorityBuilder { /// Converts the authority builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.answer.into_message() } } @@ -1139,6 +1164,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AuthorityBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AuthorityBuilder { @@ -1336,9 +1370,7 @@ impl AdditionalBuilder { self.rewind(); self.authority } -} -impl AdditionalBuilder { /// Converts the additional builder into an additional builder. /// /// In other words, does absolutely nothing. @@ -1350,15 +1382,14 @@ impl AdditionalBuilder { pub fn finish(self) -> Target { self.authority.finish() } +} +impl AdditionalBuilder { /// Converts the additional builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.authority.into_message() } } @@ -1413,6 +1444,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AdditionalBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AdditionalBuilder { diff --git a/src/base/opt/mod.rs b/src/base/opt/mod.rs index cb6a6ce0c..5557ce94e 100644 --- a/src/base/opt/mod.rs +++ b/src/base/opt/mod.rs @@ -40,17 +40,17 @@ opt_types! { //============ Module Content ================================================ use super::header::Header; -use super::iana::{OptRcode, OptionCode, Rtype}; +use super::iana::{Class, OptRcode, OptionCode, Rtype}; use super::name::{Dname, ToDname}; use super::rdata::{ComposeRecordData, ParseRecordData, RecordData}; -use super::record::Record; -use super::wire::{Composer, FormError, ParseError}; +use super::record::{Record, Ttl}; +use super::wire::{Compose, Composer, FormError, ParseError}; use crate::utils::base16; use core::cmp::Ordering; use core::convert::TryInto; use core::marker::PhantomData; use core::{fmt, hash, mem}; -use octseq::builder::{OctetsBuilder, ShortBuf}; +use octseq::builder::{EmptyBuilder, OctetsBuilder, ShortBuf}; use octseq::octets::{Octets, OctetsFrom}; use octseq::parse::Parser; @@ -76,6 +76,15 @@ pub struct Opt { octets: Octs, } +impl Opt { + /// Creates empty OPT record data. + pub fn empty() -> Self { + Self { + octets: Octs::empty(), + } + } +} + impl> Opt { /// Creates OPT record data from an octets sequence. /// @@ -83,7 +92,18 @@ impl> Opt { /// options. It does not check whether the options themselves are valid. pub fn from_octets(octets: Octs) -> Result { Opt::check_slice(octets.as_ref())?; - Ok(Opt { octets }) + Ok(unsafe { Self::from_octets_unchecked(octets) }) + } + + /// Creates OPT record data from octets without checking. + /// + /// # Safety + /// + /// The caller needs to ensure that the slice contains correctly encoded + /// OPT record data. The data of the options themselves does not need to + /// be correct. + unsafe fn from_octets_unchecked(octets: Octs) -> Self { + Self { octets } } /// Parses OPT record data from the beginning of a parser. @@ -128,6 +148,12 @@ impl Opt<[u8]> { } } +impl + ?Sized> Opt { + pub fn for_slice_ref(&self) -> Opt<&[u8]> { + unsafe { Opt::from_octets_unchecked(self.octets.as_ref()) } + } +} + impl + ?Sized> Opt { /// Returns the length of the OPT record data. pub fn len(&self) -> usize { @@ -163,6 +189,44 @@ impl + ?Sized> Opt { } } +impl Opt { + /// Appends a new option to the OPT data. + pub fn push( + &mut self, + option: &Opt, + ) -> Result<(), BuildDataError> { + self.push_raw_option(option.code(), option.compose_len(), |target| { + option.compose_option(target) + }) + } + + /// Appends a raw option to the OPT data. + /// + /// The method will append an option with the given option code. The data + /// of the option will be written via the closure `op`. + pub fn push_raw_option( + &mut self, + code: OptionCode, + option_len: u16, + op: F, + ) -> Result<(), BuildDataError> + where + F: FnOnce(&mut Octs) -> Result<(), Octs::AppendError>, + { + LongOptData::check_len( + self.octets + .as_ref() + .len() + .saturating_add(usize::from(option_len)), + )?; + + code.compose(&mut self.octets)?; + option_len.compose(&mut self.octets)?; + op(&mut self.octets)?; + Ok(()) + } +} + //--- OctetsFrom impl OctetsFrom> for Opt @@ -283,7 +347,8 @@ impl + ?Sized> fmt::Debug for Opt { /// /// The OPT record reappropriates the record header for encoding some /// basic information. This type provides access to this information. It -/// consists of the record header accept for its `rdlen` field. +/// consists of the record header with the exception of the fiinal `rdlen` +/// field. /// /// This is so that `OptBuilder` can safely deref to this type. /// @@ -440,6 +505,23 @@ impl OptRecord { } } + /// Converts the OPT record into a regular record. + pub fn as_record(&self) -> Record<&'static Dname<[u8]>, Opt<&[u8]>> + where + Octs: AsRef<[u8]>, + { + Record::new( + Dname::root_slice(), + Class::Int(self.udp_payload_size), + Ttl::from_secs( + u32::from(self.ext_rcode) << 24 + | u32::from(self.version) << 16 + | u32::from(self.flags), + ), + self.data.for_slice_ref(), + ) + } + /// Returns the UDP payload size. /// /// Through this field a sender of a message can signal the maximum size @@ -451,6 +533,11 @@ impl OptRecord { self.udp_payload_size } + /// Sets the UDP payload size. + pub fn set_udp_payload_size(&mut self, value: u16) { + self.udp_payload_size = value + } + /// Returns the extended rcode. /// /// Some of the bits of the rcode are stored in the regular message @@ -484,6 +571,44 @@ impl OptRecord { } } +impl OptRecord { + /// Appends a new option to the OPT data. + pub fn push( + &mut self, + option: &Opt, + ) -> Result<(), BuildDataError> { + self.data.push(option) + } + + /// Appends a raw option to the OPT data. + /// + /// The method will append an option with the given option code. The data + /// of the option will be written via the closure `op`. + pub fn push_raw_option( + &mut self, + code: OptionCode, + option_len: u16, + op: F, + ) -> Result<(), BuildDataError> + where + F: FnOnce(&mut Octs) -> Result<(), Octs::AppendError>, + { + self.data.push_raw_option(code, option_len, op) + } +} + +impl Default for OptRecord { + fn default() -> Self { + Self { + udp_payload_size: 0, + ext_rcode: 0, + version: 0, + flags: 0, + data: Opt::empty(), + } + } +} + //--- From impl From>> for OptRecord { @@ -521,6 +646,20 @@ impl AsRef> for OptRecord { } } +//--- Debug + +impl> fmt::Debug for OptRecord { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("OptRecord") + .field("udp_payload_size", &self.udp_payload_size) + .field("ext_rcord", &self.ext_rcode) + .field("version", &self.version) + .field("flags", &self.flags) + .field("data", &self.data) + .finish() + } +} + //------------ OptionHeader -------------------------------------------------- /// The header of an OPT option. @@ -859,13 +998,27 @@ impl std::error::Error for LongOptData {} /// An error happened while constructing an SVCB value. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum BuildDataError { - /// The value would exceed the allow length of a value. + /// The value would exceed the allowed length of a value. LongOptData, /// The underlying octets builder ran out of buffer space. ShortBuf, } +impl BuildDataError { + /// Converts the error into a `LongOptData` error for ‘endless’ buffers. + /// + /// # Panics + /// + /// This method will panic if the error is of the `ShortBuf` variant. + pub fn unlimited_buf(self) -> LongOptData { + match self { + Self::LongOptData => LongOptData(()), + Self::ShortBuf => panic!("ShortBuf on unlimited buffer"), + } + } +} + impl From for BuildDataError { fn from(_: LongOptData) -> Self { Self::LongOptData diff --git a/src/lib.rs b/src/lib.rs index 6fe39406e..af99c537e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,9 @@ //! //! Currently, there are the following modules: //! +#![cfg_attr(feature = "net", doc = "* [net]:")] +#![cfg_attr(not(feature = "net"), doc = "* net:")] +//! Sending and receiving DNS message. #![cfg_attr(feature = "resolv", doc = "* [resolv]:")] #![cfg_attr(not(feature = "resolv"), doc = "* resolv:")] //! An asynchronous DNS resolver based on the @@ -48,12 +51,13 @@ //! Finally, the [dep] module contains re-exports of some important //! dependencies to help avoid issues with multiple versions of a crate. //! -//! # Reference of Feature Flags +//! # Reference of feature flags //! -//! The following is the complete list of the feature flags available. +//! The following is the complete list of the feature flags with the +//! exception of unstable features which are described below. //! //! * `bytes`: Enables using the types `Bytes` and `BytesMut` from the -//! [bytes](https://github.com/tokio-rs/bytes) crate as octet sequences. +//! [bytes](https://github.com/tokio-rs/bytes) crate as octet sequences. //! * `chrono`: Adds the [chrono](https://github.com/chronotope/chrono) //! crate as a dependency. This adds support for generating serial numbers //! from time stamps. @@ -104,6 +108,30 @@ #![cfg_attr(feature = "zonefile", doc = " [zonefile]")] #![cfg_attr(not(feature = "zonefile"), doc = " zonefile")] //! module and currently also enables the `bytes` and `std` features. +//! +//! # Unstable features +//! +//! When adding new functionality to the crate, practical experience is +//! necessary to arrive at a good, user friendly design. Unstable features +//! allow adding and rapidly changing new code without having to release +//! versions allowing breaking changes all the time. If you use unstable +//! features, it is best to specify a concrete version as a dependency in +//! `Cargo.toml` using the `=` operator, e.g.: +//! +//! ```text +//! [dependencies] +//! domain = "=0.9.3" +//! ``` +//! +//! Currently, the following unstable features exist: +//! +//! * `unstable-client-transport`: sending and receiving DNS messages from +//! a client perspective; primarily the `net::client` module. +//! +//! Note: Some functionality is currently informally marked as +//! “experimental” since it was introduced before adoption of the concept +//! of unstable features. These will follow proper Semver practice but may +//! significant changes in releases with breakting changes. #![no_std] #![allow(renamed_and_removed_lints)] @@ -121,6 +149,7 @@ extern crate core; pub mod base; pub mod dep; +pub mod net; pub mod rdata; pub mod resolv; pub mod sign; diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs new file mode 100644 index 000000000..a6197e639 --- /dev/null +++ b/src/net/client/dgram.rs @@ -0,0 +1,510 @@ +//! A client over datagram protocols. +//! +//! This module implements a DNS client for use with datagram protocols, i.e., +//! message-oriented, connection-less, unreliable network protocols. In +//! practice, this is pretty much exclusively UDP. + +#![warn(missing_docs)] + +// To do: +// - cookies + +use crate::base::Message; +use crate::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend, + AsyncDgramSendEx, +}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; +use bytes::Bytes; +use core::{cmp, fmt}; +use octseq::OctetsInto; +use std::boxed::Box; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::{error, io}; +use tokio::sync::Semaphore; +use tokio::time::{timeout_at, Duration, Instant}; + +//------------ Configuration Constants ---------------------------------------- + +/// Configuration limits for the maximum number of parallel requests. +const MAX_PARALLEL: DefMinMax = DefMinMax::new(100, 1, 1000); + +/// Configuration limits for the read timeout. +const READ_TIMEOUT: DefMinMax = DefMinMax::new( + Duration::from_secs(5), + Duration::from_millis(1), + Duration::from_secs(60), +); + +/// Configuration limits for the maximum number of retries. +const MAX_RETRIES: DefMinMax = DefMinMax::new(5, 1, 100); + +/// Default UDP payload size. +const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; + +/// The default receive buffer size. +const DEF_RECV_SIZE: usize = 2000; + +//------------ Config --------------------------------------------------------- + +/// Configuration of a datagram transport. +#[derive(Clone, Debug)] +pub struct Config { + /// Maximum number of parallel requests for a transport connection. + max_parallel: usize, + + /// Read timeout. + read_timeout: Duration, + + /// Maximum number of retries. + max_retries: u8, + + /// EDNS UDP payload size. + /// + /// If this is `None`, no OPT record will be included at all. + udp_payload_size: Option, + + /// Receive buffer size. + recv_size: usize, +} + +impl Config { + /// Creates a new config with default values. + pub fn new() -> Self { + Default::default() + } + + /// Sets the maximum number of parallel requests. + /// + /// Once this many number of requests are currently outstanding, + /// additional requests will wait. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_max_parallel(&mut self, value: usize) { + self.max_parallel = MAX_PARALLEL.limit(value) + } + + /// Returns the maximum number of parallel requests. + pub fn max_parallel(&self) -> usize { + self.max_parallel + } + + /// Sets the read timeout. + /// + /// The read timeout is the maximum amount of time to wait for any + /// response after a request was sent. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_read_timeout(&mut self, value: Duration) { + self.read_timeout = READ_TIMEOUT.limit(value) + } + + /// Returns the read timeout. + pub fn read_timeout(&self) -> Duration { + self.read_timeout + } + + /// Sets the maximum number a request is retried before giving up. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_max_retries(&mut self, value: u8) { + self.max_retries = MAX_RETRIES.limit(value) + } + + /// Returns the maximum number of request retries. + pub fn max_retries(&self) -> u8 { + self.max_retries + } + + /// Sets the requested UDP payload size. + /// + /// This value indicates to the server the maximum size of a UDP packet. + /// For UDP on public networks, this value should be left at the default + /// of 1232 to avoid issues rising from packet fragmentation. See + /// [draft-ietf-dnsop-avoid-fragmentation] for a discussion on these + /// issues and recommendations. + /// + /// On private networks or protocols other than UDP, other values can be + /// used. + /// + /// Setting the UDP payload size to `None` currently results in messages + /// that will not include an OPT record. + /// + /// [draft-ietf-dnsop-avoid-fragmentation]: https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/ + pub fn set_udp_payload_size(&mut self, value: Option) { + self.udp_payload_size = value; + } + + /// Returns the UDP payload size. + pub fn udp_payload_size(&self) -> Option { + self.udp_payload_size + } + + /// Sets the receive buffer size. + /// + /// This is the amount of memory that is allocated for receiving a + /// response. + pub fn set_recv_size(&mut self, size: usize) { + self.recv_size = size + } + + /// Returns the receive buffer size. + pub fn recv_size(&self) -> usize { + self.recv_size + } +} + +impl Default for Config { + fn default() -> Self { + Self { + max_parallel: MAX_PARALLEL.default(), + read_timeout: READ_TIMEOUT.default(), + max_retries: MAX_RETRIES.default(), + udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), + recv_size: DEF_RECV_SIZE, + } + } +} + +//------------ Connection ----------------------------------------------------- + +/// A datagram protocol connection. +/// +/// Because it owns the connection’s resources, this type is not `Clone`. +/// However, it is entirely safe to share it by sticking it into e.g. an arc. +#[derive(Debug)] +pub struct Connection { + state: Arc>, +} + +#[derive(Debug)] +struct ConnectionState { + /// User configuration variables. + config: Config, + + /// Connections to datagram sockets. + connect: S, + + /// Semaphore to limit access to UDP sockets. + semaphore: Semaphore, +} + +impl Connection { + /// Create a new datagram transport with default configuration. + pub fn new(connect: S) -> Self { + Self::with_config(connect, Default::default()) + } + + /// Create a new datagram transport with a given configuration. + pub fn with_config(connect: S, config: Config) -> Self { + Self { + state: Arc::new(ConnectionState { + semaphore: Semaphore::new(config.max_parallel), + config, + connect, + }), + } + } +} + +impl Connection +where + S: AsyncConnect, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, +{ + /// Performs a request. + /// + /// Sends the provided and returns either a response or an error. If there + /// are currently too many active queries, the future will wait until the + /// number has dropped below the limit. + async fn handle_request_impl( + self, + mut request: Req, + ) -> Result, Error> { + // Acquire the semaphore or wait for it. + let _ = self + .state + .semaphore + .acquire() + .await + .expect("semaphore closed"); + + // A place to store the receive buffer for reuse. + let mut reuse_buf = None; + + // Transmit loop. + for _ in 0..self.state.config.max_retries { + let mut sock = self + .state + .connect + .connect() + .await + .map_err(QueryError::connect)?; + + // Set random ID in header + request.header_mut().set_random_id(); + + // Set UDP payload size if necessary. + if let Some(size) = self.state.config.udp_payload_size { + request.set_udp_payload_size(size) + } + + // Create the message and send it out. + let request_msg = request.to_message(); + let dgram = request_msg.as_slice(); + let sent = sock.send(dgram).await.map_err(QueryError::send)?; + if sent != dgram.len() { + return Err(QueryError::short_send().into()); + } + + // Receive loop. It may at most take read_timeout time. + let deadline = Instant::now() + self.state.config.read_timeout; + while deadline > Instant::now() { + let mut buf = reuse_buf.take().unwrap_or_else(|| { + // XXX use uninit'ed mem here. + vec![0; self.state.config.recv_size] + }); + let len = + match timeout_at(deadline, sock.recv(&mut buf)).await { + Ok(Ok(len)) => len, + Ok(Err(err)) => { + // Receiving failed. + return Err(QueryError::receive(err).into()); + } + Err(_) => { + // Timeout. + break; + } + }; + buf.truncate(len); + + // We ignore garbage since there is a timer on this whole + // thing. + let answer = match Message::try_from_octets(buf) { + Ok(answer) => answer, + Err(buf) => { + // Just go back to receiving. + reuse_buf = Some(buf); + continue; + } + }; + + if !request.is_answer(answer.for_slice()) { + // Wrong answer, go back to receiving + reuse_buf = Some(answer.into_octets()); + continue; + } + return Ok(answer.octets_into()); + } + } + Err(QueryError::timeout().into()) + } +} + +//--- Clone + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } + } +} + +//--- SendRequest + +impl SendRequest for Connection +where + S: AsyncConnect + Clone + Send + Sync + 'static, + S::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, + Req: ComposeRequest + Clone + Send + Sync + 'static, +{ + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request { + fut: Box::pin(self.clone().handle_request_impl(request_msg)), + }) + } +} + +//------------ Request ------------------------------------------------------ + +/// The state of a DNS request. +pub struct Request { + /// Future that does the actual work of GetResponse. + fut: Pin, Error>> + Send>>, +} + +impl Request { + /// Async function that waits for the future stored in Request to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.fut).await + } +} + +impl fmt::Debug for Request { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + todo!() + } +} + +impl GetResponse for Request { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} + +//------------ DefMinMax ----------------------------------------------------- + +/// The default, minimum, and maximum values for a config variable. +#[derive(Clone, Copy)] +struct DefMinMax { + /// The default value, + def: T, + + /// The minimum value, + min: T, + + /// The maximum value, + max: T, +} + +impl DefMinMax { + /// Creates a new value. + const fn new(def: T, min: T, max: T) -> Self { + Self { def, min, max } + } + + /// Returns the default value. + fn default(self) -> T { + self.def + } + + /// Trims the given value to fit into the minimum/maximum range. + fn limit(self, value: T) -> T + where + T: Ord, + { + cmp::max(self.min, cmp::min(self.max, value)) + } +} + +//============ Errors ======================================================== + +//------------ QueryError ---------------------------------------------------- + +/// A query failed. +#[derive(Debug)] +pub struct QueryError { + /// Which step failed? + kind: QueryErrorKind, + + /// The underlying IO error. + io: std::io::Error, +} + +impl QueryError { + fn new(kind: QueryErrorKind, io: io::Error) -> Self { + Self { kind, io } + } + + fn connect(io: io::Error) -> Self { + Self::new(QueryErrorKind::Connect, io) + } + + fn send(io: io::Error) -> Self { + Self::new(QueryErrorKind::Send, io) + } + + fn short_send() -> Self { + Self::new( + QueryErrorKind::Send, + io::Error::new(io::ErrorKind::Other, "short request sent"), + ) + } + + fn timeout() -> Self { + Self::new( + QueryErrorKind::Timeout, + io::Error::new(io::ErrorKind::TimedOut, "timeout expired"), + ) + } + + fn receive(io: io::Error) -> Self { + Self::new(QueryErrorKind::Receive, io) + } +} + +impl QueryError { + /// Returns information about when the query has failed. + pub fn kind(&self) -> QueryErrorKind { + self.kind + } + + /// Converts the query error into the underlying IO error. + pub fn io_error(self) -> std::io::Error { + self.io + } +} + +impl From for std::io::Error { + fn from(err: QueryError) -> std::io::Error { + err.io + } +} + +impl fmt::Display for QueryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: {}", self.kind.error_str(), self.io) + } +} + +impl error::Error for QueryError {} + +//------------ QueryErrorKind ------------------------------------------------ + +/// Which part of processing the query failed? +#[derive(Copy, Clone, Debug)] +pub enum QueryErrorKind { + /// Failed to connect to the remote. + Connect, + + /// Failed to send the request. + Send, + + /// The request has timed out. + Timeout, + + /// Failed to read the response. + Receive, +} + +impl QueryErrorKind { + /// Returns the string to be used when displaying a query error. + fn error_str(self) -> &'static str { + match self { + Self::Connect => "connecting failed", + Self::Send => "sending request failed", + Self::Timeout | Self::Receive => "reading response failed", + } + } +} + +impl fmt::Display for QueryErrorKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + Self::Connect => "connecting failed", + Self::Send => "sending request failed", + Self::Timeout => "request timeout", + Self::Receive => "reading response failed", + }) + } +} diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs new file mode 100644 index 000000000..42b007ade --- /dev/null +++ b/src/net/client/dgram_stream.rs @@ -0,0 +1,251 @@ +//! A UDP transport that falls back to TCP if the reply is truncated + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - handle shutdown + +use crate::base::Message; +use crate::net::client::dgram; +use crate::net::client::multi_stream; +use crate::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramSend, +}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; +use bytes::Bytes; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +//------------ Config --------------------------------------------------------- + +/// Configuration for an octet_stream transport connection. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Configuration for the UDP transport. + dgram: dgram::Config, + + /// Configuration for the multi_stream (TCP) transport. + multi_stream: multi_stream::Config, +} + +impl Config { + /// Creates a new config with default values. + pub fn new() -> Self { + Default::default() + } + + /// Creates a new config from the two portions. + pub fn from_parts( + dgram: dgram::Config, + multi_stream: multi_stream::Config, + ) -> Self { + Self { + dgram, + multi_stream, + } + } + + /// Returns the datagram config. + pub fn dgram(&self) -> &dgram::Config { + &self.dgram + } + + /// Returns a mutable reference to the datagram config. + pub fn dgram_mut(&mut self) -> &mut dgram::Config { + &mut self.dgram + } + + /// Sets the datagram config. + pub fn set_dgram(&mut self, dgram: dgram::Config) { + self.dgram = dgram + } + + /// Returns the stream config. + pub fn stream(&self) -> &multi_stream::Config { + &self.multi_stream + } + + /// Returns a mutable reference to the stream config. + pub fn stream_mut(&mut self) -> &mut multi_stream::Config { + &mut self.multi_stream + } + + /// Sets the stream config. + pub fn set_stream(&mut self, stream: multi_stream::Config) { + self.multi_stream = stream + } +} + +//------------ Connection ----------------------------------------------------- + +/// DNS transport connection that first issues a query over a UDP transport and +/// falls back to TCP if the reply is truncated. +#[derive(Clone)] +pub struct Connection { + /// The UDP transport connection. + udp_conn: Arc>, + + /// The TCP transport connection. + tcp_conn: multi_stream::Connection, +} + +impl Connection +where + DgramS: AsyncConnect + Clone + Send + Sync + 'static, + DgramS::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, +{ + /// Creates a new multi-stream transport with default configuration. + pub fn new( + dgram_remote: DgramS, + stream_remote: StreamS, + ) -> (Self, multi_stream::Transport) { + Self::with_config(dgram_remote, stream_remote, Default::default()) + } + + /// Creates a new multi-stream transport. + pub fn with_config( + dgram_remote: DgramS, + stream_remote: StreamS, + config: Config, + ) -> (Self, multi_stream::Transport) { + let udp_conn = + dgram::Connection::with_config(dgram_remote, config.dgram).into(); + let (tcp_conn, transport) = multi_stream::Connection::with_config( + stream_remote, + config.multi_stream, + ); + (Self { udp_conn, tcp_conn }, transport) + } +} + +//--- SendRequest + +impl SendRequest for Connection +where + DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, + DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, +{ + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request::new( + request_msg, + self.udp_conn.clone(), + self.tcp_conn.clone(), + )) + } +} + +//------------ Request -------------------------------------------------------- + +/// Object that contains the current state of a query. +#[derive(Debug)] +pub struct Request { + /// Reqeust message. + request_msg: Req, + + /// UDP transport to be used. + udp_conn: Arc>, + + /// TCP transport to be used. + tcp_conn: multi_stream::Connection, + + /// Current state of the request. + state: QueryState, +} + +/// Status of the query. +#[derive(Debug)] +enum QueryState { + /// Start a request over the UDP transport. + StartUdpRequest, + + /// Get the response from the UDP transport. + GetUdpResponse(Box), + + /// Start a request over the TCP transport. + StartTcpRequest, + + /// Get the response from the TCP transport. + GetTcpResponse(Box), +} + +impl Request +where + S: AsyncConnect + Clone + Send + Sync + 'static, + Req: ComposeRequest + Clone + 'static, +{ + /// Create a new Request object. + /// + /// The initial state is to start with a UDP transport. + fn new( + request_msg: Req, + udp_conn: Arc>, + tcp_conn: multi_stream::Connection, + ) -> Request { + Self { + request_msg, + udp_conn, + tcp_conn, + state: QueryState::StartUdpRequest, + } + } + + /// Get the response of a DNS request. + /// + /// This function is cancel safe. + async fn get_response_impl(&mut self) -> Result, Error> + where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + { + loop { + match &mut self.state { + QueryState::StartUdpRequest => { + let msg = self.request_msg.clone(); + let request = self.udp_conn.send_request(msg); + self.state = QueryState::GetUdpResponse(request); + continue; + } + QueryState::GetUdpResponse(ref mut request) => { + let response = request.get_response().await?; + if response.header().tc() { + self.state = QueryState::StartTcpRequest; + continue; + } + return Ok(response); + } + QueryState::StartTcpRequest => { + let msg = self.request_msg.clone(); + let request = self.tcp_conn.send_request(msg); + self.state = QueryState::GetTcpResponse(request); + continue; + } + QueryState::GetTcpResponse(ref mut query) => { + let response = query.get_response().await?; + return Ok(response); + } + } + } + } +} + +impl GetResponse for Request +where + S: AsyncConnect + Clone + Debug + Send + Sync + 'static, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, +{ + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs new file mode 100644 index 000000000..85e162094 --- /dev/null +++ b/src/net/client/mod.rs @@ -0,0 +1,145 @@ +//! Sending requests and receiving responses. +//! +//! This module provides DNS transport protocols that allow sending a DNS +//! request and receiving the corresponding reply. +//! +//! Sending a request and receiving the reply consists of four steps: +//! 1) Creating a request message, +//! 2) Creating a DNS transport, +//! 3) Sending the request, and +//! 4) Receiving the reply. +//! +//! The first and second step are independent and can happen in any order. +//! The third step uses the resuts of the first and second step. +//! Finally, the fourth step uses the result of the third step. + +//! # Creating a request message +//! +//! The DNS transport protocols expect a request message that implements the +//! [ComposeRequest][request::ComposeRequest] trait. +//! This trait allows transports to add ENDS(0) options, set flags, etc. +//! The [RequestMessage][request::RequestMessage] type implements this trait. +//! The [new][request::RequestMessage::new] method of RequestMessage create +//! a new RequestMessage object based an existing messsage (that implements +//! ```Into>```). +//! +//! For example: +//! ```rust +//! # use domain::base::{Dname, MessageBuilder, Rtype}; +//! # use domain::net::client::request::RequestMessage; +//! let mut msg = MessageBuilder::new_vec(); +//! msg.header_mut().set_rd(true); +//! let mut msg = msg.question(); +//! msg.push( +//! (Dname::vec_from_str("example.com").unwrap(), Rtype::Aaaa) +//! ).unwrap(); +//! let req = RequestMessage::new(msg); +//! ``` + +//! # Creating a DNS transport +//! +//! Creating a DNS transport typically involves creating a configuration +//! object, creating the underlying network connection, creating the +//! DNS transport and running a ```run``` method as a separate task. This +//! is illustrated in the following example: +//! ```rust +//! # use domain::net::client::multi_stream; +//! # use domain::net::client::protocol::TcpConnect; +//! # use domain::net::client::request::SendRequest; +//! # use std::time::Duration; +//! # async fn _test() { +//! # let server_addr = String::from("127.0.0.1:53"); +//! let mut multi_stream_config = multi_stream::Config::default(); +//! multi_stream_config.stream_mut().set_response_timeout( +//! Duration::from_millis(100), +//! ); +//! let tcp_connect = TcpConnect::new(server_addr); +//! let (tcp_conn, transport) = multi_stream::Connection::with_config( +//! tcp_connect, multi_stream_config +//! ); +//! tokio::spawn(transport.run()); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); +//! # let mut request = tcp_conn.send_request(req); +//! # } +//! ``` +//! The currently implemented DNS transports have the following layering. At +//! the lowest layer are [dgram] and [stream]. The dgram transport is used for +//! DNS over UDP, the stream transport is used for DNS over a single TCP or +//! TLS connection. The transport works as long as the connection continuous +//! to exist. +//! The [multi_stream] transport is layered on top of stream, and creates new +//! TCP or TLS connections when old ones terminates. +//! Next, [dgram_stream] combines the dgram transport with the multi_stream +//! transport. This is typically needed because a request over UDP can receive +//! a truncated response, which should be retried over TCP. +//! Finally, the [redundant] transport can select the best transport out of +//! a collection of underlying transports. + +//! # Sending the request +//! +//! A DNS transport implements the [SendRequest][request::SendRequest] trait. +//! This trait provides a single method, +//! [send_request][request::SendRequest::send_request] and returns an object +//! that provides the response. +//! +//! For example: +//! ```no_run +//! # use domain::net::client::request::SendRequest; +//! # async fn _test() { +//! # let (tls_conn, _) = domain::net::client::stream::Connection::new( +//! # domain::net::client::protocol::TcpConnect::new( +//! # String::from("127.0.0.1:53") +//! # ) +//! # ); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); +//! let mut request = tls_conn.send_request(req); +//! # } +//! ``` +//! where ```tls_conn``` is a transport connection for DNS over TLS. + +//! # Receiving the request +//! +//! The [send_request][request::SendRequest::send_request] method returns an +//! object that implements the [GetResponse][request::GetResponse] trait. +//! This trait provides a single method, +//! [get_response][request::GetResponse::get_response], which returns the +//! DNS response message or an error. This method is intended to be +//! cancelation safe. +//! +//! For example: +//! ```no_run +//! # use crate::domain::net::client::request::SendRequest; +//! # async fn _test() { +//! # let (tls_conn, _) = domain::net::client::stream::Connection::new( +//! # domain::net::client::protocol::TcpConnect::new( +//! # String::from("127.0.0.1:53") +//! # ) +//! # ); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); +//! # let mut request = tls_conn.send_request(req); +//! let reply = request.get_response().await; +//! # } +//! ``` + +//! # Example with various transport connections +//! ```no_run +#![doc = include_str!("../../../examples/client-transports.rs")] +//! ``` + +#![cfg(feature = "unstable-client-transport")] +#![cfg_attr(docsrs, doc(cfg(feature = "unstable-client-transport")))] +#![warn(missing_docs)] + +pub mod dgram; +pub mod dgram_stream; +pub mod multi_stream; +pub mod protocol; +pub mod redundant; +pub mod request; +pub mod stream; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs new file mode 100644 index 000000000..a5690a464 --- /dev/null +++ b/src/net/client/multi_stream.rs @@ -0,0 +1,628 @@ +//! A DNS over multiple octet streams transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - too many connection errors + +use crate::base::Message; +use crate::net::client::protocol::AsyncConnect; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; +use crate::net::client::stream; +use bytes::Bytes; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; +use rand::random; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tokio::io; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::{sleep_until, Instant}; + +//------------ Constants ----------------------------------------------------- + +/// Capacity of the channel that transports `ChanReq`. +const DEF_CHAN_CAP: usize = 8; + +/// Error messafe when the connection is closed. +const ERR_CONN_CLOSED: &str = "connection closed"; + +//------------ Config --------------------------------------------------------- + +/// Configuration for an multi-stream transport. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Configuration of the underlying stream transport. + stream: stream::Config, +} + +impl Config { + /// Returns the underlying stream config. + pub fn stream(&self) -> &stream::Config { + &self.stream + } + + /// Returns a mutable reference to the underlying stream config. + pub fn stream_mut(&mut self) -> &mut stream::Config { + &mut self.stream + } +} + +impl From for Config { + fn from(stream: stream::Config) -> Self { + Self { stream } + } +} + +//------------ Connection ----------------------------------------------------- + +/// A connection to a multi-stream transport. +#[derive(Debug)] +pub struct Connection { + /// The sender half of the connection request channel. + sender: mpsc::Sender>, +} + +impl Connection { + /// Creates a new multi-stream transport with default configuration. + pub fn new(remote: Remote) -> (Self, Transport) { + Self::with_config(remote, Default::default()) + } + + /// Creates a new multi-stream transport. + pub fn with_config( + remote: Remote, + config: Config, + ) -> (Self, Transport) { + let (sender, transport) = Transport::new(remote, config); + (Self { sender }, transport) + } +} + +impl Connection { + /// Sends a request and receives a response. + pub async fn request( + &self, + request: Req, + ) -> Result, Error> { + Request::new(self.clone(), request).get_response().await + } + + /// Starts a request. + /// + /// This is the future that is returned by the `SendRequest` impl. + async fn _send_request( + &self, + request: &Req, + ) -> Result, Error> + where + Req: 'static, + { + let gr = Request::new(self.clone(), request.clone()); + Ok(Box::new(gr)) + } + + /// Request a new connection. + async fn new_conn( + &self, + opt_id: Option, + ) -> Result>, Error> { + let (sender, receiver) = oneshot::channel(); + let req = ChanReq { + cmd: ReqCmd::NewConn(opt_id, sender), + }; + self.sender + .send(req) + .await + .map_err(|_| Error::ConnectionClosed)?; + Ok(receiver) + } + + /// Request a shutdown. + pub async fn shutdown(&self) -> Result<(), &'static str> { + let req = ChanReq { + cmd: ReqCmd::Shutdown, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } + } +} + +//--- Clone + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} + +//--- SendRequest + +impl SendRequest for Connection +where + Req: ComposeRequest + Clone + 'static, +{ + fn send_request(&self, request: Req) -> Box { + Box::new(Request::new(self.clone(), request)) + } +} + +//------------ Request -------------------------------------------------------- + +/// The connection side of an active request. +#[derive(Debug)] +struct Request { + /// The request message. + /// + /// It is kept so we can compare a response with it. + request_msg: Req, + + /// Current state of the query. + state: QueryState, + + /// The underlying transport. + conn: Connection, + + /// The id of the most recent connection, if any. + conn_id: Option, + + /// Number of retries with delay. + delayed_retry_count: u64, +} + +/// The states of the query state machine. +#[derive(Debug)] +enum QueryState { + /// Request a new connection. + RequestConn, + + /// Receive a new connection from the receiver. + ReceiveConn(oneshot::Receiver>), + + /// Start a query using the given stream transport. + StartQuery(Arc>), + + /// Get the result of the query. + GetResult(stream::Request), + + /// Wait until trying again. + /// + /// The instant represents when the error occurred, the duration how + /// long to wait. + Delay(Instant, Duration), + + /// A response has been received and the query is done. + Done, +} + +/// The response to a connection request. +type ChanResp = Result, Arc>; + +/// The successful response to a connection request. +#[derive(Debug)] +struct ChanRespOk { + /// The id of this connection. + id: u64, + + /// The new stream transport to use for sending a request. + conn: Arc>, +} + +impl Request { + /// Creates a new query. + fn new(conn: Connection, request_msg: Req) -> Self { + Self { + conn, + request_msg, + state: QueryState::RequestConn, + conn_id: None, + delayed_retry_count: 0, + } + } +} + +impl Request { + /// Get the result of a DNS request. + /// + /// This function is cancellation safe. If its future is dropped before + /// it is resolved, you can call it again to get a new future. + pub async fn get_response(&mut self) -> Result, Error> { + loop { + match self.state { + QueryState::RequestConn => { + let rx = match self.conn.new_conn(self.conn_id).await { + Ok(rx) => rx, + Err(err) => { + self.state = QueryState::Done; + return Err(err); + } + }; + self.state = QueryState::ReceiveConn(rx); + } + QueryState::ReceiveConn(ref mut receiver) => { + let res = match receiver.await { + Ok(res) => res, + Err(_) => { + // Assume receive error + self.state = QueryState::Done; + return Err(Error::StreamReceiveError); + } + }; + + // Another Result. This time from executing the request + match res { + Err(_) => { + self.delayed_retry_count += 1; + let retry_time = + retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + Ok(ok_res) => { + let id = ok_res.id; + let conn = ok_res.conn; + + self.conn_id = Some(id); + self.state = QueryState::StartQuery(conn); + continue; + } + } + } + QueryState::StartQuery(ref mut conn) => { + self.state = QueryState::GetResult( + conn.get_request(self.request_msg.clone()), + ); + continue; + } + QueryState::GetResult(ref mut query) => { + match query.get_response().await { + Ok(reply) => return Ok(reply), + // XXX This replicates the previous behavior. But + // maybe we should have a whole category of + // fatal errors where retrying doesn’t make any + // sense? + Err(Error::WrongReplyForQuery) => { + return Err(Error::WrongReplyForQuery) + } + Err(_) => { + self.delayed_retry_count += 1; + let retry_time = + retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + } + } + QueryState::Delay(instant, duration) => { + sleep_until(instant + duration).await; + self.state = QueryState::RequestConn; + } + QueryState::Done => { + panic!("Already done"); + } + } + } + } +} + +impl GetResponse for Request { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(Self::get_response(self)) + } +} + +//------------ Transport ------------------------------------------------ + +/// The actual implementation of [Connection]. +#[derive(Debug)] +pub struct Transport { + /// User configuration values. + config: Config, + + /// The remote destination. + stream: Remote, + + /// Underlying stream connection. + conn_state: SingleConnState3, + + /// Current connection id. + conn_id: u64, + + /// Receiver part of the channel. + receiver: mpsc::Receiver>, +} + +#[derive(Debug)] +/// A request to [Connection::run] either for a new stream or to +/// shutdown. +struct ChanReq { + /// A requests consists of a command. + cmd: ReqCmd, +} + +#[derive(Debug)] +/// Commands that can be requested. +enum ReqCmd { + /// Request for a (new) connection. + /// + /// The id of the previous connection (if any) is passed as well as a + /// channel to send the reply. + NewConn(Option, ReplySender), + + /// Shutdown command. + Shutdown, +} + +/// This is the type of sender in [ReqCmd]. +type ReplySender = oneshot::Sender>; + +/// State of the current underlying stream transport. +#[derive(Debug)] +enum SingleConnState3 { + /// No current stream transport. + None, + + /// Current stream transport. + Some(Arc>), + + /// State that deals with an error getting a new octet stream from + /// a connection stream. + Err(ErrorState), +} + +/// State associated with a failed attempt to create a new stream +/// transport. +#[derive(Clone, Debug)] +struct ErrorState { + /// The error we got from the most recent attempt. + error: Arc, + + /// How many times we tried so far. + retries: u64, + + /// When we got an error. + timer: Instant, + + /// Time to wait before trying to create a new connection. + timeout: Duration, +} + +impl Transport { + /// Creates a new transport. + fn new( + stream: Remote, + config: Config, + ) -> (mpsc::Sender>, Self) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + ( + sender, + Self { + config, + stream, + conn_state: SingleConnState3::None, + conn_id: 0, + receiver, + }, + ) + } +} + +impl Transport +where + Remote: AsyncConnect, + Remote::Connection: AsyncRead + AsyncWrite, + Req: ComposeRequest, +{ + /// Run the transport machinery. + pub async fn run(mut self) { + let mut curr_cmd: Option> = None; + let mut do_stream = false; + let mut runners = FuturesUnordered::new(); + let mut stream_fut: Pin< + Box< + dyn Future< + Output = Result, + > + Send, + >, + > = Box::pin(stream_nop()); + let mut opt_chan = None; + + loop { + if let Some(req) = curr_cmd { + assert!(!do_stream); + curr_cmd = None; + match req { + ReqCmd::NewConn(opt_id, chan) => { + if let SingleConnState3::Err(error_state) = + &self.conn_state + { + if error_state.timer.elapsed() + < error_state.timeout + { + let resp = + ChanResp::Err(error_state.error.clone()); + + // Ignore errors. We don't care if the receiver + // is gone + _ = chan.send(resp); + continue; + } + + // Try to set up a new connection + } + + // Check if the command has an id greather than the + // current id. + if let Some(id) = opt_id { + if id >= self.conn_id { + // We need a new connection. Remove the + // current one. This is the best place to + // increment conn_id. + self.conn_id += 1; + self.conn_state = SingleConnState3::None; + } + } + // If we still have a connection then we can reply + // immediately. + if let SingleConnState3::Some(conn) = &self.conn_state + { + let resp = ChanResp::Ok(ChanRespOk { + id: self.conn_id, + conn: conn.clone(), + }); + // Ignore errors. We don't care if the receiver + // is gone + _ = chan.send(resp); + } else { + opt_chan = Some(chan); + stream_fut = Box::pin(self.stream.connect()); + do_stream = true; + } + } + ReqCmd::Shutdown => break, + } + } + + if do_stream { + let runners_empty = runners.is_empty(); + + loop { + tokio::select! { + res_conn = stream_fut.as_mut() => { + do_stream = false; + stream_fut = Box::pin(stream_nop()); + + let stream = match res_conn { + Ok(stream) => stream, + Err(error) => { + let error = Arc::new(error); + match self.conn_state { + SingleConnState3::None => + self.conn_state = + SingleConnState3::Err(ErrorState { + error: error.clone(), + retries: 0, + timer: Instant::now(), + timeout: retry_time(0), + }), + SingleConnState3::Some(_) => + panic!("Illegal Some state"), + SingleConnState3::Err(error_state) => { + self.conn_state = + SingleConnState3::Err(ErrorState { + error: + error_state.error.clone(), + retries: error_state.retries+1, + timer: Instant::now(), + timeout: retry_time( + error_state.retries+1), + }); + } + } + + let resp = ChanResp::Err(error); + let loc_opt_chan = opt_chan.take(); + + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.expect("weird, no channel?") + .send(resp); + break; + } + }; + let (conn, tran) = stream::Connection::with_config( + stream, self.config.stream.clone() + ); + let conn = Arc::new(conn); + runners.push(Box::pin(tran.run())); + + let resp = ChanResp::Ok(ChanRespOk { + id: self.conn_id, + conn: conn.clone(), + }); + self.conn_state = SingleConnState3::Some(conn); + + let loc_opt_chan = opt_chan.take(); + + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.expect("weird, no channel?") + .send(resp); + break; + } + _ = runners.next(), if !runners_empty => { + } + } + } + continue; + } + + assert!(curr_cmd.is_none()); + let recv_fut = self.receiver.recv(); + let runners_empty = runners.is_empty(); + tokio::select! { + msg = recv_fut => { + if msg.is_none() { + // All references to the connection object have been + // dropped. Shutdown. + break; + } + curr_cmd = Some(msg.expect("None is checked before").cmd); + } + _ = runners.next(), if !runners_empty => { + } + } + } + + // Avoid new queries + drop(self.receiver); + + // Wait for existing stream runners to terminate + while !runners.is_empty() { + runners.next().await; + } + } +} + +//------------ Utility -------------------------------------------------------- + +/// Compute the retry timeout based on the number of retries so far. +/// +/// The computation is a random value (in microseconds) between zero and +/// two to the power of the number of retries. +fn retry_time(retries: u64) -> Duration { + let to_secs = if retries > 6 { 60 } else { 1 << retries }; + let to_usecs = to_secs * 1000000; + let rnd: f64 = random(); + let to_usecs = to_usecs as f64 * rnd; + Duration::from_micros(to_usecs as u64) +} + +/// Helper function to create an empty future that is compatible with the +/// future returned by a connection stream. +async fn stream_nop() -> Result { + Err(io::Error::new(io::ErrorKind::Other, "nop")) +} diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs new file mode 100644 index 000000000..3a5779dec --- /dev/null +++ b/src/net/client/protocol.rs @@ -0,0 +1,311 @@ +//! Underlying transport protocols. + +use core::future::Future; +use core::pin::Pin; +use pin_project_lite::pin_project; +use std::boxed::Box; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::ReadBuf; +use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket}; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::{ClientConfig, ServerName}; +use tokio_rustls::TlsConnector; + +/// How many times do we try a new random port if we get ‘address in use.’ +const RETRY_RANDOM_PORT: usize = 10; + +//------------ AsyncConnect -------------------------------------------------- + +/// Establish a connection asynchronously. +/// +/// +pub trait AsyncConnect { + /// The type of an established connection. + type Connection; + + /// The future establishing the connection. + type Fut: Future> + Send; + + /// Returns a future that establishing a connection. + fn connect(&self) -> Self::Fut; +} + +//------------ TcpConnect -------------------------------------------------- + +/// Create new TCP connections. +#[derive(Clone, Copy, Debug)] +pub struct TcpConnect { + /// Remote address to connect to. + addr: Addr, +} + +impl TcpConnect { + /// Create new TCP connections. + /// + /// addr is the destination address to connect to. + pub fn new(addr: Addr) -> Self { + Self { addr } + } +} + +impl AsyncConnect for TcpConnect +where + Addr: ToSocketAddrs + Clone + Send + 'static, +{ + type Connection = TcpStream; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + Box::pin(TcpStream::connect(self.addr.clone())) + } +} + +//------------ TlsConnect ----------------------------------------------------- + +/// Create new TLS connections +#[derive(Clone, Debug)] +pub struct TlsConnect { + /// Configuration for setting up a TLS connection. + client_config: Arc, + + /// Server name for certificate verification. + server_name: ServerName, + + /// Remote address to connect to. + addr: Addr, +} + +impl TlsConnect { + /// Function to create a new TLS connection stream + pub fn new( + client_config: impl Into>, + server_name: ServerName, + addr: Addr, + ) -> Self { + Self { + client_config: client_config.into(), + server_name, + addr, + } + } +} + +impl AsyncConnect for TlsConnect +where + Addr: ToSocketAddrs + Clone + Send + 'static, +{ + type Connection = TlsStream; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + let tls_connection = TlsConnector::from(self.client_config.clone()); + let server_name = self.server_name.clone(); + let addr = self.addr.clone(); + Box::pin(async { + let box_connection = Box::new(tls_connection); + let tcp = TcpStream::connect(addr).await?; + box_connection.connect(server_name, tcp).await + }) + } +} + +//------------ UdpConnect -------------------------------------------------- + +/// Create new TCP connections. +#[derive(Clone, Copy, Debug)] +pub struct UdpConnect { + /// Remote address to connect to. + addr: SocketAddr, +} + +impl UdpConnect { + /// Create new UDP connections. + /// + /// addr is the destination address to connect to. + pub fn new(addr: SocketAddr) -> Self { + Self { addr } + } + + /// Bind to a random local UDP port. + async fn bind_and_connect(self) -> Result { + let mut i = 0; + let sock = loop { + let local: SocketAddr = if self.addr.is_ipv4() { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => break sock, + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(err); + } else { + i += 1 + } + } + } + }; + sock.connect(self.addr).await?; + Ok(sock) + } +} + +impl AsyncConnect for UdpConnect { + type Connection = UdpSocket; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + Box::pin(self.bind_and_connect()) + } +} + +//------------ AsyncDgramRecv ------------------------------------------------- + +/// Receive a datagram packets asynchronously. +pub trait AsyncDgramRecv { + /// Polled receive. + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll>; +} + +impl AsyncDgramRecv for UdpSocket { + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + UdpSocket::poll_recv(self, cx, buf) + } +} + +//------------ AsyncDgramRecvEx ----------------------------------------------- + +/// Convenvience trait to turn poll_recv into an asynchronous function. +pub trait AsyncDgramRecvEx: AsyncDgramRecv { + /// Asynchronous receive function. + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> DgramRecv<'a, Self> + where + Self: Unpin, + { + DgramRecv { + receiver: self, + buf, + } + } +} + +impl AsyncDgramRecvEx for R {} + +//------------ DgramRecv ----------------------------------------------------- + +pin_project! { + /// Return value of recv. This captures the future for recv. + pub struct DgramRecv<'a, R: ?Sized> { + receiver: &'a R, + buf: &'a mut [u8], + } +} + +impl Future for DgramRecv<'_, R> { + type Output = io::Result; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let me = self.project(); + let mut buf = ReadBuf::new(me.buf); + match Pin::new(me.receiver).poll_recv(cx, &mut buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(res) => { + if let Err(err) = res { + return Poll::Ready(Err(err)); + } + } + } + Poll::Ready(Ok(buf.filled().len())) + } +} + +//------------ AsyncDgramSend ------------------------------------------------- + +/// Send a datagram packet asynchronously. +/// +/// +pub trait AsyncDgramSend { + /// Polled send function. + fn poll_send( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; +} + +impl AsyncDgramSend for UdpSocket { + fn poll_send( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + UdpSocket::poll_send(self, cx, buf) + } +} + +//------------ AsyncDgramSendEx ---------------------------------------------- + +/// Convenience trait that turns poll_send into an asynchronous function. +pub trait AsyncDgramSendEx: AsyncDgramSend { + /// Asynchronous function to send a packet. + fn send<'a>(&'a self, buf: &'a [u8]) -> DgramSend<'a, Self> + where + Self: Unpin, + { + DgramSend { sender: self, buf } + } +} + +impl AsyncDgramSendEx for S {} + +//------------ DgramSend ----------------------------------------------------- + +/// This is the return value of send. It captures the future for send. +pub struct DgramSend<'a, S: ?Sized> { + /// The datagram send object. + sender: &'a S, + + /// The buffer that needs to be sent. + buf: &'a [u8], +} + +impl Future for DgramSend<'_, S> { + type Output = io::Result; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.sender).poll_send(cx, self.buf) + } +} diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs new file mode 100644 index 000000000..d0e9a4d95 --- /dev/null +++ b/src/net/client/redundant.rs @@ -0,0 +1,762 @@ +//! A transport that multiplexes requests over multiple redundant transports. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use bytes::Bytes; + +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; + +use octseq::Octets; + +use rand::random; + +use std::boxed::Box; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::pin::Pin; +use std::vec::Vec; + +use tokio::sync::{mpsc, oneshot}; +use tokio::time::{sleep_until, Duration, Instant}; + +use crate::base::iana::OptRcode; +use crate::base::Message; +use crate::net::client::request::{Error, GetResponse, SendRequest}; + +/* +Basic algorithm: +- keep track of expected response time for every upstream +- start with the upstream with the lowest expected response time +- set a timer to the expect response time. +- if the timer expires before reply arrives, send the query to the next lowest + and set a timer +- when a reply arrives update the expected response time for the relevant + upstream and for the ones that failed. + +Based on a random number generator: +- pick a different upstream rather then the best but set the timer to the + expected response time of the best. +*/ + +/// Capacity of the channel that transports [ChanReq]. +const DEF_CHAN_CAP: usize = 8; + +/// Time in milliseconds for the initial response time estimate. +const DEFAULT_RT_MS: u64 = 300; + +/// The initial response time estimate for unused connections. +const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS); + +/// Maintain a moving average for the measured response time and the +/// square of that. The window is SMOOTH_N. +const SMOOTH_N: f64 = 8.; + +/// Chance to probe a worse connection. +const PROBE_P: f64 = 0.05; + +/// Avoid sending two requests at the same time. +/// +/// When a worse connection is probed, give it a slight head start. +const PROBE_RT: Duration = Duration::from_millis(1); + +//------------ Config --------------------------------------------------------- + +/// User configuration variables. +#[derive(Clone, Copy, Debug, Default)] +pub struct Config { + /// Defer transport errors. + pub defer_transport_error: bool, + + /// Defer replies that report Refused. + pub defer_refused: bool, + + /// Defer replies that report ServFail. + pub defer_servfail: bool, +} + +//------------ Connection ----------------------------------------------------- + +/// This type represents a transport connection. +#[derive(Debug)] +pub struct Connection { + /// User configuation. + config: Config, + + /// To send a request to the runner. + sender: mpsc::Sender>, +} + +impl Connection { + /// Create a new connection. + pub fn new() -> (Self, Transport) { + Self::with_config(Default::default()) + } + + /// Create a new connection with a given config. + pub fn with_config(config: Config) -> (Self, Transport) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + (Self { config, sender }, Transport::new(receiver)) + } + + /// Add a transport connection. + pub async fn add( + &self, + conn: Box + Send + Sync>, + ) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::Add(AddReq { conn, tx })) + .await + .expect("send should not fail"); + rx.await.expect("receive should not fail") + } + + /// Implementation of the query method. + async fn request_impl( + self, + request_msg: Req, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::GetRT(RTReq { tx })) + .await + .expect("send should not fail"); + let conn_rt = rx.await.expect("receive should not fail")?; + Query::new(self.config, request_msg, conn_rt, self.sender.clone()) + .get_response() + .await + } +} + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + config: self.config, + sender: self.sender.clone(), + } + } +} + +impl SendRequest + for Connection +{ + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request { + fut: Box::pin(self.clone().request_impl(request_msg)), + }) + } +} + +//------------ Request ------------------------------------------------------- + +/// An active request. +pub struct Request { + /// The underlying future. + fut: Pin, Error>> + Send>>, +} + +impl Request { + /// Async function that waits for the future stored in Query to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.fut).await + } +} + +impl GetResponse for Request { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} + +impl Debug for Request { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("Request") + .field("fut", &format_args!("_")) + .finish() + } +} + +//------------ Query -------------------------------------------------------- + +/// This type represents an active query request. +#[derive(Debug)] +pub struct Query { + /// User configuration. + config: Config, + + /// The state of the query + state: QueryState, + + /// The reuqest message + request_msg: Req, + + /// List of connections identifiers and estimated response times. + conn_rt: Vec, + + /// Channel to send requests to the run function. + sender: mpsc::Sender>, + + /// List of futures for outstanding requests. + fut_list: + FuturesUnordered + Send>>>, + + /// Transport error that should be reported if nothing better shows + /// up. + deferred_transport_error: Option, + + /// Reply that should be returned to the user if nothing better shows + /// up. + deferred_reply: Option>, + + /// The result from one of the connectons. + result: Option, Error>>, + + /// Index of the connection that returned a result. + res_index: usize, +} + +/// The various states a query can be in. +#[derive(Debug)] +enum QueryState { + /// The initial state + Init, + + /// Start a request on a specific connection. + Probe(usize), + + /// Report the response time for a specific index in the list. + Report(usize), + + /// Wait for one of the requests to finish. + Wait, +} + +/// The commands that can be sent to the run function. +enum ChanReq { + /// Add a connection + Add(AddReq), + + /// Get the list of estimated response times for all connections + GetRT(RTReq), + + /// Start a query + Query(RequestReq), + + /// Report how long it took to get a response + Report(TimeReport), + + /// Report that a connection failed to provide a timely response + Failure(TimeReport), +} + +impl Debug for ChanReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("ChanReq").finish() + } +} + +/// Request to add a new connection +struct AddReq { + /// New connection to add + conn: Box + Send + Sync>, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to an Add request +type AddReply = Result<(), Error>; + +/// Request to give the estimated response times for all connections +struct RTReq /**/ { + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to a RT request +type RTReply = Result, Error>; + +/// Request to start a request +struct RequestReq { + /// Identifier of connection + id: u64, + + /// Request message + request_msg: Req, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +impl Debug for RequestReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("RequestReq") + .field("id", &self.id) + .field("request_msg", &self.request_msg) + .finish() + } +} + +/// Reply to a request request. +type RequestReply = Result, Error>; + +/// Report the amount of time until success or failure. +#[derive(Debug)] +struct TimeReport { + /// Identifier of the transport connection. + id: u64, + + /// Time spend waiting for a reply. + elapsed: Duration, +} + +/// Connection statistics to compute the estimated response time. +struct ConnStats { + /// Aproximation of the windowed average of response times. + mean: f64, + + /// Aproximation of the windowed average of the square of response times. + mean_sq: f64, +} + +/// Data required to schedule requests and report timing results. +#[derive(Clone, Debug)] +struct ConnRT { + /// Estimated response time. + est_rt: Duration, + + /// Identifier of the connection. + id: u64, + + /// Start of a request using this connection. + start: Option, +} + +/// Result of the futures in fut_list. +type FutListOutput = (usize, Result, Error>); + +impl Query { + /// Create a new query object. + fn new( + config: Config, + request_msg: Req, + mut conn_rt: Vec, + sender: mpsc::Sender>, + ) -> Self { + let conn_rt_len = conn_rt.len(); + conn_rt.sort_unstable_by(conn_rt_cmp); + + // Do we want to probe a less performant upstream? + if conn_rt_len > 1 && random::() < PROBE_P { + let index: usize = 1 + random::() % (conn_rt_len - 1); + conn_rt[index].est_rt = PROBE_RT; + + // Sort again + conn_rt.sort_unstable_by(conn_rt_cmp); + } + + Self { + config, + request_msg, + conn_rt, + sender, + state: QueryState::Init, + fut_list: FuturesUnordered::new(), + deferred_transport_error: None, + deferred_reply: None, + result: None, + res_index: 0, + } + } + + /// Implementation of get_response. + async fn get_response(&mut self) -> Result, Error> { + loop { + match self.state { + QueryState::Init => { + if self.conn_rt.is_empty() { + return Err(Error::NoTransportAvailable); + } + self.state = QueryState::Probe(0); + continue; + } + QueryState::Probe(ind) => { + self.conn_rt[ind].start = Some(Instant::now()); + let fut = start_request( + ind, + self.conn_rt[ind].id, + self.sender.clone(), + self.request_msg.clone(), + ); + self.fut_list.push(Box::pin(fut)); + let timeout = Instant::now() + self.conn_rt[ind].est_rt; + loop { + tokio::select! { + res = self.fut_list.next() => { + let res = res.expect("res should not be empty"); + match res.1 { + Err(ref err) => { + if self.config.defer_transport_error { + if self.deferred_transport_error.is_none() { + self.deferred_transport_error = Some(err.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; + } + // Return error to the user. + } + Ok(ref msg) => { + if skip(msg, &self.config) { + if self.deferred_reply.is_none() { + self.deferred_reply = Some(msg.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; + } + // Now we have a reply that can be + // returned to the user. + } + } + self.result = Some(res.1); + self.res_index= res.0; + + self.state = QueryState::Report(0); + // Break out of receive loop + break; + } + _ = sleep_until(timeout) => { + // Move to the next Probe state if there + // are more upstreams to try, otherwise + // move to the Wait state. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else { + QueryState::Wait + }; + // Break out of receive loop + break; + } + } + } + // Continue with state machine loop + continue; + } + QueryState::Report(ind) => { + if ind >= self.conn_rt.len() + || self.conn_rt[ind].start.is_none() + { + // Nothing more to report. Return result. + let res = self + .result + .take() + .expect("result should not be empty"); + return res; + } + + let start = self.conn_rt[ind] + .start + .expect("start time should not be empty"); + let elapsed = start.elapsed(); + let time_report = TimeReport { + id: self.conn_rt[ind].id, + elapsed, + }; + let report = if ind == self.res_index { + // Succesfull entry + ChanReq::Report(time_report) + } else { + // Failed entry + ChanReq::Failure(time_report) + }; + + // Send could fail but we don't care. + let _ = self.sender.send(report).await; + + self.state = QueryState::Report(ind + 1); + continue; + } + QueryState::Wait => { + loop { + if self.fut_list.is_empty() { + // We have nothing left. There should be a reply or + // an error. Prefer a reply over an error. + if self.deferred_reply.is_some() { + let msg = self + .deferred_reply + .take() + .expect("just checked for Some"); + return Ok(msg); + } + if self.deferred_transport_error.is_some() { + let err = self + .deferred_transport_error + .take() + .expect("just checked for Some"); + return Err(err); + } + panic!("either deferred_reply or deferred_error should be present"); + } + let res = self.fut_list.next().await; + let res = res.expect("res should not be empty"); + match res.1 { + Err(ref err) => { + if self.config.defer_transport_error { + if self.deferred_transport_error.is_none() + { + self.deferred_transport_error = + Some(err.clone()); + } + // Just continue with the next future, or + // finish if fut_list is empty. + continue; + } + // Return error to the user. + } + Ok(ref msg) => { + if skip(msg, &self.config) { + if self.deferred_reply.is_none() { + self.deferred_reply = + Some(msg.clone()); + } + // Just continue with the next future, or + // finish if fut_list is empty. + continue; + } + // Return reply to user. + } + } + self.result = Some(res.1); + self.res_index = res.0; + self.state = QueryState::Report(0); + // Break out of loop to continue with the state machine + break; + } + continue; + } + } + } + } +} + +//------------ Transport ----------------------------------------------------- + +/// Type that actually implements the connection. +#[derive(Debug)] +pub struct Transport { + /// Receive side of the channel used by the runner. + receiver: mpsc::Receiver>, +} + +impl<'a, Req: Clone + Send + Sync + 'static> Transport { + /// Implementation of the new method. + fn new(receiver: mpsc::Receiver>) -> Self { + Self { receiver } + } + + /// Run method. + pub async fn run(mut self) { + let mut next_id: u64 = 10; + let mut conn_stats: Vec = Vec::new(); + let mut conn_rt: Vec = Vec::new(); + let mut conns: Vec + Send + Sync>> = + Vec::new(); + + loop { + let req = match self.receiver.recv().await { + Some(req) => req, + None => break, // All references to connection objects are + // dropped. Shutdown. + }; + match req { + ChanReq::Add(add_req) => { + let id = next_id; + next_id += 1; + conn_stats.push(ConnStats { + mean: (DEFAULT_RT_MS as f64) / 1000., + mean_sq: 0., + }); + conn_rt.push(ConnRT { + id, + est_rt: DEFAULT_RT, + start: None, + }); + conns.push(add_req.conn); + + // Don't care if send fails + let _ = add_req.tx.send(Ok(())); + } + ChanReq::GetRT(rt_req) => { + // Don't care if send fails + let _ = rt_req.tx.send(Ok(conn_rt.clone())); + } + ChanReq::Query(request_req) => { + let opt_ind = + conn_rt.iter().position(|e| e.id == request_req.id); + match opt_ind { + Some(ind) => { + let query = conns[ind] + .send_request(request_req.request_msg); + // Don't care if send fails + let _ = request_req.tx.send(Ok(query)); + } + None => { + // Don't care if send fails + let _ = request_req + .tx + .send(Err(Error::RedundantTransportNotFound)); + } + } + } + ChanReq::Report(time_report) => { + let opt_ind = + conn_rt.iter().position(|e| e.id == time_report.id); + if let Some(ind) = opt_ind { + let elapsed = time_report.elapsed.as_secs_f64(); + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = + if var < 0. { 0. } else { f64::sqrt(var) }; + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + } + } + ChanReq::Failure(time_report) => { + let opt_ind = + conn_rt.iter().position(|e| e.id == time_report.id); + if let Some(ind) = opt_ind { + let elapsed = time_report.elapsed.as_secs_f64(); + if elapsed < conn_stats[ind].mean { + // Do not update the mean if a + // failure took less time than the + // current mean. + continue; + } + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = + if var < 0. { 0. } else { f64::sqrt(var) }; + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + } + } + } + } + } +} + +//------------ Utility -------------------------------------------------------- + +/// Async function to send a request and wait for the reply. +/// +/// This gives a single future that we can put in a list. +async fn start_request( + index: usize, + id: u64, + sender: mpsc::Sender>, + request_msg: Req, +) -> (usize, Result, Error>) { + let (tx, rx) = oneshot::channel(); + sender + .send(ChanReq::Query(RequestReq { + id, + request_msg, + tx, + })) + .await + .expect("send is expected to work"); + let mut request = match rx.await.expect("receive is expected to work") { + Err(err) => return (index, Err(err)), + Ok(request) => request, + }; + let reply = request.get_response().await; + + (index, reply) +} + +/// Compare ConnRT elements based on estimated response time. +fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering { + e1.est_rt.cmp(&e2.est_rt) +} + +/// Return if this reply should be skipped or not. +fn skip(msg: &Message, config: &Config) -> bool { + // Check if we actually need to check. + if !config.defer_refused && !config.defer_servfail { + return false; + } + + let opt_rcode = get_opt_rcode(msg); + // OptRcode needs PartialEq + if let OptRcode::Refused = opt_rcode { + if config.defer_refused { + return true; + } + } + if let OptRcode::ServFail = opt_rcode { + if config.defer_servfail { + return true; + } + } + + false +} + +/// Get the extended rcode of a message. +fn get_opt_rcode(msg: &Message) -> OptRcode { + let opt = msg.opt(); + match opt { + Some(opt) => opt.rcode(msg.header()), + None => { + // Convert Rcode to OptRcode, this should be part of + // OptRcode + OptRcode::from_int(msg.header().rcode().to_int() as u16) + } + } +} diff --git a/src/net/client/request.rs b/src/net/client/request.rs new file mode 100644 index 000000000..b89bf0ee5 --- /dev/null +++ b/src/net/client/request.rs @@ -0,0 +1,410 @@ +//! Constructing and sending requests. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use crate::base::iana::Rcode; +use crate::base::message::CopyRecordsError; +use crate::base::message_builder::{ + AdditionalBuilder, MessageBuilder, PushError, StaticCompressor, +}; +use crate::base::opt::{ComposeOptData, LongOptData, OptRecord}; +use crate::base::wire::Composer; +use crate::base::{Header, Message, ParsedDname, Rtype}; +use crate::rdata::AllRecordData; +use bytes::Bytes; +use octseq::Octets; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::vec::Vec; +use std::{error, fmt}; + +//------------ ComposeRequest ------------------------------------------------ + +/// A trait that allows composing a request as a series. +pub trait ComposeRequest: Debug + Send + Sync { + /// Appends the final message to a provided composer. + fn append_message( + &self, + target: &mut Target, + ) -> Result<(), CopyRecordsError>; + + /// Create a message that captures the recorded changes. + fn to_message(&self) -> Message>; + + /// Create a message that captures the recorded changes and convert to + /// a Vec. + fn to_vec(&self) -> Vec; + + /// Return a reference to a mutable Header to record changes to the header. + fn header_mut(&mut self) -> &mut Header; + + /// Set the UDP payload size. + fn set_udp_payload_size(&mut self, value: u16); + + /// Add an EDNS option. + fn add_opt( + &mut self, + opt: &impl ComposeOptData, + ) -> Result<(), LongOptData>; + + /// Returns whether a message is an answer to the request. + fn is_answer(&self, answer: &Message<[u8]>) -> bool; +} + +//------------ SendRequest --------------------------------------------------- + +/// Trait for starting a DNS request based on a request composer. +/// +/// In the future, the return type of request should become an associated type. +/// However, the use of 'dyn Request' in redundant currently prevents that. +pub trait SendRequest { + /// Request function that takes a ComposeRequest type. + fn send_request(&self, request_msg: CR) -> Box; +} + +//------------ GetResponse --------------------------------------------------- + +/// Trait for getting the result of a DNS query. +/// +/// In the future, the return type of get_response should become an associated +/// type. However, too many uses of 'dyn GetResponse' currently prevent that. +pub trait GetResponse: Debug { + /// Get the result of a DNS request. + /// + /// This function is intended to be cancel safe. + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + >; +} + +//------------ RequestMessage ------------------------------------------------ + +/// Object that implements the ComposeRequest trait for a Message object. +#[derive(Clone, Debug)] +pub struct RequestMessage> { + /// Base message. + msg: Message, + + /// New header. + header: Header, + + /// The OPT record to add if required. + opt: Option>>, +} + +impl + Debug + Octets> RequestMessage { + /// Create a new BMB object. + pub fn new(msg: impl Into>) -> Self { + let msg = msg.into(); + let header = msg.header(); + Self { + msg, + header, + opt: None, + } + } + + /// Returns a mutable reference to the OPT record. + /// + /// Adds one if necessary. + fn opt_mut(&mut self) -> &mut OptRecord> { + self.opt.get_or_insert_with(Default::default) + } + + /// Appends the message to a composer. + fn append_message_impl( + &self, + mut target: MessageBuilder, + ) -> Result, CopyRecordsError> { + let source = &self.msg; + + *target.header_mut() = self.header; + + let source = source.question(); + let mut target = target.question(); + for rr in source { + target.push(rr?)?; + } + let mut source = source.answer()?; + let mut target = target.answer(); + for rr in &mut source { + let rr = rr? + .into_record::>>()? + .expect("record expected"); + target.push(rr)?; + } + + let mut source = + source.next_section()?.expect("section should be present"); + let mut target = target.authority(); + for rr in &mut source { + let rr = rr? + .into_record::>>()? + .expect("record expected"); + target.push(rr)?; + } + + let source = + source.next_section()?.expect("section should be present"); + let mut target = target.additional(); + for rr in source { + let rr = rr?; + if rr.rtype() != Rtype::Opt { + let rr = rr + .into_record::>>()? + .expect("record expected"); + target.push(rr)?; + } + } + + if let Some(opt) = self.opt.as_ref() { + target.push(opt.as_record())?; + } + + Ok(target) + } + + /// Create new message based on the changes to the base message. + fn to_message_impl(&self) -> Result>, Error> { + let target = + MessageBuilder::from_target(StaticCompressor::new(Vec::new())) + .expect("Vec is expected to have enough space"); + + let target = self.append_message_impl(target)?; + + // It would be nice to use .builder() here. But that one deletes all + // section. We have to resort to .as_builder() which gives a + // reference and then .clone() + let result = target.as_builder().clone(); + let msg = Message::from_octets(result.finish().into_target()).expect( + "Message should be able to parse output from MessageBuilder", + ); + Ok(msg) + } +} + +impl + Clone + Debug + Octets + Send + Sync + 'static> + ComposeRequest for RequestMessage +{ + fn append_message( + &self, + target: &mut Target, + ) -> Result<(), CopyRecordsError> { + let target = MessageBuilder::from_target(target) + .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?; + self.append_message_impl(target)?; + Ok(()) + } + + fn to_vec(&self) -> Vec { + let msg = self.to_message(); + msg.as_octets().clone() + } + + fn to_message(&self) -> Message> { + self.to_message_impl().unwrap() + } + + fn header_mut(&mut self) -> &mut Header { + &mut self.header + } + + fn set_udp_payload_size(&mut self, value: u16) { + self.opt_mut().set_udp_payload_size(value); + } + + fn add_opt( + &mut self, + opt: &impl ComposeOptData, + ) -> Result<(), LongOptData> { + self.opt_mut().push(opt).map_err(|e| e.unlimited_buf()) + } + + fn is_answer(&self, answer: &Message<[u8]>) -> bool { + let answer_header = answer.header(); + let answer_hcounts = answer.header_counts(); + + // First check qr is set and IDs match. + if !answer_header.qr() || answer_header.id() != self.header.id() { + return false; + } + + // If the result is an error, then the question section can be empty. + // In that case we require all other sections to be empty as well. + if answer_header.rcode() != Rcode::NoError + && answer_hcounts.qdcount() == 0 + && answer_hcounts.ancount() == 0 + && answer_hcounts.nscount() == 0 + && answer_hcounts.arcount() == 0 + { + // We can accept this as a valid reply. + return true; + } + + // Now the question section in the reply has to be the same as in the + // query. + if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() { + false + } else { + answer.question() == self.msg.for_slice().question() + } + } +} + +//------------ Error --------------------------------------------------------- + +/// Error type for client transports. +#[derive(Clone, Debug)] +pub enum Error { + /// Connection was already closed. + ConnectionClosed, + + /// The OPT record has become too long. + OptTooLong, + + /// PushError from MessageBuilder. + MessageBuilderPushError, + + /// ParseError from Message. + MessageParseError, + + /// Underlying transport not found in redundant connection + RedundantTransportNotFound, + + /// Octet sequence too short to be a valid DNS message. + ShortMessage, + + /// Message too long for stream transport. + StreamLongMessage, + + /// Stream transport closed because it was idle (for too long). + StreamIdleTimeout, + + /// Error receiving a reply. + // + StreamReceiveError, + + /// Reading from stream gave an error. + StreamReadError(Arc), + + /// Reading from stream took too long. + StreamReadTimeout, + + /// Too many outstand queries on a single stream transport. + StreamTooManyOutstandingQueries, + + /// Writing to a stream gave an error. + StreamWriteError(Arc), + + /// Reading for a stream ended unexpectedly. + StreamUnexpectedEndOfData, + + /// Reply does not match the query. + WrongReplyForQuery, + + /// No transport available to transmit request. + NoTransportAvailable, + + /// An error happened in the datagram transport. + Dgram(Arc), +} + +impl From for Error { + fn from(_: LongOptData) -> Self { + Self::OptTooLong + } +} + +impl From for Error { + fn from(err: super::dgram::QueryError) -> Self { + Self::Dgram(err.into()) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::ConnectionClosed => write!(f, "connection closed"), + Error::OptTooLong => write!(f, "OPT record is too long"), + Error::MessageBuilderPushError => { + write!(f, "PushError from MessageBuilder") + } + Error::MessageParseError => write!(f, "ParseError from Message"), + Error::RedundantTransportNotFound => write!( + f, + "Underlying transport not found in redundant connection" + ), + Error::ShortMessage => { + write!(f, "octet sequence to short to be a valid message") + } + Error::StreamLongMessage => { + write!(f, "message too long for stream transport") + } + Error::StreamIdleTimeout => { + write!(f, "stream was idle for too long") + } + Error::StreamReceiveError => write!(f, "error receiving a reply"), + Error::StreamReadError(_) => { + write!(f, "error reading from stream") + } + Error::StreamReadTimeout => { + write!(f, "timeout reading from stream") + } + Error::StreamTooManyOutstandingQueries => { + write!(f, "too many outstanding queries on stream") + } + Error::StreamWriteError(_) => { + write!(f, "error writing to stream") + } + Error::StreamUnexpectedEndOfData => { + write!(f, "unexpected end of data") + } + Error::WrongReplyForQuery => { + write!(f, "reply does not match query") + } + Error::NoTransportAvailable => { + write!(f, "no transport available") + } + Error::Dgram(err) => fmt::Display::fmt(err, f), + } + } +} + +impl From for Error { + fn from(err: CopyRecordsError) -> Self { + match err { + CopyRecordsError::Parse(_) => Self::MessageParseError, + CopyRecordsError::Push(_) => Self::MessageBuilderPushError, + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Error::ConnectionClosed => None, + Error::OptTooLong => None, + Error::MessageBuilderPushError => None, + Error::MessageParseError => None, + Error::RedundantTransportNotFound => None, + Error::ShortMessage => None, + Error::StreamLongMessage => None, + Error::StreamIdleTimeout => None, + Error::StreamReceiveError => None, + Error::StreamReadError(e) => Some(e), + Error::StreamReadTimeout => None, + Error::StreamTooManyOutstandingQueries => None, + Error::StreamWriteError(e) => Some(e), + Error::StreamUnexpectedEndOfData => None, + Error::WrongReplyForQuery => None, + Error::NoTransportAvailable => None, + Error::Dgram(err) => Some(err), + } + } +} diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs new file mode 100644 index 000000000..3e3b05817 --- /dev/null +++ b/src/net/client/stream.rs @@ -0,0 +1,897 @@ +//! A client transport using a stream socket. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// RFC 7766 describes DNS over TCP +// RFC 7828 describes the edns-tcp-keepalive option + +// TODO: +// - errors +// - connect errors? Retry after connection refused? +// - server errors +// - ID out of range +// - ID not in use +// - reply for wrong query +// - timeouts +// - request timeout +// - create new connection after end/failure of previous one + +use crate::base::message::Message; +use crate::base::message_builder::StreamTarget; +use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; +use bytes; +use bytes::{Bytes, BytesMut}; +use core::cmp; +use core::convert::From; +use octseq::Octets; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use std::vec::Vec; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::sleep; + +//------------ Configuration Constants ---------------------------------------- + +/// Default response timeout. +/// +/// Note: nsd has 120 seconds, unbound has 3 seconds. +const DEF_RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); + +/// Minimum configuration value for the response timeout. +const MIN_RESPONSE_TIMEOUT: Duration = Duration::from_millis(1); + +/// Maximum configuration value for the response timeout. +const MAX_RESPONSE_TIMEOUT: Duration = Duration::from_secs(600); + +/// Capacity of the channel that transports `ChanReq`s. +const DEF_CHAN_CAP: usize = 8; + +/// Capacity of a private channel dispatching responses. +const READ_REPLY_CHAN_CAP: usize = 8; + +//------------ Config --------------------------------------------------------- + +/// Configuration for a stream transport connection. +#[derive(Clone, Debug)] +pub struct Config { + /// Response timeout. + response_timeout: Duration, +} + +impl Config { + /// Creates a new, default config. + pub fn new() -> Self { + Default::default() + } + + /// Returns the response timeout. + /// + /// This is the amount of time to wait on a non-idle connection for a + /// response to an outstanding request. + pub fn response_timeout(&self) -> Duration { + self.response_timeout + } + + /// Sets the response timeout. + /// + /// Excessive values are quietly trimmed. + // + // XXX Maybe that’s wrong and we should rather return an error? + pub fn set_response_timeout(&mut self, timeout: Duration) { + self.response_timeout = cmp::max( + cmp::min(timeout, MAX_RESPONSE_TIMEOUT), + MIN_RESPONSE_TIMEOUT, + ) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + response_timeout: DEF_RESPONSE_TIMEOUT, + } + } +} + +//------------ Connection ----------------------------------------------------- + +/// A connection to a single stream transport. +#[derive(Debug)] +pub struct Connection { + /// The sender half of the request channel. + sender: mpsc::Sender>, +} + +impl Connection { + /// Creates a new stream transport with default configuration. + /// + /// Returns a connection and a future that drives the transport using + /// the provided stream. This future needs to be run while any queries + /// are active. This is most easly achieved by spawning it into a runtime. + /// It terminates when the last connection is dropped. + pub fn new(stream: Stream) -> (Self, Transport) { + Self::with_config(stream, Default::default()) + } + + /// Creates a new stream transport with the given configuration. + /// + /// Returns a connection and a future that drives the transport using + /// the provided stream. This future needs to be run while any queries + /// are active. This is most easly achieved by spawning it into a runtime. + /// It terminates when the last connection is dropped. + pub fn with_config( + stream: Stream, + config: Config, + ) -> (Self, Transport) { + let (sender, transport) = Transport::new(stream, config); + (Self { sender }, transport) + } +} + +impl Connection { + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [ReqRepl] object wrapped in a [Result]. + async fn handle_request_impl( + self, + msg: Req, + ) -> Result, Error> { + let (sender, receiver) = oneshot::channel(); + let req = ChanReq { sender, msg }; + self.sender.send(req).await.map_err(|_| { + // Send error. The receiver is gone, this means that the + // connection is closed. + Error::ConnectionClosed + })?; + receiver.await.map_err(|_| Error::StreamReceiveError)? + } + + /// Returns a request handler for this connection. + pub fn get_request(&self, request_msg: Req) -> Request { + Request { + fut: Box::pin(self.clone().handle_request_impl(request_msg)), + } + } +} + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} + +impl SendRequest + for Connection +{ + fn send_request(&self, request_msg: Req) -> Box { + Box::new(self.get_request(request_msg)) + } +} + +//------------ Request ------------------------------------------------------- + +/// An active request. +pub struct Request { + /// The underlying future. + fut: Pin, Error>> + Send>>, +} + +impl Request { + /// Async function that waits for the future stored in Request to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.fut).await + } +} + +impl GetResponse for Request { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} + +impl Debug for Request { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Request") + .field("fut", &format_args!("_")) + .finish() + } +} + +//------------ Transport ----------------------------------------------------- + +/// The underlying machinery of a stream transport. +#[derive(Debug)] +pub struct Transport { + /// The stream socket towards the remove end. + stream: Stream, + + /// Transport configuration. + config: Config, + + /// The receiver half of request channel. + receiver: mpsc::Receiver>, +} + +/// A message from a `Request` to start a new request. +#[derive(Debug)] +struct ChanReq { + /// DNS request message + msg: Req, + + /// Sender to send result back to [Request] + sender: ReplySender, +} + +/// This is the type of sender in [ChanReq]. +type ReplySender = oneshot::Sender; + +/// A message back to `Request` returning a response. +type ChanResp = Result, Error>; + +/// Internal datastructure of [Transport::run] to keep track of +/// the status of the connection. +// The types Status and ConnState are only used in Transport +struct Status { + /// State of the connection. + state: ConnState, + + /// Do we need to include edns-tcp-keepalive in an outogoing request. + /// + /// Typically this is true at the start of the connection and gets + /// cleared when we successfully managed to include the option in a + /// request. + send_keepalive: bool, + + /// Time we are allow to keep the connection open when idle. + /// + /// Initially we assume that the idle timeout is zero. A received + /// edns-tcp-keepalive option may change that. + idle_timeout: Option, +} + +/// Status of the connection. Used in [Status]. +enum ConnState { + /// The connection is in this state from the start and when at least + /// one active DNS request is present. + /// + /// The instant contains the time of the first request or the + /// most recent response that was received. + Active(Option), + + /// This state represent a connection that went idle and has an + /// idle timeout. + /// + /// The instant contains the time the connection went idle. + Idle(Instant), + + /// This state represent an idle connection where either there was no + /// idle timeout or the idle timer expired. + IdleTimeout, + + /// A read error occurred. + ReadError(Error), + + /// It took too long to receive a response. + ReadTimeout, + + /// A write error occurred. + WriteError(Error), +} + +impl Transport { + /// Creates a new transport. + fn new( + stream: Stream, + config: Config, + ) -> (mpsc::Sender>, Self) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + ( + sender, + Self { + config, + stream, + receiver, + }, + ) + } +} + +impl Transport +where + Stream: AsyncRead + AsyncWrite, + Req: ComposeRequest, +{ + /// Run the transport machinery. + pub async fn run(mut self) { + let (reply_sender, mut reply_receiver) = + mpsc::channel::>(READ_REPLY_CHAN_CAP); + + let (read_stream, mut write_stream) = tokio::io::split(self.stream); + + let reader_fut = Self::reader(read_stream, reply_sender); + tokio::pin!(reader_fut); + + let mut status = Status { + state: ConnState::Active(None), + idle_timeout: None, + send_keepalive: true, + }; + let mut query_vec = Queries::new(); + + let mut reqmsg: Option> = None; + let mut reqmsg_offset = 0; + + loop { + let opt_timeout = match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let elapsed = instant.elapsed(); + if elapsed > self.config.response_timeout { + Self::error( + Error::StreamReadTimeout, + &mut query_vec, + ); + status.state = ConnState::ReadTimeout; + break; + } + Some(self.config.response_timeout - elapsed) + } else { + None + } + } + ConnState::Idle(instant) => { + if let Some(timeout) = &status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= *timeout { + // Move to IdleTimeout and end + // the loop + status.state = ConnState::IdleTimeout; + break; + } + Some(*timeout - elapsed) + } else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout + | ConnState::ReadError(_) + | ConnState::WriteError(_) => None, // No timers here + ConnState::ReadTimeout => { + panic!("should not be in loop with ReadTimeout"); + } + }; + + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => + // Just use the response timeout + { + self.config.response_timeout + } + }; + + let sleep_fut = sleep(timeout); + let recv_fut = self.receiver.recv(); + + let (do_write, msg) = match &reqmsg { + None => { + let msg: &[u8] = &[]; + (false, msg) + } + Some(msg) => { + let msg: &[u8] = msg; + (true, msg) + } + }; + + tokio::select! { + biased; + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(error) => { + Self::error(error.clone(), &mut query_vec); + status.state = ConnState::ReadError(error); + // Reader failed. Break + // out of loop and + // shut down + break + } + } + } + opt_answer = reply_receiver.recv() => { + let answer = opt_answer.expect("reader died?"); + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + Self::handle_opts(opts, + &mut status); + }; + drop(opt_record); + Self::demux_reply(answer, &mut status, &mut query_vec); + } + res = write_stream.write(&msg[reqmsg_offset..]), + if do_write => { + match res { + Err(error) => { + let error = + Error::StreamWriteError(Arc::new(error)); + Self::error(error.clone(), &mut query_vec); + status.state = + ConnState::WriteError(error); + break; + } + Ok(len) => { + reqmsg_offset += len; + if reqmsg_offset >= msg.len() { + reqmsg = None; + reqmsg_offset = 0; + } + } + } + } + res = recv_fut, if !do_write => { + match res { + Some(req) => { + Self::insert_req( + req, &mut status, &mut reqmsg, &mut query_vec + ) + } + None => { + // All references to the connection object have + // been dropped. Shutdown. + break; + } + } + } + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + } + + } + + // Check if the connection is idle + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => break, + ConnState::ReadError(_) + | ConnState::ReadTimeout + | ConnState::WriteError(_) => { + panic!("Should not be here"); + } + } + } + + // Send FIN + _ = write_stream.shutdown().await; + } + + /// This function reads a DNS message from the connection and sends + /// it to [Transport::run]. + /// + /// Reading has to be done in two steps: first read a two octet value + /// the specifies the length of the message, and then read in a loop the + /// body of the message. + /// + /// This function is not async cancellation safe. + async fn reader( + mut sock: tokio::io::ReadHalf, + sender: mpsc::Sender>, + ) -> Result<(), Error> { + loop { + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + return Err(Error::StreamReadError(Arc::new(error))); + } + } as usize; + + let mut buf = BytesMut::with_capacity(len); + + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; + } + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + return Err(Error::StreamUnexpectedEndOfData); + } + } + Err(error) => { + return Err(Error::StreamReadError(Arc::new(error))); + } + }; + + // Check if we are done at the head of the loop + } + + let reply_message = Message::::from_octets(buf.into()); + match reply_message { + Ok(answer) => { + sender + .send(answer) + .await + .expect("can't send reply to run"); + } + Err(_) => { + // The only possible error is short message + return Err(Error::ShortMessage); + } + } + } + } + + /// Reports an error to all outstanding queries. + fn error(error: Error, query_vec: &mut Queries>) { + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + for item in query_vec.drain() { + _ = item.sender.send(Err(error.clone())); + } + } + + /// Handles received EDNS options. + /// + /// In particular, it processes the edns-tcp-keepalive option. + fn handle_opts>( + opts: &OptRecord, + status: &mut Status, + ) { + // XXX This handles _all_ keepalive options. I think just using the + // first option as returned by Opt::tcp_keepalive should be good + // enough? -- M. + for option in opts.opt().iter().flatten() { + if let AllOptData::TcpKeepalive(tcpkeepalive) = option { + Self::handle_keepalive(tcpkeepalive, status); + } + } + } + + /// Demultiplexes a response and sends it to the right query. + /// + /// In addition, the status is updated to IdleTimeout or Idle if there + /// are no remaining pending requests. + fn demux_reply( + answer: Message, + status: &mut Status, + query_vec: &mut Queries>, + ) { + // We got an answer, reset the timer + status.state = ConnState::Active(Some(Instant::now())); + + // Get the correct query and send it the reply. + let req = match query_vec.try_remove(answer.header().id()) { + Some(req) => req, + None => { + // No query with this ID. We should + // mark the connection as broken + return; + } + }; + let answer = if req.msg.is_answer(answer.for_slice()) { + Ok(answer) + } else { + Err(Error::WrongReplyForQuery) + }; + _ = req.sender.send(answer); + + if query_vec.is_empty() { + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this independent. + status.state = ConnState::Active(None); + + status.state = if status.idle_timeout.is_none() { + // Assume that we can just move to IdleTimeout + // state + ConnState::IdleTimeout + } else { + ConnState::Idle(Instant::now()) + } + } + } + + /// Insert a request in query_vec and return the request to be sent + /// in *reqmsg. + /// + /// First the status is checked, an error is returned if not Active or + /// idle. Addend a edns-tcp-keepalive option if needed. + // Note: maybe reqmsg should be a return value. + fn insert_req( + req: ChanReq, + status: &mut Status, + reqmsg: &mut Option>, + query_vec: &mut Queries>, + ) { + match &status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer.is_none() { + status.state = ConnState::Active(Some(Instant::now())); + } + } + ConnState::Idle(_) => { + // Go back to active + status.state = ConnState::Active(Some(Instant::now())); + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + _ = req.sender.send(Err(Error::StreamIdleTimeout)); + return; + } + ConnState::ReadError(error) => { + _ = req.sender.send(Err(error.clone())); + return; + } + ConnState::ReadTimeout => { + _ = req.sender.send(Err(Error::StreamReadTimeout)); + return; + } + ConnState::WriteError(error) => { + _ = req.sender.send(Err(error.clone())); + return; + } + } + + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + let (index, req) = match query_vec.insert(req) { + Ok(res) => res, + Err(req) => { + // Send an appropriate error and return. + _ = req + .sender + .send(Err(Error::StreamTooManyOutstandingQueries)); + return; + } + }; + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + + let hdr = req.msg.header_mut(); + hdr.set_id(index); + + if status.send_keepalive + && req.msg.add_opt(&TcpKeepalive::new(None)).is_ok() + { + status.send_keepalive = false; + } + + match Self::convert_query(&req.msg) { + Ok(msg) => { + *reqmsg = Some(msg); + } + Err(err) => { + // Take the sender out again and return the error. + if let Some(req) = query_vec.try_remove(index) { + _ = req.sender.send(Err(err)); + } + } + } + } + + /// Handle a received edns-tcp-keepalive option. + fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { + if let Some(value) = opt_value.timeout() { + let value_dur = Duration::from(value); + status.idle_timeout = Some(value_dur); + } + } + + /// Convert the query message to a vector. + fn convert_query(msg: &Req) -> Result, Error> { + let mut target = StreamTarget::new_vec(); + msg.append_message(&mut target) + .map_err(|_| Error::StreamLongMessage)?; + Ok(target.into_target()) + } +} + +//------------ Queries ------------------------------------------------------- + +/// Mapping outstanding queries to their ID. +/// +/// This is generic over anything rather than our concrete request type for +/// easier testing. +#[derive(Clone, Debug)] +struct Queries { + /// The number of elements in `vec` that are not None. + count: usize, + + /// Index in `vec? where to look for a space for a new query. + curr: usize, + + /// Vector of senders to forward a DNS reply message (or error) to. + vec: Vec>, +} + +impl Queries { + /// Creates a new empty value. + fn new() -> Self { + Self { + count: 0, + curr: 0, + vec: Vec::new(), + } + } + + /// Returns whether there are no more outstanding queries. + fn is_empty(&self) -> bool { + self.count == 0 + } + + /// Inserts the given query. + /// + /// Upon success, returns the index and a mutable reference to the stored + /// query. + /// + /// Upon error, which means the set is full, returns the query. + fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> { + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2 * self.count > u16::MAX.into() { + return Err(req); + } + + // If more than half the vec is empty, we try and find the index of + // an empty slot. + let idx = if self.vec.len() >= 2 * self.count { + let mut found = None; + for idx in self.curr..self.vec.len() { + if self.vec[idx].is_none() { + found = Some(idx); + break; + } + } + found + } else { + None + }; + + // If we have an index, we can insert there, otherwise we need to + // append. + let idx = match idx { + Some(idx) => { + self.vec[idx] = Some(req); + idx + } + None => { + let idx = self.vec.len(); + self.vec.push(Some(req)); + idx + } + }; + + self.count += 1; + if idx == self.curr { + self.curr += 1; + } + let req = self.vec[idx].as_mut().expect("no inserted item?"); + let idx = u16::try_from(idx).expect("query vec too large"); + Ok((idx, req)) + } + + /// Tries to remove and return the query at the given index. + /// + /// Returns `None` if there was no query there. + fn try_remove(&mut self, index: u16) -> Option { + let res = self.vec.get_mut(usize::from(index))?.take()?; + self.count = self.count.saturating_sub(1); + self.curr = cmp::min(self.curr, index.into()); + Some(res) + } + + /// Removes all queries and returns an iterator over them. + fn drain(&mut self) -> impl Iterator + '_ { + let res = self.vec.drain(..).flatten(); // Skips all the `None`s. + self.count = 0; + self.curr = 0; + res + } +} + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[allow(clippy::needless_range_loop)] + fn queries_insert_remove() { + // Insert items, remove a few, insert a few more. Check that + // everything looks right. + let mut idxs = [None; 20]; + let mut queries = Queries::new(); + + for i in 0..12 { + let (idx, item) = queries.insert(i).unwrap(); + idxs[i] = Some(idx); + assert_eq!(i, *item); + } + assert_eq!(queries.count, 12); + assert_eq!(queries.vec.iter().flatten().count(), 12); + + for i in [1, 2, 3, 4, 7, 9] { + let item = queries.try_remove(idxs[i].unwrap()).unwrap(); + assert_eq!(i, item); + idxs[i] = None; + } + assert_eq!(queries.count, 6); + assert_eq!(queries.vec.iter().flatten().count(), 6); + + for i in 12..20 { + let (idx, item) = queries.insert(i).unwrap(); + idxs[i] = Some(idx); + assert_eq!(i, *item); + } + assert_eq!(queries.count, 14); + assert_eq!(queries.vec.iter().flatten().count(), 14); + + for i in 0..20 { + if let Some(idx) = idxs[i] { + let item = queries.try_remove(idx).unwrap(); + assert_eq!(i, item); + } + } + assert_eq!(queries.count, 0); + assert_eq!(queries.vec.iter().flatten().count(), 0); + } + + #[test] + fn queries_overrun() { + // This is just a quick check that inserting to much stuff doesn’t + // break. + let mut queries = Queries::new(); + for i in 0..usize::from(u16::MAX) * 2 { + let _ = queries.insert(i); + } + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 000000000..5b7e9435b --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,18 @@ +//! Sending and receiving DNS messages. +//! +//! This module provides types, traits, and function for sending and receiving +//! DNS messages. +//! +//! Currently, the module only provides the unstable +#![cfg_attr(feature = "unstable-client-transport", doc = " [`client`]")] +#![cfg_attr(not(feature = "unstable-client-transport"), doc = " `client`")] +//! sub-module intended for sending requests and receiving responses to them. +#![cfg_attr( + not(feature = "unstable-client-transport"), + doc = " The `unstable-client-transport` feature is necessary to enable this module." +)] +//! +#![cfg(feature = "net")] +#![cfg_attr(docsrs, doc(cfg(feature = "net")))] + +pub mod client; diff --git a/src/resolv/stub/conf.rs b/src/resolv/stub/conf.rs index b40be6139..24554e487 100644 --- a/src/resolv/stub/conf.rs +++ b/src/resolv/stub/conf.rs @@ -66,11 +66,6 @@ pub struct ResolvOptions { /// it is supposed to mean. pub primary: bool, - /// Ignore trunactions errors, don’t retry with TCP. - /// - /// This option is implemented by the query. - pub ign_tc: bool, - /// Set the recursion desired bit in queries. /// /// Enabled by default. @@ -186,7 +181,6 @@ impl Default for ResolvOptions { aa_only: false, use_vc: false, primary: false, - ign_tc: false, stay_open: false, use_inet6: false, rotate: false, @@ -556,9 +550,6 @@ impl fmt::Display for ResolvConf { if self.options.primary { options.push("primary".into()) } - if self.options.ign_tc { - options.push("ign-tc".into()) - } if !self.options.recurse { options.push("no-recurse".into()) } diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index c6e5a59ff..2f6fad373 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -14,30 +14,36 @@ use self::conf::{ }; use crate::base::iana::Rcode; use crate::base::message::Message; -use crate::base::message_builder::{ - AdditionalBuilder, MessageBuilder, StreamTarget, -}; +use crate::base::message_builder::{AdditionalBuilder, MessageBuilder}; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; +use crate::net::client::dgram_stream; +use crate::net::client::multi_stream; +use crate::net::client::protocol::{TcpConnect, UdpConnect}; +use crate::net::client::redundant; +use crate::net::client::request::{ + ComposeRequest, Error, RequestMessage, SendRequest, +}; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; use crate::resolv::resolver::{Resolver, SearchNames}; use bytes::Bytes; +use futures_util::stream::{FuturesUnordered, StreamExt}; use octseq::array::Array; use std::boxed::Box; +use std::fmt::Debug; use std::future::Future; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::pin::Pin; -use std::slice::SliceIndex; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::string::ToString; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::vec::Vec; use std::{io, ops}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpStream, UdpSocket}; #[cfg(feature = "resolv-sync")] use tokio::runtime; +use tokio::sync::Mutex; use tokio::time::timeout; //------------ Sub-modules --------------------------------------------------- @@ -46,9 +52,6 @@ pub mod conf; //------------ Module Configuration ------------------------------------------ -/// How many times do we try a new random port if we get ‘address in use.’ -const RETRY_RANDOM_PORT: usize = 10; - //------------ StubResolver -------------------------------------------------- /// A DNS stub resolver. @@ -70,16 +73,14 @@ const RETRY_RANDOM_PORT: usize = 10; /// [`query()`]: #method.query /// [`run()`]: #method.run /// [`run_with_conf()`]: #method.run_with_conf -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct StubResolver { - /// Preferred servers. - preferred: ServerList, - - /// Streaming servers. - stream: ServerList, + transport: Mutex>>>>, /// Resolver options. options: ResolvOptions, + + servers: Vec, } impl StubResolver { @@ -91,11 +92,10 @@ impl StubResolver { /// Creates a new resolver using the given configuraiton. pub fn from_conf(conf: ResolvConf) -> Self { StubResolver { - preferred: ServerList::from_conf(&conf, |s| { - s.transport.is_preferred() - }), - stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()), + transport: None.into(), options: conf.options, + + servers: conf.servers, } } @@ -118,6 +118,108 @@ impl StubResolver { ) -> Result { Query::new(self)?.run(message).await } + + async fn setup_transport< + CR: Clone + Debug + ComposeRequest + Send + Sync + 'static, + >( + &self, + ) -> Result, Error> { + // Create a redundant transport and fill it with the right transports + let (redun, transp) = redundant::Connection::new(); + + // Start the run function on a separate task. + let redun_run_fut = transp.run(); + + // It would be nice to have just one task. However redun.run() has to + // execute before we can call redun.add(). However, we need to know + // the type of the elements we add to FuturesUnordered. For the moment + // we have two tasks. + tokio::spawn(async move { + redun_run_fut.await; + }); + + let fut_list_tcp = FuturesUnordered::new(); + let fut_list_udp_tcp = FuturesUnordered::new(); + + // Start the tasks with empty base transports. We need redun to be + // running before we can add transports. + + // We have 3 modes of operation: use_vc: only use TCP, ign_tc: only + // UDP no fallback to TCP, and normal with is UDP falling back to TCP. + if self.options.use_vc { + for s in &self.servers { + if let Transport::Tcp = s.transport { + let (conn, tran) = multi_stream::Connection::new( + TcpConnect::new(s.addr), + ); + // Start the run function on a separate task. + let run_fut = tran.run(); + fut_list_tcp.push(async move { + run_fut.await; + }); + redun.add(Box::new(conn)).await?; + } + } + } else { + for s in &self.servers { + if let Transport::Udp = s.transport { + let udp_connect = UdpConnect::new(s.addr); + let tcp_connect = TcpConnect::new(s.addr); + let (conn, tran) = dgram_stream::Connection::new( + udp_connect, + tcp_connect, + ); + // Start the run function on a separate task. + fut_list_udp_tcp.push(async move { + tran.run().await; + }); + redun.add(Box::new(conn)).await?; + } + } + } + + tokio::spawn(async move { + run(fut_list_tcp, fut_list_udp_tcp).await; + }); + + Ok(redun) + } + + async fn get_transport( + &self, + ) -> Result>>, Error> { + let mut opt_transport = self.transport.lock().await; + + match &*opt_transport { + Some(transport) => Ok(transport.clone()), + None => { + let transport = self.setup_transport().await?; + *opt_transport = Some(transport.clone()); + Ok(transport) + } + } + } +} + +async fn run( + mut fut_list_tcp: FuturesUnordered, + mut fut_list_udp_tcp: FuturesUnordered, +) { + loop { + let tcp_empty = fut_list_tcp.is_empty(); + let udp_tcp_empty = fut_list_udp_tcp.is_empty(); + if tcp_empty && udp_tcp_empty { + break; + } + tokio::select! { + _ = fut_list_tcp.next(), if !tcp_empty => { + // Nothing to do + } + _ = fut_list_udp_tcp.next(), if !udp_tcp_empty => { + // Nothing to do + } + } + } } impl StubResolver { @@ -168,10 +270,10 @@ impl StubResolver { /// The only argument is a closure taking a reference to a `StubResolver` /// and returning a future. Whatever that future resolves to will be /// returned. - pub fn run(op: F) -> R::Output + pub fn run(op: F) -> R::Output where - R: Future + Send + 'static, - R::Output: Send + 'static, + R: Future> + Send + 'static, + E: From, F: FnOnce(StubResolver) -> R + Send + 'static, { Self::run_with_conf(ResolvConf::default(), op) @@ -183,17 +285,16 @@ impl StubResolver { /// tailor-making your own resolver. /// /// [`run()`]: #method.run - pub fn run_with_conf(conf: ResolvConf, op: F) -> R::Output + pub fn run_with_conf(conf: ResolvConf, op: F) -> R::Output where - R: Future + Send + 'static, - R::Output: Send + 'static, + R: Future> + Send + 'static, + E: From, F: FnOnce(StubResolver) -> R + Send + 'static, { let resolver = Self::from_conf(conf); let runtime = runtime::Builder::new_current_thread() .enable_all() - .build() - .unwrap(); + .build()?; runtime.block_on(op(resolver)) } } @@ -238,14 +339,7 @@ pub struct Query<'a> { /// The resolver whose configuration we are using. resolver: &'a StubResolver, - /// Are we still in the preferred server list or have gone streaming? - preferred: bool, - - /// The number of attempts, starting with zero. - attempt: usize, - - /// The index in the server list we currently trying. - counter: ServerListCounter, + edns: Arc, /// The preferred error to return. /// @@ -259,23 +353,9 @@ pub struct Query<'a> { impl<'a> Query<'a> { pub fn new(resolver: &'a StubResolver) -> Result { - let (preferred, counter) = - if resolver.options().use_vc || resolver.preferred.is_empty() { - if resolver.stream.is_empty() { - return Err(io::Error::new( - io::ErrorKind::NotFound, - "no servers available", - )); - } - (false, resolver.stream.counter(resolver.options().rotate)) - } else { - (true, resolver.preferred.counter(resolver.options().rotate)) - }; Ok(Query { resolver, - preferred, - attempt: 0, - counter, + edns: Arc::new(AtomicBool::new(true)), error: Err(io::Error::new( io::ErrorKind::TimedOut, "all timed out", @@ -291,26 +371,14 @@ impl<'a> Query<'a> { match self.run_query(&mut message).await { Ok(answer) => { if answer.header().rcode() == Rcode::FormErr - && self.current_server().does_edns() + && self.does_edns() { // FORMERR with EDNS: turn off EDNS and try again. - self.current_server().disable_edns(); + self.disable_edns(); continue; } else if answer.header().rcode() == Rcode::ServFail { // SERVFAIL: go to next server. self.update_error_servfail(answer); - } else if answer.header().tc() - && self.preferred - && !self.resolver.options().ign_tc - { - // Truncated. If we can, switch to stream transports - // and try again. Otherwise return the truncated - // answer. - if self.switch_to_stream() { - continue; - } else { - return Ok(answer); - } } else { // I guess we have an answer ... return Ok(answer); @@ -318,20 +386,16 @@ impl<'a> Query<'a> { } Err(err) => self.update_error(err), } - if !self.next_server() { - return self.error; - } + return self.error; } } fn create_message(question: Question) -> QueryMessage { - let mut message = MessageBuilder::from_target( - StreamTarget::new(Default::default()).unwrap(), - ) - .unwrap(); + let mut message = MessageBuilder::from_target(Default::default()) + .expect("MessageBuilder should not fail"); message.header_mut().set_rd(true); let mut message = message.question(); - message.push(question).unwrap(); + message.push(question).expect("push should not fail"); message.additional() } @@ -339,18 +403,22 @@ impl<'a> Query<'a> { &mut self, message: &mut QueryMessage, ) -> Result { - let server = self.current_server(); - server.prepare_message(message); - server.query(message).await - } + let msg = Message::from_octets(message.as_target().to_vec()) + .expect("Message::from_octets should not fail"); - fn current_server(&self) -> &ServerInfo { - let list = if self.preferred { - &self.resolver.preferred - } else { - &self.resolver.stream - }; - self.counter.info(list) + let request_msg = RequestMessage::new(msg); + + let transport = self.resolver.get_transport().await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, e.to_string()) + })?; + let mut gr_fut = transport.send_request(request_msg); + let reply = + timeout(self.resolver.options.timeout, gr_fut.get_response()) + .await? + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e.to_string()) + })?; + Ok(Answer { message: reply }) } fn update_error(&mut self, err: io::Error) { @@ -366,41 +434,19 @@ impl<'a> Query<'a> { self.error = Ok(answer) } - fn switch_to_stream(&mut self) -> bool { - if !self.preferred { - // We already did this. - return false; - } - self.preferred = false; - self.attempt = 0; - self.counter = - self.resolver.stream.counter(self.resolver.options().rotate); - true + pub fn does_edns(&self) -> bool { + self.edns.load(Ordering::Relaxed) } - fn next_server(&mut self) -> bool { - if self.counter.next() { - return true; - } - self.attempt += 1; - if self.attempt >= self.resolver.options().attempts { - return false; - } - self.counter = if self.preferred { - self.resolver - .preferred - .counter(self.resolver.options().rotate) - } else { - self.resolver.stream.counter(self.resolver.options().rotate) - }; - true + pub fn disable_edns(&self) { + self.edns.store(false, Ordering::Relaxed); } } //------------ QueryMessage -------------------------------------------------- // XXX This needs to be re-evaluated if we start adding OPTtions to the query. -pub(super) type QueryMessage = AdditionalBuilder>>; +pub(super) type QueryMessage = AdditionalBuilder>; //------------ Answer -------------------------------------------------------- @@ -451,312 +497,6 @@ impl AsRef> for Answer { } } -//------------ ServerInfo ---------------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerInfo { - /// The basic server configuration. - conf: ServerConf, - - /// Whether this server supports EDNS. - /// - /// We start out with assuming it does and unset it if we get a FORMERR. - edns: Arc, -} - -impl ServerInfo { - pub fn does_edns(&self) -> bool { - self.edns.load(Ordering::Relaxed) - } - - pub fn disable_edns(&self) { - self.edns.store(false, Ordering::Relaxed); - } - - pub fn prepare_message(&self, query: &mut QueryMessage) { - query.rewind(); - if self.does_edns() { - query - .opt(|opt| { - opt.set_udp_payload_size(self.conf.udp_payload_size); - Ok(()) - }) - .unwrap(); - } - } - - pub async fn query( - &self, - query: &QueryMessage, - ) -> Result { - let res = match self.conf.transport { - Transport::Udp => { - timeout( - self.conf.request_timeout, - Self::udp_query( - query, - self.conf.addr, - self.conf.recv_size, - ), - ) - .await - } - Transport::Tcp => { - timeout( - self.conf.request_timeout, - Self::tcp_query(query, self.conf.addr), - ) - .await - } - }; - match res { - Ok(Ok(answer)) => Ok(answer), - Ok(Err(err)) => Err(err), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "request timed out", - )), - } - } - - pub async fn tcp_query( - query: &QueryMessage, - addr: SocketAddr, - ) -> Result { - let mut sock = TcpStream::connect(&addr).await?; - sock.write_all(query.as_target().as_stream_slice()).await?; - - // This loop can be infinite because we have a timeout on this whole - // thing, anyway. - loop { - let mut buf = Vec::new(); - let len = sock.read_u16().await? as u64; - AsyncReadExt::take(&mut sock, len) - .read_to_end(&mut buf) - .await?; - if let Ok(answer) = Message::from_octets(buf.into()) { - if answer.is_answer(&query.as_message()) { - return Ok(answer.into()); - } - // else try with the next message. - } else { - return Err(io::Error::new( - io::ErrorKind::Other, - "short buf", - )); - } - } - } - - pub async fn udp_query( - query: &QueryMessage, - addr: SocketAddr, - recv_size: usize, - ) -> Result { - let sock = Self::udp_bind(addr.is_ipv4()).await?; - sock.connect(addr).await?; - let sent = sock.send(query.as_target().as_dgram_slice()).await?; - if sent != query.as_target().as_dgram_slice().len() { - return Err(io::Error::new( - io::ErrorKind::Other, - "short UDP send", - )); - } - loop { - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let len = sock.recv(&mut buf).await?; - buf.truncate(len); - - // We ignore garbage since there is a timer on this whole thing. - let answer = match Message::from_octets(buf.into()) { - Ok(answer) => answer, - Err(_) => continue, - }; - if !answer.is_answer(&query.as_message()) { - continue; - } - return Ok(answer.into()); - } - } - - async fn udp_bind(v4: bool) -> Result { - let mut i = 0; - loop { - let local: SocketAddr = if v4 { - ([0u8; 4], 0).into() - } else { - ([0u16; 8], 0).into() - }; - match UdpSocket::bind(&local).await { - Ok(sock) => return Ok(sock), - Err(err) => { - if i == RETRY_RANDOM_PORT { - return Err(err); - } else { - i += 1 - } - } - } - } - } -} - -impl From for ServerInfo { - fn from(conf: ServerConf) -> Self { - ServerInfo { - conf, - edns: Arc::new(AtomicBool::new(true)), - } - } -} - -impl<'a> From<&'a ServerConf> for ServerInfo { - fn from(conf: &'a ServerConf) -> Self { - conf.clone().into() - } -} - -//------------ ServerList ---------------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerList { - /// The actual list of servers. - servers: Vec, - - /// Where to start accessing the list. - /// - /// In rotate mode, this value will always keep growing and will have to - /// be used modulo `servers`’s length. - /// - /// When it eventually wraps around the end of usize’s range, there will - /// be a jump in rotation. Since that will happen only oh-so-often, we - /// accept that in favour of simpler code. - start: Arc, -} - -impl ServerList { - pub fn from_conf(conf: &ResolvConf, filter: F) -> Self - where - F: Fn(&ServerConf) -> bool, - { - ServerList { - servers: { - conf.servers - .iter() - .filter(|f| filter(f)) - .map(Into::into) - .collect() - }, - start: Arc::new(AtomicUsize::new(0)), - } - } - - pub fn is_empty(&self) -> bool { - self.servers.is_empty() - } - - pub fn counter(&self, rotate: bool) -> ServerListCounter { - let res = ServerListCounter::new(self); - if rotate { - self.rotate() - } - res - } - - pub fn iter(&self) -> ServerListIter { - ServerListIter::new(self) - } - - pub fn rotate(&self) { - self.start.fetch_add(1, Ordering::SeqCst); - } -} - -impl<'a> IntoIterator for &'a ServerList { - type Item = &'a ServerInfo; - type IntoIter = ServerListIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -impl> ops::Index for ServerList { - type Output = >::Output; - - fn index(&self, index: I) -> &>::Output { - self.servers.index(index) - } -} - -//------------ ServerListCounter --------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerListCounter { - cur: usize, - end: usize, -} - -impl ServerListCounter { - fn new(list: &ServerList) -> Self { - if list.servers.is_empty() { - return ServerListCounter { cur: 0, end: 0 }; - } - - // We modulo the start value here to prevent hick-ups towards the - // end of usize’s range. - let start = list.start.load(Ordering::Relaxed) % list.servers.len(); - ServerListCounter { - cur: start, - end: start + list.servers.len(), - } - } - - #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> bool { - let next = self.cur + 1; - if next < self.end { - self.cur = next; - true - } else { - false - } - } - - pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo { - &list[self.cur % list.servers.len()] - } -} - -//------------ ServerListIter ------------------------------------------------ - -#[derive(Clone, Debug)] -struct ServerListIter<'a> { - servers: &'a ServerList, - counter: ServerListCounter, -} - -impl<'a> ServerListIter<'a> { - fn new(list: &'a ServerList) -> Self { - ServerListIter { - servers: list, - counter: ServerListCounter::new(list), - } - } -} - -impl<'a> Iterator for ServerListIter<'a> { - type Item = &'a ServerInfo; - - fn next(&mut self) -> Option { - if self.counter.next() { - Some(self.counter.info(self.servers)) - } else { - None - } - } -} - //------------ SearchIter ---------------------------------------------------- #[derive(Clone, Debug)] diff --git a/test-data/basic.rpl b/test-data/basic.rpl new file mode 100644 index 000000000..72f453fe0 --- /dev/null +++ b/test-data/basic.rpl @@ -0,0 +1,161 @@ +do-ip6: no + +; config options +; target-fetch-policy: "3 2 1 0 0" +; name: "." + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test iterator with NS falsely declaring referral answer as authoritative. + +; K.ROOT-SERVERS.NET. +RANGE_BEGIN 0 100 + ADDRESS 193.0.14.129 +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +. IN NS +SECTION ANSWER +. IN NS K.ROOT-SERVERS.NET. +SECTION ADDITIONAL +K.ROOT-SERVERS.NET. IN A 193.0.14.129 +ENTRY_END + +; net. +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +net. IN NS +SECTION AUTHORITY +. IN SOA . . 0 0 0 0 0 +ENTRY_END + +; root-servers.net. +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +root-servers.net. IN NS +SECTION ANSWER +root-servers.net. IN NS k.root-servers.net. +SECTION ADDITIONAL +k.root-servers.net. IN A 193.0.14.129 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +root-servers.net. IN A +SECTION AUTHORITY +root-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +k.root-servers.net. IN A +SECTION ANSWER +k.root-servers.net. IN A 193.0.14.129 +SECTION ADDITIONAL +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +k.root-servers.net. IN AAAA +SECTION AUTHORITY +root-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +; gtld-servers.net. +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +gtld-servers.net. IN NS +SECTION ANSWER +gtld-servers.net. IN NS a.gtld-servers.net. +SECTION ADDITIONAL +a.gtld-servers.net. IN A 192.5.6.30 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +gtld-servers.net. IN A +SECTION AUTHORITY +gtld-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +a.gtld-servers.net. IN A +SECTION ANSWER +a.gtld-servers.net. IN A 192.5.6.30 +SECTION ADDITIONAL +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +a.gtld-servers.net. IN AAAA +SECTION AUTHORITY +gtld-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +RANGE_END + +; a.gtld-servers.net. +RANGE_BEGIN 0 100 + ADDRESS 192.5.6.30 + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id copy_query +REPLY QR RD NOERROR +SECTION QUESTION +example.com. IN A +SECTION ANSWER +example.com. IN A 93.184.216.34 +ENTRY_END + +RANGE_END + +STEP 1 QUERY +ENTRY_BEGIN +REPLY RD +SECTION QUESTION +example.com. IN A +ENTRY_END + +; recursion happens here. +STEP 10 CHECK_ANSWER +ENTRY_BEGIN +MATCH all +REPLY QR RD RA NOERROR +SECTION QUESTION +example.com. IN A +SECTION ANSWER +example.com. IN A 93.184.216.34 +ENTRY_END + +SCENARIO_END diff --git a/tests/net-client.rs b/tests/net-client.rs new file mode 100644 index 000000000..10a7ab467 --- /dev/null +++ b/tests/net-client.rs @@ -0,0 +1,148 @@ +#![cfg(feature = "net")] +mod net; + +use crate::net::deckard::client::do_client; +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::connect::Connect; +use crate::net::deckard::connection::Connection; +use crate::net::deckard::dgram::Dgram; +use crate::net::deckard::parse_deckard::parse_file; +use domain::net::client::dgram; +use domain::net::client::dgram_stream; +use domain::net::client::multi_stream; +use domain::net::client::redundant; +use domain::net::client::stream; +use std::fs::File; +use std::net::IpAddr; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use tokio::net::TcpStream; + +const TEST_FILE: &str = "test-data/basic.rpl"; + +#[test] +fn dgram() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Dgram::new(deckard.clone(), step_value.clone()); + let octstr = dgram::Connection::new(conn); + + do_client(&deckard, octstr, &step_value).await; + }); +} + +#[test] +fn single() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Connection::new(deckard.clone(), step_value.clone()); + let (octstr, transport) = stream::Connection::new(conn); + tokio::spawn(async move { + transport.run().await; + }); + + do_client(&deckard, octstr, &step_value).await; + }); +} + +#[test] +fn multi() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let (ms, ms_tran) = multi_stream::Connection::new(multi_conn); + tokio::spawn(async move { + ms_tran.run().await; + println!("multi conn run terminated"); + }); + + do_client(&deckard, ms.clone(), &step_value).await; + }); +} + +#[test] +fn dgram_stream() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Dgram::new(deckard.clone(), step_value.clone()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let (ds, tran) = dgram_stream::Connection::new(conn, multi_conn); + tokio::spawn(async move { + tran.run().await; + println!("dgram_stream conn run terminated"); + }); + + do_client(&deckard, ds, &step_value).await; + }); +} + +#[test] +fn redundant() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let (ms, ms_tran) = multi_stream::Connection::new(multi_conn); + tokio::spawn(async move { + ms_tran.run().await; + println!("multi conn run terminated"); + }); + + // Redundant add previous connection. + let (redun, transp) = redundant::Connection::new(); + let run_fut = transp.run(); + tokio::spawn(async move { + run_fut.await; + println!("redundant conn run terminated"); + }); + redun.add(Box::new(ms.clone())).await.unwrap(); + + do_client(&deckard, redun, &step_value).await; + }); +} + +#[test] +#[ignore] +// Connect directly to the internet. Disabled by default. +fn tcp() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let server_addr = + SocketAddr::new(IpAddr::from_str("9.9.9.9").unwrap(), 53); + + let tcp_conn = match TcpStream::connect(server_addr).await { + Ok(conn) => conn, + Err(err) => { + println!( + "TCP Connection to {server_addr} failed: {err}, exiting" + ); + return; + } + }; + + let (tcp, transport) = stream::Connection::new(tcp_conn); + tokio::spawn(async move { + transport.run().await; + println!("single TCP run terminated"); + }); + + do_client(&deckard, tcp, &CurrStepValue::new()).await; + }); +} diff --git a/tests/net/deckard/client.rs b/tests/net/deckard/client.rs new file mode 100644 index 000000000..5929c9d7e --- /dev/null +++ b/tests/net/deckard/client.rs @@ -0,0 +1,90 @@ +use crate::net::deckard::matches::match_msg; +use crate::net::deckard::parse_deckard::{Deckard, Entry, Reply, StepType}; +use crate::net::deckard::parse_query; +use bytes::Bytes; + +use domain::base::{Message, MessageBuilder}; +use domain::net::client::request::{RequestMessage, SendRequest}; +use std::sync::Mutex; + +pub async fn do_client>>>( + deckard: &Deckard, + request: R, + step_value: &CurrStepValue, +) { + let mut resp: Option> = None; + + // Assume steps are in order. Maybe we need to define that. + for step in &deckard.scenario.steps { + step_value.set(step.step_value); + match step.step_type { + StepType::Query => { + let reqmsg = entry2reqmsg(step.entry.as_ref().unwrap()); + let mut req = request.send_request(reqmsg); + resp = Some(req.get_response().await.unwrap()); + } + StepType::CheckAnswer => { + let answer = resp.take().unwrap(); + if !match_msg(step.entry.as_ref().unwrap(), &answer, true) { + panic!("reply failed"); + } + } + StepType::TimePasses + | StepType::Traffic + | StepType::CheckTempfile + | StepType::Assign => todo!(), + } + } + println!("Done"); +} + +fn entry2reqmsg(entry: &Entry) -> RequestMessage> { + let sections = entry.sections.as_ref().unwrap(); + let mut msg = MessageBuilder::new_vec().question(); + for q in §ions.question { + let question = match q { + parse_query::Entry::QueryRecord(question) => question, + _ => todo!(), + }; + msg.push(question).unwrap(); + } + let msg = msg.answer(); + for _a in §ions.answer { + todo!(); + } + let msg = msg.authority(); + for _a in §ions.authority { + todo!(); + } + let mut msg = msg.additional(); + for _a in §ions.additional { + todo!(); + } + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + if reply.rd { + msg.header_mut().set_rd(true); + } + let msg = msg.into_message(); + RequestMessage::new(msg) +} + +#[derive(Debug)] +pub struct CurrStepValue { + v: Mutex, +} + +impl CurrStepValue { + pub fn new() -> Self { + Self { v: 0.into() } + } + fn set(&self, v: u64) { + let mut self_v = self.v.lock().unwrap(); + *self_v = v; + } + pub fn get(&self) -> u64 { + *(self.v.lock().unwrap()) + } +} diff --git a/tests/net/deckard/connect.rs b/tests/net/deckard/connect.rs new file mode 100644 index 000000000..287710290 --- /dev/null +++ b/tests/net/deckard/connect.rs @@ -0,0 +1,34 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::connection::Connection; +use crate::net::deckard::parse_deckard::Deckard; +use domain::net::client::protocol::AsyncConnect; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +pub struct Connect { + deckard: Deckard, + step_value: Arc, +} + +impl Connect { + pub fn new(deckard: Deckard, step_value: Arc) -> Connect { + Self { + deckard, + step_value, + } + } +} + +impl AsyncConnect for Connect { + type Connection = Connection; + type Fut = Pin< + Box> + Send>, + >; + + fn connect(&self) -> Self::Fut { + let deckard = self.deckard.clone(); + let step_value = self.step_value.clone(); + Box::pin(async move { Ok(Connection::new(deckard, step_value)) }) + } +} diff --git a/tests/net/deckard/connection.rs b/tests/net/deckard/connection.rs new file mode 100644 index 000000000..d106a5395 --- /dev/null +++ b/tests/net/deckard/connection.rs @@ -0,0 +1,110 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::parse_deckard::Deckard; +use crate::net::deckard::server::do_server; +use domain::base::Message; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::Waker; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +#[derive(Debug)] +pub struct Connection { + deckard: Deckard, + step_value: Arc, + waker: Option, + reply: Option>>, + send_body: bool, + + tmpbuf: Vec, +} + +impl Connection { + pub fn new( + deckard: Deckard, + step_value: Arc, + ) -> Connection { + Self { + deckard, + step_value, + waker: None, + reply: None, + send_body: false, + tmpbuf: Vec::new(), + } + } +} + +impl AsyncRead for Connection { + fn poll_read( + mut self: Pin<&mut Self>, + context: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.reply.is_some() { + let slice = self.reply.as_ref().unwrap().as_slice(); + let len = slice.len(); + if self.send_body { + buf.put_slice(slice); + self.reply = None; + return Poll::Ready(Ok(())); + } else { + buf.put_slice(&(len as u16).to_be_bytes()); + self.send_body = true; + return Poll::Ready(Ok(())); + } + } + self.reply = None; + self.send_body = false; + self.waker = Some(context.waker().clone()); + Poll::Pending + } +} + +impl AsyncWrite for Connection { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.tmpbuf.push(buf[0]); + let buflen = self.tmpbuf.len(); + if buflen < 2 { + return Poll::Ready(Ok(1)); + } + let mut len_str: [u8; 2] = [0; 2]; + len_str.copy_from_slice(&self.tmpbuf[0..2]); + let len = u16::from_be_bytes(len_str) as usize; + if buflen != 2 + len { + return Poll::Ready(Ok(1)); + } + let msg = Message::from_octets(self.tmpbuf[2..].to_vec()).unwrap(); + self.tmpbuf = Vec::new(); + let opt_reply = do_server(&msg, &self.deckard, &self.step_value); + if opt_reply.is_some() { + // Do we need to support more than one reply? + self.reply = opt_reply; + let opt_waker = self.waker.take(); + if let Some(waker) = opt_waker { + waker.wake(); + } + } + Poll::Ready(Ok(1)) + } + fn poll_flush( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + todo!() + } + fn poll_shutdown( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + // Do we need to do anything here? + Poll::Ready(Ok(())) + } +} diff --git a/tests/net/deckard/dgram.rs b/tests/net/deckard/dgram.rs new file mode 100644 index 000000000..6fd0eb33f --- /dev/null +++ b/tests/net/deckard/dgram.rs @@ -0,0 +1,108 @@ +//! Provide server-side of datagram protocols + +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::parse_deckard::Deckard; +use crate::net::deckard::server::do_server; +use domain::base::Message; +use domain::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramSend, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex as SyncMutex; +use std::task::{Context, Poll, Waker}; +use tokio::io::ReadBuf; + +#[derive(Clone, Debug)] +pub struct Dgram { + deckard: Deckard, + step_value: Arc, +} + +impl Dgram { + pub fn new(deckard: Deckard, step_value: Arc) -> Self { + Self { + deckard, + step_value, + } + } +} + +impl AsyncConnect for Dgram { + type Connection = DgramConnection; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + fn connect(&self) -> Self::Fut { + let deckard = self.deckard.clone(); + let step_value = self.step_value.clone(); + Box::pin(async move { Ok(DgramConnection::new(deckard, step_value)) }) + } +} + +pub struct DgramConnection { + deckard: Deckard, + step_value: Arc, + + reply: SyncMutex>>>, + waker: SyncMutex>, +} + +impl DgramConnection { + fn new(deckard: Deckard, step_value: Arc) -> Self { + Self { + deckard, + step_value, + reply: SyncMutex::new(None), + waker: SyncMutex::new(None), + } + } +} +impl AsyncDgramRecv for DgramConnection { + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut reply = self.reply.lock().unwrap(); + if (*reply).is_some() { + let slice = (*reply).as_ref().unwrap().as_slice(); + buf.put_slice(slice); + *reply = None; + return Poll::Ready(Ok(())); + } + *reply = None; + let mut waker = self.waker.lock().unwrap(); + *waker = Some(cx.waker().clone()); + Poll::Pending + } +} + +impl AsyncDgramSend for DgramConnection { + fn poll_send( + &self, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let msg = Message::from_octets(buf).unwrap(); + let opt_reply = do_server(&msg, &self.deckard, &self.step_value); + let len = buf.len(); + if opt_reply.is_some() { + // Do we need to support more than one reply? + let mut reply = self.reply.lock().unwrap(); + *reply = opt_reply; + drop(reply); + let mut waker = self.waker.lock().unwrap(); + let opt_waker = (*waker).take(); + drop(waker); + if let Some(waker) = opt_waker { + waker.wake(); + } + } + Poll::Ready(Ok(len)) + } +} diff --git a/tests/net/deckard/matches.rs b/tests/net/deckard/matches.rs new file mode 100644 index 000000000..1da791495 --- /dev/null +++ b/tests/net/deckard/matches.rs @@ -0,0 +1,305 @@ +use crate::net::deckard::parse_deckard::{Entry, Matches, Reply}; +use crate::net::deckard::parse_query; +use domain::base::iana::Opcode; +use domain::base::iana::OptRcode; +use domain::base::iana::Rtype; +use domain::base::Message; +use domain::base::ParsedDname; +use domain::base::QuestionSection; +use domain::base::RecordSection; +use domain::dep::octseq::Octets; +use domain::rdata::ZoneRecordData; +use domain::zonefile::inplace::Entry as ZonefileEntry; +//use std::fmt::Debug; + +pub fn match_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a>( + entry: &Entry, + msg: &'a Message, + verbose: bool, +) -> bool +where + ::Range<'a>: Clone, +{ + let sections = entry.sections.as_ref().unwrap(); + + let mut matches: Matches = match &entry.matches { + Some(matches) => matches.clone(), + None => Default::default(), + }; + + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + + if matches.all { + matches.opcode = true; + matches.qtype = true; + matches.qname = true; + matches.flags = true; + matches.rcode = true; + matches.answer = true; + matches.authority = true; + matches.additional = true; + } + + if matches.question { + matches.qtype = true; + matches.qname = true; + } + + if matches.additional { + let mut arcount = msg.header_counts().arcount(); + if msg.opt().is_some() { + arcount -= 1; + } + if !match_section( + sections.additional.clone(), + msg.additional().unwrap(), + arcount, + verbose, + ) { + if verbose { + println!("match_msg: additional section does not match"); + } + return false; + } + } + if matches.answer + && !match_section( + sections.answer.clone(), + msg.answer().unwrap(), + msg.header_counts().ancount(), + verbose, + ) + { + if verbose { + todo!(); + } + return false; + } + if matches.authority + && !match_section( + sections.authority.clone(), + msg.authority().unwrap(), + msg.header_counts().nscount(), + verbose, + ) + { + if verbose { + todo!(); + } + return false; + } + if matches.fl_do { + todo!(); + } + if matches.flags { + let header = msg.header(); + if reply.qr != header.qr() { + if verbose { + todo!(); + } + return false; + } + if reply.aa != header.aa() { + if verbose { + println!( + "match_msg: AA does not match, got {}, expected {}", + header.aa(), + reply.aa + ); + } + return false; + } + if reply.tc != header.tc() { + if verbose { + todo!(); + } + return false; + } + if reply.rd != header.rd() { + if verbose { + println!( + "match_msg: RD does not match, got {}, expected {}", + header.aa(), + reply.aa + ); + } + return false; + } + if reply.ad != header.ad() { + if verbose { + todo!(); + } + return false; + } + if reply.cd != header.cd() { + if verbose { + todo!(); + } + return false; + } + } + if matches.opcode { + // Not clear what that means. JUst check if it is Query + if msg.header().opcode() != Opcode::Query { + if verbose { + todo!(); + } + return false; + } + } + if (matches.qname || matches.qtype) + && !match_question( + sections.question.clone(), + msg.question(), + matches.qname, + matches.qtype, + ) + { + if verbose { + println!("match_msg: question section does not match"); + } + return false; + } + if matches.rcode { + let msg_rcode = + get_opt_rcode(&Message::from_octets(msg.as_slice()).unwrap()); + if reply.noerror { + if let OptRcode::NoError = msg_rcode { + // Okay + } else { + if verbose { + todo!(); + } + return false; + } + } else { + println!("reply {reply:?}"); + panic!("no rcode to match?"); + } + } + if matches.subdomain { + todo!() + } + if matches.tcp { + todo!() + } + if matches.ttl { + todo!() + } + if matches.udp { + todo!() + } + + // All checks passed! + true +} + +fn match_section< + 'a, + Octs: Clone + Octets = Octs2> + 'a, + Octs2: AsRef<[u8]> + Clone, +>( + mut match_section: Vec, + msg_section: RecordSection<'a, Octs>, + msg_count: u16, + verbose: bool, +) -> bool { + if match_section.len() != msg_count.into() { + if verbose { + todo!(); + } + return false; + } + 'outer: for msg_rr in msg_section { + let msg_rr = msg_rr.unwrap(); + if msg_rr.rtype() == Rtype::Opt { + continue; + } + for (index, mat_rr) in match_section.iter().enumerate() { + // Remove outer Record + let mat_rr = if let ZonefileEntry::Record(record) = mat_rr { + record + } else { + panic!("include not expected"); + }; + if msg_rr.owner() != mat_rr.owner() { + continue; + } + if msg_rr.class() != mat_rr.class() { + continue; + } + if msg_rr.rtype() != mat_rr.rtype() { + continue; + } + let msg_rdata = msg_rr + .clone() + .into_record::>>() + .unwrap() + .unwrap(); + if msg_rdata.data() != mat_rr.data() { + continue; + } + + // Found one. Delete this entry + match_section.swap_remove(index); + continue 'outer; + } + // Nothing matches + if verbose { + println!( + "no match for record {} {} {}", + msg_rr.owner(), + msg_rr.class(), + msg_rr.rtype() + ); + } + return false; + } + // All entries in the reply were matched. + true +} + +fn match_question( + match_section: Vec, + msg_section: QuestionSection<'_, Octs>, + match_qname: bool, + match_qtype: bool, +) -> bool { + if match_section.is_empty() { + // Nothing to match. + return true; + } + for msg_rr in msg_section { + let msg_rr = msg_rr.unwrap(); + let mat_rr = if let parse_query::Entry::QueryRecord(record) = + &match_section[0] + { + record + } else { + panic!("include not expected"); + }; + if match_qname && msg_rr.qname() != mat_rr.qname() { + return false; + } + if match_qtype && msg_rr.qtype() != mat_rr.qtype() { + return false; + } + } + // All entries in the reply were matched. + true +} + +fn get_opt_rcode(msg: &Message) -> OptRcode { + let opt = msg.opt(); + match opt { + Some(opt) => opt.rcode(msg.header()), + None => { + // Convert Rcode to OptRcode, this should be part of + // OptRcode + OptRcode::from_int(msg.header().rcode().to_int() as u16) + } + } +} diff --git a/tests/net/deckard/mod.rs b/tests/net/deckard/mod.rs new file mode 100644 index 000000000..c61857ce9 --- /dev/null +++ b/tests/net/deckard/mod.rs @@ -0,0 +1,8 @@ +pub mod client; +pub mod connect; +pub mod connection; +pub mod dgram; +mod matches; +pub mod parse_deckard; +mod parse_query; +mod server; diff --git a/tests/net/deckard/parse_deckard.rs b/tests/net/deckard/parse_deckard.rs new file mode 100644 index 000000000..b7fe2fb54 --- /dev/null +++ b/tests/net/deckard/parse_deckard.rs @@ -0,0 +1,609 @@ +use std::default::Default; +use std::fmt::Debug; +use std::io::{self, BufRead, Read}; +use std::net::IpAddr; + +use crate::net::deckard::parse_query; +use crate::net::deckard::parse_query::Zonefile as QueryZonefile; +use domain::zonefile::inplace::Entry as ZonefileEntry; +use domain::zonefile::inplace::Zonefile; + +const CONFIG_END: &str = "CONFIG_END"; +const SCENARIO_BEGIN: &str = "SCENARIO_BEGIN"; +const SCENARIO_END: &str = "SCENARIO_END"; +const RANGE_BEGIN: &str = "RANGE_BEGIN"; +const RANGE_END: &str = "RANGE_END"; +const ADDRESS: &str = "ADDRESS"; +const ENTRY_BEGIN: &str = "ENTRY_BEGIN"; +const ENTRY_END: &str = "ENTRY_END"; +const MATCH: &str = "MATCH"; +const ADJUST: &str = "ADJUST"; +const REPLY: &str = "REPLY"; +const SECTION: &str = "SECTION"; +const QUESTION: &str = "QUESTION"; +const ANSWER: &str = "ANSWER"; +const AUTHORITY: &str = "AUTHORITY"; +const ADDITIONAL: &str = "ADDITIONAL"; +const STEP: &str = "STEP"; +const STEP_TYPE_QUERY: &str = "QUERY"; +const STEP_TYPE_CHECK_ANSWER: &str = "CHECK_ANSWER"; +const STEP_TYPE_TIME_PASSES: &str = "TIME_PASSES"; +const STEP_TYPE_TRAFFIC: &str = "TRAFFIC"; +const STEP_TYPE_CHECK_TEMPFILE: &str = "CHECK_TEMPFILE"; +const STEP_TYPE_ASSIGN: &str = "ASSIGN"; + +enum Section { + Question, + Answer, + Authority, + Additional, +} + +#[derive(Clone, Debug)] +pub enum StepType { + Query, + CheckAnswer, + TimePasses, + Traffic, + CheckTempfile, + Assign, +} + +#[derive(Clone, Debug, Default)] +pub struct Config { + lines: Vec, +} + +#[derive(Clone, Debug)] +pub struct Deckard { + pub config: Config, + pub scenario: Scenario, +} + +pub fn parse_file(file: F) -> Deckard { + let mut lines = io::BufReader::new(file).lines(); + Deckard { + config: parse_config(&mut lines), + scenario: parse_scenario(&mut lines), + } +} + +fn parse_config>>( + l: &mut Lines, +) -> Config { + let mut config: Config = Default::default(); + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + if clean_line == CONFIG_END { + break; + } + config.lines.push(clean_line.to_string()); + } + config +} + +#[derive(Clone, Debug, Default)] +pub struct Scenario { + pub ranges: Vec, + pub steps: Vec, +} + +pub fn parse_scenario< + Lines: Iterator>, +>( + l: &mut Lines, +) -> Scenario { + let mut scenario: Scenario = Default::default(); + // Find SCENARIO_BEGIN + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == SCENARIO_BEGIN { + break; + } + println!("parse_scenario: garbage line {clean_line:?}"); + panic!("bad line"); + } + + // Find RANGE_BEGIN, STEP, or SCENARIO_END + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == RANGE_BEGIN { + scenario.ranges.push(parse_range(tokens, l)); + continue; + } + if token == STEP { + scenario.steps.push(parse_step(tokens, l)); + continue; + } + if token == SCENARIO_END { + break; + } + todo!(); + } + scenario +} + +#[derive(Clone, Debug, Default)] +pub struct Range { + pub start_value: u64, + pub end_value: u64, + addr: Option, + pub entry: Vec, +} + +fn parse_range>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> Range { + let mut range: Range = Range { + start_value: tokens.next().unwrap().parse::().unwrap(), + end_value: tokens.next().unwrap().parse::().unwrap(), + ..Default::default() + }; + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ADDRESS { + let addr_str = tokens.next().unwrap(); + range.addr = Some(addr_str.parse().unwrap()); + continue; + } + if token == ENTRY_BEGIN { + range.entry.push(parse_entry(l)); + continue; + } + if token == RANGE_END { + break; + } + todo!(); + } + //println!("parse_range: {:?}", range); + range +} + +#[derive(Clone, Debug)] +pub struct Step { + pub step_value: u64, + pub step_type: StepType, + pub entry: Option, +} + +fn parse_step>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> Step { + let step_value = tokens.next().unwrap().parse::().unwrap(); + let step_type_str = tokens.next().unwrap(); + let step_type = if step_type_str == STEP_TYPE_QUERY { + StepType::Query + } else if step_type_str == STEP_TYPE_CHECK_ANSWER { + StepType::CheckAnswer + } else if step_type_str == STEP_TYPE_TIME_PASSES { + StepType::TimePasses + } else if step_type_str == STEP_TYPE_TRAFFIC { + StepType::Traffic + } else if step_type_str == STEP_TYPE_CHECK_TEMPFILE { + StepType::CheckTempfile + } else if step_type_str == STEP_TYPE_ASSIGN { + StepType::Assign + } else { + todo!(); + }; + let mut step = Step { + step_value, + step_type, + entry: None, + }; + + match step.step_type { + StepType::Query => (), // Continue with entry + StepType::CheckAnswer => (), // Continue with entry + StepType::TimePasses => { + println!("parse_step: should handle TIME_PASSES"); + return step; + } + StepType::Traffic => { + println!("parse_step: should handle TRAFFIC"); + return step; + } + StepType::CheckTempfile => { + println!("parse_step: should handle CHECK_TEMPFILE"); + return step; + } + StepType::Assign => { + println!("parse_step: should handle ASSIGN"); + return step; + } + } + + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ENTRY_BEGIN { + step.entry = Some(parse_entry(l)); + //println!("parse_step: {:?}", step); + return step; + } + todo!(); + } +} + +#[derive(Clone, Debug, Default)] +pub struct Entry { + pub matches: Option, + pub adjust: Option, + pub reply: Option, + pub sections: Option, +} + +fn parse_entry>>( + l: &mut Lines, +) -> Entry { + let mut entry = Entry { + matches: None, + adjust: None, + reply: None, + sections: None, + }; + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == MATCH { + entry.matches = Some(parse_match(tokens)); + continue; + } + if token == ADJUST { + entry.adjust = Some(parse_adjust(tokens)); + continue; + } + if token == REPLY { + entry.reply = Some(parse_reply(tokens)); + continue; + } + if token == SECTION { + let (sections, line) = parse_section(tokens, l); + //println!("parse_entry: sections {:?}", sections); + entry.sections = Some(sections); + let clean_line = get_clean_line(line.as_ref()); + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ENTRY_END { + break; + } + todo!(); + } + if token == ENTRY_END { + break; + } + todo!(); + } + entry +} + +#[derive(Clone, Debug)] +pub struct Sections { + pub question: Vec, + pub answer: Vec, + pub authority: Vec, + pub additional: Vec, +} + +fn parse_section>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> (Sections, String) { + let mut sections = Sections { + question: Vec::new(), + answer: Vec::new(), + authority: Vec::new(), + additional: Vec::new(), + }; + let next = tokens.next().unwrap(); + let mut section = if next == QUESTION { + Section::Question + } else { + panic!("Bad section {next}"); + }; + // Should extract which section + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == SECTION { + let next = tokens.next().unwrap(); + section = if next == QUESTION { + Section::Question + } else if next == ANSWER { + Section::Answer + } else if next == AUTHORITY { + Section::Authority + } else if next == ADDITIONAL { + Section::Additional + } else { + panic!("Bad section {next}"); + }; + continue; + } + if token == ENTRY_END { + return (sections, line); + } + + match section { + Section::Question => { + let mut zonefile = QueryZonefile::new(); + zonefile.extend_from_slice(clean_line.as_ref()); + zonefile.extend_from_slice(b"\n"); + let e = zonefile.next_entry().unwrap(); + sections.question.push(e.unwrap()); + } + Section::Answer | Section::Authority | Section::Additional => { + let mut zonefile = Zonefile::new(); + zonefile.extend_from_slice(b"$ORIGIN .\n"); + zonefile.extend_from_slice(b"ignore 3600 in ns ignore\n"); + zonefile.extend_from_slice(clean_line.as_ref()); + zonefile.extend_from_slice(b"\n"); + let _e = zonefile.next_entry().unwrap(); + let e = zonefile.next_entry().unwrap(); + + let e = e.unwrap(); + match section { + Section::Question => panic!("should not be here"), + Section::Answer => sections.answer.push(e), + Section::Authority => sections.authority.push(e), + Section::Additional => sections.additional.push(e), + } + } + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Matches { + pub additional: bool, + pub all: bool, + pub answer: bool, + pub authority: bool, + pub fl_do: bool, + pub flags: bool, + pub opcode: bool, + pub qname: bool, + pub qtype: bool, + pub question: bool, + pub rcode: bool, + pub subdomain: bool, + pub tcp: bool, + pub ttl: bool, + pub udp: bool, +} + +fn parse_match(mut tokens: LineTokens<'_>) -> Matches { + let mut matches: Matches = Default::default(); + + loop { + let token = match tokens.next() { + None => return matches, + Some(token) => token, + }; + + if token == "all" { + matches.all = true; + } else if token == "DO" { + matches.fl_do = true; + } else if token == "opcode" { + matches.opcode = true; + } else if token == "qname" { + matches.qname = true; + } else if token == "question" { + matches.question = true; + } else if token == "qtype" { + matches.qtype = true; + } else if token == "subdomain" { + matches.subdomain = true; + } else if token == "TCP" { + matches.tcp = true; + } else if token == "ttl" { + matches.ttl = true; + } else if token == "UDP" { + matches.tcp = true; + } else { + println!("should handle match {token:?}"); + todo!(); + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Adjust { + pub copy_id: bool, + pub copy_query: bool, +} + +fn parse_adjust(mut tokens: LineTokens<'_>) -> Adjust { + let mut adjust: Adjust = Default::default(); + + loop { + let token = match tokens.next() { + None => return adjust, + Some(token) => token, + }; + + if token == "copy_id" { + adjust.copy_id = true; + } else if token == "copy_query" { + adjust.copy_query = true; + } else { + println!("should handle adjust {token:?}"); + todo!(); + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Reply { + pub aa: bool, + pub ad: bool, + pub cd: bool, + pub fl_do: bool, + pub formerr: bool, + pub noerror: bool, + pub nxdomain: bool, + pub qr: bool, + pub ra: bool, + pub rd: bool, + pub refused: bool, + pub servfail: bool, + pub tc: bool, + pub yxdomain: bool, +} + +fn parse_reply(mut tokens: LineTokens<'_>) -> Reply { + let mut reply: Reply = Default::default(); + + loop { + let token = match tokens.next() { + None => return reply, + Some(token) => token, + }; + + if token == "AA" { + reply.aa = true; + } else if token == "AD" { + reply.ad = true; + } else if token == "CD" { + reply.cd = true; + } else if token == "DO" { + reply.fl_do = true; + } else if token == "FORMERR" { + reply.formerr = true; + } else if token == "NOERROR" { + reply.noerror = true; + } else if token == "NXDOMAIN" { + reply.nxdomain = true; + } else if token == "QR" { + reply.qr = true; + } else if token == "RA" { + reply.ra = true; + } else if token == "RD" { + reply.rd = true; + } else if token == "REFUSED" { + reply.refused = true; + } else if token == "SERVFAIL" { + reply.servfail = true; + } else if token == "TC" { + reply.tc = true; + } else if token == "YXDOMAIN" { + reply.yxdomain = true; + } else { + println!("should handle reply {token:?}"); + todo!(); + } + } +} + +fn get_clean_line(line: &str) -> Option<&str> { + //println!("get clean line for {:?}", line); + let opt_comment = line.find(';'); + let line = if let Some(index) = opt_comment { + &line[0..index] + } else { + line + }; + let trimmed = line.trim(); + + //println!("line after trim() {:?}", trimmed); + + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } +} + +struct LineTokens<'a> { + str: &'a str, + curr_index: usize, +} + +impl<'a> LineTokens<'a> { + fn new(str: &'a str) -> Self { + Self { str, curr_index: 0 } + } +} + +impl<'a> Iterator for LineTokens<'a> { + type Item = &'a str; + fn next(&mut self) -> Option { + let cur_str = &self.str[self.curr_index..]; + + if cur_str.is_empty() { + return None; + } + + // Assume cur_str starts with a token + for (index, char) in cur_str.char_indices() { + if !char.is_whitespace() { + continue; + } + let start_index = self.curr_index; + let end_index = start_index + index; + + let space_str = &self.str[end_index..]; + + for (index, char) in space_str.char_indices() { + if char.is_whitespace() { + continue; + } + + self.curr_index = end_index + index; + return Some(&self.str[start_index..end_index]); + } + + todo!(); + } + self.curr_index = self.str.len(); + Some(cur_str) + } +} diff --git a/tests/net/deckard/parse_query.rs b/tests/net/deckard/parse_query.rs new file mode 100644 index 000000000..6bf22e4f2 --- /dev/null +++ b/tests/net/deckard/parse_query.rs @@ -0,0 +1,1501 @@ +//! A zonefile scanner keeping data in place. +//! +//! The zonefile scanner provided by this module reads the entire zonefile +//! into memory and tries as much as possible to modify re-use this memory +//! when scanning data. It uses the `Bytes` family of types for safely +//! storing, manipulating, and returning the data and thus requires the +//! `bytes` feature to be enabled. +//! +//! This may or may not be a good strategy. It was primarily implemented to +//! see that the [`Scan`] trait is powerful enough to build such an +//! implementation. +// #![cfg(feature = "bytes")] +// #![cfg_attr(docsrs, doc(cfg(feature = "bytes")))] + +use bytes::buf::UninitSlice; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use core::str::FromStr; +use core::{fmt, str}; +use domain::base::charstr::CharStr; +use domain::base::iana::{Class, Rtype}; +use domain::base::name::{Chain, Dname, RelativeDname, ToDname}; +use domain::base::scan::{ + BadSymbol, ConvertSymbols, EntrySymbol, Scan, Scanner, ScannerError, + Symbol, SymbolOctetsError, +}; +use domain::base::Question; +use domain::base::Ttl; +use domain::dep::octseq::str::Str; + +//------------ Type Aliases -------------------------------------------------- + +/// The type used for scanned domain names. +pub type ScannedDname = Chain, Dname>; + +/// The type used for scanned records. + +pub type ScannedQueryRecord = Question; + +/// The type used for scanned strings. +pub type ScannedString = Str; + +//------------ Zonefile ------------------------------------------------------ + +/// A zonefile to be scanned. +/// +/// A value of this types holds data to be scanned in memory and allows +/// fetching entries by acting as an iterator. +/// +/// The type implements the `bytes::BufMut` trait for appending data directly +/// into the memory buffer. The function [`load`][Self::load] can be used to +/// create a value directly from a reader. +/// +/// Once data has been added, you can simply iterate over the value to +/// get entries. The [`next_entry`][Self::next_entry] method provides an +/// alternative with a more question mark friendly signature. +#[derive(Clone, Debug)] +pub struct Zonefile { + /// This is where we keep the data of the next entry. + buf: SourceBuf, + + /// The current origin. + origin: Option>, + + /// The last owner. + last_owner: Option, + + /// The last TTL. + last_ttl: Option, + + /// The last class. + last_class: Option, +} + +impl Zonefile { + /// Creates a new, empty value. + pub fn new() -> Self { + Self::with_buf(SourceBuf::with_empty_buf(BytesMut::new())) + } + + /// Creates a new, empty value with the given capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_buf(SourceBuf::with_empty_buf(BytesMut::with_capacity( + capacity + 1, + ))) + } + + /// Creates a new value using the given buffer. + fn with_buf(buf: SourceBuf) -> Self { + Zonefile { + buf, + origin: Some(Dname::root_bytes()), + last_owner: None, + last_ttl: Some(Ttl::ZERO), + last_class: None, + } + } +} + +impl Default for Zonefile { + fn default() -> Self { + Self::new() + } +} + +impl<'a> From<&'a str> for Zonefile { + fn from(src: &'a str) -> Self { + Self::from(src.as_bytes()) + } +} + +impl<'a> From<&'a [u8]> for Zonefile { + fn from(src: &'a [u8]) -> Self { + let mut res = Self::with_capacity(src.len() + 1); + res.extend_from_slice(src); + res + } +} + +impl Zonefile { + /// Appends the given slice to the end of the buffer. + pub fn extend_from_slice(&mut self, slice: &[u8]) { + self.buf.buf.extend_from_slice(slice) + } +} + +unsafe impl BufMut for Zonefile { + fn remaining_mut(&self) -> usize { + self.buf.buf.remaining_mut() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + self.buf.buf.advance_mut(cnt); + } + + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.buf.buf.chunk_mut() + } +} + +impl Zonefile { + /// Returns the next entry in the zonefile. + /// + /// Returns `Ok(None)` if the end of the file has been reached. Returns + /// an error if scanning the next entry failed. + /// + /// This method is identical to the `next` method of the iterator + /// implementation but has the return type transposed for easier use + /// with the question mark operator. + pub fn next_entry(&mut self) -> Result, Error> { + loop { + match EntryScanner::new(self)?.scan_entry()? { + ScannedEntry::Entry(entry) => return Ok(Some(entry)), + ScannedEntry::Origin(origin) => self.origin = Some(origin), + ScannedEntry::Ttl(ttl) => self.last_ttl = Some(ttl), + ScannedEntry::Empty => {} + ScannedEntry::Eof => return Ok(None), + } + } + } + + /// Returns the origin name of the zonefile. + fn get_origin(&self) -> Result, EntryError> { + self.origin + .as_ref() + .cloned() + .ok_or_else(EntryError::missing_origin) + } +} + +impl Iterator for Zonefile { + type Item = Result; + + fn next(&mut self) -> Option { + self.next_entry().transpose() + } +} + +//------------ Entry --------------------------------------------------------- + +/// An entry of a zonefile. +#[derive(Clone, Debug)] +pub enum Entry { + /// A DNS record. + QueryRecord(ScannedQueryRecord), + + /// An include directive. + /// + /// When this entry is encountered, the referenced file should be scanned + /// next. If `origin` is given, this file should be scanned with it as the + /// initial origin name, + Include { + /// The path to the file to be included. + path: ScannedString, + + /// The initial origin name of the included file, if provided. + origin: Option>, + }, +} + +//------------ ScannedEntry -------------------------------------------------- + +/// A raw scanned entry of a zonefile. +/// +/// This includes all the entry types that we can handle internally and don’t +/// have to bubble up to the user. +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +enum ScannedEntry { + /// An entry that should be handed to the user. + Entry(Entry), + + /// An `$ORIGIN` directive changing the origin name. + Origin(Dname), + + /// A `$TTL` directive changing the default TTL if it isn’t given. + Ttl(Ttl), + + /// An empty entry. + Empty, + + /// The end of file was reached. + Eof, +} + +//------------ EntryScanner -------------------------------------------------- + +/// The entry scanner for a zonefile. +/// +/// A value of this type is created for each entry. It implements the +/// [`Scanner`] interface. +#[derive(Debug)] +struct EntryScanner<'a> { + /// The zonefile we are working on. + zonefile: &'a mut Zonefile, +} + +impl<'a> EntryScanner<'a> { + /// Creates a new entry scanner using the given zonefile. + fn new(zonefile: &'a mut Zonefile) -> Result { + Ok(EntryScanner { zonefile }) + } + + /// Scans a single entry from the zone file. + fn scan_entry(&mut self) -> Result { + self._scan_entry() + .map_err(|err| self.zonefile.buf.error(err)) + } + + /// Scans a single entry from the zone file. + /// + /// This is identical to `scan_entry` but with a more convenient error + /// type. + fn _scan_entry(&mut self) -> Result { + self.zonefile.buf.next_item()?; + match self.zonefile.buf.cat { + ItemCat::None => Ok(ScannedEntry::Eof), + ItemCat::LineFeed => Ok(ScannedEntry::Empty), + ItemCat::Unquoted | ItemCat::Quoted => { + if self.zonefile.buf.has_space { + // Indented entry: a record with the last owner as the + // owner. + self.scan_owner_record( + match self.zonefile.last_owner.as_ref() { + Some(owner) => owner.clone(), + None => { + return Err(EntryError::missing_last_owner()) + } + }, + false, + ) + } else if self.zonefile.buf.peek_symbol() + == Some(Symbol::Char('$')) + { + self.scan_control() + } else if self.zonefile.buf.skip_at_token()? { + self.scan_at_record() + } else { + self.scan_record() + } + } + } + } + + /// Scans a regular record. + fn scan_record(&mut self) -> Result { + let owner = ScannedDname::scan(self)?; + self.scan_owner_record(owner, true) + } + + /// Scans a regular record with an owner name of `@`. + fn scan_at_record(&mut self) -> Result { + let owner = RelativeDname::empty_bytes() + .chain(match self.zonefile.origin.as_ref().cloned() { + Some(origin) => origin, + None => return Err(EntryError::missing_origin()), + }) + .unwrap(); // Chaining an empty name will always work. + self.scan_owner_record(owner, true) + } + + /// Scans a regular record with an explicit owner name. + fn scan_owner_record( + &mut self, + owner: ScannedDname, + new_owner: bool, + ) -> Result { + let (class, qtype) = self.scan_qcr()?; + + if new_owner { + self.zonefile.last_owner = Some(owner.clone()); + } + + let class = match class { + Some(class) => { + self.zonefile.last_class = Some(class); + class + } + None => match self.zonefile.last_class { + Some(class) => class, + None => return Err(EntryError::missing_last_class()), + }, + }; + + self.zonefile.buf.require_line_feed()?; + + Ok(ScannedEntry::Entry(Entry::QueryRecord(Question::new( + owner, qtype, class, + )))) + } + + /// Scans the class, and type portions of a query record. + fn scan_qcr(&mut self) -> Result<(Option, Rtype), EntryError> { + // Possible options are: + // + // [] [] + // [] [] + + enum Ctr { + Class(Class), + Qtype(Rtype), + } + + let first = self.scan_ascii_str(|s| { + if let Ok(qtype) = Rtype::from_str(s) { + Ok(Ctr::Qtype(qtype)) + } else if let Ok(class) = Class::from_str(s) { + Ok(Ctr::Class(class)) + } else { + Err(EntryError::expected_qtype()) + } + })?; + + match first { + Ctr::Class(class) => { + // We have a class. Now there may be a qtype. + let qtype = self.scan_ascii_str(|s| { + if let Ok(qtype) = Rtype::from_str(s) { + Ok(qtype) + } else { + Err(EntryError::expected_qtype()) + } + })?; + + Ok((Some(class), qtype)) + } + Ctr::Qtype(qtype) => Ok((None, qtype)), + } + } + + /// Scans a control directive. + fn scan_control(&mut self) -> Result { + let ctrl = self.scan_string()?; + if ctrl.eq_ignore_ascii_case("$ORIGIN") { + let origin = self.scan_dname()?.to_dname().unwrap(); + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Origin(origin)) + } else if ctrl.eq_ignore_ascii_case("$INCLUDE") { + let path = self.scan_string()?; + let origin = if !self.zonefile.buf.is_line_feed() { + Some(self.scan_dname()?.to_dname().unwrap()) + } else { + None + }; + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Entry(Entry::Include { path, origin })) + } else if ctrl.eq_ignore_ascii_case("$TTL") { + let ttl = u32::scan(self)?; + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Ttl(Ttl::from_secs(ttl))) + } else { + Err(EntryError::unknown_control()) + } + } +} + +impl<'a> Scanner for EntryScanner<'a> { + type Octets = Bytes; + type OctetsBuilder = BytesMut; + type Dname = ScannedDname; + type Error = EntryError; + + fn has_space(&self) -> bool { + self.zonefile.buf.has_space + } + + fn continues(&mut self) -> bool { + !matches!(self.zonefile.buf.cat, ItemCat::None | ItemCat::LineFeed) + } + + fn scan_symbols(&mut self, mut op: F) -> Result<(), Self::Error> + where + F: FnMut(Symbol) -> Result<(), Self::Error>, + { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + op(sym)?; + } + self.zonefile.buf.next_item() + } + + fn scan_entry_symbols(&mut self, mut op: F) -> Result<(), Self::Error> + where + F: FnMut(EntrySymbol) -> Result<(), Self::Error>, + { + loop { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + op(sym.into())?; + } + op(EntrySymbol::EndOfToken)?; + self.zonefile.buf.next_item()?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + Ok(()) + } + + fn convert_token>( + &mut self, + mut convert: C, + ) -> Result { + let mut write = 0; + let mut builder = None; + self.convert_one_token(&mut convert, &mut write, &mut builder)?; + if let Some(data) = convert.process_tail()? { + self.append_data(data, &mut write, &mut builder); + } + match builder { + Some(builder) => Ok(builder.freeze()), + None => Ok(self.zonefile.buf.split_to(write).freeze()), + } + } + + fn convert_entry>( + &mut self, + mut convert: C, + ) -> Result { + let mut write = 0; + let mut builder = None; + loop { + self.convert_one_token(&mut convert, &mut write, &mut builder)?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + if let Some(data) = convert.process_tail()? { + self.append_data(data, &mut write, &mut builder); + } + match builder { + Some(builder) => Ok(builder.freeze()), + None => Ok(self.zonefile.buf.split_to(write).freeze()), + } + } + + fn scan_octets(&mut self) -> Result { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_ascii_symbol()?.is_some() {} + + // If we aren’t done yet, we have escaped characters to replace. + let mut write = self.zonefile.buf.start; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + self.zonefile.buf.buf[write] = sym.into_octet()?; + write += 1; + } + + // Done. `write` marks the end. + self.zonefile.buf.next_item()?; + Ok(self.zonefile.buf.split_to(write).freeze()) + } + + fn scan_ascii_str(&mut self, op: F) -> Result + where + F: FnOnce(&str) -> Result, + { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + let mut write = 0; + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_ascii_symbol()?.is_some() { + write += 1; + } + + // If we not reached the end of the token, we have escaped characters + // to replace. + if !matches!(self.zonefile.buf.cat, ItemCat::None) { + while let Some(sym) = self.zonefile.buf.next_symbol()? { + self.zonefile.buf.buf[write] = sym.into_ascii()?; + write += 1; + } + } + + // Done. `write` marks the end. Process via op and return. + let res = op(unsafe { + str::from_utf8_unchecked(&self.zonefile.buf.buf[..write]) + })?; + self.zonefile.buf.next_item()?; + Ok(res) + } + + fn scan_dname(&mut self) -> Result { + // Because the labels in a domain name have their content preceeded + // by the length octet, an unescaped domain name can be almost as is + // if we have one extra octet to the left. Luckily, we always do + // (SourceBuf makes sure of it). + self.zonefile.buf.require_token()?; + + // Let’s prepare everything. We cut off the bits we don’t need with + // the result that the buffer’s start will be 1 and we set `write` + // to be 0, i.e., the start of the buffer. This also means that write + // will contain the length of the domain name assembled so far, so we + // can easily check if it has gotten too long. + assert!(self.zonefile.buf.start > 0, "missing token prefix space"); + self.zonefile.buf.trim_to(self.zonefile.buf.start - 1); + let mut write = 0; + + // Now convert label by label. + loop { + let start = write; + match self.convert_label(&mut write)? { + None => { + // End of token right after a dot, so this is an absolute + // name. Unless we have not done anything yet, then we + // have an empty domain name which is just the origin. + self.zonefile.buf.next_item()?; + if start == 0 { + return RelativeDname::empty_bytes() + .chain(self.zonefile.get_origin()?) + .map_err(|_| EntryError::bad_dname()); + } else { + return unsafe { + RelativeDname::from_octets_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + .chain(Dname::root()) + .map_err(|_| EntryError::bad_dname()) + }; + } + } + Some(true) => { + // Last symbol was a dot: check length and continue. + if write > 254 { + return Err(EntryError::bad_dname()); + } + } + Some(false) => { + // Reached end of token. This means we have a relative + // dname. + self.zonefile.buf.next_item()?; + return unsafe { + RelativeDname::from_octets_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + .chain(self.zonefile.get_origin()?) + .map_err(|_| EntryError::bad_dname()) + }; + } + } + } + } + + fn scan_charstr(&mut self) -> Result, Self::Error> { + self.scan_octets().and_then(|octets| { + CharStr::from_octets(octets) + .map_err(|_| EntryError::bad_charstr()) + }) + } + + fn scan_string(&mut self) -> Result, Self::Error> { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_char_symbol()?.is_some() {} + + // If we aren’t done yet, we have escaped characters to replace. + let mut write = self.zonefile.buf.start; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + write += sym + .into_char()? + .encode_utf8( + &mut self.zonefile.buf.buf + [write..self.zonefile.buf.start], + ) + .len(); + } + + // Done. `write` marks the end. + self.zonefile.buf.next_item()?; + Ok(unsafe { + Str::from_utf8_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + }) + } + + fn scan_charstr_entry(&mut self) -> Result { + // Because char-strings are never longer than their representation + // format, we can definitely do this in place. Specifically, we move + // the content around in such a way that by the end we have the result + // in the space of buf before buf.start. + + // Reminder: char-string are one length byte followed by that many + // content bytes. We use the byte just before self.read as the length + // byte of the first char-string. This way, if there is only one and + // it isn’t escaped, we don’t need to move anything at all. + + // Let’s prepare everything. We cut off the bits we don’t need with + // the result that the buffer’s start will be 1 and we set `write` + // to be 0, i.e., the start of the buffer. This also means that write + // will contain the length of the domain name assembled so far, so we + // can easily check if it has gotten too long. + assert!(self.zonefile.buf.start > 0, "missing token prefix space"); + self.zonefile.buf.trim_to(self.zonefile.buf.start - 1); + let mut write = 0; + + // Now convert token by token. + loop { + self.convert_charstr(&mut write)?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + + Ok(self.zonefile.buf.split_to(write).freeze()) + } + + fn scan_opt_unknown_marker(&mut self) -> Result { + self.zonefile.buf.skip_unknown_marker() + } + + fn octets_builder(&mut self) -> Result { + Ok(BytesMut::new()) + } +} + +impl<'a> EntryScanner<'a> { + /// Converts a single token using a token converter. + fn convert_one_token< + S: From, + C: ConvertSymbols, + >( + &mut self, + convert: &mut C, + write: &mut usize, + builder: &mut Option, + ) -> Result<(), EntryError> { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + if let Some(data) = convert.process_symbol(sym.into())? { + self.append_data(data, write, builder); + } + } + self.zonefile.buf.next_item() + } + + /// Appends output data. + /// + /// If the data fits into the portion of the buffer before the current + /// read positiion, puts it there. Otherwise creates a new builder. If + /// it created a new builder or if one was passed in via `builder`, + /// appends the data to that. + fn append_data( + &mut self, + data: &[u8], + write: &mut usize, + builder: &mut Option, + ) { + if let Some(builder) = builder.as_mut() { + builder.extend_from_slice(data); + return; + } + + let new_write = *write + data.len(); + if new_write > self.zonefile.buf.start { + let mut new_builder = BytesMut::with_capacity(new_write); + new_builder.extend_from_slice(&self.zonefile.buf.buf[..*write]); + new_builder.extend_from_slice(data); + *builder = Some(new_builder); + } else { + self.zonefile.buf.buf[*write..new_write].copy_from_slice(data); + *write = new_write; + } + } + + /// Converts a single label of a domain name. + /// + /// The next symbol of the buffer should be the first symbol of the + /// label’s content. The method reads symbols from the buffer and + /// constructs a single label complete with length octets starting at + /// `write`. + /// + /// If it reaches the end of the token before making a label, returns + /// `None`. Otherwise returns whether it encountered a dot at the end of + /// the label. I.e., `Some(true)` means a dot was read as the last symbol + /// and `Some(false)` means the end of token was encountered right after + /// the label. + fn convert_label( + &mut self, + write: &mut usize, + ) -> Result, EntryError> { + let start = *write; + *write += 1; + let latest = *write + 64; // If write goes here, the label is too long + if *write == self.zonefile.buf.start { + // Reading and writing position is equal, so we don’t need to + // convert char symbols. Read char symbols until the end of label + // or an escape sequence. + loop { + match self.zonefile.buf.next_ascii_symbol()? { + Some(b'.') => { + // We found an unescaped dot, ie., end of label. + // Update the length octet and return. + self.zonefile.buf.buf[start] = + (*write - start - 1) as u8; + return Ok(Some(true)); + } + Some(_) => { + // A char symbol. Just increase the write index. + *write += 1; + if *write >= latest { + return Err(EntryError::bad_dname()); + } + } + None => { + // Either we got an escape sequence or we reached the + // end of the token. Break out of the loop and decide + // below. + break; + } + } + } + } + + // Now we need to process the label with potential escape sequences. + loop { + match self.zonefile.buf.next_symbol()? { + None => { + // We reached the end of the token. + if *write > start + 1 { + self.zonefile.buf.buf[start] = + (*write - start - 1) as u8; + return Ok(Some(false)); + } else { + return Ok(None); + } + } + Some(Symbol::Char('.')) => { + // We found an unescaped dot, ie., end of label. + // Update the length octet and return. + self.zonefile.buf.buf[start] = (*write - start - 1) as u8; + return Ok(Some(true)); + } + Some(sym) => { + // Any other symbol: Decode it and proceed to the next + // route. + self.zonefile.buf.buf[*write] = sym.into_octet()?; + *write += 1; + if *write >= latest { + return Err(EntryError::bad_dname()); + } + } + } + } + } + + /// Converts a character string. + fn convert_charstr( + &mut self, + write: &mut usize, + ) -> Result<(), EntryError> { + let start = *write; + *write += 1; + let latest = *write + 255; // If write goes here, charstr is too long + if *write == self.zonefile.buf.start { + // Reading and writing position is equal, so we don’t need to + // convert char symbols. Read char symbols until the end of label + // or an escape sequence. + while self.zonefile.buf.next_ascii_symbol()?.is_some() { + *write += 1; + if *write >= latest { + return Err(EntryError::bad_charstr()); + } + } + } + + // Now we need to process the charstr with potential escape sequences. + loop { + match self.zonefile.buf.next_symbol()? { + None => { + self.zonefile.buf.next_item()?; + self.zonefile.buf.buf[start] = (*write - start - 1) as u8; + return Ok(()); + } + Some(sym) => { + self.zonefile.buf.buf[*write] = sym.into_octet()?; + *write += 1; + if *write >= latest { + return Err(EntryError::bad_charstr()); + } + } + } + } + } +} + +//------------ SourceBuf ----------------------------------------------------- + +/// The buffer to read data from and also into if possible. +#[derive(Clone, Debug)] +struct SourceBuf { + /// The underlying ‘real’ buffer. + /// + /// This buffer contains the data we still need to process. This contains + /// the white space and other octets just before the start of the next + /// token as well since that can be used as extra space for in-place + /// manipulations. + buf: BytesMut, + + /// Where in `buf` is the next symbol to read. + start: usize, + + /// The category of the current item. + cat: ItemCat, + + /// Is the token preceeded by white space? + has_space: bool, + + /// How many unclosed opening parentheses did we see at `start`? + parens: usize, + + /// The line number of the current line. + line_num: usize, + + /// The position of the first character of the current line. + /// + /// This may be negative if we cut off bits of the current line. + line_start: isize, +} + +impl SourceBuf { + /// Create a new empty buffer. + /// + /// Assumes that `buf` is empty. Adds a single byte to the buffer which + /// we would need for parsing if the first token is a domain name. + fn with_empty_buf(mut buf: BytesMut) -> Self { + buf.put_u8(0); + SourceBuf { + buf, + start: 1, + cat: ItemCat::None, + has_space: false, + parens: 0, + line_num: 1, + line_start: 1, + } + } + + /// Enriches an entry error with position information. + fn error(&self, err: EntryError) -> Error { + Error { + err, + line: self.line_num, + col: ((self.start as isize) + 1 - self.line_start) as usize, + } + } + + /// Checks whether the current item is a token. + fn require_token(&self) -> Result<(), EntryError> { + match self.cat { + ItemCat::None => Err(EntryError::short_buf()), + ItemCat::LineFeed => Err(EntryError::end_of_entry()), + ItemCat::Quoted | ItemCat::Unquoted => Ok(()), + } + } + + /// Returns whether the current item is a line feed. + fn is_line_feed(&self) -> bool { + matches!(self.cat, ItemCat::LineFeed) + } + + /// Requires that we have reached a line feed. + fn require_line_feed(&self) -> Result<(), EntryError> { + if self.is_line_feed() { + Ok(()) + } else { + Err(EntryError::trailing_tokens()) + } + } + + /// Returns the next symbol but doesn’t advance the buffer. + /// + /// Returns `None` if the current item is a line feed or end-of-file + /// or if we have reached the end of token or if it is not a valid symbol. + fn peek_symbol(&self) -> Option { + match self.cat { + ItemCat::None | ItemCat::LineFeed => None, + ItemCat::Unquoted => { + let sym = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, _))) => sym, + Ok(None) | Err(_) => return None, + }; + + if sym.is_word_char() { + Some(sym) + } else { + None + } + } + ItemCat::Quoted => { + let sym = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, _))) => sym, + Ok(None) | Err(_) => return None, + }; + + if sym == Symbol::Char('"') { + None + } else { + Some(sym) + } + } + } + } + + /// Skips over the current token if it contains only an `@` symbol. + /// + /// Returns whether it did skip the token. + fn skip_at_token(&mut self) -> Result { + if self.peek_symbol() != Some(Symbol::Char('@')) { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start + 1) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + match self.cat { + ItemCat::None | ItemCat::LineFeed => unreachable!(), + ItemCat::Unquoted => { + if !sym.is_word_char() { + self.start += 1; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } else { + Ok(false) + } + } + ItemCat::Quoted => { + if sym == Symbol::Char('"') { + self.start = sym_end; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } else { + Ok(false) + } + } + } + } + + /// Skips over the unknown marker token. + /// + /// Returns whether it didskip the token. + fn skip_unknown_marker(&mut self) -> Result { + if !matches!(self.cat, ItemCat::Unquoted) { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some(some)) => some, + _ => return Ok(false), + }; + + if sym != Symbol::SimpleEscape(b'#') { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, sym_end) { + Ok(Some(some)) => some, + _ => return Ok(false), + }; + if sym.is_word_char() { + return Ok(false); + } + + self.start = sym_end; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } + + /// Returns the next symbol of the current token. + /// + /// Returns `None` if the current item is a line feed or end-of-file + /// or if we have reached the end of token. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + fn next_symbol(&mut self) -> Result, EntryError> { + self._next_symbol(|sym| Ok(Some(sym))) + } + + /// Returns the next symbol if it is an unescaped ASCII symbol. + /// + /// Returns `None` if the symbol is escaped or not a printable ASCII + /// character or `self.next_symbol` would return `None`. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + #[allow(clippy::manual_range_contains)] // Hard disagree. + fn next_ascii_symbol(&mut self) -> Result, EntryError> { + if matches!(self.cat, ItemCat::None | ItemCat::LineFeed) { + return Ok(None); + } + + let ch = match self.buf.get(self.start) { + Some(ch) => *ch, + None => return Ok(None), + }; + + match self.cat { + ItemCat::Unquoted => { + if ch < 0x21 + || ch > 0x7F + || ch == b'"' + || ch == b'(' + || ch == b')' + || ch == b';' + || ch == b'\\' + { + return Ok(None); + } + } + ItemCat::Quoted => { + if ch == b'"' { + self.start += 1; + self.cat = ItemCat::None; + return Ok(None); + } else if ch < 0x21 || ch > 0x7F || ch == b'\\' { + return Ok(None); + } + } + _ => unreachable!(), + } + self.start += 1; + Ok(Some(ch)) + } + + /// Returns the next symbol if it is unescaped. + /// + /// Returns `None` if the symbol is escaped or `self.next_symbol` would + /// return `None`. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + fn next_char_symbol(&mut self) -> Result, EntryError> { + self._next_symbol(|sym| { + if let Symbol::Char(ch) = sym { + Ok(Some(ch)) + } else { + Ok(None) + } + }) + } + + /// Internal helper for `next_symbol` and friends. + /// + /// This only exists so we don’t have to copy and paste the fiddely part + /// of the logic. It behaves like `next_symbol` but provides an option + /// for the called to decide whether they want the symbol or not. + #[inline] + fn _next_symbol(&mut self, want: F) -> Result, EntryError> + where + F: Fn(Symbol) -> Result, EntryError>, + { + match self.cat { + ItemCat::None | ItemCat::LineFeed => Ok(None), + ItemCat::Unquoted => { + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + if !sym.is_word_char() { + self.cat = ItemCat::None; + Ok(None) + } else { + match want(sym)? { + Some(some) => { + self.start = sym_end; + Ok(Some(some)) + } + None => Ok(None), + } + } + } + ItemCat::Quoted => { + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + let res = match want(sym)? { + Some(some) => some, + None => return Ok(None), + }; + + if sym == Symbol::Char('"') { + self.start = sym_end; + self.cat = ItemCat::None; + Ok(None) + } else { + self.start = sym_end; + if sym == Symbol::Char('\n') { + self.line_num += 1; + self.line_start = self.start as isize; + } + Ok(Some(res)) + } + } + } + } + + /// Prepares the next item. + /// + /// # Panics + /// + /// This method must only ever by called if the current item is + /// not a token or if the current token has been read all the way to the + /// end. The latter is true if [`Self::next_symbol`] has returned + /// `Ok(None)` at least once. + /// + /// If the current item is a token and has not been read all the way to + /// the end, the method will panic to maintain consistency of the data. + fn next_item(&mut self) -> Result<(), EntryError> { + assert!( + matches!(self.cat, ItemCat::None | ItemCat::LineFeed), + "token not completely read ({:?} at {}:{})", + self.cat, + self.line_num, + ((self.start as isize) + 1 - self.line_start) as usize, + ); + + self.has_space = false; + + loop { + let ch = match self.buf.get(self.start) { + Some(&ch) => ch, + None => { + self.cat = ItemCat::None; + return Ok(()); + } + }; + + // Skip and mark actual white space. + if matches!(ch, b' ' | b'\t' | b'\r') { + self.has_space = true; + self.start += 1; + } + // CR: ignore for compatibility with Windows-style line endings. + else if ch == b'\r' { + self.start += 1; + } + // Opening parenthesis: increase group level. + else if ch == b'(' { + self.parens += 1; + self.start += 1; + } + // Closing parenthesis: decrease group level or error out. + else if ch == b')' { + if self.parens > 0 { + self.parens -= 1; + self.start += 1; + } else { + return Err(EntryError::unbalanced_parens()); + } + } + // Semicolon: comment -- skip to line end. + else if ch == b';' { + self.start += 1; + while let Some(true) = + self.buf.get(self.start).map(|ch| *ch != b'\n') + { + self.start += 1; + } + // Next iteration deals with the LF. + } + // Line end: skip over it. Ignore if we are inside a paren group. + else if ch == b'\n' { + self.start += 1; + self.line_num += 1; + self.line_start = self.start as isize; + if self.parens == 0 { + self.cat = ItemCat::LineFeed; + break; + } + } + // Double quote: quoted token + else if ch == b'"' { + self.start += 1; + self.cat = ItemCat::Quoted; + break; + } + // Else: unquoted token + else { + self.cat = ItemCat::Unquoted; + break; + } + } + Ok(()) + } + + /// Splits off the beginning of the buffer up to the given index. + /// + /// # Panics + /// + /// The method panics if `at` is greater than `self.start`. + fn split_to(&mut self, at: usize) -> BytesMut { + assert!(at <= self.start); + let res = self.buf.split_to(at); + self.start -= at; + self.line_start -= at as isize; + res + } + + /// Splits off the beginning of the buffer but doesn’t return it. + /// + /// # Panics + /// + /// The method panics if `at` is greater than `self.start`. + fn trim_to(&mut self, at: usize) { + assert!(at <= self.start); + self.buf.advance(at); + self.start -= at; + self.line_start -= at as isize; + } +} + +//------------ ItemCat ------------------------------------------------------- + +/// The category of the current item in a source buffer. +#[allow(dead_code)] // XXX +#[derive(Clone, Copy, Debug)] +enum ItemCat { + /// We don’t currently have an item. + /// + /// This is used to indicate that we have reached the end of a token or + /// that we have reached the end of the buffer. + // + // XXX: We might need a separate category for EOF. But let’s see if we + // can get away with mixing this up, first. + None, + + /// An unquoted normal token. + /// + /// This is a token that did not start with a double quote and will end + /// at the next white space. + Unquoted, + + /// A quoted normal token. + /// + /// This is a token that did start with a double quote and will end at + /// the next unescaped double quote. + /// + /// Note that the start position of the buffer indicates the first + /// character that is part of the content, i.e., the position right after + /// the opening double quote. + Quoted, + + /// A line feed. + /// + /// This is an empty token. The start position is right after the actual + /// line feed. + LineFeed, +} + +//------------ EntryError ---------------------------------------------------- + +/// An error returned by the entry scanner. +#[derive(Debug)] +struct EntryError(&'static str); + +impl EntryError { + fn bad_symbol(_err: SymbolOctetsError) -> Self { + EntryError("bad symbol") + } + + fn bad_charstr() -> Self { + EntryError("bad charstr") + } + + fn bad_dname() -> Self { + EntryError("bad dname") + } + + fn unbalanced_parens() -> Self { + EntryError("unbalanced parens") + } + + fn missing_last_owner() -> Self { + EntryError("missing last owner") + } + + fn missing_last_class() -> Self { + EntryError("missing last class") + } + + fn missing_origin() -> Self { + EntryError("missing origin") + } + + fn expected_qtype() -> Self { + EntryError("expected qtype") + } + + fn unknown_control() -> Self { + EntryError("unknown control") + } +} + +impl ScannerError for EntryError { + fn custom(msg: &'static str) -> Self { + EntryError(msg) + } + + fn end_of_entry() -> Self { + Self("unexpected end of entry") + } + + fn short_buf() -> Self { + Self("short buffer") + } + + fn trailing_tokens() -> Self { + Self("trailing tokens") + } +} + +impl From for EntryError { + fn from(_: SymbolOctetsError) -> Self { + EntryError("symbol octets error") + } +} + +impl From for EntryError { + fn from(_: BadSymbol) -> Self { + EntryError("bad symbol") + } +} + +impl fmt::Display for EntryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.0.as_ref()) + } +} + +//#[cfg(feature = "std")] +impl std::error::Error for EntryError {} + +//------------ Error --------------------------------------------------------- + +#[derive(Debug)] +pub struct Error { + err: EntryError, + line: usize, + col: usize, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}: {}", self.line, self.col, self.err) + } +} + +//#[cfg(feature = "std")] +impl std::error::Error for Error {} + +//============ Tests ========================================================= + +/* +#[cfg(test)] +#[cfg(feature = "std")] +mod test { + use super::*; + use std::vec::Vec; + + fn with_entry(s: &str, op: impl FnOnce(EntryScanner)) { + let mut zone = Zonefile::with_capacity(s.len()); + zone.extend_from_slice(s.as_bytes()); + let entry = EntryScanner::new(&mut zone).unwrap(); + entry.zonefile.buf.next_item().unwrap(); + op(entry) + } + + #[test] + fn scan_symbols() { + fn test(zone: &str, tok: impl AsRef<[u8]>) { + with_entry(zone, |mut entry| { + let mut tok = tok.as_ref(); + entry + .scan_symbols(|sym| { + let sym = sym.into_octet().unwrap(); + assert_eq!(sym, tok[0]); + tok = &tok[1..]; + Ok(()) + }) + .unwrap(); + }); + } + + test(" unquoted\n", b"unquoted"); + test(" unquoted ", b"unquoted"); + test("unquoted ", b"unquoted"); + test("unqu\\oted ", b"unquoted"); + test("unqu\\111ted ", b"unquoted"); + test(" \"quoted\"\n", b"quoted"); + test(" \"quoted\" ", b"quoted"); + test("\"quoted\" ", b"quoted"); + } + + #[derive(serde::Deserialize)] + #[allow(clippy::type_complexity)] + struct TestCase { + origin: Dname, + zonefile: std::string::String, + result: Vec, ZoneRecordData>>>, + } + + impl TestCase { + fn test(yaml: &str) { + let case = serde_yaml::from_str::(yaml).unwrap(); + let mut input = case.zonefile.as_bytes(); + let mut zone = Zonefile::load(&mut input).unwrap(); + zone.set_origin(case.origin); + let mut result = case.result.as_slice(); + while let Some(entry) = zone.next_entry().unwrap() { + match entry { + Entry::Record(record) => { + let (first, tail) = result.split_first().unwrap(); + assert_eq!(first, &record); + result = tail; + } + _ => panic!(), + } + } + } + } + + #[test] + fn test_data() { + TestCase::test(include_str!("../../test-data/zonefiles/basic.yaml")); + TestCase::test(include_str!("../../test-data/zonefiles/escape.yaml")); + TestCase::test(include_str!("../../test-data/zonefiles/unknown.yaml")); + } +} +*/ diff --git a/tests/net/deckard/server.rs b/tests/net/deckard/server.rs new file mode 100644 index 000000000..f8eecf760 --- /dev/null +++ b/tests/net/deckard/server.rs @@ -0,0 +1,133 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::matches::match_msg; +use crate::net::deckard::parse_deckard; +use crate::net::deckard::parse_deckard::{Adjust, Deckard, Reply}; +use crate::net::deckard::parse_query; +use domain::base::iana::rcode::Rcode; +use domain::base::{Message, MessageBuilder}; +use domain::dep::octseq::Octets; +use domain::zonefile::inplace::Entry as ZonefileEntry; + +pub fn do_server<'a, Oct: Clone + Octets + 'a>( + msg: &'a Message, + deckard: &Deckard, + step_value: &CurrStepValue, +) -> Option>> +where + ::Range<'a>: Clone, +{ + let ranges = &deckard.scenario.ranges; + let step = step_value.get(); + for range in ranges { + if step < range.start_value || step > range.end_value { + continue; + } + for entry in &range.entry { + if !match_msg(entry, msg, false) { + continue; + } + let reply = do_adjust(entry, msg); + return Some(reply); + } + } + todo!(); +} + +fn do_adjust( + entry: &parse_deckard::Entry, + reqmsg: &Message, +) -> Message> { + let sections = entry.sections.as_ref().unwrap(); + let adjust: Adjust = match &entry.adjust { + Some(adjust) => adjust.clone(), + None => Default::default(), + }; + let mut msg = MessageBuilder::new_vec().question(); + if adjust.copy_query { + for q in reqmsg.question() { + msg.push(q.unwrap()).unwrap(); + } + } else { + for q in §ions.question { + let question = match q { + parse_query::Entry::QueryRecord(question) => question, + _ => todo!(), + }; + msg.push(question).unwrap(); + } + } + let mut msg = msg.answer(); + for a in §ions.answer { + let rec = if let ZonefileEntry::Record(record) = a { + record + } else { + panic!("include not expected") + }; + msg.push(rec).unwrap(); + } + let mut msg = msg.authority(); + for a in §ions.authority { + let rec = if let ZonefileEntry::Record(record) = a { + record + } else { + panic!("include not expected") + }; + msg.push(rec).unwrap(); + } + let mut msg = msg.additional(); + for _a in §ions.additional { + todo!(); + } + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + if reply.aa { + msg.header_mut().set_aa(true); + } + if reply.ad { + todo!() + } + if reply.cd { + todo!() + } + if reply.fl_do { + todo!() + } + if reply.formerr { + todo!() + } + if reply.noerror { + msg.header_mut().set_rcode(Rcode::NoError); + } + if reply.nxdomain { + todo!() + } + if reply.qr { + msg.header_mut().set_qr(true); + } + if reply.ra { + todo!() + } + if reply.rd { + msg.header_mut().set_rd(true); + } + if reply.refused { + todo!() + } + if reply.servfail { + todo!() + } + if reply.tc { + todo!() + } + if reply.yxdomain { + todo!() + } + if adjust.copy_id { + msg.header_mut().set_id(reqmsg.header().id()); + } else { + todo!(); + } + msg.into_message() +} diff --git a/tests/net/mod.rs b/tests/net/mod.rs new file mode 100644 index 000000000..4e7b62367 --- /dev/null +++ b/tests/net/mod.rs @@ -0,0 +1 @@ +pub mod deckard;