From 72da252cc91c8b780959b8f77f1f762d519b72cc Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 19 Dec 2022 16:38:51 +0100 Subject: [PATCH 001/124] First cut at standalone TCP transport --- Cargo.toml | 3 +- src/lib.rs | 1 + src/transport/mod.rs | 5 + src/transport/tcp.rs | 261 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 src/transport/mod.rs create mode 100644 src/transport/tcp.rs diff --git a/Cargo.toml b/Cargo.toml index 34ba716e2..e2955fa48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time"] } +tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work @@ -47,6 +47,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] +transport = ["bytes", "tokio"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/lib.rs b/src/lib.rs index 6fe39406e..c08354145 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,7 @@ pub mod rdata; pub mod resolv; pub mod sign; pub mod test; +pub mod transport; pub mod tsig; pub mod utils; pub mod validate; diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 000000000..b56d82637 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,5 @@ +//! DNS transport protocols +#![cfg(feature = "transport")] +#![cfg_attr(docsrs, doc(cfg(feature = "transport")))] + +pub mod tcp; diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs new file mode 100644 index 000000000..e2a9de8bc --- /dev/null +++ b/src/transport/tcp.rs @@ -0,0 +1,261 @@ +//! A DNS over TCP transport + +use std::sync::Arc; +use std::sync::Mutex as Std_mutex; +use std::vec::Vec; +use std::collections::VecDeque; +use bytes::{Bytes, BytesMut}; + +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::octets; + +use tokio::io; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio::net::tcp::{ReadHalf, WriteHalf}; +use tokio::sync::Notify; + +struct SingleQuery { + reply: Option, &'static str>>, + complete: Arc, +} + +struct Queries { + count: usize, + vec: Vec>, +} + +pub struct TcpConnection { + stream: Std_mutex, + + // Should deal with keepalive + + /* Vector with outstanding queries */ + query_vec: Std_mutex, + + /* Vector with outstanding requests that need to be transmitted */ + tx_queue: Std_mutex>>, + + worker_notify: Notify, +} + +// impl<'a, Octets: octets::OctetsBuilder + AsRef<[u8]> + AsMut<[u8]> + Clone + 'a> TcpConnection<'a> { +impl TcpConnection { + pub async fn connect(addr: A) -> + io::Result { + let tcp = TcpStream::connect(addr).await?; + Ok(Self { + stream: Std_mutex::new(tcp), + query_vec: Std_mutex::new(Queries { + count: 0, + vec: Vec::new() + }), + tx_queue: Std_mutex::new(VecDeque::new()), + worker_notify: Notify::new(), + }) + } + + fn insert_answer(&self, answer: Message) { + let ind16 = answer.header().id(); + let index: usize = ind16.into(); + + println!("Got ID {}", ind16); + + println!("Before query_vec.lock()"); + let mut query_vec = self.query_vec.lock().unwrap(); + + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the TCP connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the TCP connection as broken + return; + } + Some(query) => { + match &query.reply { + None => { + query.reply = + Some(Ok( + answer)); + query.complete. + notify_one(); + return; + } + _ => { + // Already got a + // result. + return; + } + } + } + } + } + + async fn reader(&self, sock: &mut ReadHalf<'_>) { + loop { + let len = sock.read_u16().await.unwrap() as usize; + + let mut buf = BytesMut::with_capacity(len); + + let reslen = sock.read_buf(&mut buf).await.unwrap(); + + let reply_message = Message::::from_octets(buf.into()); + if let Ok(answer) = reply_message { + self.insert_answer(answer); + + // else try with the next message. + } else { + panic!("Read error"); + //return Err(io::Error::new( + // io::ErrorKind::Other, + // "short buf", + //)); + } + } + } + + async fn writer(&self, sock: &mut WriteHalf<'_>) { + loop { + let mut tx_queue = self.tx_queue.lock().unwrap(); + let head = tx_queue.pop_front(); + drop(tx_queue); + match head { + Some(vec) => { + sock.write_all(&vec).await; + () + } + None => + break, + } + } + } + + pub async fn worker(&self) -> Option<()> { + let mut stream = self.stream.lock().unwrap(); + let (mut read_stream, mut write_stream) = stream.split(); + + let reader_fut = self.reader(&mut read_stream); + tokio::pin!(reader_fut); + + loop { + println!("in loop"); + let writer_fut = self.writer(&mut write_stream); + + tokio::select! { + read = &mut reader_fut => { + panic!("reader terminated"); + } + write = writer_fut => { + // The writer is done. Wait + // for a notify + () + } + } + + println!("Waiting for work"); + let notify_fut = self.worker_notify.notified(); + + tokio::select! { + read = &mut reader_fut => { + panic!("reader terminated"); + } + notify = notify_fut => { + // Got notified, start writing + println!("Got work"); + () + } + } + } + } + + // Insert a message in the query vector. Return the index + fn insert(&self) + -> usize { + let q = Some(SingleQuery { + reply: None, + complete: Arc::new(Notify::new()), + }); + let mut query_vec = self.query_vec.lock().unwrap(); + let vec_len = query_vec.vec.len(); + if vec_len < 2*(query_vec.count+1) { + // Just append + query_vec.vec.push(q); + query_vec.count = query_vec.count + 1; + let index = query_vec.vec.len()-1; + return index; + } + panic!("Sould insert"); + 0 + } + + fn queue_query>(&self, + msg: &MessageBuilder>>) { + + let query_vec = self.query_vec.lock().unwrap(); + let vec = msg.as_target().as_target().as_stream_slice(); + + // Store a close of the request. That makes life easier + // and requests tend to be small + let mut tx_queue = self.tx_queue.lock().unwrap(); + // self.tx_queue.push_back(vec.to_vec()); + tx_queue.push_back(vec.to_vec()); + } + + pub async fn query + + AsMut<[u8]>>(&self, + query_msg: &mut MessageBuilder>>) -> Result, &'static str> { + let index = self.insert(); + let ind16: u16 = index.try_into().unwrap(); + + let hdr = query_msg.header_mut(); + hdr.set_id(ind16); + + self.queue_query(query_msg); + + // Now kick the worker to transmit the query + self.worker_notify.notify_one(); + + // Wait for reply + println!("Waiting for reply"); + let mut query_vec = self.query_vec.lock().unwrap(); + let local_notify = query_vec.vec[index].as_mut().unwrap(). + complete.clone(); + drop(query_vec); + local_notify.notified().await; + println!("Got reply"); + + // Get the lock again to take a look + let mut query_vec = self.query_vec.lock().unwrap(); + let opt_q = query_vec.vec[index].take(); + query_vec.count = query_vec.count - 1; + drop(query_vec); + + if let Some(q) = opt_q + { + if let Some(result) = q.reply + { + if let Ok(answer) = &result + { + if !answer.is_answer(&query_msg. + as_message()) { + // Wrong answer, try again? + panic!("wring answer"); + } + } + return result; + } + panic!("inconsistent state"); + } + + panic!("inconsistent state"); + } + +} From a9895eefc1724087ecd09997712d4ee79de8878a Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 2 Jan 2023 14:04:36 +0100 Subject: [PATCH 002/124] Added TODOs and a comment about predictable IDs --- src/transport/tcp.rs | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index e2a9de8bc..1ad0ac018 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -1,5 +1,21 @@ //! A DNS over TCP transport +// TODO: +// - errors +// - read errors +// - write errors +// - connect errors? Retry after connection refused? +// - server errors +// - ID out of range +// - ID not in use +// - reply for wrong query +// - separate Query object +// - timeouts +// - idle timeout +// - channel timeout +// - request timeout +// - create new TCP connection after end/failure of previous one + use std::sync::Arc; use std::sync::Mutex as Std_mutex; use std::vec::Vec; @@ -201,10 +217,9 @@ impl TcpConnection { let query_vec = self.query_vec.lock().unwrap(); let vec = msg.as_target().as_target().as_stream_slice(); - // Store a close of the request. That makes life easier + // Store a clone of the request. That makes life easier // and requests tend to be small let mut tx_queue = self.tx_queue.lock().unwrap(); - // self.tx_queue.push_back(vec.to_vec()); tx_queue.push_back(vec.to_vec()); } @@ -215,6 +230,10 @@ impl TcpConnection { let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); + // We set the ID to the array index. Wouter recommends + // against this. He argues that defense in depth suggests + // that a random ID is better because it works even if + // TCP sequence numbers could be predicted. let hdr = query_msg.header_mut(); hdr.set_id(ind16); @@ -247,7 +266,7 @@ impl TcpConnection { if !answer.is_answer(&query_msg. as_message()) { // Wrong answer, try again? - panic!("wring answer"); + panic!("wrong answer"); } } return result; From 2c02b389be7d97b353e710f1cdb45b1878262a5e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 2 Jan 2023 14:12:55 +0100 Subject: [PATCH 003/124] Improved text about random IDs --- src/transport/tcp.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 1ad0ac018..41508c69c 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -230,10 +230,13 @@ impl TcpConnection { let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); - // We set the ID to the array index. Wouter recommends - // against this. He argues that defense in depth suggests - // that a random ID is better because it works even if - // TCP sequence numbers could be predicted. + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." let hdr = query_msg.header_mut(); hdr.set_id(ind16); From 99fe4fcd3204c9967877f475d72e24d264221079 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 20 Feb 2023 13:41:01 +0100 Subject: [PATCH 004/124] Separate Query object to capture the lifetime of the query. --- src/transport/tcp.rs | 237 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 208 insertions(+), 29 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 41508c69c..3c3c12896 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -9,7 +9,6 @@ // - ID out of range // - ID not in use // - reply for wrong query -// - separate Query object // - timeouts // - idle timeout // - channel timeout @@ -23,7 +22,7 @@ use std::collections::VecDeque; use bytes::{Bytes, BytesMut}; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; -use crate::base::octets; +use octseq::{Octets, OctetsBuilder}; use tokio::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -31,8 +30,14 @@ use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::sync::Notify; +enum SingleQueryState { + Busy, + Done(Result, &'static str>), + Canceled, +} + struct SingleQuery { - reply: Option, &'static str>>, + state: SingleQueryState, complete: Arc, } @@ -41,7 +46,7 @@ struct Queries { vec: Vec>, } -pub struct TcpConnection { +struct InnerTcpConnection { stream: Std_mutex, // Should deal with keepalive @@ -55,10 +60,24 @@ pub struct TcpConnection { worker_notify: Notify, } -// impl<'a, Octets: octets::OctetsBuilder + AsRef<[u8]> + AsMut<[u8]> + Clone + 'a> TcpConnection<'a> { -impl TcpConnection { +pub struct TcpConnection { + inner: Arc, +} + +enum QueryState { + Busy(usize), // index + Done, +} + +pub struct Query { + transport: Arc, + query_msg: Message>, + state: QueryState, +} + +impl InnerTcpConnection { pub async fn connect(addr: A) -> - io::Result { + io::Result { let tcp = TcpStream::connect(addr).await?; Ok(Self { stream: Std_mutex::new(tcp), @@ -75,9 +94,6 @@ impl TcpConnection { let ind16 = answer.header().id(); let index: usize = ind16.into(); - println!("Got ID {}", ind16); - - println!("Before query_vec.lock()"); let mut query_vec = self.query_vec.lock().unwrap(); let vec_len = query_vec.vec.len(); @@ -95,18 +111,31 @@ impl TcpConnection { return; } Some(query) => { - match &query.reply { - None => { - query.reply = - Some(Ok( + match &query.state { + SingleQueryState::Busy => { + query.state = + SingleQueryState:: + Done(Ok( answer)); query.complete. notify_one(); return; } - _ => { + SingleQueryState::Canceled => { + //`The query has been + // canceled already + // Clean up. + let _ = query_vec. + vec[index]. + take(); + query_vec.count = + query_vec. + count - 1; + return; + } + SingleQueryState::Done(_) => { // Already got a - // result. + // result. return; } } @@ -161,7 +190,6 @@ impl TcpConnection { tokio::pin!(reader_fut); loop { - println!("in loop"); let writer_fut = self.writer(&mut write_stream); tokio::select! { @@ -175,7 +203,6 @@ impl TcpConnection { } } - println!("Waiting for work"); let notify_fut = self.worker_notify.notified(); tokio::select! { @@ -184,7 +211,6 @@ impl TcpConnection { } notify = notify_fut => { // Got notified, start writing - println!("Got work"); () } } @@ -195,7 +221,7 @@ impl TcpConnection { fn insert(&self) -> usize { let q = Some(SingleQuery { - reply: None, + state: SingleQueryState::Busy, complete: Arc::new(Notify::new()), }); let mut query_vec = self.query_vec.lock().unwrap(); @@ -211,10 +237,9 @@ impl TcpConnection { 0 } - fn queue_query>(&self, - msg: &MessageBuilder>>) { + fn queue_query + AsRef<[u8]>> + (&self, msg: &MessageBuilder>>) { - let query_vec = self.query_vec.lock().unwrap(); let vec = msg.as_target().as_target().as_stream_slice(); // Store a clone of the request. That makes life easier @@ -223,10 +248,9 @@ impl TcpConnection { tx_queue.push_back(vec.to_vec()); } - pub async fn query + - AsMut<[u8]>>(&self, - query_msg: &mut MessageBuilder>>) -> Result, &'static str> { + pub async fn query + AsMut<[u8]>> + (&self, query_msg: &mut MessageBuilder>>) + -> Result, &'static str> { let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); @@ -246,7 +270,6 @@ impl TcpConnection { self.worker_notify.notify_one(); // Wait for reply - println!("Waiting for reply"); let mut query_vec = self.query_vec.lock().unwrap(); let local_notify = query_vec.vec[index].as_mut().unwrap(). complete.clone(); @@ -262,7 +285,7 @@ impl TcpConnection { if let Some(q) = opt_q { - if let Some(result) = q.reply + if let SingleQueryState::Done(result) = q.state { if let Ok(answer) = &result { @@ -280,4 +303,160 @@ impl TcpConnection { panic!("inconsistent state"); } + pub fn query2 + AsRef<[u8]>> + (&self, + query_msg: &mut MessageBuilder>>) -> usize { + let index = self.insert(); + let ind16: u16 = index.try_into().unwrap(); + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + let hdr = query_msg.header_mut(); + hdr.set_id(ind16); + + self.queue_query(query_msg); + + // Now kick the worker to transmit the query + self.worker_notify.notify_one(); + + index + } + + pub async fn get_result(&self, query_msg: &Message, + index: usize) -> Result, &'static str> { + // Wait for reply + let mut query_vec = self.query_vec.lock().unwrap(); + let local_notify = query_vec.vec[index].as_mut().unwrap(). + complete.clone(); + drop(query_vec); + local_notify.notified().await; + + // Get the lock again to take a look + let mut query_vec = self.query_vec.lock().unwrap(); + let opt_q = query_vec.vec[index].take(); + query_vec.count = query_vec.count - 1; + drop(query_vec); + + if let Some(q) = opt_q + { + if let SingleQueryState::Done(result) = q.state + { + if let Ok(answer) = &result + { + if !answer.is_answer(query_msg) { + // Wrong answer, try again? + panic!("wrong answer"); + } + } + return result; + } + panic!("inconsistent state"); + } + + panic!("inconsistent state"); + } + + fn cancel(&self, index: usize) { + let mut query_vec = self.query_vec.lock().unwrap(); + + match &mut query_vec.vec[index] { + None => { + panic!("Cancel called, but nothing to cancel"); + } + Some(query) => { + match &query.state { + SingleQueryState::Busy => { + query.state = + SingleQueryState::Canceled; + return; + } + SingleQueryState::Canceled => { + panic!("Already canceled"); + } + SingleQueryState::Done(_) => { + // Remove the result + let _ = query_vec. + vec[index].take(); + query_vec.count = + query_vec.count - 1; + drop(query_vec); + } + } + } + } + } +} + +impl TcpConnection { + pub async fn connect(addr: A) -> + io::Result { + let tcpconnection = InnerTcpConnection::connect(addr).await?; + Ok(Self { inner: Arc::new(tcpconnection) }) + } + pub async fn worker(&self) -> Option<()> { + self.inner.worker().await + } + pub async fn query + AsRef<[u8]>> + (&self, query_msg: &mut MessageBuilder>>) + -> Result, &'static str> { + self.inner.query(query_msg).await + } + pub fn query2 + AsRef<[u8]>> + (&self, query_msg: &mut MessageBuilder>>) + -> Result { + let index = self.inner.query2(query_msg); + let msg = &query_msg.as_message(); + Ok(Query::new(self, msg, index)) + } +} + + +impl Query { + fn new(transport: &TcpConnection, + query_msg: &Message, + index: usize) -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec).unwrap(); + Self { + transport: transport.inner.clone(), + query_msg: msg, + state: QueryState::Busy(index) } + } + pub async fn get_result(&mut self) -> + Result, &'static str> { + // Just the result of get_result on tranport. We should record + // that we got an answer to avoid asking again + match self.state { + QueryState::Busy(index) => { + let result = self.transport.get_result( + &self.query_msg, index).await; + self.state = QueryState::Done; + result + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} + +impl Drop for Query { + fn drop(&mut self) { + match self.state { + QueryState::Busy(index) => { + self.transport.cancel(index); + } + QueryState::Done => { + // Done, nothing to cancel + } + } + } } From eeace82f1e4b5d1d0872b77435be36b7fbfb5c5b Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 27 Feb 2023 14:45:13 +0100 Subject: [PATCH 005/124] Support for idle timeouts and edns-tcp-keepalive --- src/transport/tcp.rs | 433 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 344 insertions(+), 89 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 3c3c12896..16d441003 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -1,5 +1,8 @@ //! A DNS over TCP transport +// RFC 7766 describes DNS over TCP +// RFC 7828 describes the edns-tcp-keepalive option + // TODO: // - errors // - read errors @@ -10,18 +13,20 @@ // - ID not in use // - reply for wrong query // - timeouts -// - idle timeout // - channel timeout // - request timeout // - create new TCP connection after end/failure of previous one +use std::collections::VecDeque; use std::sync::Arc; use std::sync::Mutex as Std_mutex; +use std::time::{Duration, Instant}; use std::vec::Vec; -use std::collections::VecDeque; use bytes::{Bytes, BytesMut}; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::{Message, MessageBuilder, opt::{AllOptData, OptRecord, + TcpKeepalive}, StaticCompressor, StreamTarget}; +use crate::base::wire::Composer; use octseq::{Octets, OctetsBuilder}; use tokio::io; @@ -29,6 +34,12 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::sync::Notify; +use tokio::time::sleep; + +const ERR_IDLE_TIMEOUT: &str = "idle connection was closed"; + +// From RFC 7828. This should go somewhere with the option parsing +const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; enum SingleQueryState { Busy, @@ -42,14 +53,42 @@ struct SingleQuery { } struct Queries { + // Number of queries in the vector. The count of element that are + // not None count: usize, + + // Number of queries that are still waiting for an answer + busy: usize, + + // Index in the vector where to look for a space for a new query + curr: usize, + vec: Vec>, } +enum ConnState { + Active, + Idle(Instant), + IdleTimeout, +} + +struct Keepalive { + state: ConnState, + + // For edns-tcp-keepalive, we have a boolean the specifies if we + // need to send one (typically at the start of the connection). + // Initially we assume that the idle timeout is zero. A received + // edns-tcp-keepalive option may change that. What the best way to + // specify time in Rust? Currently we specify it in milliseconds. + send_keepalive: bool, + idle_timeout: Option, +} + struct InnerTcpConnection { stream: Std_mutex, - // Should deal with keepalive + /* keepalive */ + keepalive: Std_mutex, /* Vector with outstanding queries */ query_vec: Std_mutex, @@ -81,8 +120,15 @@ impl InnerTcpConnection { let tcp = TcpStream::connect(addr).await?; Ok(Self { stream: Std_mutex::new(tcp), + keepalive: Std_mutex::new(Keepalive { + state: ConnState::Active, + send_keepalive: true, + idle_timeout: None, + }), query_vec: Std_mutex::new(Queries { count: 0, + busy: 0, + curr: 0, vec: Vec::new() }), tx_queue: Std_mutex::new(VecDeque::new()), @@ -91,56 +137,97 @@ impl InnerTcpConnection { } fn insert_answer(&self, answer: Message) { - let ind16 = answer.header().id(); - let index: usize = ind16.into(); + let ind16 = answer.header().id(); + let index: usize = ind16.into(); - let mut query_vec = self.query_vec.lock().unwrap(); + let mut query_vec = self.query_vec.lock().unwrap(); - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the TCP connection as broken + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the TCP connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the TCP connection as broken return; } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { - None => { - // No query with this ID. We should - // mark the TCP connection as broken - return; - } - Some(query) => { - match &query.state { - SingleQueryState::Busy => { - query.state = - SingleQueryState:: - Done(Ok( - answer)); - query.complete. - notify_one(); - return; - } - SingleQueryState::Canceled => { - //`The query has been - // canceled already - // Clean up. - let _ = query_vec. - vec[index]. - take(); - query_vec.count = - query_vec. - count - 1; - return; - } - SingleQueryState::Done(_) => { - // Already got a - // result. - return; - } + Some(query) => { + match &query.state { + SingleQueryState::Busy => { + query.state = + SingleQueryState:: + Done(Ok( + answer)); + query.complete. + notify_one(); + } + SingleQueryState::Canceled => { + //`The query has been + // canceled already + // Clean up. + let _ = query_vec. + vec[index]. + take(); + query_vec.count = + query_vec. + count - 1; + } + SingleQueryState::Done(_) => { + // Already got a + // result. + return; } } } + } + query_vec.busy = query_vec.busy-1; + if query_vec.busy == 0 { + let mut keepalive = self.keepalive.lock().unwrap(); + if keepalive.idle_timeout == None { + // Assume that we can just move to IdleTimeout + // state + keepalive.state = ConnState::IdleTimeout; + + // Notify the worker. Then the worker can + // close the tcp connection + self.worker_notify.notify_one(); + } + else { + keepalive.state = + ConnState::Idle(Instant::now()); + + // Notify the worker. The worker waits for + // the timeout to expire + self.worker_notify.notify_one(); + } + } + } + + fn handle_keepalive(&self, opt_value: TcpKeepalive) { + if let Some(value) = opt_value.timeout() { + let mut keepalive = self.keepalive.lock().unwrap(); + keepalive.idle_timeout = + Some(Duration::from_millis(u64::from(value) * + EDNS_TCP_KEEPALIE_TO_MS)); + } + } + + fn handle_opts> + (&self, opts: &OptRecord) { + for option in opts.iter() { + let opt = option.unwrap(); + match opt { + AllOptData::TcpKeepalive(tcpkeepalive) => { + self.handle_keepalive(tcpkeepalive); + } + _ => {} + } + } } async fn reader(&self, sock: &mut ReadHalf<'_>) { @@ -153,6 +240,11 @@ impl InnerTcpConnection { let reply_message = Message::::from_octets(buf.into()); if let Ok(answer) = reply_message { + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + self.handle_opts(opts); + }; self.insert_answer(answer); // else try with the next message. @@ -192,6 +284,7 @@ impl InnerTcpConnection { loop { let writer_fut = self.writer(&mut write_stream); + println!("worker: before writer"); tokio::select! { read = &mut reader_fut => { panic!("reader terminated"); @@ -205,16 +298,99 @@ impl InnerTcpConnection { let notify_fut = self.worker_notify.notified(); - tokio::select! { - read = &mut reader_fut => { - panic!("reader terminated"); + println!("worker: before reader"); + let mut opt_timeout: Option = None; + let mut keepalive = self.keepalive.lock().unwrap(); + if let ConnState::Idle(instant) = keepalive.state { + if let Some(timeout) = keepalive.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= timeout { + // Move to IdleTimeout and end + // the loop + keepalive.state = + ConnState::IdleTimeout; + break; + } + opt_timeout = Some(timeout - elapsed); } - notify = notify_fut => { - // Got notified, start writing - () + else { + panic!("Idle state but no timeout"); + } + } + drop(keepalive); + + + if let Some(timeout) = opt_timeout { + let sleep_fut = sleep(timeout); + + println!("sleeping for {:?}", timeout); + tokio::select! { + read = &mut reader_fut => { + panic!("reader terminated"); + } + _ = notify_fut => { + // Got notified, start writing + () + } + _ = sleep_fut => { + // Idle timeout expired, just + // continue with the loop + () + } + } + } else { + tokio::select! { + read = &mut reader_fut => { + panic!("reader terminated"); + } + _ = notify_fut => { + // Got notified, start writing + () + } + } + } + + println!("worker: got notified"); + + // Check if the connection is idle + let keepalive = self.keepalive.lock().unwrap(); + match keepalive.state { + ConnState::Active | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => { + break } } + drop(keepalive); + } + + // Send FIN to server + println!("worker: sending FIN"); + write_stream.shutdown().await; + + // Stay around until the last query result is collected + loop { + println!("worker: checking query count"); + let query_vec = self.query_vec.lock().unwrap(); + if query_vec.count == 0 { + // We are done + break; + } + drop(query_vec); + + println!("waiting for last query to end"); + self.worker_notify.notified().await; } + None + } + + fn insert_at(query_vec: &mut Queries, index: usize, + q: Option) { + query_vec.vec[index] = q; + query_vec.count = query_vec.count + 1; + query_vec.busy = query_vec.busy + 1; + query_vec.curr = index + 1; } // Insert a message in the query vector. Return the index @@ -230,11 +406,44 @@ impl InnerTcpConnection { // Just append query_vec.vec.push(q); query_vec.count = query_vec.count + 1; + query_vec.busy = query_vec.busy + 1; let index = query_vec.vec.len()-1; return index; } - panic!("Sould insert"); - 0 + let loc_curr = query_vec.curr; + + for index in loc_curr..vec_len { + match query_vec.vec[index] { + Some(_) => { + // Already in use, just continue + () + } + None => { + Self::insert_at(&mut query_vec, + index, q); + return index; + } + } + } + + // Nothing until the end of the vector. Try for the entire + // vector + for index in 0..vec_len { + match query_vec.vec[index] { + Some(_) => { + // Already in use, just continue + () + } + None => { + Self::insert_at(&mut query_vec, + index, q); + return index; + } + } + } + + // Still nothing, that is not good + panic!("insert failed"); } fn queue_query + AsRef<[u8]>> @@ -303,10 +512,39 @@ impl InnerTcpConnection { panic!("inconsistent state"); } - pub fn query2 + AsRef<[u8]>> + pub fn query2 + AsRef<[u8]> + + Composer + Clone> (&self, - query_msg: &mut MessageBuilder>>) -> usize { + query_msg: &mut MessageBuilder>> + ) -> Result { + + // Check the state of the connection, fail if the connection is in + // IdleTimeout. If the connection is Idle, move it back to Active + // Also check for the need to send a keepalive + let mut keepalive = self.keepalive.lock().unwrap(); + match keepalive.state { + ConnState::Active => { + // Nothing to do + () + } + ConnState::Idle(_) => { + // Go back to active + keepalive.state = ConnState::Active; + () + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + return Err(ERR_IDLE_TIMEOUT); + } + } + + let mut do_keepalive = false; + if keepalive.send_keepalive { + do_keepalive = true; + keepalive.send_keepalive = false; + } + drop(keepalive); + let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); @@ -320,48 +558,64 @@ impl InnerTcpConnection { let hdr = query_msg.header_mut(); hdr.set_id(ind16); - self.queue_query(query_msg); + if do_keepalive { + let mut msgadd = query_msg.clone().additional(); - // Now kick the worker to transmit the query - self.worker_notify.notify_one(); + // send an empty keepalive option + msgadd.opt(|opt| { + opt.tcp_keepalive(None) + }); + self.queue_query(&msgadd); + } else { + self.queue_query(query_msg); + } - index - } - pub async fn get_result(&self, query_msg: &Message, - index: usize) -> Result, &'static str> { - // Wait for reply - let mut query_vec = self.query_vec.lock().unwrap(); - let local_notify = query_vec.vec[index].as_mut().unwrap(). - complete.clone(); - drop(query_vec); - local_notify.notified().await; + // Now kick the worker to transmit the query + self.worker_notify.notify_one(); - // Get the lock again to take a look - let mut query_vec = self.query_vec.lock().unwrap(); - let opt_q = query_vec.vec[index].take(); - query_vec.count = query_vec.count - 1; - drop(query_vec); + Ok(index) + } - if let Some(q) = opt_q - { - if let SingleQueryState::Done(result) = q.state + pub async fn get_result(&self, query_msg: &Message, + index: usize) -> Result, &'static str> { + // Wait for reply + let mut query_vec = self.query_vec.lock().unwrap(); + let local_notify = query_vec.vec[index].as_mut().unwrap(). + complete.clone(); + drop(query_vec); + local_notify.notified().await; + + // Get the lock again to take a look + let mut query_vec = self.query_vec.lock().unwrap(); + let opt_q = query_vec.vec[index].take(); + query_vec.count = query_vec.count - 1; + println!("get_result: query count is now {}", query_vec.count); + if query_vec.count == 0 { + // The worker may be waiting for this + self.worker_notify.notify_one(); + } + drop(query_vec); + + if let Some(q) = opt_q { - if let Ok(answer) = &result + if let SingleQueryState::Done(result) = q.state { - if !answer.is_answer(query_msg) { - // Wrong answer, try again? - panic!("wrong answer"); + if let Ok(answer) = &result + { + if !answer.is_answer(query_msg) { + // Wrong answer, try again? + panic!("wrong answer"); + } } + return result; } - return result; + panic!("inconsistent state"); } + panic!("inconsistent state"); } - panic!("inconsistent state"); - } - fn cancel(&self, index: usize) { let mut query_vec = self.query_vec.lock().unwrap(); @@ -407,11 +661,12 @@ impl TcpConnection { -> Result, &'static str> { self.inner.query(query_msg).await } - pub fn query2 + AsRef<[u8]>> + pub fn query2 + AsRef<[u8]> + + Composer + Clone> (&self, query_msg: &mut MessageBuilder>>) -> Result { - let index = self.inner.query2(query_msg); + let index = self.inner.query2(query_msg)?; let msg = &query_msg.as_message(); Ok(Query::new(self, msg, index)) } From d7baa6410d2a2533a5e6f1044621088268a9fded Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 7 Mar 2023 16:53:58 +0100 Subject: [PATCH 006/124] Handle tcp read and write errors --- src/transport/tcp.rs | 226 +++++++++++++++++++++++++++++++++---------- 1 file changed, 175 insertions(+), 51 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 16d441003..8638a7ba0 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -37,13 +37,15 @@ use tokio::sync::Notify; use tokio::time::sleep; const ERR_IDLE_TIMEOUT: &str = "idle connection was closed"; +const ERR_READ_ERROR: &str = "read error"; +const ERR_WRITE_ERROR: &str = "write error"; // From RFC 7828. This should go somewhere with the option parsing const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; enum SingleQueryState { Busy, - Done(Result, &'static str>), + Done(Result, Arc>), Canceled, } @@ -70,6 +72,8 @@ enum ConnState { Active, Idle(Instant), IdleTimeout, + ReadError, + WriteError, } struct Keepalive { @@ -230,48 +234,111 @@ impl InnerTcpConnection { } } - async fn reader(&self, sock: &mut ReadHalf<'_>) { + async fn reader(&self, sock: &mut ReadHalf<'_>) -> Result<(), &str> { loop { - let len = sock.read_u16().await.unwrap() as usize; + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + self.tcp_error(error); + let mut keepalive = self.keepalive.lock().unwrap(); + keepalive.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + } as usize; let mut buf = BytesMut::with_capacity(len); - let reslen = sock.read_buf(&mut buf).await.unwrap(); + let read_res = sock.read_buf(&mut buf).await; + match read_res { + Ok(_) => (), // We don't need the result + Err(error) => { + self.tcp_error(error); + let mut keepalive = self.keepalive.lock().unwrap(); + keepalive.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + }; let reply_message = Message::::from_octets(buf.into()); - if let Ok(answer) = reply_message { - // Check for a edns-tcp-keepalive option - let opt_record = answer.opt(); - if let Some(ref opts) = opt_record { - self.handle_opts(opts); - }; - self.insert_answer(answer); - - // else try with the next message. - } else { - panic!("Read error"); - //return Err(io::Error::new( - // io::ErrorKind::Other, - // "short buf", - //)); + + match reply_message { + Ok(answer) => { + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + self.handle_opts(opts); + }; + self.insert_answer(answer); + } + Err(_) => { + // The only possible error is short message + let error = io::Error::new(io::ErrorKind::Other, + "short buf"); + self.tcp_error(error); + let mut keepalive = self.keepalive.lock().unwrap(); + keepalive.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } } } } - async fn writer(&self, sock: &mut WriteHalf<'_>) { + fn tcp_error(&self, error: std::io::Error) { + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + let arc_error = Arc::new(error); + let mut query_vec = self.query_vec.lock().unwrap(); + for query in &mut query_vec.vec { + match query { + None => { + continue; + } + Some(q) => { + match q.state { + SingleQueryState::Busy => { + q.state = + SingleQueryState:: + Done(Err( + arc_error + .clone())); + q.complete. + notify_one(); + } + SingleQueryState::Done(_) | + SingleQueryState::Canceled => + // Nothing to do + () + } + } + } + } + } + + async fn writer(&self, sock: &mut WriteHalf<'_>) -> + Result<(), &'static str> { loop { let mut tx_queue = self.tx_queue.lock().unwrap(); let head = tx_queue.pop_front(); drop(tx_queue); match head { Some(vec) => { - sock.write_all(&vec).await; + let res = sock.write_all(&vec).await; + if let Err(error) = res { + self.tcp_error(error); + let mut keepalive = + self.keepalive.lock().unwrap(); + keepalive.state = + ConnState::WriteError; + return Err(ERR_WRITE_ERROR); + } () } None => break, } } + Ok(()) } pub async fn worker(&self) -> Option<()> { @@ -284,21 +351,38 @@ impl InnerTcpConnection { loop { let writer_fut = self.writer(&mut write_stream); - println!("worker: before writer"); tokio::select! { - read = &mut reader_fut => { - panic!("reader terminated"); + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Reader failed. Break + // out of loop and + // shut down + break + } } - write = writer_fut => { - // The writer is done. Wait - // for a notify - () + res = writer_fut => { + match res { + Ok(_) => + // The writer is done. + // Wait for a notify. + (), + Err(_) => + // Writer failed. Break + // out of loop and + // shut down + break + } } } let notify_fut = self.worker_notify.notified(); - println!("worker: before reader"); let mut opt_timeout: Option = None; let mut keepalive = self.keepalive.lock().unwrap(); if let ConnState::Idle(instant) = keepalive.state { @@ -323,10 +407,20 @@ impl InnerTcpConnection { if let Some(timeout) = opt_timeout { let sleep_fut = sleep(timeout); - println!("sleeping for {:?}", timeout); tokio::select! { - read = &mut reader_fut => { - panic!("reader terminated"); + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Reader failed. Break + // out of loop and + // shut down + break + } } _ = notify_fut => { // Got notified, start writing @@ -340,8 +434,19 @@ impl InnerTcpConnection { } } else { tokio::select! { - read = &mut reader_fut => { - panic!("reader terminated"); + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Reader failed. Break + // out of loop and + // shut down + break + } } _ = notify_fut => { // Got notified, start writing @@ -350,8 +455,6 @@ impl InnerTcpConnection { } } - println!("worker: got notified"); - // Check if the connection is idle let keepalive = self.keepalive.lock().unwrap(); match keepalive.state { @@ -361,17 +464,19 @@ impl InnerTcpConnection { ConnState::IdleTimeout => { break } + ConnState::ReadError | + ConnState::WriteError => { + panic!("Should not be here"); + } } drop(keepalive); } - // Send FIN to server - println!("worker: sending FIN"); - write_stream.shutdown().await; + // Send FIN to server. Ignore any errors. + let _ = write_stream.shutdown().await; // Stay around until the last query result is collected loop { - println!("worker: checking query count"); let query_vec = self.query_vec.lock().unwrap(); if query_vec.count == 0 { // We are done @@ -379,7 +484,6 @@ impl InnerTcpConnection { } drop(query_vec); - println!("waiting for last query to end"); self.worker_notify.notified().await; } None @@ -459,7 +563,7 @@ impl InnerTcpConnection { pub async fn query + AsMut<[u8]>> (&self, query_msg: &mut MessageBuilder>>) - -> Result, &'static str> { + -> Result, Arc> { let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); @@ -484,7 +588,6 @@ impl InnerTcpConnection { complete.clone(); drop(query_vec); local_notify.notified().await; - println!("Got reply"); // Get the lock again to take a look let mut query_vec = self.query_vec.lock().unwrap(); @@ -536,6 +639,12 @@ impl InnerTcpConnection { // The connection has been closed. Report error return Err(ERR_IDLE_TIMEOUT); } + ConnState::ReadError => { + return Err(ERR_READ_ERROR); + } + ConnState::WriteError => { + return Err(ERR_WRITE_ERROR); + } } let mut do_keepalive = false; @@ -562,23 +671,39 @@ impl InnerTcpConnection { let mut msgadd = query_msg.clone().additional(); // send an empty keepalive option - msgadd.opt(|opt| { + let res = msgadd.opt(|opt| { opt.tcp_keepalive(None) }); - self.queue_query(&msgadd); + match res { + Ok(_) => + self.queue_query(&msgadd), + Err(_) => { + // Adding keepalive option + // failed. Send the original + // request and turn the + // send_keepalive flag back on + let mut keepalive = + self.keepalive.lock() + .unwrap(); + keepalive.send_keepalive = + true; + drop(keepalive); + self.queue_query(query_msg); + } + } } else { self.queue_query(query_msg); } - // Now kick the worker to transmit the query self.worker_notify.notify_one(); Ok(index) } - pub async fn get_result(&self, query_msg: &Message, - index: usize) -> Result, &'static str> { + pub async fn get_result(&self, + query_msg: &Message, index: usize) -> + Result, Arc> { // Wait for reply let mut query_vec = self.query_vec.lock().unwrap(); let local_notify = query_vec.vec[index].as_mut().unwrap(). @@ -590,7 +715,6 @@ impl InnerTcpConnection { let mut query_vec = self.query_vec.lock().unwrap(); let opt_q = query_vec.vec[index].take(); query_vec.count = query_vec.count - 1; - println!("get_result: query count is now {}", query_vec.count); if query_vec.count == 0 { // The worker may be waiting for this self.worker_notify.notify_one(); @@ -658,7 +782,7 @@ impl TcpConnection { } pub async fn query + AsRef<[u8]>> (&self, query_msg: &mut MessageBuilder>>) - -> Result, &'static str> { + -> Result, Arc> { self.inner.query(query_msg).await } pub fn query2 + AsRef<[u8]> + @@ -686,7 +810,7 @@ impl Query { state: QueryState::Busy(index) } } pub async fn get_result(&mut self) -> - Result, &'static str> { + Result, Arc> { // Just the result of get_result on tranport. We should record // that we got an answer to avoid asking again match self.state { From 544933a3b1b1067321fa5d291fb48779a8645310 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 7 Mar 2023 16:56:21 +0100 Subject: [PATCH 007/124] A bit of cleanup --- src/transport/tcp.rs | 65 ++------------------------------------------ 1 file changed, 3 insertions(+), 62 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 8638a7ba0..29219a9e3 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -561,61 +561,7 @@ impl InnerTcpConnection { tx_queue.push_back(vec.to_vec()); } - pub async fn query + AsMut<[u8]>> - (&self, query_msg: &mut MessageBuilder>>) - -> Result, Arc> { - let index = self.insert(); - let ind16: u16 = index.try_into().unwrap(); - - // We set the ID to the array index. Defense in depth - // suggests that a random ID is better because it works - // even if TCP sequence numbers could be predicted. However, - // Section 9.3 of RFC 5452 recommends retrying over TCP - // if many spoofed answers arrive over UDP: "TCP, by the - // nature of its use of sequence numbers, is far more - // resilient against forgery by third parties." - let hdr = query_msg.header_mut(); - hdr.set_id(ind16); - - self.queue_query(query_msg); - - // Now kick the worker to transmit the query - self.worker_notify.notify_one(); - - // Wait for reply - let mut query_vec = self.query_vec.lock().unwrap(); - let local_notify = query_vec.vec[index].as_mut().unwrap(). - complete.clone(); - drop(query_vec); - local_notify.notified().await; - - // Get the lock again to take a look - let mut query_vec = self.query_vec.lock().unwrap(); - let opt_q = query_vec.vec[index].take(); - query_vec.count = query_vec.count - 1; - drop(query_vec); - - if let Some(q) = opt_q - { - if let SingleQueryState::Done(result) = q.state - { - if let Ok(answer) = &result - { - if !answer.is_answer(&query_msg. - as_message()) { - // Wrong answer, try again? - panic!("wrong answer"); - } - } - return result; - } - panic!("inconsistent state"); - } - - panic!("inconsistent state"); - } - - pub fn query2 + AsRef<[u8]> + + pub fn query + AsRef<[u8]> + Composer + Clone> (&self, query_msg: &mut MessageBuilder>> @@ -780,17 +726,12 @@ impl TcpConnection { pub async fn worker(&self) -> Option<()> { self.inner.worker().await } - pub async fn query + AsRef<[u8]>> - (&self, query_msg: &mut MessageBuilder>>) - -> Result, Arc> { - self.inner.query(query_msg).await - } - pub fn query2 + AsRef<[u8]> + + pub fn query + AsRef<[u8]> + Composer + Clone> (&self, query_msg: &mut MessageBuilder>>) -> Result { - let index = self.inner.query2(query_msg)?; + let index = self.inner.query(query_msg)?; let msg = &query_msg.as_message(); Ok(Query::new(self, msg, index)) } From cf76e5765f92c3f73bcc37a8c1b2aab02551e341 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 13 Mar 2023 17:08:02 +0100 Subject: [PATCH 008/124] Timeouts and a bit of cleanup --- src/transport/tcp.rs | 210 ++++++++++++++++++++++++++++--------------- 1 file changed, 137 insertions(+), 73 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 29219a9e3..9f91ac5a3 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -5,19 +5,18 @@ // TODO: // - errors -// - read errors -// - write errors // - connect errors? Retry after connection refused? // - server errors // - ID out of range // - ID not in use // - reply for wrong query // - timeouts -// - channel timeout // - request timeout +// - limit number of outstanding queries to 32K // - create new TCP connection after end/failure of previous one use std::collections::VecDeque; +use std::ops::DerefMut; use std::sync::Arc; use std::sync::Mutex as Std_mutex; use std::time::{Duration, Instant}; @@ -38,11 +37,21 @@ use tokio::time::sleep; const ERR_IDLE_TIMEOUT: &str = "idle connection was closed"; const ERR_READ_ERROR: &str = "read error"; +const ERR_READ_TIMEOUT: &str = "read timeout"; const ERR_WRITE_ERROR: &str = "write error"; // From RFC 7828. This should go somewhere with the option parsing const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; +// Implement a simple response timer to see if the connection and the server +// are alive. Set the timer when the connection goes from idle to busy. +// Reset the timer each time a reply arrives. Cancel the timer when the +// connection goes back to idle. When the time expires, mark all outstanding +// queries as timed out and shutdown the connection. +// +// Note: nsd has 120 seconds, unbound has 3 seconds. +const RESPONSE_TIMEOUT_S: u64 = 19; + enum SingleQueryState { Busy, Done(Result, Arc>), @@ -69,14 +78,15 @@ struct Queries { } enum ConnState { - Active, + Active(Option), Idle(Instant), IdleTimeout, ReadError, + ReadTimeout, WriteError, } -struct Keepalive { +struct Status { state: ConnState, // For edns-tcp-keepalive, we have a boolean the specifies if we @@ -91,8 +101,8 @@ struct Keepalive { struct InnerTcpConnection { stream: Std_mutex, - /* keepalive */ - keepalive: Std_mutex, + /* status */ + status: Std_mutex, /* Vector with outstanding queries */ query_vec: Std_mutex, @@ -124,8 +134,8 @@ impl InnerTcpConnection { let tcp = TcpStream::connect(addr).await?; Ok(Self { stream: Std_mutex::new(tcp), - keepalive: Std_mutex::new(Keepalive { - state: ConnState::Active, + status: Std_mutex::new(Status { + state: ConnState::Active(None), send_keepalive: true, idle_timeout: None, }), @@ -140,7 +150,32 @@ impl InnerTcpConnection { }) } + // Take a query out of query_vec and decrement the query count + fn take_query(&self, index: usize) -> Option + { + let mut query_vec = self.query_vec.lock().unwrap(); + self.vec_take_query(query_vec.deref_mut(), index) + } + + // Very similar to take_query, but sometime the caller already has + // a lock on the mutex + fn vec_take_query(&self, query_vec: &mut Queries, index: usize) -> + Option{ + let query = query_vec.vec[index].take(); + query_vec.count = query_vec.count - 1; + if query_vec.count == 0 { + // The worker may be waiting for this + self.worker_notify.notify_one(); + } + query + } + fn insert_answer(&self, answer: Message) { + // We got an answer, reset the timer + let mut status = self.status.lock().unwrap(); + status.state = ConnState::Active(Some(Instant::now())); + drop(status); + let ind16 = answer.header().id(); let index: usize = ind16.into(); @@ -174,12 +209,9 @@ impl InnerTcpConnection { //`The query has been // canceled already // Clean up. - let _ = query_vec. - vec[index]. - take(); - query_vec.count = - query_vec. - count - 1; + let _ = self.vec_take_query( + query_vec.deref_mut(), + index); } SingleQueryState::Done(_) => { // Already got a @@ -191,18 +223,25 @@ impl InnerTcpConnection { } query_vec.busy = query_vec.busy-1; if query_vec.busy == 0 { - let mut keepalive = self.keepalive.lock().unwrap(); - if keepalive.idle_timeout == None { + let mut status = self.status.lock().unwrap(); + + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this indenpendent. + status.state = ConnState::Active(None); + + if status.idle_timeout == None { // Assume that we can just move to IdleTimeout // state - keepalive.state = ConnState::IdleTimeout; + status.state = ConnState::IdleTimeout; // Notify the worker. Then the worker can // close the tcp connection self.worker_notify.notify_one(); } else { - keepalive.state = + status.state = ConnState::Idle(Instant::now()); // Notify the worker. The worker waits for @@ -214,8 +253,8 @@ impl InnerTcpConnection { fn handle_keepalive(&self, opt_value: TcpKeepalive) { if let Some(value) = opt_value.timeout() { - let mut keepalive = self.keepalive.lock().unwrap(); - keepalive.idle_timeout = + let mut status = self.status.lock().unwrap(); + status.idle_timeout = Some(Duration::from_millis(u64::from(value) * EDNS_TCP_KEEPALIE_TO_MS)); } @@ -241,8 +280,8 @@ impl InnerTcpConnection { Ok(len) => len, Err(error) => { self.tcp_error(error); - let mut keepalive = self.keepalive.lock().unwrap(); - keepalive.state = ConnState::ReadError; + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; return Err(ERR_READ_ERROR); } } as usize; @@ -254,8 +293,8 @@ impl InnerTcpConnection { Ok(_) => (), // We don't need the result Err(error) => { self.tcp_error(error); - let mut keepalive = self.keepalive.lock().unwrap(); - keepalive.state = ConnState::ReadError; + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; return Err(ERR_READ_ERROR); } }; @@ -276,8 +315,8 @@ impl InnerTcpConnection { let error = io::Error::new(io::ErrorKind::Other, "short buf"); self.tcp_error(error); - let mut keepalive = self.keepalive.lock().unwrap(); - keepalive.state = ConnState::ReadError; + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; return Err(ERR_READ_ERROR); } } @@ -326,9 +365,9 @@ impl InnerTcpConnection { let res = sock.write_all(&vec).await; if let Err(error) = res { self.tcp_error(error); - let mut keepalive = - self.keepalive.lock().unwrap(); - keepalive.state = + let mut status = + self.status.lock().unwrap(); + status.state = ConnState::WriteError; return Err(ERR_WRITE_ERROR); } @@ -384,14 +423,31 @@ impl InnerTcpConnection { let notify_fut = self.worker_notify.notified(); let mut opt_timeout: Option = None; - let mut keepalive = self.keepalive.lock().unwrap(); - if let ConnState::Idle(instant) = keepalive.state { - if let Some(timeout) = keepalive.idle_timeout { + let mut status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let timeout = Duration::from_secs( + RESPONSE_TIMEOUT_S); + let elapsed = instant.elapsed(); + if elapsed > timeout { + let error = io::Error::new( + io::ErrorKind::Other, + "read timeout"); + self.tcp_error(error); + status.state = ConnState::ReadTimeout; + break; + } + opt_timeout = Some(timeout - elapsed); + } + } + ConnState::Idle(instant) => { + if let Some(timeout) = status.idle_timeout { let elapsed = instant.elapsed(); if elapsed >= timeout { // Move to IdleTimeout and end // the loop - keepalive.state = + status.state = ConnState::IdleTimeout; break; } @@ -400,8 +456,15 @@ impl InnerTcpConnection { else { panic!("Idle state but no timeout"); } + } + ConnState::IdleTimeout | + ConnState::ReadError | + ConnState::WriteError => + (), // No timers here + ConnState::ReadTimeout => panic!( + "should not be in loop with ReadTimeout") } - drop(keepalive); + drop(status); if let Some(timeout) = opt_timeout { @@ -456,20 +519,21 @@ impl InnerTcpConnection { } // Check if the connection is idle - let keepalive = self.keepalive.lock().unwrap(); - match keepalive.state { - ConnState::Active | ConnState::Idle(_) => { + let status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { // Keep going } ConnState::IdleTimeout => { break } ConnState::ReadError | + ConnState::ReadTimeout | ConnState::WriteError => { panic!("Should not be here"); } } - drop(keepalive); + drop(status); } // Send FIN to server. Ignore any errors. @@ -567,18 +631,23 @@ impl InnerTcpConnection { query_msg: &mut MessageBuilder>> ) -> Result { - // Check the state of the connection, fail if the connection is in - // IdleTimeout. If the connection is Idle, move it back to Active - // Also check for the need to send a keepalive - let mut keepalive = self.keepalive.lock().unwrap(); - match keepalive.state { - ConnState::Active => { - // Nothing to do + // Check the state of the connection, fail if the connection + // is in IdleTimeout. If the connection is Idle, move it + // back to Active. Also check for the need to send a keepalive + let mut status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer == None { + status.state = ConnState::Active(Some( + Instant::now())); + } () } ConnState::Idle(_) => { // Go back to active - keepalive.state = ConnState::Active; + status.state = ConnState::Active(Some( + Instant::now())); () } ConnState::IdleTimeout => { @@ -588,17 +657,20 @@ impl InnerTcpConnection { ConnState::ReadError => { return Err(ERR_READ_ERROR); } + ConnState::ReadTimeout => { + return Err(ERR_READ_TIMEOUT); + } ConnState::WriteError => { return Err(ERR_WRITE_ERROR); } } let mut do_keepalive = false; - if keepalive.send_keepalive { + if status.send_keepalive { do_keepalive = true; - keepalive.send_keepalive = false; + status.send_keepalive = false; } - drop(keepalive); + drop(status); let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); @@ -628,12 +700,12 @@ impl InnerTcpConnection { // failed. Send the original // request and turn the // send_keepalive flag back on - let mut keepalive = - self.keepalive.lock() + let mut status = + self.status.lock() .unwrap(); - keepalive.send_keepalive = + status.send_keepalive = true; - drop(keepalive); + drop(status); self.queue_query(query_msg); } } @@ -652,21 +724,13 @@ impl InnerTcpConnection { Result, Arc> { // Wait for reply let mut query_vec = self.query_vec.lock().unwrap(); - let local_notify = query_vec.vec[index].as_mut().unwrap(). - complete.clone(); + let local_notify = query_vec.vec[index].as_mut(). + unwrap().complete.clone(); drop(query_vec); local_notify.notified().await; - // Get the lock again to take a look - let mut query_vec = self.query_vec.lock().unwrap(); - let opt_q = query_vec.vec[index].take(); - query_vec.count = query_vec.count - 1; - if query_vec.count == 0 { - // The worker may be waiting for this - self.worker_notify.notify_one(); - } - drop(query_vec); - + // take a look + let opt_q = self.take_query(index); if let Some(q) = opt_q { if let SingleQueryState::Done(result) = q.state @@ -697,7 +761,8 @@ impl InnerTcpConnection { match &query.state { SingleQueryState::Busy => { query.state = - SingleQueryState::Canceled; + SingleQueryState:: + Canceled; return; } SingleQueryState::Canceled => { @@ -705,11 +770,9 @@ impl InnerTcpConnection { } SingleQueryState::Done(_) => { // Remove the result - let _ = query_vec. - vec[index].take(); - query_vec.count = - query_vec.count - 1; - drop(query_vec); + let _ = self.vec_take_query( + query_vec.deref_mut(), + index); } } } @@ -748,7 +811,8 @@ impl Query { Self { transport: transport.inner.clone(), query_msg: msg, - state: QueryState::Busy(index) } + state: QueryState::Busy(index) + } } pub async fn get_result(&mut self) -> Result, Arc> { From 297d5eb75fa2e31919fd88f6b493f9d02abdae4f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 15 Mar 2023 16:00:17 +0100 Subject: [PATCH 009/124] Limit the number of concurrent queries to 32K and some small changes --- src/transport/tcp.rs | 88 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 9f91ac5a3..abd6ae717 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -39,6 +39,7 @@ const ERR_IDLE_TIMEOUT: &str = "idle connection was closed"; const ERR_READ_ERROR: &str = "read error"; const ERR_READ_TIMEOUT: &str = "read timeout"; const ERR_WRITE_ERROR: &str = "write error"; +const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; // From RFC 7828. This should go somewhere with the option parsing const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; @@ -273,6 +274,7 @@ impl InnerTcpConnection { } } + // This function is not async cancellation safe async fn reader(&self, sock: &mut ReadHalf<'_>) -> Result<(), &str> { loop { let read_res = sock.read_u16().await; @@ -288,16 +290,43 @@ impl InnerTcpConnection { let mut buf = BytesMut::with_capacity(len); - let read_res = sock.read_buf(&mut buf).await; - match read_res { - Ok(_) => (), // We don't need the result - Err(error) => { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; } - }; + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + let error = io::Error::new( + io::ErrorKind::Other, + "unexpected end of data"); + self.tcp_error(error); + let mut status = self.status.lock(). + unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + } + Err(error) => { + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + }; + + // Check if we are done at the head of the loop + } let reply_message = Message::::from_octets(buf.into()); @@ -380,6 +409,8 @@ impl InnerTcpConnection { Ok(()) } + // This function is not async cancellation safe because it calls + // reader which is not async cancellation safe pub async fn worker(&self) -> Option<()> { let mut stream = self.stream.lock().unwrap(); let (mut read_stream, mut write_stream) = stream.split(); @@ -563,20 +594,34 @@ impl InnerTcpConnection { // Insert a message in the query vector. Return the index fn insert(&self) - -> usize { + -> Result { let q = Some(SingleQuery { state: SingleQueryState::Busy, complete: Arc::new(Notify::new()), }); let mut query_vec = self.query_vec.lock().unwrap(); + + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2*query_vec.count > u16::MAX.into() { + return Err(ERR_TOO_MANY_QUERIES); + } + let vec_len = query_vec.vec.len(); - if vec_len < 2*(query_vec.count+1) { + + // Append if the amount of empty space in the vector is less + // than half. But limit vec_len to u16::MAX + if vec_len < 2*(query_vec.count+1) && vec_len < + u16::MAX.into() { // Just append query_vec.vec.push(q); query_vec.count = query_vec.count + 1; query_vec.busy = query_vec.busy + 1; let index = query_vec.vec.len()-1; - return index; + return Ok(index); } let loc_curr = query_vec.curr; @@ -589,7 +634,7 @@ impl InnerTcpConnection { None => { Self::insert_at(&mut query_vec, index, q); - return index; + return Ok(index); } } } @@ -605,7 +650,7 @@ impl InnerTcpConnection { None => { Self::insert_at(&mut query_vec, index, q); - return index; + return Ok(index); } } } @@ -665,6 +710,11 @@ impl InnerTcpConnection { } } + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + let index = self.insert()?; + let mut do_keepalive = false; if status.send_keepalive { do_keepalive = true; @@ -672,7 +722,6 @@ impl InnerTcpConnection { } drop(status); - let index = self.insert(); let ind16: u16 = index.try_into().unwrap(); // We set the ID to the array index. Defense in depth @@ -737,9 +786,12 @@ impl InnerTcpConnection { { if let Ok(answer) = &result { - if !answer.is_answer(query_msg) { - // Wrong answer, try again? - panic!("wrong answer"); + if !answer.is_answer( + query_msg) { + return Err(Arc::new( + io::Error::new( + io::ErrorKind::Other, + "wrong answer"))); } } return result; From 1500676dd256e7f372e785b718f4538fe33940ce Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 17 Mar 2023 15:16:48 +0100 Subject: [PATCH 010/124] Rewrote tokio::select! loop --- src/transport/tcp.rs | 206 ++++++++++++++++++++++--------------------- 1 file changed, 105 insertions(+), 101 deletions(-) diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index abd6ae717..31dff9a84 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -97,6 +97,7 @@ struct Status { // specify time in Rust? Currently we specify it in milliseconds. send_keepalive: bool, idle_timeout: Option, + do_shutdown: bool, } struct InnerTcpConnection { @@ -112,6 +113,7 @@ struct InnerTcpConnection { tx_queue: Std_mutex>>, worker_notify: Notify, + writer_notify: Notify, } pub struct TcpConnection { @@ -139,6 +141,7 @@ impl InnerTcpConnection { state: ConnState::Active(None), send_keepalive: true, idle_timeout: None, + do_shutdown: false, }), query_vec: Std_mutex::new(Queries { count: 0, @@ -148,6 +151,7 @@ impl InnerTcpConnection { }), tx_queue: Std_mutex::new(VecDeque::new()), worker_notify: Notify::new(), + writer_notify: Notify::new(), }) } @@ -383,76 +387,63 @@ impl InnerTcpConnection { } } + // This function is not async cancellation safe async fn writer(&self, sock: &mut WriteHalf<'_>) -> Result<(), &'static str> { loop { - let mut tx_queue = self.tx_queue.lock().unwrap(); - let head = tx_queue.pop_front(); - drop(tx_queue); - match head { - Some(vec) => { - let res = sock.write_all(&vec).await; - if let Err(error) = res { - self.tcp_error(error); - let mut status = - self.status.lock().unwrap(); - status.state = - ConnState::WriteError; - return Err(ERR_WRITE_ERROR); + loop { + // Check if we need to shutdown + let status = self.status.lock().unwrap(); + let do_shutdown = status.do_shutdown; + drop(status); + + if do_shutdown { + // Ignore errors + _ = sock.shutdown().await; + + // Do we need to clear do_shutdown? + break; + } + + let mut tx_queue = self.tx_queue.lock(). + unwrap(); + let head = tx_queue.pop_front(); + drop(tx_queue); + match head { + Some(vec) => { + let res = sock.write_all(&vec).await; + if let Err(error) = res { + self.tcp_error(error); + let mut status = + self.status.lock(). + unwrap(); + status.state = + ConnState::WriteError; + return Err(ERR_WRITE_ERROR); + } + () + } + None => + break, } - () - } - None => - break, } + + self.writer_notify.notified().await; } - Ok(()) } // This function is not async cancellation safe because it calls - // reader which is not async cancellation safe + // reader and writer which are not async cancellation safe pub async fn worker(&self) -> Option<()> { let mut stream = self.stream.lock().unwrap(); let (mut read_stream, mut write_stream) = stream.split(); let reader_fut = self.reader(&mut read_stream); tokio::pin!(reader_fut); + let writer_fut = self.writer(&mut write_stream); + tokio::pin!(writer_fut); loop { - let writer_fut = self.writer(&mut write_stream); - - tokio::select! { - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Reader failed. Break - // out of loop and - // shut down - break - } - } - res = writer_fut => { - match res { - Ok(_) => - // The writer is done. - // Wait for a notify. - (), - Err(_) => - // Writer failed. Break - // out of loop and - // shut down - break - } - } - } - - let notify_fut = self.worker_notify.notified(); - let mut opt_timeout: Option = None; let mut status = self.status.lock().unwrap(); match status.state { @@ -498,55 +489,58 @@ impl InnerTcpConnection { drop(status); - if let Some(timeout) = opt_timeout { - let sleep_fut = sleep(timeout); + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => + // Just use the response timeout + Duration::from_secs(RESPONSE_TIMEOUT_S) + }; - tokio::select! { - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Reader failed. Break - // out of loop and - // shut down - break - } - } - _ = notify_fut => { - // Got notified, start writing - () - } - _ = sleep_fut => { - // Idle timeout expired, just - // continue with the loop - () - } + let sleep_fut = sleep(timeout); + let notify_fut = self.worker_notify.notified(); + + tokio::select! { + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Reader failed. Break + // out of loop and + // shut down + break + } } - } else { - tokio::select! { - res = &mut reader_fut => { - match res { + res = &mut writer_fut => { + match res { Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Reader failed. Break - // out of loop and - // shut down - break - } - } - _ = notify_fut => { - // Got notified, start writing - () + // The writer should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Writer failed. Break + // out of loop and + // shut down + break } } + + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + () + } + _ = notify_fut => { + // Got notifies, go through the loop + // to see what changed. + () + } + } // Check if the connection is idle @@ -567,8 +561,18 @@ impl InnerTcpConnection { drop(status); } - // Send FIN to server. Ignore any errors. - let _ = write_stream.shutdown().await; + // We can't see a FIN directly because the writer_fut owns + // write_stream. + let mut status = self.status.lock().unwrap(); + status.do_shutdown = true; + drop(status); + + // Kick writer + self.writer_notify.notify_one(); + + // Wait for writer to terminate. Ignore the result. We may + // want a timer here + _ = writer_fut.await; // Stay around until the last query result is collected loop { @@ -762,8 +766,8 @@ impl InnerTcpConnection { self.queue_query(query_msg); } - // Now kick the worker to transmit the query - self.worker_notify.notify_one(); + // Now kick the writer to transmit the query + self.writer_notify.notify_one(); Ok(index) } From 24ba9a6041f71bd2b3f06359f19ae9b0fdedc0f1 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 30 Mar 2023 12:27:44 +0200 Subject: [PATCH 011/124] Rename transport to net. --- Cargo.toml | 2 +- src/lib.rs | 2 +- src/net/client/mod.rs | 5 +++++ src/{transport => net/client}/tcp.rs | 0 src/net/mod.rs | 5 +++++ src/transport/mod.rs | 5 ----- 6 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 src/net/client/mod.rs rename src/{transport => net/client}/tcp.rs (100%) create mode 100644 src/net/mod.rs delete mode 100644 src/transport/mod.rs diff --git a/Cargo.toml b/Cargo.toml index e2955fa48..331076df1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -transport = ["bytes", "tokio"] +net = ["bytes", "tokio"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/lib.rs b/src/lib.rs index c08354145..d13e66273 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,7 +125,7 @@ pub mod rdata; pub mod resolv; pub mod sign; pub mod test; -pub mod transport; +pub mod net; pub mod tsig; pub mod utils; pub mod validate; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs new file mode 100644 index 000000000..1dd6e4be8 --- /dev/null +++ b/src/net/client/mod.rs @@ -0,0 +1,5 @@ +//! DNS transport protocols +#![cfg(feature = "net")] +#![cfg_attr(docsrs, doc(cfg(feature = "net")))] + +pub mod tcp; diff --git a/src/transport/tcp.rs b/src/net/client/tcp.rs similarity index 100% rename from src/transport/tcp.rs rename to src/net/client/tcp.rs diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 000000000..ebc56aacd --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,5 @@ +//! DNS transport protocols +#![cfg(feature = "net")] +#![cfg_attr(docsrs, doc(cfg(feature = "net")))] + +pub mod client; diff --git a/src/transport/mod.rs b/src/transport/mod.rs deleted file mode 100644 index b56d82637..000000000 --- a/src/transport/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! DNS transport protocols -#![cfg(feature = "transport")] -#![cfg_attr(docsrs, doc(cfg(feature = "transport")))] - -pub mod tcp; From 1b4244c7d98fc8d203b3bdb1b106495393605e9e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 17 Apr 2023 09:19:17 +0200 Subject: [PATCH 012/124] Renamed tcp.rs to tcp_mutex.ts and fixed clippy warnings. --- src/net/client/mod.rs | 2 +- src/net/client/{tcp.rs => tcp_mutex.rs} | 188 +++++++++++++----------- 2 files changed, 105 insertions(+), 85 deletions(-) rename src/net/client/{tcp.rs => tcp_mutex.rs} (87%) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 1dd6e4be8..c4588996e 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,4 +2,4 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] -pub mod tcp; +pub mod tcp_mutex; diff --git a/src/net/client/tcp.rs b/src/net/client/tcp_mutex.rs similarity index 87% rename from src/net/client/tcp.rs rename to src/net/client/tcp_mutex.rs index 31dff9a84..64971de22 100644 --- a/src/net/client/tcp.rs +++ b/src/net/client/tcp_mutex.rs @@ -101,7 +101,7 @@ struct Status { } struct InnerTcpConnection { - stream: Std_mutex, + stream: Std_mutex>, /* status */ status: Std_mutex, @@ -136,7 +136,7 @@ impl InnerTcpConnection { io::Result { let tcp = TcpStream::connect(addr).await?; Ok(Self { - stream: Std_mutex::new(tcp), + stream: Std_mutex::new(Some(tcp)), status: Std_mutex::new(Status { state: ConnState::Active(None), send_keepalive: true, @@ -167,7 +167,7 @@ impl InnerTcpConnection { fn vec_take_query(&self, query_vec: &mut Queries, index: usize) -> Option{ let query = query_vec.vec[index].take(); - query_vec.count = query_vec.count - 1; + query_vec.count -= 1; if query_vec.count == 0 { // The worker may be waiting for this self.worker_notify.notify_one(); @@ -226,7 +226,7 @@ impl InnerTcpConnection { } } } - query_vec.busy = query_vec.busy-1; + query_vec.busy -= 1; if query_vec.busy == 0 { let mut status = self.status.lock().unwrap(); @@ -236,7 +236,7 @@ impl InnerTcpConnection { // this indenpendent. status.state = ConnState::Active(None); - if status.idle_timeout == None { + if status.idle_timeout.is_none() { // Assume that we can just move to IdleTimeout // state status.state = ConnState::IdleTimeout; @@ -269,11 +269,8 @@ impl InnerTcpConnection { (&self, opts: &OptRecord) { for option in opts.iter() { let opt = option.unwrap(); - match opt { - AllOptData::TcpKeepalive(tcpkeepalive) => { - self.handle_keepalive(tcpkeepalive); - } - _ => {} + if let AllOptData::TcpKeepalive(tcpkeepalive) = opt { + self.handle_keepalive(tcpkeepalive); } } } @@ -393,9 +390,14 @@ impl InnerTcpConnection { loop { loop { // Check if we need to shutdown - let status = self.status.lock().unwrap(); - let do_shutdown = status.do_shutdown; - drop(status); + let do_shutdown = { + // Extra block to satisfy clippy + // await_holding_lock + let status = self.status.lock() + .unwrap(); + status.do_shutdown + // drop(status); + }; if do_shutdown { // Ignore errors @@ -405,10 +407,14 @@ impl InnerTcpConnection { break; } - let mut tx_queue = self.tx_queue.lock(). - unwrap(); - let head = tx_queue.pop_front(); - drop(tx_queue); + let head = { + // Extra block to satisfy clippy + // await_holding_lock + let mut tx_queue = self.tx_queue + .lock().unwrap(); + tx_queue.pop_front() + // drop(tx_queue); + }; match head { Some(vec) => { let res = sock.write_all(&vec).await; @@ -421,7 +427,6 @@ impl InnerTcpConnection { ConnState::WriteError; return Err(ERR_WRITE_ERROR); } - () } None => break, @@ -435,7 +440,13 @@ impl InnerTcpConnection { // This function is not async cancellation safe because it calls // reader and writer which are not async cancellation safe pub async fn worker(&self) -> Option<()> { - let mut stream = self.stream.lock().unwrap(); + let mut stream = { + // Extra block to satisfy clippy + // await_holding_lock + let mut opt_stream = self.stream.lock().unwrap(); + opt_stream.take().unwrap() + // drop(opt_stream); + }; let (mut read_stream, mut write_stream) = stream.split(); let reader_fut = self.reader(&mut read_stream); @@ -444,50 +455,53 @@ impl InnerTcpConnection { tokio::pin!(writer_fut); loop { - let mut opt_timeout: Option = None; - let mut status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(opt_instant) => { - if let Some(instant) = opt_instant { - let timeout = Duration::from_secs( - RESPONSE_TIMEOUT_S); - let elapsed = instant.elapsed(); - if elapsed > timeout { - let error = io::Error::new( - io::ErrorKind::Other, - "read timeout"); - self.tcp_error(error); - status.state = ConnState::ReadTimeout; - break; - } - opt_timeout = Some(timeout - elapsed); - } - } - ConnState::Idle(instant) => { - if let Some(timeout) = status.idle_timeout { - let elapsed = instant.elapsed(); - if elapsed >= timeout { - // Move to IdleTimeout and end - // the loop - status.state = - ConnState::IdleTimeout; + let opt_timeout: Option = { + // Extra block to satisfy clippy + // await_holding_lock + let mut status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let timeout = Duration::from_secs( + RESPONSE_TIMEOUT_S); + let elapsed = instant.elapsed(); + if elapsed > timeout { + let error = io::Error::new( + io::ErrorKind::Other, + "read timeout"); + self.tcp_error(error); + status.state = ConnState::ReadTimeout; break; + } + Some(timeout - elapsed) } - opt_timeout = Some(timeout - elapsed); - } - else { - panic!("Idle state but no timeout"); + else { None } + } + ConnState::Idle(instant) => { + if let Some(timeout) = status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= timeout { + // Move to IdleTimeout and end + // the loop + status.state = + ConnState::IdleTimeout; + break; + } + Some(timeout - elapsed) + } + else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout | + ConnState::ReadError | + ConnState::WriteError => + None, // No timers here + ConnState::ReadTimeout => panic!( + "should not be in loop with ReadTimeout") } - } - ConnState::IdleTimeout | - ConnState::ReadError | - ConnState::WriteError => - (), // No timers here - ConnState::ReadTimeout => panic!( - "should not be in loop with ReadTimeout") - } - drop(status); - + // drop(status); + }; // For simplicity, make sure we always have a timeout let timeout = match opt_timeout { @@ -533,12 +547,10 @@ impl InnerTcpConnection { _ = sleep_fut => { // Timeout expired, just // continue with the loop - () } _ = notify_fut => { // Got notifies, go through the loop // to see what changed. - () } } @@ -563,9 +575,13 @@ impl InnerTcpConnection { // We can't see a FIN directly because the writer_fut owns // write_stream. - let mut status = self.status.lock().unwrap(); - status.do_shutdown = true; - drop(status); + { + // Extra block to satisfy clippy + // await_holding_lock + let mut status = self.status.lock().unwrap(); + status.do_shutdown = true; + // drop(status); + }; // Kick writer self.writer_notify.notify_one(); @@ -576,12 +592,16 @@ impl InnerTcpConnection { // Stay around until the last query result is collected loop { - let query_vec = self.query_vec.lock().unwrap(); - if query_vec.count == 0 { - // We are done - break; + { + // Extra block to satisfy clippy + // await_holding_lock + let query_vec = self.query_vec.lock().unwrap(); + if query_vec.count == 0 { + // We are done + break; + } + // drop(query_vec); } - drop(query_vec); self.worker_notify.notified().await; } @@ -591,8 +611,8 @@ impl InnerTcpConnection { fn insert_at(query_vec: &mut Queries, index: usize, q: Option) { query_vec.vec[index] = q; - query_vec.count = query_vec.count + 1; - query_vec.busy = query_vec.busy + 1; + query_vec.count += 1; + query_vec.busy += 1; query_vec.curr = index + 1; } @@ -622,8 +642,8 @@ impl InnerTcpConnection { u16::MAX.into() { // Just append query_vec.vec.push(q); - query_vec.count = query_vec.count + 1; - query_vec.busy = query_vec.busy + 1; + query_vec.count += 1; + query_vec.busy += 1; let index = query_vec.vec.len()-1; return Ok(index); } @@ -633,7 +653,6 @@ impl InnerTcpConnection { match query_vec.vec[index] { Some(_) => { // Already in use, just continue - () } None => { Self::insert_at(&mut query_vec, @@ -649,7 +668,6 @@ impl InnerTcpConnection { match query_vec.vec[index] { Some(_) => { // Already in use, just continue - () } None => { Self::insert_at(&mut query_vec, @@ -687,17 +705,15 @@ impl InnerTcpConnection { match status.state { ConnState::Active(timer) => { // Set timer if we don't have one already - if timer == None { + if timer.is_none() { status.state = ConnState::Active(Some( Instant::now())); } - () } ConnState::Idle(_) => { // Go back to active status.state = ConnState::Active(Some( Instant::now())); - () } ConnState::IdleTimeout => { // The connection has been closed. Report error @@ -776,10 +792,15 @@ impl InnerTcpConnection { query_msg: &Message, index: usize) -> Result, Arc> { // Wait for reply - let mut query_vec = self.query_vec.lock().unwrap(); - let local_notify = query_vec.vec[index].as_mut(). - unwrap().complete.clone(); - drop(query_vec); + let local_notify = { + // Extra block to satisfy clippy + // await_holding_lock + let mut query_vec = self.query_vec.lock() + .unwrap(); + query_vec.vec[index].as_mut(). + unwrap().complete.clone() + // drop(query_vec); + }; local_notify.notified().await; // take a look @@ -819,7 +840,6 @@ impl InnerTcpConnection { query.state = SingleQueryState:: Canceled; - return; } SingleQueryState::Canceled => { panic!("Already canceled"); From 6d004b6a0b242ff023605a30013e8fc2429bd9a8 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 17 Apr 2023 09:21:05 +0200 Subject: [PATCH 013/124] A channel based implementation of a TCP transport. --- Cargo.toml | 2 +- src/net/client/mod.rs | 1 + src/net/client/tcp_channel.rs | 840 ++++++++++++++++++++++++++++++++++ 3 files changed, 842 insertions(+), 1 deletion(-) create mode 100644 src/net/client/tcp_channel.rs diff --git a/Cargo.toml b/Cargo.toml index 331076df1..120da3aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "tokio"] +net = ["bytes", "futures", "tokio"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index c4588996e..18e3d282c 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -3,3 +3,4 @@ #![cfg_attr(docsrs, doc(cfg(feature = "net")))] pub mod tcp_mutex; +pub mod tcp_channel; diff --git a/src/net/client/tcp_channel.rs b/src/net/client/tcp_channel.rs new file mode 100644 index 000000000..5a849e28c --- /dev/null +++ b/src/net/client/tcp_channel.rs @@ -0,0 +1,840 @@ +//! A DNS over TCP transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// RFC 7766 describes DNS over TCP +// RFC 7828 describes the edns-tcp-keepalive option + +// TODO: +// - errors +// - connect errors? Retry after connection refused? +// - server errors +// - ID out of range +// - ID not in use +// - reply for wrong query +// - timeouts +// - request timeout +// - create new TCP connection after end/failure of previous one + +use core::convert::From; +use std::fmt::Debug; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use std::vec::Vec; +use futures::lock::Mutex as Futures_mutex; +use bytes; +use bytes::{Bytes, BytesMut}; + +use crate::base::{Message, MessageBuilder, opt::{AllOptData, OptRecord, + TcpKeepalive}, StaticCompressor, StreamTarget}; +use crate::base::wire::Composer; +use octseq::{Octets, OctetsBuilder}; + +use tokio::io; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio::net::tcp::ReadHalf; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::sleep; + +/// Error returned when too many queries are currently active. +const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; + +/// Constant from RFC 7828. How to convert the value on the +/// edns-tcp-keepalive option to milliseconds. +// This should go somewhere with the option parsing +const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; + +/// Time to wait on a non-idle TCP connection for the other side to send +/// a response on any outstanding query. +// Implement a simple response timer to see if the connection and the server +// are alive. Set the timer when the connection goes from idle to busy. +// Reset the timer each time a reply arrives. Cancel the timer when the +// connection goes back to idle. When the time expires, mark all outstanding +// queries as timed out and shutdown the connection. +// +// Note: nsd has 120 seconds, unbound has 3 seconds. +const RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); + +/// Capacity of the channel that transports [ChanReq]. +const DEF_CHAN_CAP: usize = 8; + +/// Capacity of a private channel between [InnerTcpConnection::reader] and +/// [InnerTcpConnection::run]. +const READ_REPLY_CHAN_CAP: usize = 8; + +/// Error reported when the TCP connection is closed and +/// [InnerTcpConnection::run] terminated. +const ERR_CONN_CLOSED: &str = "connection closed"; + +/// This is the type of sender in [ChanReq]. +type ReplySender = oneshot::Sender; + +#[derive(Debug)] +/// A request from [Query] to [TcpConnection::run] to start a DNS request. +struct ChanReq { + /// DNS request message + msg: MessageBuilder>>, + + /// Sender to send result back to [Query] + sender: ReplySender, +} + +#[derive(Debug)] +/// a response to a [ChanReq]. +struct Response { + /// The 2 octet id that went into the outgoing DNS request. + /// + /// This id is needed to match the response with the query. + id: u16, + + /// The DNS reply message. + reply: Message, +} +/// Response to the DNS request sent by [InnerTcpConnection::run] to [Query]. +type ChanResp = Result>; + +/// The actual implementation of [TcpConnection]. +struct InnerTcpConnection { + /// TCP connection protected by a mutex to allow read/write access by + /// [InnerTcpConnection::run]. + stream: Futures_mutex, + + /// [InnerTcpConnection::sender] and [InnerTcpConnection::receiver] are + /// part of a single channel. + /// + /// Used by [Query] to send requests to [InnerTcpConnection::run]. + sender: mpsc::Sender>, + + /// receiver part of the channel. + /// + /// Protected by a mutex to allow read/write access by + /// [InnerTcpConnection::run]. + /// The Option is to allow [InnerTcpConnection::run] to signal that the + /// TCP connection is closed. + receiver: Futures_mutex>>>, +} + +/// Internal datastructure of [InnerTcpConnection::run] to keep track of +/// outstanding DNS requests. +struct Queries { + /// The number of elements in [Queries::vec] that are not None. + count: usize, + + /// Index in the [Queries::vec] where to look for a space for a new query. + curr: usize, + + /// Vector of senders to forward a DNS reply message (or error) to. + vec: Vec>, +} + +#[derive(Clone)] +/// A single DNS over TCP connection. +pub struct TcpConnection { + /// Reference counted [InnerTcpConnection]. + inner: Arc>, +} + +/// Status of a query. Used in [Query]. +enum QueryState { + /// A request is in progress. + /// + /// The receiver for receiving the response is part of this state. + Busy(oneshot::Receiver), + + /// The response has been received and the query is done. + Done, +} + +/// This struct represent an active DNS query. +pub struct Query { + /// Request message. + /// + /// The reply message is compared with the request message to see if + /// it matches the query. + query_msg: Message>, + + /// Current state of the query. + state: QueryState, +} + +/// Internal datastructure of [InnerTcpConnection::run] to keep track of +/// the status of the TCP connection. +// The types Status and ConnState are only used in InnerTcpConnection +struct Status { + /// State of the TCP connection. + state: ConnState, + + /// Boolean if we need to include an edns-tcp-keepalive option in an + /// outogoing request. + /// + /// Typically send_keepalive is true at the start of the connection. + /// it gets cleared when we successfully managed to include the option + /// in a request. + send_keepalive: bool, + + /// Time we are allow to keep the TCP connection open when idle. + /// + /// Initially we assume that the idle timeout is zero. A received + /// edns-tcp-keepalive option may change that. + idle_timeout: Option, +} +/// Status of the TCP connection. Used in [Status]. +enum ConnState { + /// The connection is in this state from the start and when at least + /// one active DNS request is present. + /// + /// The instant contains the time of the first request or the + /// most recent response that was received. + Active(Option), + + /// This state represent a TCP connection that went idle and has an + /// idle timeout. + /// + /// The instant contains the time the connection went idle. + Idle(Instant), + + /// This state represent an idle connection where either there was no + /// idle timeout or the idle timer expired. + IdleTimeout, + + /// A read error occurred. + ReadError, + + /// It took too long to receive a (or another) response. + ReadTimeout, + + /// A write error occurred. + WriteError, +} + +/// A DNS message received to [InnerTcpConnection::reader] and sent to +/// [InnerTcpConnection::run]. +// This type could be local to InnerTcpConnection, but I don't know how +type ReaderChanReply = Message; + +impl + Clone + Composer + Debug + OctetsBuilder> + InnerTcpConnection { + + /// Constructor for [InnerTcpConnection]. + /// + /// This is the implementation of [TcpConnection::connect]. + pub async fn connect(addr: A) -> + io::Result> { + let tcp = TcpStream::connect(addr).await?; + let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); + Ok(Self { + stream: Futures_mutex::new(tcp), + sender: tx, + receiver: Futures_mutex::new(Some(rx)), + }) + } + + /// Main execution function for [InnerTcpConnection]. + /// + /// This function Gets called by [TcpConnection::run]. + /// This function is not async cancellation safe + pub async fn run(&self) -> Option<()> { + let mut stream = self.stream.lock().await; + let (mut read_stream, mut write_stream) = stream.split(); + + let (reply_sender, mut reply_receiver) = + mpsc::channel::(READ_REPLY_CHAN_CAP); + + let reader_fut = Self::reader(&mut read_stream, reply_sender); + tokio::pin!(reader_fut); + + let mut receiver = { + let mut locked_opt_receiver = self.receiver.lock().await; + let opt_receiver = locked_opt_receiver.take(); + opt_receiver.expect("no receiver present?") + }; + + let mut status = Status { + state: ConnState::Active(None), + idle_timeout: None, + send_keepalive: true, + }; + let mut query_vec = Queries { + count: 0, + curr: 0, + vec: Vec::new() + }; + + let mut reqmsg: Option> = None; + + loop { + let opt_timeout = match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let elapsed = instant.elapsed(); + if elapsed > RESPONSE_TIMEOUT { + let error = io::Error::new( + io::ErrorKind::Other, "read timeout"); + Self::tcp_error(error, &mut query_vec); + status.state = ConnState::ReadTimeout; + break; + } + Some(RESPONSE_TIMEOUT - elapsed) + } + else { None } + } + ConnState::Idle(instant) => { + if let Some(timeout) = &status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= *timeout { + // Move to IdleTimeout and end + // the loop + status.state = ConnState::IdleTimeout; + break; + } + Some(*timeout - elapsed) + } + else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout | + ConnState::ReadError | + ConnState::WriteError => + None, // No timers here + ConnState::ReadTimeout => panic!( + "should not be in loop with ReadTimeout") + }; + + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => RESPONSE_TIMEOUT, + // Just use the response timeout + }; + + let sleep_fut = sleep(timeout); + let recv_fut = receiver.recv(); + + let (do_write, msg) = match &reqmsg { + None => { + let msg: &[u8] = &[]; + (false, msg) + } + Some(msg) => { + let msg: &[u8] = msg; + (true, msg) + } + }; + + tokio::select! { + biased; + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(error) => { + Self::tcp_error(error, + &mut query_vec); + status.state = + ConnState::ReadError; + // Reader failed. Break + // out of loop and + // shut down + break + } + } + } + opt_answer = reply_receiver.recv() => { + let answer = opt_answer.expect("reader died?"); + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + Self::handle_opts(opts, + &mut status); + }; + drop(opt_record); + Self::demux_reply(answer, + &mut status, &mut query_vec); + } + res = write_stream.write_all(msg), + if do_write => { + if let Err(error) = res { + Self::tcp_error(error, + &mut query_vec); + status.state = + ConnState::WriteError; + break; + } + else { + reqmsg = None; + } + } + res = recv_fut, if !do_write => { + match res { + Some(req) => + self.insert_req(req, &mut status, + &mut reqmsg, &mut query_vec), + None => panic!("recv failed"), + } + } + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + } + + } + + // Check if the connection is idle + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => { + break + } + ConnState::ReadError | + ConnState::ReadTimeout | + ConnState::WriteError => { + panic!("Should not be here"); + } + } + } + + // Send FIN + _ = write_stream.shutdown().await; + + None + } + + /// This function sends a DNS request to [InnerTcpConnection::run]. + pub async fn query + (&self, sender: oneshot::Sender, + query_msg: &mut MessageBuilder>> + ) -> Result<(), &'static str> { + + // We should figure out how to get query_msg. + let msg_clone = query_msg.clone(); + + let req = ChanReq { sender, msg: msg_clone }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + Err(ERR_CONN_CLOSED), + Ok(_) => Ok(()) + } + } + + /// This function reads a DNS message from the TCP connection and sends + /// it to [InnerTcpConnection::run]. + /// + /// Reading has to be done in two steps: first read a two octet value + /// the specifies the length of the message, and then read in a loop the + /// body of the message. + /// + /// This function is not async cancellation safe. + async fn reader(sock: &mut ReadHalf<'_>, + sender: mpsc::Sender) + -> Result<(), std::io::Error> { + loop { + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + return Err(error); + } + } as usize; + + let mut buf = BytesMut::with_capacity(len); + + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; + } + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + let error = io::Error::new( + io::ErrorKind::Other, + "unexpected end of data"); + return Err(error); + } + } + Err(error) => { + return Err(error); + } + }; + + // Check if we are done at the head of the loop + } + + let reply_message = Message::::from_octets(buf.into()); + match reply_message { + Ok(answer) => { + sender.send(answer).await.expect("can't send reply to run"); + } + Err(_) => { + // The only possible error is short message + let error = io::Error::new(io::ErrorKind::Other, + "short buf"); + return Err(error); + } + } + } + } + + /// An error occured, report the error to all outstanding [Query] objects. + fn tcp_error(error: std::io::Error, query_vec: &mut Queries) { + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + let arc_error = Arc::new(error); + for index in 0..query_vec.vec.len() { + if !query_vec.vec[index].is_none() { + let sender = Self::take_query(query_vec, index) + .expect("we tested is_none before"); + _ = sender.send(Err(arc_error.clone())); + } + } + } + + /// Handle received EDNS options, in particular the edns-tcp-keepalive + /// option. + fn handle_opts> + (opts: &OptRecord, status: &mut Status) { + for option in opts.iter() { + if let Ok(AllOptData::TcpKeepalive(tcpkeepalive)) = option { + Self::handle_keepalive(tcpkeepalive, status); + } + } + } + + /// Demultiplex a DNS reply and send it to the right [Query] object. + /// + /// In addition, the status is updated to IdleTimeout or Idle if there + /// are no remaining pending requests. + fn demux_reply(answer: Message, status: &mut Status, + query_vec: &mut Queries) { + // We got an answer, reset the timer + status.state = ConnState::Active(Some(Instant::now())); + + let ind16 = answer.header().id(); + let index: usize = ind16.into(); + + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the TCP connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the TCP connection as broken + return; + } + Some(_) => { + let sender = Self::take_query(query_vec, index).unwrap(); + let ind16: u16 = index.try_into().unwrap(); + let reply = Response { + reply: answer, + id: ind16, + }; + _ = sender.send(Ok(reply)); + } + } + if query_vec.count == 0 { + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this independent. + status.state = ConnState::Active(None); + + status.state = if status.idle_timeout.is_none() { + // Assume that we can just move to IdleTimeout + // state + ConnState::IdleTimeout + } + else { + ConnState::Idle(Instant::now()) + } + } + } + + /// Insert a request in query_vec and return the request to be sent + /// in *reqmsg. + /// + /// First the status is checked, an error is returned if not Active or + /// idle. Addend a edns-tcp-keepalive option if needed. + // Note: maybe reqmsg should be a return value. + fn insert_req(&self, mut req: ChanReq, status: &mut Status, + reqmsg: &mut Option>, query_vec: &mut Queries) { + match status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer.is_none() { + status.state = ConnState::Active(Some(Instant::now())); + } + } + ConnState::Idle(_) => { + // Go back to active + status.state = ConnState::Active(Some(Instant::now())); + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, "idle timeout")))); + return; + } + ConnState::ReadError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, "read error")))); + return; + } + ConnState::ReadTimeout => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, "read timeout")))); + return; + } + ConnState::WriteError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, "write error")))); + return; + } + } + + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + // XXX + let index = Self::insert(req.sender, query_vec).unwrap(); + + let ind16: u16 = index.try_into().unwrap(); + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + let hdr = req.msg.header_mut(); + hdr.set_id(ind16); + + if status.send_keepalive { + let mut msgadd = req.msg.clone().additional(); + + // send an empty keepalive option + let res = msgadd.opt(|opt| { opt.tcp_keepalive(None) }); + match res { + Ok(_) => { + Self::convert_query(&msgadd, reqmsg); + status.send_keepalive = false; + } + Err(_) => { + // Adding keepalive option + // failed. Send the original + // request. + Self::convert_query(&req.msg, reqmsg); + } + } + } else { + Self::convert_query(&req.msg, reqmsg); + } + } + + /// Take an element out of query_vec. + fn take_query(query_vec: &mut Queries, index: usize) + -> Option { + let query = query_vec.vec[index].take(); + query_vec.count -= 1; + query + } + + /// Handle a received edns-tcp-keepalive option. + fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { + if let Some(value) = opt_value.timeout() { + status.idle_timeout = + Some(Duration::from_millis(u64::from(value) * + EDNS_TCP_KEEPALIE_TO_MS)); + } + } + + /// Convert the query message to a vector. + // This function should return the vector instead of storing it + // through a reference. + fn convert_query + AsRef<[u8]>> + (msg: &MessageBuilder>>, + reqmsg: &mut Option>) { + + let vec = msg.as_target().as_target().as_stream_slice(); + + // Store a clone of the request. That makes life easier + // and requests tend to be small + *reqmsg = Some(vec.to_vec()); + } + + /// Insert a sender (for the reply) in the query_vec and return the index. + fn insert(sender: oneshot::Sender, + query_vec: &mut Queries) -> Result { + let q = Some(sender); + + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2*query_vec.count > u16::MAX.into() { + return Err(ERR_TOO_MANY_QUERIES); + } + + let vec_len = query_vec.vec.len(); + + // Append if the amount of empty space in the vector is less + // than half. But limit vec_len to u16::MAX + if vec_len < 2*(query_vec.count+1) && vec_len < + u16::MAX.into() { + // Just append + query_vec.vec.push(q); + query_vec.count += 1; + let index = query_vec.vec.len()-1; + return Ok(index); + } + let loc_curr = query_vec.curr; + + for index in loc_curr..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Nothing until the end of the vector. Try for the entire + // vector + for index in 0..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Still nothing, that is not good + panic!("insert failed"); + } + + /// Insert a sender at a specific position in query_vec and update + /// the statistics. + fn insert_at(query_vec: &mut Queries, index: usize, + q: Option) { + query_vec.vec[index] = q; + query_vec.count += 1; + query_vec.curr = index + 1; + } +} + +impl + AsRef<[u8]> + Clone + Composer + Debug + + OctetsBuilder> TcpConnection { + /// Constructor for [TcpConnection]. + /// + /// Takes an address ([ToSocketAddrs]) and + /// returns a [TcpConnection] wrapped in a [Result](io::Result). + pub async fn connect(addr: A) -> + io::Result> { + let tcpconnection = InnerTcpConnection::connect(addr).await?; + Ok(Self { inner: Arc::new(tcpconnection) }) + } + + /// Main execution function for [TcpConnection]. + /// + /// This function has to run in the background or together with + /// any calls to [query](Self::query) or [Query::get_result]. + pub async fn run(&self) -> Option<()> { + self.inner.run().await + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query + (&self, + query_msg: &mut MessageBuilder>>) + -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = &query_msg.as_message(); + Ok(Query::new(msg, rx)) + } +} + + +impl Query { + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new + (query_msg: &Message, receiver: oneshot::Receiver) + -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec).unwrap(); + Self { + query_msg: msg, + state: QueryState::Busy(receiver), + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result(&mut self) -> + Result, Arc> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "receive error"))); + } + let res = res.unwrap(); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.unwrap(); + let msg = resp.reply; + + let hdr = self.query_msg.header_mut(); + hdr.set_id(resp.id); + + if !msg.is_answer(&self.query_msg) { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, "wrong answer"))); + } + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} From 52aa7b0a2b903c9fac90d6f29a19156f671ed154 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 17 Apr 2023 10:24:45 +0200 Subject: [PATCH 014/124] Fix some clippy warnings --- src/net/client/tcp_channel.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/net/client/tcp_channel.rs b/src/net/client/tcp_channel.rs index 5a849e28c..625d31e65 100644 --- a/src/net/client/tcp_channel.rs +++ b/src/net/client/tcp_channel.rs @@ -500,7 +500,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> // any reply that may be on its way. let arc_error = Arc::new(error); for index in 0..query_vec.vec.len() { - if !query_vec.vec[index].is_none() { + if query_vec.vec[index].is_some() { let sender = Self::take_query(query_vec, index) .expect("we tested is_none before"); _ = sender.send(Err(arc_error.clone())); @@ -512,8 +512,8 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// option. fn handle_opts> (opts: &OptRecord, status: &mut Status) { - for option in opts.iter() { - if let Ok(AllOptData::TcpKeepalive(tcpkeepalive)) = option { + for option in opts.iter().flatten() { + if let AllOptData::TcpKeepalive(tcpkeepalive) = option { Self::handle_keepalive(tcpkeepalive, status); } } From 2ea565eeb1f32030c26f1683057791bbaaef711d Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 17 Apr 2023 16:53:11 +0200 Subject: [PATCH 015/124] Fix fmt errors. --- src/lib.rs | 2 +- src/net/client/mod.rs | 2 +- src/net/client/tcp_channel.rs | 1126 +++++++++++----------- src/net/client/tcp_mutex.rs | 1639 ++++++++++++++++----------------- 4 files changed, 1402 insertions(+), 1367 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d13e66273..8184d2d4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,11 +121,11 @@ extern crate core; pub mod base; pub mod dep; +pub mod net; pub mod rdata; pub mod resolv; pub mod sign; pub mod test; -pub mod net; pub mod tsig; pub mod utils; pub mod validate; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 18e3d282c..f24db07dc 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,5 +2,5 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] -pub mod tcp_mutex; pub mod tcp_channel; +pub mod tcp_mutex; diff --git a/src/net/client/tcp_channel.rs b/src/net/client/tcp_channel.rs index 625d31e65..a51bc3139 100644 --- a/src/net/client/tcp_channel.rs +++ b/src/net/client/tcp_channel.rs @@ -17,31 +17,33 @@ // - request timeout // - create new TCP connection after end/failure of previous one +use bytes; +use bytes::{Bytes, BytesMut}; use core::convert::From; +use futures::lock::Mutex as Futures_mutex; use std::fmt::Debug; use std::sync::Arc; use std::time::{Duration, Instant}; use std::vec::Vec; -use futures::lock::Mutex as Futures_mutex; -use bytes; -use bytes::{Bytes, BytesMut}; -use crate::base::{Message, MessageBuilder, opt::{AllOptData, OptRecord, - TcpKeepalive}, StaticCompressor, StreamTarget}; use crate::base::wire::Composer; +use crate::base::{ + opt::{AllOptData, OptRecord, TcpKeepalive}, + Message, MessageBuilder, StaticCompressor, StreamTarget, +}; use octseq::{Octets, OctetsBuilder}; use tokio::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::net::tcp::ReadHalf; +use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio::time::sleep; /// Error returned when too many queries are currently active. const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; -/// Constant from RFC 7828. How to convert the value on the +/// Constant from RFC 7828. How to convert the value on the /// edns-tcp-keepalive option to milliseconds. // This should go somewhere with the option parsing const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; @@ -116,7 +118,7 @@ struct InnerTcpConnection { receiver: Futures_mutex>>>, } -/// Internal datastructure of [InnerTcpConnection::run] to keep track of +/// Internal datastructure of [InnerTcpConnection::run] to keep track of /// outstanding DNS requests. struct Queries { /// The number of elements in [Queries::vec] that are not None. @@ -159,7 +161,7 @@ pub struct Query { state: QueryState, } -/// Internal datastructure of [InnerTcpConnection::run] to keep track of +/// Internal datastructure of [InnerTcpConnection::run] to keep track of /// the status of the TCP connection. // The types Status and ConnState are only used in InnerTcpConnection struct Status { @@ -177,7 +179,7 @@ struct Status { /// Time we are allow to keep the TCP connection open when idle. /// /// Initially we assume that the idle timeout is zero. A received - /// edns-tcp-keepalive option may change that. + /// edns-tcp-keepalive option may change that. idle_timeout: Option, } /// Status of the TCP connection. Used in [Status]. @@ -185,7 +187,7 @@ enum ConnState { /// The connection is in this state from the start and when at least /// one active DNS request is present. /// - /// The instant contains the time of the first request or the + /// The instant contains the time of the first request or the /// most recent response that was received. Active(Option), @@ -215,215 +217,224 @@ enum ConnState { type ReaderChanReply = Message; impl + Clone + Composer + Debug + OctetsBuilder> - InnerTcpConnection { - + InnerTcpConnection +{ /// Constructor for [InnerTcpConnection]. /// /// This is the implementation of [TcpConnection::connect]. - pub async fn connect(addr: A) -> - io::Result> { - let tcp = TcpStream::connect(addr).await?; - let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); - Ok(Self { - stream: Futures_mutex::new(tcp), - sender: tx, - receiver: Futures_mutex::new(Some(rx)), - }) + pub async fn connect( + addr: A, + ) -> io::Result> { + let tcp = TcpStream::connect(addr).await?; + let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); + Ok(Self { + stream: Futures_mutex::new(tcp), + sender: tx, + receiver: Futures_mutex::new(Some(rx)), + }) } /// Main execution function for [InnerTcpConnection]. /// /// This function Gets called by [TcpConnection::run]. - /// This function is not async cancellation safe + /// This function is not async cancellation safe pub async fn run(&self) -> Option<()> { - let mut stream = self.stream.lock().await; - let (mut read_stream, mut write_stream) = stream.split(); - - let (reply_sender, mut reply_receiver) = - mpsc::channel::(READ_REPLY_CHAN_CAP); - - let reader_fut = Self::reader(&mut read_stream, reply_sender); - tokio::pin!(reader_fut); - - let mut receiver = { - let mut locked_opt_receiver = self.receiver.lock().await; - let opt_receiver = locked_opt_receiver.take(); - opt_receiver.expect("no receiver present?") - }; - - let mut status = Status { - state: ConnState::Active(None), - idle_timeout: None, - send_keepalive: true, - }; - let mut query_vec = Queries { - count: 0, - curr: 0, - vec: Vec::new() - }; - - let mut reqmsg: Option> = None; - - loop { - let opt_timeout = match status.state { - ConnState::Active(opt_instant) => { - if let Some(instant) = opt_instant { - let elapsed = instant.elapsed(); - if elapsed > RESPONSE_TIMEOUT { - let error = io::Error::new( - io::ErrorKind::Other, "read timeout"); - Self::tcp_error(error, &mut query_vec); - status.state = ConnState::ReadTimeout; - break; - } - Some(RESPONSE_TIMEOUT - elapsed) - } - else { None } - } - ConnState::Idle(instant) => { - if let Some(timeout) = &status.idle_timeout { - let elapsed = instant.elapsed(); - if elapsed >= *timeout { - // Move to IdleTimeout and end - // the loop - status.state = ConnState::IdleTimeout; - break; - } - Some(*timeout - elapsed) - } - else { - panic!("Idle state but no timeout"); - } - } - ConnState::IdleTimeout | - ConnState::ReadError | - ConnState::WriteError => - None, // No timers here - ConnState::ReadTimeout => panic!( - "should not be in loop with ReadTimeout") - }; - - // For simplicity, make sure we always have a timeout - let timeout = match opt_timeout { - Some(timeout) => timeout, - None => RESPONSE_TIMEOUT, - // Just use the response timeout - }; - - let sleep_fut = sleep(timeout); - let recv_fut = receiver.recv(); - - let (do_write, msg) = match &reqmsg { - None => { - let msg: &[u8] = &[]; - (false, msg) - } - Some(msg) => { - let msg: &[u8] = msg; - (true, msg) - } - }; - - tokio::select! { - biased; - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(error) => { - Self::tcp_error(error, - &mut query_vec); - status.state = - ConnState::ReadError; - // Reader failed. Break - // out of loop and - // shut down - break - } - } - } - opt_answer = reply_receiver.recv() => { - let answer = opt_answer.expect("reader died?"); - // Check for a edns-tcp-keepalive option - let opt_record = answer.opt(); - if let Some(ref opts) = opt_record { - Self::handle_opts(opts, - &mut status); - }; - drop(opt_record); - Self::demux_reply(answer, - &mut status, &mut query_vec); - } - res = write_stream.write_all(msg), - if do_write => { - if let Err(error) = res { - Self::tcp_error(error, - &mut query_vec); - status.state = - ConnState::WriteError; - break; - } - else { - reqmsg = None; - } - } - res = recv_fut, if !do_write => { - match res { - Some(req) => - self.insert_req(req, &mut status, - &mut reqmsg, &mut query_vec), - None => panic!("recv failed"), - } - } - _ = sleep_fut => { - // Timeout expired, just - // continue with the loop - } - - } - - // Check if the connection is idle - match status.state { - ConnState::Active(_) | ConnState::Idle(_) => { - // Keep going - } - ConnState::IdleTimeout => { - break - } - ConnState::ReadError | - ConnState::ReadTimeout | - ConnState::WriteError => { - panic!("Should not be here"); - } - } - } - - // Send FIN - _ = write_stream.shutdown().await; - - None + let mut stream = self.stream.lock().await; + let (mut read_stream, mut write_stream) = stream.split(); + + let (reply_sender, mut reply_receiver) = + mpsc::channel::(READ_REPLY_CHAN_CAP); + + let reader_fut = Self::reader(&mut read_stream, reply_sender); + tokio::pin!(reader_fut); + + let mut receiver = { + let mut locked_opt_receiver = self.receiver.lock().await; + let opt_receiver = locked_opt_receiver.take(); + opt_receiver.expect("no receiver present?") + }; + + let mut status = Status { + state: ConnState::Active(None), + idle_timeout: None, + send_keepalive: true, + }; + let mut query_vec = Queries { + count: 0, + curr: 0, + vec: Vec::new(), + }; + + let mut reqmsg: Option> = None; + + loop { + let opt_timeout = match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let elapsed = instant.elapsed(); + if elapsed > RESPONSE_TIMEOUT { + let error = io::Error::new( + io::ErrorKind::Other, + "read timeout", + ); + Self::tcp_error(error, &mut query_vec); + status.state = ConnState::ReadTimeout; + break; + } + Some(RESPONSE_TIMEOUT - elapsed) + } else { + None + } + } + ConnState::Idle(instant) => { + if let Some(timeout) = &status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= *timeout { + // Move to IdleTimeout and end + // the loop + status.state = ConnState::IdleTimeout; + break; + } + Some(*timeout - elapsed) + } else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout + | ConnState::ReadError + | ConnState::WriteError => None, // No timers here + ConnState::ReadTimeout => { + panic!("should not be in loop with ReadTimeout"); + } + }; + + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => + // Just use the response timeout + { + RESPONSE_TIMEOUT + } + }; + + let sleep_fut = sleep(timeout); + let recv_fut = receiver.recv(); + + let (do_write, msg) = match &reqmsg { + None => { + let msg: &[u8] = &[]; + (false, msg) + } + Some(msg) => { + let msg: &[u8] = msg; + (true, msg) + } + }; + + tokio::select! { + biased; + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(error) => { + Self::tcp_error(error, + &mut query_vec); + status.state = + ConnState::ReadError; + // Reader failed. Break + // out of loop and + // shut down + break + } + } + } + opt_answer = reply_receiver.recv() => { + let answer = opt_answer.expect("reader died?"); + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + Self::handle_opts(opts, + &mut status); + }; + drop(opt_record); + Self::demux_reply(answer, + &mut status, &mut query_vec); + } + res = write_stream.write_all(msg), + if do_write => { + if let Err(error) = res { + Self::tcp_error(error, + &mut query_vec); + status.state = + ConnState::WriteError; + break; + } + else { + reqmsg = None; + } + } + res = recv_fut, if !do_write => { + match res { + Some(req) => + self.insert_req(req, &mut status, + &mut reqmsg, &mut query_vec), + None => panic!("recv failed"), + } + } + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + } + + } + + // Check if the connection is idle + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => break, + ConnState::ReadError + | ConnState::ReadTimeout + | ConnState::WriteError => { + panic!("Should not be here"); + } + } + } + + // Send FIN + _ = write_stream.shutdown().await; + + None } /// This function sends a DNS request to [InnerTcpConnection::run]. - pub async fn query - (&self, sender: oneshot::Sender, - query_msg: &mut MessageBuilder>> - ) -> Result<(), &'static str> { - - // We should figure out how to get query_msg. - let msg_clone = query_msg.clone(); - - let req = ChanReq { sender, msg: msg_clone }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - Err(ERR_CONN_CLOSED), - Ok(_) => Ok(()) - } + pub async fn query( + &self, + sender: oneshot::Sender, + query_msg: &mut MessageBuilder>>, + ) -> Result<(), &'static str> { + // We should figure out how to get query_msg. + let msg_clone = query_msg.clone(); + + let req = ChanReq { + sender, + msg: msg_clone, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } } /// This function reads a DNS message from the TCP connection and sends @@ -431,146 +442,155 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// /// Reading has to be done in two steps: first read a two octet value /// the specifies the length of the message, and then read in a loop the - /// body of the message. + /// body of the message. /// /// This function is not async cancellation safe. - async fn reader(sock: &mut ReadHalf<'_>, - sender: mpsc::Sender) - -> Result<(), std::io::Error> { - loop { - let read_res = sock.read_u16().await; - let len = match read_res { - Ok(len) => len, - Err(error) => { - return Err(error); - } - } as usize; - - let mut buf = BytesMut::with_capacity(len); - - loop { - let curlen = buf.len(); - if curlen >= len { - if curlen > len { - panic!( - "reader: got too much data {curlen}, expetect {len}"); - } - - // We got what we need - break; - } - - let read_res = sock.read_buf(&mut buf).await; - - match read_res { - Ok(readlen) => { - if readlen == 0 { - let error = io::Error::new( - io::ErrorKind::Other, - "unexpected end of data"); - return Err(error); - } - } - Err(error) => { - return Err(error); - } - }; - - // Check if we are done at the head of the loop - } - - let reply_message = Message::::from_octets(buf.into()); - match reply_message { - Ok(answer) => { - sender.send(answer).await.expect("can't send reply to run"); - } - Err(_) => { - // The only possible error is short message - let error = io::Error::new(io::ErrorKind::Other, - "short buf"); - return Err(error); - } - } - } + async fn reader( + sock: &mut ReadHalf<'_>, + sender: mpsc::Sender, + ) -> Result<(), std::io::Error> { + loop { + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + return Err(error); + } + } as usize; + + let mut buf = BytesMut::with_capacity(len); + + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; + } + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + let error = io::Error::new( + io::ErrorKind::Other, + "unexpected end of data", + ); + return Err(error); + } + } + Err(error) => { + return Err(error); + } + }; + + // Check if we are done at the head of the loop + } + + let reply_message = Message::::from_octets(buf.into()); + match reply_message { + Ok(answer) => { + sender + .send(answer) + .await + .expect("can't send reply to run"); + } + Err(_) => { + // The only possible error is short message + let error = + io::Error::new(io::ErrorKind::Other, "short buf"); + return Err(error); + } + } + } } /// An error occured, report the error to all outstanding [Query] objects. fn tcp_error(error: std::io::Error, query_vec: &mut Queries) { - // Update all requests that are in progress. Don't wait for - // any reply that may be on its way. - let arc_error = Arc::new(error); - for index in 0..query_vec.vec.len() { - if query_vec.vec[index].is_some() { - let sender = Self::take_query(query_vec, index) - .expect("we tested is_none before"); - _ = sender.send(Err(arc_error.clone())); - } - } + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + let arc_error = Arc::new(error); + for index in 0..query_vec.vec.len() { + if query_vec.vec[index].is_some() { + let sender = Self::take_query(query_vec, index) + .expect("we tested is_none before"); + _ = sender.send(Err(arc_error.clone())); + } + } } /// Handle received EDNS options, in particular the edns-tcp-keepalive /// option. - fn handle_opts> - (opts: &OptRecord, status: &mut Status) { - for option in opts.iter().flatten() { - if let AllOptData::TcpKeepalive(tcpkeepalive) = option { - Self::handle_keepalive(tcpkeepalive, status); - } - } + fn handle_opts>( + opts: &OptRecord, + status: &mut Status, + ) { + for option in opts.iter().flatten() { + if let AllOptData::TcpKeepalive(tcpkeepalive) = option { + Self::handle_keepalive(tcpkeepalive, status); + } + } } /// Demultiplex a DNS reply and send it to the right [Query] object. /// /// In addition, the status is updated to IdleTimeout or Idle if there /// are no remaining pending requests. - fn demux_reply(answer: Message, status: &mut Status, - query_vec: &mut Queries) { - // We got an answer, reset the timer - status.state = ConnState::Active(Some(Instant::now())); - - let ind16 = answer.header().id(); - let index: usize = ind16.into(); - - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the TCP connection as broken - return; - } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { - None => { - // No query with this ID. We should - // mark the TCP connection as broken - return; - } - Some(_) => { - let sender = Self::take_query(query_vec, index).unwrap(); - let ind16: u16 = index.try_into().unwrap(); - let reply = Response { - reply: answer, - id: ind16, - }; - _ = sender.send(Ok(reply)); - } - } - if query_vec.count == 0 { - // Clear the activity timer. There is no need to do - // this because state will be set to either IdleTimeout - // or Idle just below. However, it is nicer to keep - // this independent. - status.state = ConnState::Active(None); - - status.state = if status.idle_timeout.is_none() { - // Assume that we can just move to IdleTimeout - // state - ConnState::IdleTimeout - } - else { - ConnState::Idle(Instant::now()) - } - } + fn demux_reply( + answer: Message, + status: &mut Status, + query_vec: &mut Queries, + ) { + // We got an answer, reset the timer + status.state = ConnState::Active(Some(Instant::now())); + + let ind16 = answer.header().id(); + let index: usize = ind16.into(); + + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the TCP connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the TCP connection as broken + return; + } + Some(_) => { + let sender = Self::take_query(query_vec, index).unwrap(); + let ind16: u16 = index.try_into().unwrap(); + let reply = Response { + reply: answer, + id: ind16, + }; + _ = sender.send(Ok(reply)); + } + } + if query_vec.count == 0 { + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this independent. + status.state = ConnState::Active(None); + + status.state = if status.idle_timeout.is_none() { + // Assume that we can just move to IdleTimeout + // state + ConnState::IdleTimeout + } else { + ConnState::Idle(Instant::now()) + } + } } /// Insert a request in query_vec and return the request to be sent @@ -579,181 +599,205 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// First the status is checked, an error is returned if not Active or /// idle. Addend a edns-tcp-keepalive option if needed. // Note: maybe reqmsg should be a return value. - fn insert_req(&self, mut req: ChanReq, status: &mut Status, - reqmsg: &mut Option>, query_vec: &mut Queries) { - match status.state { - ConnState::Active(timer) => { - // Set timer if we don't have one already - if timer.is_none() { - status.state = ConnState::Active(Some(Instant::now())); - } - } - ConnState::Idle(_) => { - // Go back to active - status.state = ConnState::Active(Some(Instant::now())); - } - ConnState::IdleTimeout => { - // The connection has been closed. Report error - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, "idle timeout")))); - return; - } - ConnState::ReadError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, "read error")))); - return; - } - ConnState::ReadTimeout => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, "read timeout")))); - return; - } - ConnState::WriteError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, "write error")))); - return; - } - } - - // Note that insert may fail if there are too many - // outstanding queires. First call insert before checking - // send_keepalive. - // XXX - let index = Self::insert(req.sender, query_vec).unwrap(); - - let ind16: u16 = index.try_into().unwrap(); - - // We set the ID to the array index. Defense in depth - // suggests that a random ID is better because it works - // even if TCP sequence numbers could be predicted. However, - // Section 9.3 of RFC 5452 recommends retrying over TCP - // if many spoofed answers arrive over UDP: "TCP, by the - // nature of its use of sequence numbers, is far more - // resilient against forgery by third parties." - let hdr = req.msg.header_mut(); - hdr.set_id(ind16); - - if status.send_keepalive { - let mut msgadd = req.msg.clone().additional(); - - // send an empty keepalive option - let res = msgadd.opt(|opt| { opt.tcp_keepalive(None) }); - match res { - Ok(_) => { - Self::convert_query(&msgadd, reqmsg); - status.send_keepalive = false; - } - Err(_) => { - // Adding keepalive option - // failed. Send the original - // request. - Self::convert_query(&req.msg, reqmsg); - } - } - } else { - Self::convert_query(&req.msg, reqmsg); - } + fn insert_req( + &self, + mut req: ChanReq, + status: &mut Status, + reqmsg: &mut Option>, + query_vec: &mut Queries, + ) { + match status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer.is_none() { + status.state = ConnState::Active(Some(Instant::now())); + } + } + ConnState::Idle(_) => { + // Go back to active + status.state = ConnState::Active(Some(Instant::now())); + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "idle timeout", + )))); + return; + } + ConnState::ReadError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "read error", + )))); + return; + } + ConnState::ReadTimeout => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "read timeout", + )))); + return; + } + ConnState::WriteError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "write error", + )))); + return; + } + } + + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + // XXX + let index = Self::insert(req.sender, query_vec).unwrap(); + + let ind16: u16 = index.try_into().unwrap(); + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + let hdr = req.msg.header_mut(); + hdr.set_id(ind16); + + if status.send_keepalive { + let mut msgadd = req.msg.clone().additional(); + + // send an empty keepalive option + let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); + match res { + Ok(_) => { + Self::convert_query(&msgadd, reqmsg); + status.send_keepalive = false; + } + Err(_) => { + // Adding keepalive option + // failed. Send the original + // request. + Self::convert_query(&req.msg, reqmsg); + } + } + } else { + Self::convert_query(&req.msg, reqmsg); + } } /// Take an element out of query_vec. - fn take_query(query_vec: &mut Queries, index: usize) - -> Option { - let query = query_vec.vec[index].take(); - query_vec.count -= 1; - query + fn take_query( + query_vec: &mut Queries, + index: usize, + ) -> Option { + let query = query_vec.vec[index].take(); + query_vec.count -= 1; + query } /// Handle a received edns-tcp-keepalive option. fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { - if let Some(value) = opt_value.timeout() { - status.idle_timeout = - Some(Duration::from_millis(u64::from(value) * - EDNS_TCP_KEEPALIE_TO_MS)); - } + if let Some(value) = opt_value.timeout() { + status.idle_timeout = Some(Duration::from_millis( + u64::from(value) * EDNS_TCP_KEEPALIE_TO_MS, + )); + } } /// Convert the query message to a vector. - // This function should return the vector instead of storing it + // This function should return the vector instead of storing it // through a reference. - fn convert_query + AsRef<[u8]>> - (msg: &MessageBuilder>>, - reqmsg: &mut Option>) { - - let vec = msg.as_target().as_target().as_stream_slice(); - - // Store a clone of the request. That makes life easier - // and requests tend to be small - *reqmsg = Some(vec.to_vec()); + fn convert_query + AsRef<[u8]>>( + msg: &MessageBuilder>>, + reqmsg: &mut Option>, + ) { + let vec = msg.as_target().as_target().as_stream_slice(); + + // Store a clone of the request. That makes life easier + // and requests tend to be small + *reqmsg = Some(vec.to_vec()); } /// Insert a sender (for the reply) in the query_vec and return the index. - fn insert(sender: oneshot::Sender, - query_vec: &mut Queries) -> Result { - let q = Some(sender); - - // Fail if there are to many entries already in this vector - // We cannot have more than u16::MAX entries because the - // index needs to fit in an u16. For efficiency we want to - // keep the vector half empty. So we return a failure if - // 2*count > u16::MAX - if 2*query_vec.count > u16::MAX.into() { - return Err(ERR_TOO_MANY_QUERIES); - } - - let vec_len = query_vec.vec.len(); - - // Append if the amount of empty space in the vector is less - // than half. But limit vec_len to u16::MAX - if vec_len < 2*(query_vec.count+1) && vec_len < - u16::MAX.into() { - // Just append - query_vec.vec.push(q); - query_vec.count += 1; - let index = query_vec.vec.len()-1; - return Ok(index); - } - let loc_curr = query_vec.curr; - - for index in loc_curr..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Nothing until the end of the vector. Try for the entire - // vector - for index in 0..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Still nothing, that is not good - panic!("insert failed"); + fn insert( + sender: oneshot::Sender, + query_vec: &mut Queries, + ) -> Result { + let q = Some(sender); + + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2 * query_vec.count > u16::MAX.into() { + return Err(ERR_TOO_MANY_QUERIES); + } + + let vec_len = query_vec.vec.len(); + + // Append if the amount of empty space in the vector is less + // than half. But limit vec_len to u16::MAX + if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { + // Just append + query_vec.vec.push(q); + query_vec.count += 1; + let index = query_vec.vec.len() - 1; + return Ok(index); + } + let loc_curr = query_vec.curr; + + for index in loc_curr..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Nothing until the end of the vector. Try for the entire + // vector + for index in 0..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Still nothing, that is not good + panic!("insert failed"); } /// Insert a sender at a specific position in query_vec and update /// the statistics. - fn insert_at(query_vec: &mut Queries, index: usize, - q: Option) { - query_vec.vec[index] = q; - query_vec.count += 1; - query_vec.curr = index + 1; + fn insert_at( + query_vec: &mut Queries, + index: usize, + q: Option, + ) { + query_vec.vec[index] = q; + query_vec.count += 1; + query_vec.curr = index + 1; } } -impl + AsRef<[u8]> + Clone + Composer + Debug + - OctetsBuilder> TcpConnection { +impl< + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, + > TcpConnection +{ /// Constructor for [TcpConnection]. /// /// Takes an address ([ToSocketAddrs]) and /// returns a [TcpConnection] wrapped in a [Result](io::Result). - pub async fn connect(addr: A) -> - io::Result> { - let tcpconnection = InnerTcpConnection::connect(addr).await?; - Ok(Self { inner: Arc::new(tcpconnection) }) + pub async fn connect( + addr: A, + ) -> io::Result> { + let tcpconnection = InnerTcpConnection::connect(addr).await?; + Ok(Self { + inner: Arc::new(tcpconnection), + }) } /// Main execution function for [TcpConnection]. @@ -761,80 +805,84 @@ impl + AsRef<[u8]> + Clone + Composer + Debug + /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. pub async fn run(&self) -> Option<()> { - self.inner.run().await + self.inner.run().await } /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query - (&self, - query_msg: &mut MessageBuilder>>) - -> Result { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = &query_msg.as_message(); - Ok(Query::new(msg, rx)) + pub async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = &query_msg.as_message(); + Ok(Query::new(msg, rx)) } } - impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. - fn new - (query_msg: &Message, receiver: oneshot::Receiver) - -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec).unwrap(); - Self { - query_msg: msg, - state: QueryState::Busy(receiver), - } + fn new( + query_msg: &Message, + receiver: oneshot::Receiver, + ) -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec).unwrap(); + Self { + query_msg: msg, + state: QueryState::Busy(receiver), + } } /// Get the result of a DNS query. /// /// This function returns the reply to a DNS query wrapped in a /// [Result]. - pub async fn get_result(&mut self) -> - Result, Arc> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "receive error"))); - } - let res = res.unwrap(); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let resp = res.unwrap(); - let msg = resp.reply; - - let hdr = self.query_msg.header_mut(); - hdr.set_id(resp.id); - - if !msg.is_answer(&self.query_msg) { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, "wrong answer"))); - } - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "receive error", + ))); + } + let res = res.unwrap(); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.unwrap(); + let msg = resp.reply; + + let hdr = self.query_msg.header_mut(); + hdr.set_id(resp.id); + + if !msg.is_answer(&self.query_msg) { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "wrong answer", + ))); + } + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } } } diff --git a/src/net/client/tcp_mutex.rs b/src/net/client/tcp_mutex.rs index 64971de22..226e4ae58 100644 --- a/src/net/client/tcp_mutex.rs +++ b/src/net/client/tcp_mutex.rs @@ -15,23 +15,25 @@ // - limit number of outstanding queries to 32K // - create new TCP connection after end/failure of previous one +use bytes::{Bytes, BytesMut}; use std::collections::VecDeque; use std::ops::DerefMut; use std::sync::Arc; use std::sync::Mutex as Std_mutex; use std::time::{Duration, Instant}; use std::vec::Vec; -use bytes::{Bytes, BytesMut}; -use crate::base::{Message, MessageBuilder, opt::{AllOptData, OptRecord, - TcpKeepalive}, StaticCompressor, StreamTarget}; use crate::base::wire::Composer; +use crate::base::{ + opt::{AllOptData, OptRecord, TcpKeepalive}, + Message, MessageBuilder, StaticCompressor, StreamTarget, +}; use octseq::{Octets, OctetsBuilder}; use tokio::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::net::tcp::{ReadHalf, WriteHalf}; +use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::sync::Notify; use tokio::time::sleep; @@ -54,869 +56,854 @@ const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; const RESPONSE_TIMEOUT_S: u64 = 19; enum SingleQueryState { - Busy, - Done(Result, Arc>), - Canceled, + Busy, + Done(Result, Arc>), + Canceled, } struct SingleQuery { - state: SingleQueryState, - complete: Arc, + state: SingleQueryState, + complete: Arc, } struct Queries { - // Number of queries in the vector. The count of element that are - // not None - count: usize, + // Number of queries in the vector. The count of element that are + // not None + count: usize, - // Number of queries that are still waiting for an answer - busy: usize, + // Number of queries that are still waiting for an answer + busy: usize, - // Index in the vector where to look for a space for a new query - curr: usize, + // Index in the vector where to look for a space for a new query + curr: usize, - vec: Vec>, + vec: Vec>, } enum ConnState { - Active(Option), - Idle(Instant), - IdleTimeout, - ReadError, - ReadTimeout, - WriteError, + Active(Option), + Idle(Instant), + IdleTimeout, + ReadError, + ReadTimeout, + WriteError, } struct Status { - state: ConnState, - - // For edns-tcp-keepalive, we have a boolean the specifies if we - // need to send one (typically at the start of the connection). - // Initially we assume that the idle timeout is zero. A received - // edns-tcp-keepalive option may change that. What the best way to - // specify time in Rust? Currently we specify it in milliseconds. - send_keepalive: bool, - idle_timeout: Option, - do_shutdown: bool, + state: ConnState, + + // For edns-tcp-keepalive, we have a boolean the specifies if we + // need to send one (typically at the start of the connection). + // Initially we assume that the idle timeout is zero. A received + // edns-tcp-keepalive option may change that. What the best way to + // specify time in Rust? Currently we specify it in milliseconds. + send_keepalive: bool, + idle_timeout: Option, + do_shutdown: bool, } struct InnerTcpConnection { - stream: Std_mutex>, + stream: Std_mutex>, - /* status */ - status: Std_mutex, + /* status */ + status: Std_mutex, - /* Vector with outstanding queries */ - query_vec: Std_mutex, + /* Vector with outstanding queries */ + query_vec: Std_mutex, - /* Vector with outstanding requests that need to be transmitted */ - tx_queue: Std_mutex>>, + /* Vector with outstanding requests that need to be transmitted */ + tx_queue: Std_mutex>>, - worker_notify: Notify, - writer_notify: Notify, + worker_notify: Notify, + writer_notify: Notify, } pub struct TcpConnection { - inner: Arc, + inner: Arc, } enum QueryState { - Busy(usize), // index - Done, + Busy(usize), // index + Done, } pub struct Query { - transport: Arc, - query_msg: Message>, - state: QueryState, + transport: Arc, + query_msg: Message>, + state: QueryState, } impl InnerTcpConnection { - pub async fn connect(addr: A) -> - io::Result { - let tcp = TcpStream::connect(addr).await?; - Ok(Self { - stream: Std_mutex::new(Some(tcp)), - status: Std_mutex::new(Status { - state: ConnState::Active(None), - send_keepalive: true, - idle_timeout: None, - do_shutdown: false, - }), - query_vec: Std_mutex::new(Queries { - count: 0, - busy: 0, - curr: 0, - vec: Vec::new() - }), - tx_queue: Std_mutex::new(VecDeque::new()), - worker_notify: Notify::new(), - writer_notify: Notify::new(), - }) - } - - // Take a query out of query_vec and decrement the query count - fn take_query(&self, index: usize) -> Option - { - let mut query_vec = self.query_vec.lock().unwrap(); - self.vec_take_query(query_vec.deref_mut(), index) - } - - // Very similar to take_query, but sometime the caller already has - // a lock on the mutex - fn vec_take_query(&self, query_vec: &mut Queries, index: usize) -> - Option{ - let query = query_vec.vec[index].take(); - query_vec.count -= 1; - if query_vec.count == 0 { - // The worker may be waiting for this - self.worker_notify.notify_one(); - } - query - } - - fn insert_answer(&self, answer: Message) { - // We got an answer, reset the timer - let mut status = self.status.lock().unwrap(); - status.state = ConnState::Active(Some(Instant::now())); - drop(status); - - let ind16 = answer.header().id(); - let index: usize = ind16.into(); - - let mut query_vec = self.query_vec.lock().unwrap(); - - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the TCP connection as broken - return; - } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { - None => { - // No query with this ID. We should - // mark the TCP connection as broken - return; - } - Some(query) => { - match &query.state { - SingleQueryState::Busy => { - query.state = - SingleQueryState:: - Done(Ok( - answer)); - query.complete. - notify_one(); - } - SingleQueryState::Canceled => { - //`The query has been - // canceled already - // Clean up. - let _ = self.vec_take_query( - query_vec.deref_mut(), - index); - } - SingleQueryState::Done(_) => { - // Already got a - // result. - return; - } - } - } - } - query_vec.busy -= 1; - if query_vec.busy == 0 { - let mut status = self.status.lock().unwrap(); - - // Clear the activity timer. There is no need to do - // this because state will be set to either IdleTimeout - // or Idle just below. However, it is nicer to keep - // this indenpendent. - status.state = ConnState::Active(None); - - if status.idle_timeout.is_none() { - // Assume that we can just move to IdleTimeout - // state - status.state = ConnState::IdleTimeout; - - // Notify the worker. Then the worker can - // close the tcp connection - self.worker_notify.notify_one(); - } - else { - status.state = - ConnState::Idle(Instant::now()); - - // Notify the worker. The worker waits for - // the timeout to expire - self.worker_notify.notify_one(); - } - } - } - - fn handle_keepalive(&self, opt_value: TcpKeepalive) { - if let Some(value) = opt_value.timeout() { - let mut status = self.status.lock().unwrap(); - status.idle_timeout = - Some(Duration::from_millis(u64::from(value) * - EDNS_TCP_KEEPALIE_TO_MS)); - } - } - - fn handle_opts> - (&self, opts: &OptRecord) { - for option in opts.iter() { - let opt = option.unwrap(); - if let AllOptData::TcpKeepalive(tcpkeepalive) = opt { - self.handle_keepalive(tcpkeepalive); - } - } - } - - // This function is not async cancellation safe - async fn reader(&self, sock: &mut ReadHalf<'_>) -> Result<(), &str> { - loop { - let read_res = sock.read_u16().await; - let len = match read_res { - Ok(len) => len, - Err(error) => { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } as usize; - - let mut buf = BytesMut::with_capacity(len); - - loop { - let curlen = buf.len(); - if curlen >= len { - if curlen > len { - panic!( - "reader: got too much data {curlen}, expetect {len}"); - } - - // We got what we need - break; - } - - let read_res = sock.read_buf(&mut buf).await; - - match read_res { - Ok(readlen) => { - if readlen == 0 { - let error = io::Error::new( - io::ErrorKind::Other, - "unexpected end of data"); - self.tcp_error(error); - let mut status = self.status.lock(). - unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } - Err(error) => { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - }; - - // Check if we are done at the head of the loop - } - - let reply_message = Message::::from_octets(buf.into()); - - match reply_message { - Ok(answer) => { - // Check for a edns-tcp-keepalive option - let opt_record = answer.opt(); - if let Some(ref opts) = opt_record { - self.handle_opts(opts); - }; - self.insert_answer(answer); - } - Err(_) => { - // The only possible error is short message - let error = io::Error::new(io::ErrorKind::Other, - "short buf"); - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } - } - } - - fn tcp_error(&self, error: std::io::Error) { - // Update all requests that are in progress. Don't wait for - // any reply that may be on its way. - let arc_error = Arc::new(error); - let mut query_vec = self.query_vec.lock().unwrap(); - for query in &mut query_vec.vec { - match query { - None => { - continue; - } - Some(q) => { - match q.state { - SingleQueryState::Busy => { - q.state = - SingleQueryState:: - Done(Err( - arc_error - .clone())); - q.complete. - notify_one(); - } - SingleQueryState::Done(_) | - SingleQueryState::Canceled => - // Nothing to do - () - } - } - } - } - } - - // This function is not async cancellation safe - async fn writer(&self, sock: &mut WriteHalf<'_>) -> - Result<(), &'static str> { - loop { - loop { - // Check if we need to shutdown - let do_shutdown = { - // Extra block to satisfy clippy - // await_holding_lock - let status = self.status.lock() - .unwrap(); - status.do_shutdown - // drop(status); - }; - - if do_shutdown { - // Ignore errors - _ = sock.shutdown().await; - - // Do we need to clear do_shutdown? - break; - } - - let head = { - // Extra block to satisfy clippy - // await_holding_lock - let mut tx_queue = self.tx_queue - .lock().unwrap(); - tx_queue.pop_front() - // drop(tx_queue); - }; - match head { - Some(vec) => { - let res = sock.write_all(&vec).await; - if let Err(error) = res { - self.tcp_error(error); - let mut status = - self.status.lock(). - unwrap(); - status.state = - ConnState::WriteError; - return Err(ERR_WRITE_ERROR); - } - } - None => - break, - } - } - - self.writer_notify.notified().await; - } - } - - // This function is not async cancellation safe because it calls - // reader and writer which are not async cancellation safe - pub async fn worker(&self) -> Option<()> { - let mut stream = { - // Extra block to satisfy clippy - // await_holding_lock - let mut opt_stream = self.stream.lock().unwrap(); - opt_stream.take().unwrap() - // drop(opt_stream); - }; - let (mut read_stream, mut write_stream) = stream.split(); - - let reader_fut = self.reader(&mut read_stream); - tokio::pin!(reader_fut); - let writer_fut = self.writer(&mut write_stream); - tokio::pin!(writer_fut); - - loop { - let opt_timeout: Option = { - // Extra block to satisfy clippy - // await_holding_lock - let mut status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(opt_instant) => { - if let Some(instant) = opt_instant { - let timeout = Duration::from_secs( - RESPONSE_TIMEOUT_S); - let elapsed = instant.elapsed(); - if elapsed > timeout { - let error = io::Error::new( - io::ErrorKind::Other, - "read timeout"); - self.tcp_error(error); - status.state = ConnState::ReadTimeout; - break; - } - Some(timeout - elapsed) - } - else { None } - } - ConnState::Idle(instant) => { - if let Some(timeout) = status.idle_timeout { - let elapsed = instant.elapsed(); - if elapsed >= timeout { - // Move to IdleTimeout and end - // the loop - status.state = - ConnState::IdleTimeout; - break; - } - Some(timeout - elapsed) - } - else { - panic!("Idle state but no timeout"); - } - } - ConnState::IdleTimeout | - ConnState::ReadError | - ConnState::WriteError => - None, // No timers here - ConnState::ReadTimeout => panic!( - "should not be in loop with ReadTimeout") - } - // drop(status); - }; - - // For simplicity, make sure we always have a timeout - let timeout = match opt_timeout { - Some(timeout) => timeout, - None => - // Just use the response timeout - Duration::from_secs(RESPONSE_TIMEOUT_S) - }; - - let sleep_fut = sleep(timeout); - let notify_fut = self.worker_notify.notified(); - - tokio::select! { - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Reader failed. Break - // out of loop and - // shut down - break - } - } - res = &mut writer_fut => { - match res { - Ok(_) => - // The writer should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Writer failed. Break - // out of loop and - // shut down - break - } - } - - _ = sleep_fut => { - // Timeout expired, just - // continue with the loop - } - _ = notify_fut => { - // Got notifies, go through the loop - // to see what changed. - } - - } - - // Check if the connection is idle - let status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(_) | ConnState::Idle(_) => { - // Keep going - } - ConnState::IdleTimeout => { - break - } - ConnState::ReadError | - ConnState::ReadTimeout | - ConnState::WriteError => { - panic!("Should not be here"); - } - } - drop(status); - } - - // We can't see a FIN directly because the writer_fut owns - // write_stream. - { - // Extra block to satisfy clippy - // await_holding_lock - let mut status = self.status.lock().unwrap(); - status.do_shutdown = true; - // drop(status); - }; - - // Kick writer - self.writer_notify.notify_one(); - - // Wait for writer to terminate. Ignore the result. We may - // want a timer here - _ = writer_fut.await; - - // Stay around until the last query result is collected - loop { - { - // Extra block to satisfy clippy - // await_holding_lock - let query_vec = self.query_vec.lock().unwrap(); - if query_vec.count == 0 { - // We are done - break; - } - // drop(query_vec); - } - - self.worker_notify.notified().await; - } - None - } - - fn insert_at(query_vec: &mut Queries, index: usize, - q: Option) { - query_vec.vec[index] = q; - query_vec.count += 1; - query_vec.busy += 1; - query_vec.curr = index + 1; - } - - // Insert a message in the query vector. Return the index - fn insert(&self) - -> Result { - let q = Some(SingleQuery { - state: SingleQueryState::Busy, - complete: Arc::new(Notify::new()), - }); - let mut query_vec = self.query_vec.lock().unwrap(); - - // Fail if there are to many entries already in this vector - // We cannot have more than u16::MAX entries because the - // index needs to fit in an u16. For efficiency we want to - // keep the vector half empty. So we return a failure if - // 2*count > u16::MAX - if 2*query_vec.count > u16::MAX.into() { - return Err(ERR_TOO_MANY_QUERIES); - } - - let vec_len = query_vec.vec.len(); - - // Append if the amount of empty space in the vector is less - // than half. But limit vec_len to u16::MAX - if vec_len < 2*(query_vec.count+1) && vec_len < - u16::MAX.into() { - // Just append - query_vec.vec.push(q); - query_vec.count += 1; - query_vec.busy += 1; - let index = query_vec.vec.len()-1; - return Ok(index); - } - let loc_curr = query_vec.curr; - - for index in loc_curr..vec_len { - match query_vec.vec[index] { - Some(_) => { - // Already in use, just continue - } - None => { - Self::insert_at(&mut query_vec, - index, q); - return Ok(index); - } - } - } - - // Nothing until the end of the vector. Try for the entire - // vector - for index in 0..vec_len { - match query_vec.vec[index] { - Some(_) => { - // Already in use, just continue - } - None => { - Self::insert_at(&mut query_vec, - index, q); - return Ok(index); - } - } - } - - // Still nothing, that is not good - panic!("insert failed"); - } - - fn queue_query + AsRef<[u8]>> - (&self, msg: &MessageBuilder>>) { - - let vec = msg.as_target().as_target().as_stream_slice(); - - // Store a clone of the request. That makes life easier - // and requests tend to be small - let mut tx_queue = self.tx_queue.lock().unwrap(); - tx_queue.push_back(vec.to_vec()); - } - - pub fn query + AsRef<[u8]> + - Composer + Clone> - (&self, - query_msg: &mut MessageBuilder>> - ) -> Result { - - // Check the state of the connection, fail if the connection - // is in IdleTimeout. If the connection is Idle, move it - // back to Active. Also check for the need to send a keepalive - let mut status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(timer) => { - // Set timer if we don't have one already - if timer.is_none() { - status.state = ConnState::Active(Some( - Instant::now())); - } - } - ConnState::Idle(_) => { - // Go back to active - status.state = ConnState::Active(Some( - Instant::now())); - } - ConnState::IdleTimeout => { - // The connection has been closed. Report error - return Err(ERR_IDLE_TIMEOUT); - } - ConnState::ReadError => { - return Err(ERR_READ_ERROR); - } - ConnState::ReadTimeout => { - return Err(ERR_READ_TIMEOUT); - } - ConnState::WriteError => { - return Err(ERR_WRITE_ERROR); - } - } - - // Note that insert may fail if there are too many - // outstanding queires. First call insert before checking - // send_keepalive. - let index = self.insert()?; - - let mut do_keepalive = false; - if status.send_keepalive { - do_keepalive = true; - status.send_keepalive = false; - } - drop(status); - - let ind16: u16 = index.try_into().unwrap(); - - // We set the ID to the array index. Defense in depth - // suggests that a random ID is better because it works - // even if TCP sequence numbers could be predicted. However, - // Section 9.3 of RFC 5452 recommends retrying over TCP - // if many spoofed answers arrive over UDP: "TCP, by the - // nature of its use of sequence numbers, is far more - // resilient against forgery by third parties." - let hdr = query_msg.header_mut(); - hdr.set_id(ind16); - - if do_keepalive { - let mut msgadd = query_msg.clone().additional(); - - // send an empty keepalive option - let res = msgadd.opt(|opt| { - opt.tcp_keepalive(None) - }); - match res { - Ok(_) => - self.queue_query(&msgadd), - Err(_) => { - // Adding keepalive option - // failed. Send the original - // request and turn the - // send_keepalive flag back on - let mut status = - self.status.lock() - .unwrap(); - status.send_keepalive = - true; - drop(status); - self.queue_query(query_msg); - } - } - } else { - self.queue_query(query_msg); - } - - // Now kick the writer to transmit the query - self.writer_notify.notify_one(); - - Ok(index) - } - - pub async fn get_result(&self, - query_msg: &Message, index: usize) -> - Result, Arc> { - // Wait for reply - let local_notify = { - // Extra block to satisfy clippy - // await_holding_lock - let mut query_vec = self.query_vec.lock() - .unwrap(); - query_vec.vec[index].as_mut(). - unwrap().complete.clone() - // drop(query_vec); - }; - local_notify.notified().await; - - // take a look - let opt_q = self.take_query(index); - if let Some(q) = opt_q - { - if let SingleQueryState::Done(result) = q.state - { - if let Ok(answer) = &result - { - if !answer.is_answer( - query_msg) { - return Err(Arc::new( - io::Error::new( - io::ErrorKind::Other, - "wrong answer"))); - } - } - return result; - } - panic!("inconsistent state"); - } - - panic!("inconsistent state"); - } - - fn cancel(&self, index: usize) { - let mut query_vec = self.query_vec.lock().unwrap(); - - match &mut query_vec.vec[index] { - None => { - panic!("Cancel called, but nothing to cancel"); - } - Some(query) => { - match &query.state { - SingleQueryState::Busy => { - query.state = - SingleQueryState:: - Canceled; - } - SingleQueryState::Canceled => { - panic!("Already canceled"); - } - SingleQueryState::Done(_) => { - // Remove the result - let _ = self.vec_take_query( - query_vec.deref_mut(), - index); - } - } - } - } - } + pub async fn connect( + addr: A, + ) -> io::Result { + let tcp = TcpStream::connect(addr).await?; + Ok(Self { + stream: Std_mutex::new(Some(tcp)), + status: Std_mutex::new(Status { + state: ConnState::Active(None), + send_keepalive: true, + idle_timeout: None, + do_shutdown: false, + }), + query_vec: Std_mutex::new(Queries { + count: 0, + busy: 0, + curr: 0, + vec: Vec::new(), + }), + tx_queue: Std_mutex::new(VecDeque::new()), + worker_notify: Notify::new(), + writer_notify: Notify::new(), + }) + } + + // Take a query out of query_vec and decrement the query count + fn take_query(&self, index: usize) -> Option { + let mut query_vec = self.query_vec.lock().unwrap(); + self.vec_take_query(query_vec.deref_mut(), index) + } + + // Very similar to take_query, but sometime the caller already has + // a lock on the mutex + fn vec_take_query( + &self, + query_vec: &mut Queries, + index: usize, + ) -> Option { + let query = query_vec.vec[index].take(); + query_vec.count -= 1; + if query_vec.count == 0 { + // The worker may be waiting for this + self.worker_notify.notify_one(); + } + query + } + + fn insert_answer(&self, answer: Message) { + // We got an answer, reset the timer + let mut status = self.status.lock().unwrap(); + status.state = ConnState::Active(Some(Instant::now())); + drop(status); + + let ind16 = answer.header().id(); + let index: usize = ind16.into(); + + let mut query_vec = self.query_vec.lock().unwrap(); + + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the TCP connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the TCP connection as broken + return; + } + Some(query) => { + match &query.state { + SingleQueryState::Busy => { + query.state = SingleQueryState::Done(Ok(answer)); + query.complete.notify_one(); + } + SingleQueryState::Canceled => { + //`The query has been + // canceled already + // Clean up. + let _ = + self.vec_take_query(query_vec.deref_mut(), index); + } + SingleQueryState::Done(_) => { + // Already got a + // result. + return; + } + } + } + } + query_vec.busy -= 1; + if query_vec.busy == 0 { + let mut status = self.status.lock().unwrap(); + + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this indenpendent. + status.state = ConnState::Active(None); + + if status.idle_timeout.is_none() { + // Assume that we can just move to IdleTimeout + // state + status.state = ConnState::IdleTimeout; + + // Notify the worker. Then the worker can + // close the tcp connection + self.worker_notify.notify_one(); + } else { + status.state = ConnState::Idle(Instant::now()); + + // Notify the worker. The worker waits for + // the timeout to expire + self.worker_notify.notify_one(); + } + } + } + + fn handle_keepalive(&self, opt_value: TcpKeepalive) { + if let Some(value) = opt_value.timeout() { + let mut status = self.status.lock().unwrap(); + status.idle_timeout = Some(Duration::from_millis( + u64::from(value) * EDNS_TCP_KEEPALIE_TO_MS, + )); + } + } + + fn handle_opts>( + &self, + opts: &OptRecord, + ) { + for option in opts.iter() { + let opt = option.unwrap(); + if let AllOptData::TcpKeepalive(tcpkeepalive) = opt { + self.handle_keepalive(tcpkeepalive); + } + } + } + + // This function is not async cancellation safe + async fn reader(&self, sock: &mut ReadHalf<'_>) -> Result<(), &str> { + loop { + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + } as usize; + + let mut buf = BytesMut::with_capacity(len); + + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; + } + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + let error = io::Error::new( + io::ErrorKind::Other, + "unexpected end of data", + ); + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + } + Err(error) => { + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + }; + + // Check if we are done at the head of the loop + } + + let reply_message = Message::::from_octets(buf.into()); + + match reply_message { + Ok(answer) => { + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + self.handle_opts(opts); + }; + self.insert_answer(answer); + } + Err(_) => { + // The only possible error is short message + let error = + io::Error::new(io::ErrorKind::Other, "short buf"); + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::ReadError; + return Err(ERR_READ_ERROR); + } + } + } + } + + fn tcp_error(&self, error: std::io::Error) { + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + let arc_error = Arc::new(error); + let mut query_vec = self.query_vec.lock().unwrap(); + for query in &mut query_vec.vec { + match query { + None => { + continue; + } + Some(q) => { + match q.state { + SingleQueryState::Busy => { + q.state = SingleQueryState::Done(Err( + arc_error.clone() + )); + q.complete.notify_one(); + } + SingleQueryState::Done(_) + | SingleQueryState::Canceled => + // Nothing to do + { + () + } + } + } + } + } + } + + // This function is not async cancellation safe + async fn writer( + &self, + sock: &mut WriteHalf<'_>, + ) -> Result<(), &'static str> { + loop { + loop { + // Check if we need to shutdown + let do_shutdown = { + // Extra block to satisfy clippy + // await_holding_lock + let status = self.status.lock().unwrap(); + status.do_shutdown + // drop(status); + }; + + if do_shutdown { + // Ignore errors + _ = sock.shutdown().await; + + // Do we need to clear do_shutdown? + break; + } + + let head = { + // Extra block to satisfy clippy + // await_holding_lock + let mut tx_queue = self.tx_queue.lock().unwrap(); + tx_queue.pop_front() + // drop(tx_queue); + }; + match head { + Some(vec) => { + let res = sock.write_all(&vec).await; + if let Err(error) = res { + self.tcp_error(error); + let mut status = self.status.lock().unwrap(); + status.state = ConnState::WriteError; + return Err(ERR_WRITE_ERROR); + } + } + None => break, + } + } + + self.writer_notify.notified().await; + } + } + + // This function is not async cancellation safe because it calls + // reader and writer which are not async cancellation safe + pub async fn worker(&self) -> Option<()> { + let mut stream = { + // Extra block to satisfy clippy + // await_holding_lock + let mut opt_stream = self.stream.lock().unwrap(); + opt_stream.take().unwrap() + // drop(opt_stream); + }; + let (mut read_stream, mut write_stream) = stream.split(); + + let reader_fut = self.reader(&mut read_stream); + tokio::pin!(reader_fut); + let writer_fut = self.writer(&mut write_stream); + tokio::pin!(writer_fut); + + loop { + let opt_timeout: Option = { + // Extra block to satisfy clippy + // await_holding_lock + let mut status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let timeout = + Duration::from_secs(RESPONSE_TIMEOUT_S); + let elapsed = instant.elapsed(); + if elapsed > timeout { + let error = io::Error::new( + io::ErrorKind::Other, + "read timeout", + ); + self.tcp_error(error); + status.state = ConnState::ReadTimeout; + break; + } + Some(timeout - elapsed) + } else { + None + } + } + ConnState::Idle(instant) => { + if let Some(timeout) = status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= timeout { + // Move to IdleTimeout and end + // the loop + status.state = ConnState::IdleTimeout; + break; + } + Some(timeout - elapsed) + } else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout + | ConnState::ReadError + | ConnState::WriteError => None, // No timers here + ConnState::ReadTimeout => { + panic!("should not be in loop with ReadTimeout") + } + } + // drop(status); + }; + + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => + // Just use the response timeout + { + Duration::from_secs(RESPONSE_TIMEOUT_S) + } + }; + + let sleep_fut = sleep(timeout); + let notify_fut = self.worker_notify.notified(); + + tokio::select! { + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Reader failed. Break + // out of loop and + // shut down + break + } + } + res = &mut writer_fut => { + match res { + Ok(_) => + // The writer should not + // terminate without + // error. + panic!("reader terminated"), + Err(_) => + // Writer failed. Break + // out of loop and + // shut down + break + } + } + + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + } + _ = notify_fut => { + // Got notifies, go through the loop + // to see what changed. + } + + } + + // Check if the connection is idle + let status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => break, + ConnState::ReadError + | ConnState::ReadTimeout + | ConnState::WriteError => { + panic!("Should not be here"); + } + } + drop(status); + } + + // We can't see a FIN directly because the writer_fut owns + // write_stream. + { + // Extra block to satisfy clippy + // await_holding_lock + let mut status = self.status.lock().unwrap(); + status.do_shutdown = true; + // drop(status); + }; + + // Kick writer + self.writer_notify.notify_one(); + + // Wait for writer to terminate. Ignore the result. We may + // want a timer here + _ = writer_fut.await; + + // Stay around until the last query result is collected + loop { + { + // Extra block to satisfy clippy + // await_holding_lock + let query_vec = self.query_vec.lock().unwrap(); + if query_vec.count == 0 { + // We are done + break; + } + // drop(query_vec); + } + + self.worker_notify.notified().await; + } + None + } + + fn insert_at( + query_vec: &mut Queries, + index: usize, + q: Option, + ) { + query_vec.vec[index] = q; + query_vec.count += 1; + query_vec.busy += 1; + query_vec.curr = index + 1; + } + + // Insert a message in the query vector. Return the index + fn insert(&self) -> Result { + let q = Some(SingleQuery { + state: SingleQueryState::Busy, + complete: Arc::new(Notify::new()), + }); + let mut query_vec = self.query_vec.lock().unwrap(); + + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2 * query_vec.count > u16::MAX.into() { + return Err(ERR_TOO_MANY_QUERIES); + } + + let vec_len = query_vec.vec.len(); + + // Append if the amount of empty space in the vector is less + // than half. But limit vec_len to u16::MAX + if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { + // Just append + query_vec.vec.push(q); + query_vec.count += 1; + query_vec.busy += 1; + let index = query_vec.vec.len() - 1; + return Ok(index); + } + let loc_curr = query_vec.curr; + + for index in loc_curr..vec_len { + match query_vec.vec[index] { + Some(_) => { + // Already in use, just continue + } + None => { + Self::insert_at(&mut query_vec, index, q); + return Ok(index); + } + } + } + + // Nothing until the end of the vector. Try for the entire + // vector + for index in 0..vec_len { + match query_vec.vec[index] { + Some(_) => { + // Already in use, just continue + } + None => { + Self::insert_at(&mut query_vec, index, q); + return Ok(index); + } + } + } + + // Still nothing, that is not good + panic!("insert failed"); + } + + fn queue_query + AsRef<[u8]>>( + &self, + msg: &MessageBuilder>>, + ) { + let vec = msg.as_target().as_target().as_stream_slice(); + + // Store a clone of the request. That makes life easier + // and requests tend to be small + let mut tx_queue = self.tx_queue.lock().unwrap(); + tx_queue.push_back(vec.to_vec()); + } + + pub fn query< + Octs: OctetsBuilder + AsMut<[u8]> + AsRef<[u8]> + Composer + Clone, + >( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result { + // Check the state of the connection, fail if the connection + // is in IdleTimeout. If the connection is Idle, move it + // back to Active. Also check for the need to send a keepalive + let mut status = self.status.lock().unwrap(); + match status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer.is_none() { + status.state = ConnState::Active(Some(Instant::now())); + } + } + ConnState::Idle(_) => { + // Go back to active + status.state = ConnState::Active(Some(Instant::now())); + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + return Err(ERR_IDLE_TIMEOUT); + } + ConnState::ReadError => { + return Err(ERR_READ_ERROR); + } + ConnState::ReadTimeout => { + return Err(ERR_READ_TIMEOUT); + } + ConnState::WriteError => { + return Err(ERR_WRITE_ERROR); + } + } + + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + let index = self.insert()?; + + let mut do_keepalive = false; + if status.send_keepalive { + do_keepalive = true; + status.send_keepalive = false; + } + drop(status); + + let ind16: u16 = index.try_into().unwrap(); + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + let hdr = query_msg.header_mut(); + hdr.set_id(ind16); + + if do_keepalive { + let mut msgadd = query_msg.clone().additional(); + + // send an empty keepalive option + let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); + match res { + Ok(_) => self.queue_query(&msgadd), + Err(_) => { + // Adding keepalive option + // failed. Send the original + // request and turn the + // send_keepalive flag back on + let mut status = self.status.lock().unwrap(); + status.send_keepalive = true; + drop(status); + self.queue_query(query_msg); + } + } + } else { + self.queue_query(query_msg); + } + // Now kick the writer to transmit the query + self.writer_notify.notify_one(); + + Ok(index) + } + + pub async fn get_result( + &self, + query_msg: &Message, + index: usize, + ) -> Result, Arc> { + // Wait for reply + let local_notify = { + // Extra block to satisfy clippy + // await_holding_lock + let mut query_vec = self.query_vec.lock().unwrap(); + query_vec.vec[index].as_mut().unwrap().complete.clone() + // drop(query_vec); + }; + local_notify.notified().await; + + // take a look + let opt_q = self.take_query(index); + if let Some(q) = opt_q { + if let SingleQueryState::Done(result) = q.state { + if let Ok(answer) = &result { + if !answer.is_answer(query_msg) { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "wrong answer", + ))); + } + } + return result; + } + panic!("inconsistent state"); + } + + panic!("inconsistent state"); + } + + fn cancel(&self, index: usize) { + let mut query_vec = self.query_vec.lock().unwrap(); + + match &mut query_vec.vec[index] { + None => { + panic!("Cancel called, but nothing to cancel"); + } + Some(query) => { + match &query.state { + SingleQueryState::Busy => { + query.state = SingleQueryState::Canceled; + } + SingleQueryState::Canceled => { + panic!("Already canceled"); + } + SingleQueryState::Done(_) => { + // Remove the result + let _ = + self.vec_take_query(query_vec.deref_mut(), index); + } + } + } + } + } } impl TcpConnection { - pub async fn connect(addr: A) -> - io::Result { - let tcpconnection = InnerTcpConnection::connect(addr).await?; - Ok(Self { inner: Arc::new(tcpconnection) }) - } - pub async fn worker(&self) -> Option<()> { - self.inner.worker().await - } - pub fn query + AsRef<[u8]> + - Composer + Clone> - (&self, query_msg: &mut MessageBuilder>>) - -> Result { - let index = self.inner.query(query_msg)?; - let msg = &query_msg.as_message(); - Ok(Query::new(self, msg, index)) - } + pub async fn connect( + addr: A, + ) -> io::Result { + let tcpconnection = InnerTcpConnection::connect(addr).await?; + Ok(Self { + inner: Arc::new(tcpconnection), + }) + } + pub async fn worker(&self) -> Option<()> { + self.inner.worker().await + } + pub fn query< + OctsBuilder: OctetsBuilder + AsMut<[u8]> + AsRef<[u8]> + Composer + Clone, + >( + &self, + query_msg: &mut MessageBuilder< + StaticCompressor>, + >, + ) -> Result { + let index = self.inner.query(query_msg)?; + let msg = &query_msg.as_message(); + Ok(Query::new(self, msg, index)) + } } - impl Query { - fn new(transport: &TcpConnection, - query_msg: &Message, - index: usize) -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec).unwrap(); - Self { - transport: transport.inner.clone(), - query_msg: msg, - state: QueryState::Busy(index) - } - } - pub async fn get_result(&mut self) -> - Result, Arc> { - // Just the result of get_result on tranport. We should record - // that we got an answer to avoid asking again - match self.state { - QueryState::Busy(index) => { - let result = self.transport.get_result( - &self.query_msg, index).await; - self.state = QueryState::Done; - result - } - QueryState::Done => { - panic!("Already done"); - } - } - } + fn new( + transport: &TcpConnection, + query_msg: &Message, + index: usize, + ) -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec).unwrap(); + Self { + transport: transport.inner.clone(), + query_msg: msg, + state: QueryState::Busy(index), + } + } + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + // Just the result of get_result on tranport. We should record + // that we got an answer to avoid asking again + match self.state { + QueryState::Busy(index) => { + let result = + self.transport.get_result(&self.query_msg, index).await; + self.state = QueryState::Done; + result + } + QueryState::Done => { + panic!("Already done"); + } + } + } } impl Drop for Query { - fn drop(&mut self) { - match self.state { - QueryState::Busy(index) => { - self.transport.cancel(index); - } - QueryState::Done => { - // Done, nothing to cancel - } - } - } + fn drop(&mut self) { + match self.state { + QueryState::Busy(index) => { + self.transport.cancel(index); + } + QueryState::Done => { + // Done, nothing to cancel + } + } + } } From 83dc4aa38a601682ac9115a460f8f69e5f9a461f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 17 Apr 2023 16:59:09 +0200 Subject: [PATCH 016/124] keep making changes until both fmt and clippy are happy. --- src/net/client/tcp_mutex.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/net/client/tcp_mutex.rs b/src/net/client/tcp_mutex.rs index 226e4ae58..8162705b2 100644 --- a/src/net/client/tcp_mutex.rs +++ b/src/net/client/tcp_mutex.rs @@ -373,10 +373,8 @@ impl InnerTcpConnection { } SingleQueryState::Done(_) | SingleQueryState::Canceled => - // Nothing to do - { - () - } + // Nothing to do + {} } } } From 4e8d0a3e9d45c4163e947b502b2aef848f5ae7e5 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 18 Apr 2023 16:31:20 +0200 Subject: [PATCH 017/124] Make tcp_channel.rs a bit more general. Now octet_stream.rs. --- src/net/client/mod.rs | 3 +- src/net/client/octet_stream.rs | 887 +++++++++++++++++++++++++++++++++ 2 files changed, 889 insertions(+), 1 deletion(-) create mode 100644 src/net/client/octet_stream.rs diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index f24db07dc..e3b2b2f10 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,5 +2,6 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] -pub mod tcp_channel; +pub mod octet_stream; pub mod tcp_mutex; +pub mod tcp_channel; diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs new file mode 100644 index 000000000..2a8b40fa7 --- /dev/null +++ b/src/net/client/octet_stream.rs @@ -0,0 +1,887 @@ +//! A DNS over octet stream transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// RFC 7766 describes DNS over TCP +// RFC 7828 describes the edns-tcp-keepalive option + +// TODO: +// - errors +// - connect errors? Retry after connection refused? +// - server errors +// - ID out of range +// - ID not in use +// - reply for wrong query +// - timeouts +// - request timeout +// - create new connection after end/failure of previous one + +use bytes; +use bytes::{Bytes, BytesMut}; +use core::convert::From; +use futures::lock::Mutex as Futures_mutex; +use std::fmt::Debug; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use std::vec::Vec; + +use crate::base::wire::Composer; +use crate::base::{ + opt::{AllOptData, OptRecord, TcpKeepalive}, + Message, MessageBuilder, StaticCompressor, StreamTarget, +}; +use octseq::{Octets, OctetsBuilder}; + +use tokio::io; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::sleep; + +/// Error returned when too many queries are currently active. +const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; + +/// Constant from RFC 7828. How to convert the value on the +/// edns-tcp-keepalive option to milliseconds. +// This should go somewhere with the option parsing +const EDNS_TCP_KEEPALIVE_TO_MS: u64 = 100; + +/// Time to wait on a non-idle connection for the other side to send +/// a response on any outstanding query. +// Implement a simple response timer to see if the connection and the server +// are alive. Set the timer when the connection goes from idle to busy. +// Reset the timer each time a reply arrives. Cancel the timer when the +// connection goes back to idle. When the time expires, mark all outstanding +// queries as timed out and shutdown the connection. +// +// Note: nsd has 120 seconds, unbound has 3 seconds. +const RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); + +/// Capacity of the channel that transports [ChanReq]. +const DEF_CHAN_CAP: usize = 8; + +/// Capacity of a private channel between [InnerConnection::reader] and +/// [InnerConnection::run]. +const READ_REPLY_CHAN_CAP: usize = 8; + +/// Error reported when the connection is closed and +/// [InnerConnection::run] terminated. +const ERR_CONN_CLOSED: &str = "connection closed"; + +/// This is the type of sender in [ChanReq]. +type ReplySender = oneshot::Sender; + +#[derive(Debug)] +/// A request from [Query] to [Connection::run] to start a DNS request. +struct ChanReq { + /// DNS request message + msg: MessageBuilder>>, + + /// Sender to send result back to [Query] + sender: ReplySender, +} + +#[derive(Debug)] +/// a response to a [ChanReq]. +struct Response { + /// The 2 octet id that went into the outgoing DNS request. + /// + /// This id is needed to match the response with the query. + id: u16, + + /// The DNS reply message. + reply: Message, +} +/// Response to the DNS request sent by [InnerConnection::run] to [Query]. +type ChanResp = Result>; + +/// The actual implementation of [Connection]. +struct InnerConnection { + /// [InnerConnection::sender] and [InnerConnection::receiver] are + /// part of a single channel. + /// + /// Used by [Query] to send requests to [InnerConnection::run]. + sender: mpsc::Sender>, + + /// receiver part of the channel. + /// + /// Protected by a mutex to allow read/write access by + /// [InnerConnection::run]. + /// The Option is to allow [InnerConnection::run] to signal that the + /// connection is closed. + receiver: Futures_mutex>>>, +} + +/// Internal datastructure of [InnerConnection::run] to keep track of +/// outstanding DNS requests. +struct Queries { + /// The number of elements in [Queries::vec] that are not None. + count: usize, + + /// Index in the [Queries::vec] where to look for a space for a new query. + curr: usize, + + /// Vector of senders to forward a DNS reply message (or error) to. + vec: Vec>, +} + +#[derive(Clone)] +/// A single DNS over octect stream connection. +pub struct Connection { + /// Reference counted [InnerConnection]. + inner: Arc>, +} + +/// Status of a query. Used in [Query]. +enum QueryState { + /// A request is in progress. + /// + /// The receiver for receiving the response is part of this state. + Busy(oneshot::Receiver), + + /// The response has been received and the query is done. + Done, +} + +/// This struct represent an active DNS query. +pub struct Query { + /// Request message. + /// + /// The reply message is compared with the request message to see if + /// it matches the query. + query_msg: Message>, + + /// Current state of the query. + state: QueryState, +} + +/// Internal datastructure of [InnerConnection::run] to keep track of +/// the status of the connection. +// The types Status and ConnState are only used in InnerConnection +struct Status { + /// State of the connection. + state: ConnState, + + /// Boolean if we need to include an edns-tcp-keepalive option in an + /// outogoing request. + /// + /// Typically send_keepalive is true at the start of the connection. + /// it gets cleared when we successfully managed to include the option + /// in a request. + send_keepalive: bool, + + /// Time we are allow to keep the connection open when idle. + /// + /// Initially we assume that the idle timeout is zero. A received + /// edns-tcp-keepalive option may change that. + idle_timeout: Option, +} +/// Status of the connection. Used in [Status]. +enum ConnState { + /// The connection is in this state from the start and when at least + /// one active DNS request is present. + /// + /// The instant contains the time of the first request or the + /// most recent response that was received. + Active(Option), + + /// This state represent a connection that went idle and has an + /// idle timeout. + /// + /// The instant contains the time the connection went idle. + Idle(Instant), + + /// This state represent an idle connection where either there was no + /// idle timeout or the idle timer expired. + IdleTimeout, + + /// A read error occurred. + ReadError, + + /// It took too long to receive a (or another) response. + ReadTimeout, + + /// A write error occurred. + WriteError, +} + +/// A DNS message received to [InnerConnection::reader] and sent to +/// [InnerConnection::run]. +// This type could be local to InnerConnection, but I don't know how +type ReaderChanReply = Message; + +impl + Clone + Composer + Debug + OctetsBuilder> + InnerConnection +{ + /// Constructor for [InnerConnection]. + /// + /// This is the implementation of [Connection::connect]. + pub fn new() -> io::Result> { + let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); + Ok(Self { + sender: tx, + receiver: Futures_mutex::new(Some(rx)), + }) + } + + /// Main execution function for [InnerConnection]. + /// + /// This function Gets called by [Connection::run]. + /// This function is not async cancellation safe + pub async fn run< + ReadStream: AsyncReadExt + Unpin, + WriteStream: AsyncWriteExt + Unpin, + >( + &self, + mut read_stream: ReadStream, + mut write_stream: WriteStream, + ) -> Option<()> { + let (reply_sender, mut reply_receiver) = + mpsc::channel::(READ_REPLY_CHAN_CAP); + + let reader_fut = Self::reader(&mut read_stream, reply_sender); + tokio::pin!(reader_fut); + + let mut receiver = { + let mut locked_opt_receiver = self.receiver.lock().await; + let opt_receiver = locked_opt_receiver.take(); + opt_receiver.expect("no receiver present?") + }; + + let mut status = Status { + state: ConnState::Active(None), + idle_timeout: None, + send_keepalive: true, + }; + let mut query_vec = Queries { + count: 0, + curr: 0, + vec: Vec::new(), + }; + + let mut reqmsg: Option> = None; + + loop { + let opt_timeout = match status.state { + ConnState::Active(opt_instant) => { + if let Some(instant) = opt_instant { + let elapsed = instant.elapsed(); + if elapsed > RESPONSE_TIMEOUT { + let error = io::Error::new( + io::ErrorKind::Other, + "read timeout", + ); + Self::error(error, &mut query_vec); + status.state = ConnState::ReadTimeout; + break; + } + Some(RESPONSE_TIMEOUT - elapsed) + } else { + None + } + } + ConnState::Idle(instant) => { + if let Some(timeout) = &status.idle_timeout { + let elapsed = instant.elapsed(); + if elapsed >= *timeout { + // Move to IdleTimeout and end + // the loop + status.state = ConnState::IdleTimeout; + break; + } + Some(*timeout - elapsed) + } else { + panic!("Idle state but no timeout"); + } + } + ConnState::IdleTimeout + | ConnState::ReadError + | ConnState::WriteError => None, // No timers here + ConnState::ReadTimeout => { + panic!("should not be in loop with ReadTimeout"); + } + }; + + // For simplicity, make sure we always have a timeout + let timeout = match opt_timeout { + Some(timeout) => timeout, + None => + // Just use the response timeout + { + RESPONSE_TIMEOUT + } + }; + + let sleep_fut = sleep(timeout); + let recv_fut = receiver.recv(); + + let (do_write, msg) = match &reqmsg { + None => { + let msg: &[u8] = &[]; + (false, msg) + } + Some(msg) => { + let msg: &[u8] = msg; + (true, msg) + } + }; + + tokio::select! { + biased; + res = &mut reader_fut => { + match res { + Ok(_) => + // The reader should not + // terminate without + // error. + panic!("reader terminated"), + Err(error) => { + Self::error(error, + &mut query_vec); + status.state = + ConnState::ReadError; + // Reader failed. Break + // out of loop and + // shut down + break + } + } + } + opt_answer = reply_receiver.recv() => { + let answer = opt_answer.expect("reader died?"); + // Check for a edns-tcp-keepalive option + let opt_record = answer.opt(); + if let Some(ref opts) = opt_record { + Self::handle_opts(opts, + &mut status); + }; + drop(opt_record); + Self::demux_reply(answer, + &mut status, &mut query_vec); + } + res = write_stream.write_all(msg), + if do_write => { + if let Err(error) = res { + Self::error(error, + &mut query_vec); + status.state = + ConnState::WriteError; + break; + } + else { + reqmsg = None; + } + } + res = recv_fut, if !do_write => { + match res { + Some(req) => + self.insert_req(req, &mut status, + &mut reqmsg, &mut query_vec), + None => panic!("recv failed"), + } + } + _ = sleep_fut => { + // Timeout expired, just + // continue with the loop + } + + } + + // Check if the connection is idle + match status.state { + ConnState::Active(_) | ConnState::Idle(_) => { + // Keep going + } + ConnState::IdleTimeout => break, + ConnState::ReadError + | ConnState::ReadTimeout + | ConnState::WriteError => { + panic!("Should not be here"); + } + } + } + + // Send FIN + _ = write_stream.shutdown().await; + + None + } + + /// This function sends a DNS request to [InnerConnection::run]. + pub async fn query( + &self, + sender: oneshot::Sender, + query_msg: &mut MessageBuilder>>, + ) -> Result<(), &'static str> { + // We should figure out how to get query_msg. + let msg_clone = query_msg.clone(); + + let req = ChanReq { + sender, + msg: msg_clone, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } + } + + /// This function reads a DNS message from the connection and sends + /// it to [InnerConnection::run]. + /// + /// Reading has to be done in two steps: first read a two octet value + /// the specifies the length of the message, and then read in a loop the + /// body of the message. + /// + /// This function is not async cancellation safe. + async fn reader( + //sock: &mut ReadStream, + mut sock: ReadStream, + sender: mpsc::Sender, + ) -> Result<(), std::io::Error> { + loop { + let read_res = sock.read_u16().await; + let len = match read_res { + Ok(len) => len, + Err(error) => { + return Err(error); + } + } as usize; + + let mut buf = BytesMut::with_capacity(len); + + loop { + let curlen = buf.len(); + if curlen >= len { + if curlen > len { + panic!( + "reader: got too much data {curlen}, expetect {len}"); + } + + // We got what we need + break; + } + + let read_res = sock.read_buf(&mut buf).await; + + match read_res { + Ok(readlen) => { + if readlen == 0 { + let error = io::Error::new( + io::ErrorKind::Other, + "unexpected end of data", + ); + return Err(error); + } + } + Err(error) => { + return Err(error); + } + }; + + // Check if we are done at the head of the loop + } + + let reply_message = Message::::from_octets(buf.into()); + match reply_message { + Ok(answer) => { + sender + .send(answer) + .await + .expect("can't send reply to run"); + } + Err(_) => { + // The only possible error is short message + let error = + io::Error::new(io::ErrorKind::Other, "short buf"); + return Err(error); + } + } + } + } + + /// An error occured, report the error to all outstanding [Query] objects. + fn error(error: std::io::Error, query_vec: &mut Queries) { + // Update all requests that are in progress. Don't wait for + // any reply that may be on its way. + let arc_error = Arc::new(error); + for index in 0..query_vec.vec.len() { + if query_vec.vec[index].is_some() { + let sender = Self::take_query(query_vec, index) + .expect("we tested is_none before"); + _ = sender.send(Err(arc_error.clone())); + } + } + } + + /// Handle received EDNS options, in particular the edns-tcp-keepalive + /// option. + fn handle_opts>( + opts: &OptRecord, + status: &mut Status, + ) { + for option in opts.iter().flatten() { + if let AllOptData::TcpKeepalive(tcpkeepalive) = option { + Self::handle_keepalive(tcpkeepalive, status); + } + } + } + + /// Demultiplex a DNS reply and send it to the right [Query] object. + /// + /// In addition, the status is updated to IdleTimeout or Idle if there + /// are no remaining pending requests. + fn demux_reply( + answer: Message, + status: &mut Status, + query_vec: &mut Queries, + ) { + // We got an answer, reset the timer + status.state = ConnState::Active(Some(Instant::now())); + + let ind16 = answer.header().id(); + let index: usize = ind16.into(); + + let vec_len = query_vec.vec.len(); + if index >= vec_len { + // Index is out of bouds. We should mark + // the connection as broken + return; + } + + // Do we have a query with this ID? + match &mut query_vec.vec[index] { + None => { + // No query with this ID. We should + // mark the connection as broken + return; + } + Some(_) => { + let sender = Self::take_query(query_vec, index).unwrap(); + let ind16: u16 = index.try_into().unwrap(); + let reply = Response { + reply: answer, + id: ind16, + }; + _ = sender.send(Ok(reply)); + } + } + if query_vec.count == 0 { + // Clear the activity timer. There is no need to do + // this because state will be set to either IdleTimeout + // or Idle just below. However, it is nicer to keep + // this independent. + status.state = ConnState::Active(None); + + status.state = if status.idle_timeout.is_none() { + // Assume that we can just move to IdleTimeout + // state + ConnState::IdleTimeout + } else { + ConnState::Idle(Instant::now()) + } + } + } + + /// Insert a request in query_vec and return the request to be sent + /// in *reqmsg. + /// + /// First the status is checked, an error is returned if not Active or + /// idle. Addend a edns-tcp-keepalive option if needed. + // Note: maybe reqmsg should be a return value. + fn insert_req( + &self, + mut req: ChanReq, + status: &mut Status, + reqmsg: &mut Option>, + query_vec: &mut Queries, + ) { + match status.state { + ConnState::Active(timer) => { + // Set timer if we don't have one already + if timer.is_none() { + status.state = ConnState::Active(Some(Instant::now())); + } + } + ConnState::Idle(_) => { + // Go back to active + status.state = ConnState::Active(Some(Instant::now())); + } + ConnState::IdleTimeout => { + // The connection has been closed. Report error + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "idle timeout", + )))); + return; + } + ConnState::ReadError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "read error", + )))); + return; + } + ConnState::ReadTimeout => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "read timeout", + )))); + return; + } + ConnState::WriteError => { + _ = req.sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "write error", + )))); + return; + } + } + + // Note that insert may fail if there are too many + // outstanding queires. First call insert before checking + // send_keepalive. + // XXX + let index = Self::insert(req.sender, query_vec).unwrap(); + + let ind16: u16 = index.try_into().unwrap(); + + // We set the ID to the array index. Defense in depth + // suggests that a random ID is better because it works + // even if TCP sequence numbers could be predicted. However, + // Section 9.3 of RFC 5452 recommends retrying over TCP + // if many spoofed answers arrive over UDP: "TCP, by the + // nature of its use of sequence numbers, is far more + // resilient against forgery by third parties." + let hdr = req.msg.header_mut(); + hdr.set_id(ind16); + + if status.send_keepalive { + let mut msgadd = req.msg.clone().additional(); + + // send an empty keepalive option + let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); + match res { + Ok(_) => { + Self::convert_query(&msgadd, reqmsg); + status.send_keepalive = false; + } + Err(_) => { + // Adding keepalive option + // failed. Send the original + // request. + Self::convert_query(&req.msg, reqmsg); + } + } + } else { + Self::convert_query(&req.msg, reqmsg); + } + } + + /// Take an element out of query_vec. + fn take_query( + query_vec: &mut Queries, + index: usize, + ) -> Option { + let query = query_vec.vec[index].take(); + query_vec.count -= 1; + query + } + + /// Handle a received edns-tcp-keepalive option. + fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { + if let Some(value) = opt_value.timeout() { + status.idle_timeout = Some(Duration::from_millis( + u64::from(value) * EDNS_TCP_KEEPALIVE_TO_MS, + )); + } + } + + /// Convert the query message to a vector. + // This function should return the vector instead of storing it + // through a reference. + fn convert_query + AsRef<[u8]>>( + msg: &MessageBuilder>>, + reqmsg: &mut Option>, + ) { + let vec = msg.as_target().as_target().as_stream_slice(); + + // Store a clone of the request. That makes life easier + // and requests tend to be small + *reqmsg = Some(vec.to_vec()); + } + + /// Insert a sender (for the reply) in the query_vec and return the index. + fn insert( + sender: oneshot::Sender, + query_vec: &mut Queries, + ) -> Result { + let q = Some(sender); + + // Fail if there are to many entries already in this vector + // We cannot have more than u16::MAX entries because the + // index needs to fit in an u16. For efficiency we want to + // keep the vector half empty. So we return a failure if + // 2*count > u16::MAX + if 2 * query_vec.count > u16::MAX.into() { + return Err(ERR_TOO_MANY_QUERIES); + } + + let vec_len = query_vec.vec.len(); + + // Append if the amount of empty space in the vector is less + // than half. But limit vec_len to u16::MAX + if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { + // Just append + query_vec.vec.push(q); + query_vec.count += 1; + let index = query_vec.vec.len() - 1; + return Ok(index); + } + let loc_curr = query_vec.curr; + + for index in loc_curr..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Nothing until the end of the vector. Try for the entire + // vector + for index in 0..vec_len { + if query_vec.vec[index].is_none() { + Self::insert_at(query_vec, index, q); + return Ok(index); + } + } + + // Still nothing, that is not good + panic!("insert failed"); + } + + /// Insert a sender at a specific position in query_vec and update + /// the statistics. + fn insert_at( + query_vec: &mut Queries, + index: usize, + q: Option, + ) { + query_vec.vec[index] = q; + query_vec.count += 1; + query_vec.curr = index + 1; + } +} + +impl< + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, + > Connection +{ + /// Constructor for [Connection]. + /// + /// Returns a [Connection] wrapped in a [Result](io::Result). + pub fn new() -> io::Result> { + let connection = InnerConnection::new()?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Main execution function for [Connection]. + /// + /// This function has to run in the background or together with + /// any calls to [query](Self::query) or [Query::get_result]. + pub async fn run< + ReadStream: AsyncReadExt + Unpin, + WriteStream: AsyncWriteExt + Unpin, + >( + &self, + read_stream: ReadStream, + write_stream: WriteStream, + ) -> Option<()> { + self.inner.run(read_stream, write_stream).await + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = &query_msg.as_message(); + Ok(Query::new(msg, rx)) + } +} + +impl Query { + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new( + query_msg: &Message, + receiver: oneshot::Receiver, + ) -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec).unwrap(); + Self { + query_msg: msg, + state: QueryState::Busy(receiver), + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "receive error", + ))); + } + let res = res.unwrap(); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.unwrap(); + let msg = resp.reply; + + let hdr = self.query_msg.header_mut(); + hdr.set_id(resp.id); + + if !msg.is_answer(&self.query_msg) { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "wrong answer", + ))); + } + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} From 63b203053bfb62a66efabc42d2e82bdc5ca2f49f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 18 Apr 2023 16:51:42 +0200 Subject: [PATCH 018/124] Fix fmt problem. --- src/net/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index e3b2b2f10..a1c722a18 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -3,5 +3,5 @@ #![cfg_attr(docsrs, doc(cfg(feature = "net")))] pub mod octet_stream; -pub mod tcp_mutex; pub mod tcp_channel; +pub mod tcp_mutex; From 8d7fe2b1aa2a81ff35f11e817f98bbfcdfa112d3 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 19 Apr 2023 10:46:12 +0200 Subject: [PATCH 019/124] Added example. --- src/net/client/octet_stream.rs | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 2a8b40fa7..23bd2a080 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -1,4 +1,44 @@ //! A DNS over octet stream transport +//! # Example +//! ``` +//! use domain::base::Dname; +//! use domain::base::Rtype::Aaaa; +//! use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; +//! use domain::net::client::octet_stream::Connection; +//! use std::net::{IpAddr, SocketAddr}; +//! use std::str::FromStr; +//! use tokio::net::TcpStream; +//! +//! #[tokio::main] +//! async fn main() { +//! // Create DNS request message +//! // Create a message builder wrapping a compressor wrapping a stream +//! // target. +//! let mut msg = +//! MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap(); +//! msg.header_mut().set_rd(true); +//! let mut msg = msg.question(); +//! msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) +//! .unwrap(); +//! let mut msg = msg.as_builder_mut().clone(); +//! +//! let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); +//! +//! let tcp = TcpStream::connect(server_addr).await.unwrap(); +//! let (reader, writer) = tcp.into_split(); +//! +//! let conn = Connection::new().unwrap(); +//! let conn_run = conn.clone(); +//! +//! tokio::spawn(async move { +//! conn_run.run(reader, writer).await; +//! }); +//! +//! let mut query = conn.query(&mut msg).await.unwrap(); +//! let reply = query.get_result().await; +//! println!("reply: {:?}", reply); +//! } +//! ``` #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] From fef193b896b5e69f30911a5288c4e1133650930e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 19 Apr 2023 15:09:01 +0200 Subject: [PATCH 020/124] Run now takes a single 'io' object. Separate TCP and TLS examples. --- Cargo.toml | 3 ++ examples/tcp-client.rs | 44 +++++++++++++++++++++++ examples/tls-client.rs | 64 ++++++++++++++++++++++++++++++++++ src/net/client/octet_stream.rs | 64 +++++++--------------------------- 4 files changed, 124 insertions(+), 51 deletions(-) create mode 100644 examples/tcp-client.rs create mode 100644 examples/tls-client.rs diff --git a/Cargo.toml b/Cargo.toml index 120da3aa8..eea62fc16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,12 @@ zonefile = ["bytes", "std"] ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] [dev-dependencies] +rustls = { version = "0" } serde_test = "1.0.130" serde_yaml = "0.9" tokio = { version = "1", features = ["rt-multi-thread", "io-util", "net"] } +tokio-rustls = { version = "0" } +webpki-roots = { version = "0" } [package.metadata.docs.rs] all-features = true diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs new file mode 100644 index 000000000..62e7bba56 --- /dev/null +++ b/examples/tcp-client.rs @@ -0,0 +1,44 @@ +use domain::base::Dname; +use domain::base::Rtype::Aaaa; +use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; +use domain::net::client::octet_stream::Connection; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use tokio::net::TcpStream; + +#[tokio::main] +async fn main() { + // Create DNS request message + // Create a message builder wrapping a compressor wrapping a stream + // target. + let mut msg = MessageBuilder::from_target(StaticCompressor::new( + StreamTarget::new_vec(), + )) + .unwrap(); + msg.header_mut().set_rd(true); + let mut msg = msg.question(); + msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) + .unwrap(); + let mut msg = msg.as_builder_mut().clone(); + + let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + + let tcp = match TcpStream::connect(server_addr).await { + Err(err) => { + println!("TCP connection failed with {}", err); + return; + } + Ok(tcp) => tcp, + }; + + let conn = Connection::new().unwrap(); + let conn_run = conn.clone(); + + tokio::spawn(async move { + conn_run.run(tcp).await; + }); + + let mut query = conn.query(&mut msg).await.unwrap(); + let reply = query.get_result().await; + println!("reply: {:?}", reply); +} diff --git a/examples/tls-client.rs b/examples/tls-client.rs new file mode 100644 index 000000000..9cf80dc47 --- /dev/null +++ b/examples/tls-client.rs @@ -0,0 +1,64 @@ +use domain::base::Dname; +use domain::base::Rtype::Aaaa; +use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; +use domain::net::client::octet_stream::Connection; +use rustls::{ClientConfig, ServerName}; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; + +#[tokio::main] +async fn main() { + // Create DNS request message + // Create a message builder wrapping a compressor wrapping a stream + // target. + let mut msg = MessageBuilder::from_target(StaticCompressor::new( + StreamTarget::new_vec(), + )) + .unwrap(); + msg.header_mut().set_rd(true); + let mut msg = msg.question(); + msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) + .unwrap(); + let mut msg = msg.as_builder_mut().clone(); + + let server_addr = + SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); + + let mut root_store = rustls::RootCertStore::empty(); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }), + ); + let client_config = Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(), + ); + + let tls_connection = TlsConnector::from(client_config); + + let tcp = TcpStream::connect(server_addr).await.unwrap(); + + let server_name = ServerName::try_from("dns.google").unwrap(); + let tls = tls_connection.connect(server_name, tcp).await.unwrap(); + + let conn = Connection::new().unwrap(); + let conn_run = conn.clone(); + + tokio::spawn(async move { + conn_run.run(tls).await; + }); + + let mut query = conn.query(&mut msg).await.unwrap(); + let reply = query.get_result().await; + println!("reply: {:?}", reply); +} diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 23bd2a080..108022c0a 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -1,43 +1,11 @@ //! A DNS over octet stream transport -//! # Example +//! # Example with TCP connection to port 53 //! ``` -//! use domain::base::Dname; -//! use domain::base::Rtype::Aaaa; -//! use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; -//! use domain::net::client::octet_stream::Connection; -//! use std::net::{IpAddr, SocketAddr}; -//! use std::str::FromStr; -//! use tokio::net::TcpStream; -//! -//! #[tokio::main] -//! async fn main() { -//! // Create DNS request message -//! // Create a message builder wrapping a compressor wrapping a stream -//! // target. -//! let mut msg = -//! MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap(); -//! msg.header_mut().set_rd(true); -//! let mut msg = msg.question(); -//! msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) -//! .unwrap(); -//! let mut msg = msg.as_builder_mut().clone(); -//! -//! let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); -//! -//! let tcp = TcpStream::connect(server_addr).await.unwrap(); -//! let (reader, writer) = tcp.into_split(); -//! -//! let conn = Connection::new().unwrap(); -//! let conn_run = conn.clone(); -//! -//! tokio::spawn(async move { -//! conn_run.run(reader, writer).await; -//! }); -//! -//! let mut query = conn.query(&mut msg).await.unwrap(); -//! let reply = query.get_result().await; -//! println!("reply: {:?}", reply); -//! } +#![doc = include_str!("../../../examples/tcp-client.rs")] +//! ``` +//! # Example with TLS connection to port 853 +//! ``` +#![doc = include_str!("../../../examples/tls-client.rs")] //! ``` #![warn(missing_docs)] @@ -268,17 +236,15 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// /// This function Gets called by [Connection::run]. /// This function is not async cancellation safe - pub async fn run< - ReadStream: AsyncReadExt + Unpin, - WriteStream: AsyncWriteExt + Unpin, - >( + pub async fn run( &self, - mut read_stream: ReadStream, - mut write_stream: WriteStream, + io: IO, ) -> Option<()> { let (reply_sender, mut reply_receiver) = mpsc::channel::(READ_REPLY_CHAN_CAP); + let (mut read_stream, mut write_stream) = tokio::io::split(io); + let reader_fut = Self::reader(&mut read_stream, reply_sender); tokio::pin!(reader_fut); @@ -836,15 +802,11 @@ impl< /// /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run< - ReadStream: AsyncReadExt + Unpin, - WriteStream: AsyncWriteExt + Unpin, - >( + pub async fn run( &self, - read_stream: ReadStream, - write_stream: WriteStream, + io: IO, ) -> Option<()> { - self.inner.run(read_stream, write_stream).await + self.inner.run(io).await } /// Start a DNS request. From 7f455e46277c0677a0d6a053db58190591aa001e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 19 Apr 2023 15:18:08 +0200 Subject: [PATCH 021/124] Require 'net' feature for tcp-client and tls-client examples. --- Cargo.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index eea62fc16..4c7210124 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,3 +88,11 @@ required-features = ["resolv-sync"] name = "client" required-features = ["std", "rand"] +[[example]] +name = "tcp-client" +required-features = ["net"] + +[[example]] +name = "tls-client" +required-features = ["net"] + From c2b1cdd365ecb21025492ebc922d6feaa3c7dd64 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 25 Apr 2023 13:43:06 +0200 Subject: [PATCH 022/124] Separate function add_tcp_keepalive to add the edns-tcp-keepalive option --- src/net/client/octet_stream.rs | 117 ++++++++++++++++++++++++++++++--- 1 file changed, 108 insertions(+), 9 deletions(-) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 108022c0a..65e85dc41 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -36,9 +36,11 @@ use std::vec::Vec; use crate::base::wire::Composer; use crate::base::{ - opt::{AllOptData, OptRecord, TcpKeepalive}, - Message, MessageBuilder, StaticCompressor, StreamTarget, + opt::{AllOptData, Opt, OptRecord, TcpKeepalive}, + Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, + StreamTarget, }; +use crate::rdata::AllRecordData; use octseq::{Octets, OctetsBuilder}; use tokio::io; @@ -223,7 +225,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> { /// Constructor for [InnerConnection]. /// - /// This is the implementation of [Connection::connect]. + /// This is the implementation of [Connection::new]. pub fn new() -> io::Result> { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { @@ -668,13 +670,11 @@ impl + Clone + Composer + Debug + OctetsBuilder> hdr.set_id(ind16); if status.send_keepalive { - let mut msgadd = req.msg.clone().additional(); + let res_msg = add_tcp_keepalive(&req.msg); - // send an empty keepalive option - let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); - match res { - Ok(_) => { - Self::convert_query(&msgadd, reqmsg); + match res_msg { + Ok(msg) => { + Self::convert_query(&msg, reqmsg); status.send_keepalive = false; } Err(_) => { @@ -887,3 +887,102 @@ impl Query { } } } + +/// Add an edns-tcp-keepalive option to a MessageBuilder. +/// +/// This is surprisingly difficult. We need to copy the original message to +/// a new MessageBuilder because MessageBuilder has no support for changing the +/// opt record. +fn add_tcp_keepalive( + msg: &MessageBuilder>>, +) -> Result< + MessageBuilder>>, + crate::base::message_builder::PushError, +> { + // We can't just insert a new option in an existing + // opt record. So we have to create new message and copy records + // from the old one. And insert our option while copying the opt + // record. + let src_clone = msg.clone(); + let source = Message::from_octets( + src_clone.as_target().as_target().as_dgram_slice(), + ) + .unwrap(); + let target = msg.clone(); + + let source = source.question(); + // Go to additional and back to builder to delete all sections + // except for the header + let mut target = target.additional().builder().question(); + for rr in source { + let rr = rr.unwrap(); + target.push(rr)?; + } + let mut source = source.answer().unwrap(); + let mut target = target.answer(); + for rr in &mut source { + let rr = rr.unwrap(); + let rr = rr + .into_record::>>() + .unwrap() + .unwrap(); + target.push(rr)?; + } + + let mut source = source.next_section().unwrap().unwrap(); + let mut target = target.authority(); + for rr in &mut source { + let rr = rr.unwrap(); + let rr = rr + .into_record::>>() + .unwrap() + .unwrap(); + target.push(rr)?; + } + + let source = source.next_section().unwrap().unwrap(); + let mut target = target.additional(); + let mut found_opt_rr = false; + for rr in source { + let rr = rr.unwrap(); + if rr.rtype() == Rtype::Opt { + found_opt_rr = true; + let rr = rr.into_record::>().unwrap().unwrap(); + let opt_record = OptRecord::from_record(rr); + target + .opt(|newopt| { + newopt + .set_udp_payload_size(opt_record.udp_payload_size()); + newopt.set_version(opt_record.version()); + newopt.set_dnssec_ok(opt_record.dnssec_ok()); + for option in opt_record.iter::>() { + let option = option.unwrap(); + if let AllOptData::TcpKeepalive(_) = option { + panic!("handle keepalive"); + } else { + newopt.push(&option).unwrap(); + } + } + // send an empty keepalive option + newopt.tcp_keepalive(None).unwrap(); + Ok(()) + }) + .unwrap(); + } else { + let rr = rr + .into_record::>>() + .unwrap() + .unwrap(); + target.push(rr)?; + } + } + if !found_opt_rr { + // send an empty keepalive option + target.opt(|opt| opt.tcp_keepalive(None))?; + } + + // It would be nice to use .builder() here. But that one deletes all + // section. We have to resort to .as_builder() which gives a + // reference and then .clone() + Ok(target.as_builder().clone()) +} From 2809889a0477ac4c2ca831e29ff73caf543d1117 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 8 May 2023 10:34:39 +0200 Subject: [PATCH 023/124] Adjust to new opt interface --- src/net/client/octet_stream.rs | 15 +++++---------- src/net/client/tcp_channel.rs | 12 +++--------- src/net/client/tcp_mutex.rs | 10 +++------- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 65e85dc41..52e43269d 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -51,11 +51,6 @@ use tokio::time::sleep; /// Error returned when too many queries are currently active. const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; -/// Constant from RFC 7828. How to convert the value on the -/// edns-tcp-keepalive option to milliseconds. -// This should go somewhere with the option parsing -const EDNS_TCP_KEEPALIVE_TO_MS: u64 = 100; - /// Time to wait on a non-idle connection for the other side to send /// a response on any outstanding query. // Implement a simple response timer to see if the connection and the server @@ -533,7 +528,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> opts: &OptRecord, status: &mut Status, ) { - for option in opts.iter().flatten() { + for option in opts.opt().iter().flatten() { if let AllOptData::TcpKeepalive(tcpkeepalive) = option { Self::handle_keepalive(tcpkeepalive, status); } @@ -702,9 +697,8 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// Handle a received edns-tcp-keepalive option. fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { if let Some(value) = opt_value.timeout() { - status.idle_timeout = Some(Duration::from_millis( - u64::from(value) * EDNS_TCP_KEEPALIVE_TO_MS, - )); + let value_dur = Duration::from(value); + status.idle_timeout = Some(value_dur); } } @@ -955,7 +949,8 @@ fn add_tcp_keepalive( .set_udp_payload_size(opt_record.udp_payload_size()); newopt.set_version(opt_record.version()); newopt.set_dnssec_ok(opt_record.dnssec_ok()); - for option in opt_record.iter::>() { + for option in opt_record.opt().iter::>() + { let option = option.unwrap(); if let AllOptData::TcpKeepalive(_) = option { panic!("handle keepalive"); diff --git a/src/net/client/tcp_channel.rs b/src/net/client/tcp_channel.rs index a51bc3139..41e659842 100644 --- a/src/net/client/tcp_channel.rs +++ b/src/net/client/tcp_channel.rs @@ -43,11 +43,6 @@ use tokio::time::sleep; /// Error returned when too many queries are currently active. const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; -/// Constant from RFC 7828. How to convert the value on the -/// edns-tcp-keepalive option to milliseconds. -// This should go somewhere with the option parsing -const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; - /// Time to wait on a non-idle TCP connection for the other side to send /// a response on any outstanding query. // Implement a simple response timer to see if the connection and the server @@ -530,7 +525,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> opts: &OptRecord, status: &mut Status, ) { - for option in opts.iter().flatten() { + for option in opts.opt().iter().flatten() { if let AllOptData::TcpKeepalive(tcpkeepalive) = option { Self::handle_keepalive(tcpkeepalive, status); } @@ -701,9 +696,8 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// Handle a received edns-tcp-keepalive option. fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { if let Some(value) = opt_value.timeout() { - status.idle_timeout = Some(Duration::from_millis( - u64::from(value) * EDNS_TCP_KEEPALIE_TO_MS, - )); + let value_dur = Duration::from(value); + status.idle_timeout = Some(value_dur); } } diff --git a/src/net/client/tcp_mutex.rs b/src/net/client/tcp_mutex.rs index 8162705b2..0706d662e 100644 --- a/src/net/client/tcp_mutex.rs +++ b/src/net/client/tcp_mutex.rs @@ -43,9 +43,6 @@ const ERR_READ_TIMEOUT: &str = "read timeout"; const ERR_WRITE_ERROR: &str = "write error"; const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; -// From RFC 7828. This should go somewhere with the option parsing -const EDNS_TCP_KEEPALIE_TO_MS: u64 = 100; - // Implement a simple response timer to see if the connection and the server // are alive. Set the timer when the connection goes from idle to busy. // Reset the timer each time a reply arrives. Cancel the timer when the @@ -257,9 +254,8 @@ impl InnerTcpConnection { fn handle_keepalive(&self, opt_value: TcpKeepalive) { if let Some(value) = opt_value.timeout() { let mut status = self.status.lock().unwrap(); - status.idle_timeout = Some(Duration::from_millis( - u64::from(value) * EDNS_TCP_KEEPALIE_TO_MS, - )); + let value_dur = Duration::from(value); + status.idle_timeout = Some(value_dur); } } @@ -267,7 +263,7 @@ impl InnerTcpConnection { &self, opts: &OptRecord, ) { - for option in opts.iter() { + for option in opts.opt().iter() { let opt = option.unwrap(); if let AllOptData::TcpKeepalive(tcpkeepalive) = opt { self.handle_keepalive(tcpkeepalive); From 331d6748561fcc76d65806ce67863e22f0460c7b Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:53:14 +0200 Subject: [PATCH 024/124] query_no_check (don't check if answer match query), is_answer_ignore_id, ERR_TOO_MANY_QUERIES (too many outstanding queries) --- src/net/client/octet_stream.rs | 127 ++++++++++++++++++++++++++++----- 1 file changed, 108 insertions(+), 19 deletions(-) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 52e43269d..2baf11f92 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -89,11 +89,6 @@ struct ChanReq { #[derive(Debug)] /// a response to a [ChanReq]. struct Response { - /// The 2 octet id that went into the outgoing DNS request. - /// - /// This id is needed to match the response with the query. - id: u16, - /// The DNS reply message. reply: Message, } @@ -101,6 +96,7 @@ struct Response { type ChanResp = Result>; /// The actual implementation of [Connection]. +#[derive(Debug)] struct InnerConnection { /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. @@ -130,7 +126,7 @@ struct Queries { vec: Vec>, } -#[derive(Clone)] +#[derive(Clone, Debug)] /// A single DNS over octect stream connection. pub struct Connection { /// Reference counted [InnerConnection]. @@ -160,6 +156,14 @@ pub struct Query { state: QueryState, } +/// This represents that state of an active DNS query if there is no need +/// to check that the reply matches the request. The assumption is that the +/// caller will do this check. +pub struct QueryNoCheck { + /// Current state of the query. + state: QueryState, +} + /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection @@ -566,11 +570,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> } Some(_) => { let sender = Self::take_query(query_vec, index).unwrap(); - let ind16: u16 = index.try_into().unwrap(); - let reply = Response { - reply: answer, - id: ind16, - }; + let reply = Response { reply: answer }; _ = sender.send(Ok(reply)); } } @@ -649,8 +649,17 @@ impl + Clone + Composer + Debug + OctetsBuilder> // Note that insert may fail if there are too many // outstanding queires. First call insert before checking // send_keepalive. - // XXX - let index = Self::insert(req.sender, query_vec).unwrap(); + let index = { + let res = Self::insert(req.sender, query_vec); + match res { + Err(_) => { + // insert sends an error reply, so we can just + // return here + return; + } + Ok(index) => index, + } + }; let ind16: u16 = index.try_into().unwrap(); @@ -721,17 +730,22 @@ impl + Clone + Composer + Debug + OctetsBuilder> sender: oneshot::Sender, query_vec: &mut Queries, ) -> Result { - let q = Some(sender); - // Fail if there are to many entries already in this vector // We cannot have more than u16::MAX entries because the // index needs to fit in an u16. For efficiency we want to // keep the vector half empty. So we return a failure if // 2*count > u16::MAX if 2 * query_vec.count > u16::MAX.into() { + // We own sender. So we need to send the error reply here + _ = sender.send(Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + ERR_TOO_MANY_QUERIES, + )))); return Err(ERR_TOO_MANY_QUERIES); } + let q = Some(sender); + let vec_len = query_vec.vec.len(); // Append if the amount of empty space in the vector is less @@ -816,6 +830,19 @@ impl< let msg = &query_msg.as_message(); Ok(Query::new(msg, rx)) } + + /// Start a DNS request but do not check if the reply matches the request. + /// + /// This function is similar to [Self::query]. Not checking if the reply + /// match the request avoids having to keep the request around. + pub async fn query_no_check( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + Ok(QueryNoCheck::new(rx)) + } } impl Query { @@ -864,10 +891,7 @@ impl Query { let resp = res.unwrap(); let msg = resp.reply; - let hdr = self.query_msg.header_mut(); - hdr.set_id(resp.id); - - if !msg.is_answer(&self.query_msg) { + if !is_answer_ignore_id(&msg, &self.query_msg) { return Err(Arc::new(io::Error::new( io::ErrorKind::Other, "wrong answer", @@ -882,6 +906,54 @@ impl Query { } } +impl QueryNoCheck { + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new(receiver: oneshot::Receiver) -> QueryNoCheck { + Self { + state: QueryState::Busy(receiver), + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "receive error", + ))); + } + let res = res.unwrap(); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.unwrap(); + let msg = resp.reply; + + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} + /// Add an edns-tcp-keepalive option to a MessageBuilder. /// /// This is surprisingly difficult. We need to copy the original message to @@ -981,3 +1053,20 @@ fn add_tcp_keepalive( // reference and then .clone() Ok(target.as_builder().clone()) } + +/// Check if a DNS reply match the query. Ignore whether id fields match. +fn is_answer_ignore_id< + Octs1: Octets + AsRef<[u8]>, + Octs2: Octets + AsRef<[u8]>, +>( + reply: &Message, + query: &Message, +) -> bool { + if !reply.header().qr() + || reply.header_counts().qdcount() != query.header_counts().qdcount() + { + false + } else { + reply.question() == query.question() + } +} From 99fcb0df93c3c8a3fc8e3d086b989bedd798833b Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:54:49 +0200 Subject: [PATCH 025/124] Trait for connection factories --- src/net/client/factory.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/net/client/factory.rs diff --git a/src/net/client/factory.rs b/src/net/client/factory.rs new file mode 100644 index 000000000..ee19d4fc3 --- /dev/null +++ b/src/net/client/factory.rs @@ -0,0 +1,22 @@ +//! Trait for connection factories + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use std::boxed::Box; +use std::future::Future; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// This trait is for creating new network connections. +/// +/// The IO type is the type of the resulting connection object. +pub trait ConnFactory { + /// The next method is an asynchronous function that returns a + /// new connection. + /// + /// This method is equivalent to async fn next(&self) -> Result; + fn next( + &self, + ) -> Pin> + Send + '_>>; +} From 74147919606c8df79f8c592a6dc403cb722d8cba Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:56:21 +0200 Subject: [PATCH 026/124] TCP connection factory --- src/net/client/tcp_factory.rs | 64 +++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 src/net/client/tcp_factory.rs diff --git a/src/net/client/tcp_factory.rs b/src/net/client/tcp_factory.rs new file mode 100644 index 000000000..3586db96d --- /dev/null +++ b/src/net/client/tcp_factory.rs @@ -0,0 +1,64 @@ +//! A factory for TCP connections + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use core::ops::DerefMut; +use futures::Future; +use std::boxed::Box; +use std::fmt::Debug; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::net::{TcpStream, ToSocketAddrs}; + +use crate::net::client::factory::ConnFactory; + +/// This a connection factory that produces TCP connections. +pub struct TcpConnFactory { + /// Remote address to connect to. + addr: A, +} + +/// This is an internal structure that provides the future for a new +/// connection. +pub struct Next { + /// Future for creating a new TCP connection. + future: Pin< + Box> + Send>, + >, +} + +impl TcpConnFactory { + /// Create a new factory. + /// + /// addr is the destination address to connect to. + pub fn new(addr: A) -> TcpConnFactory { + Self { addr } + } +} + +impl Future for Next { + type Output = Result; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let me = self.deref_mut(); + let io = ready!(me.future.as_mut().poll(cx))?; + Poll::Ready(Ok(io)) + } +} + +impl ConnFactory + for TcpConnFactory +{ + fn next( + &self, + ) -> Pin> + Send>> + { + Box::pin(Next { + future: Box::pin(TcpStream::connect(self.addr.clone())), + }) + } +} From 7774f0c13e44c5e2c70dbf731adfddaac6d6ac1d Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:58:11 +0200 Subject: [PATCH 027/124] TLS connection factury --- src/net/client/tls_factory.rs | 96 +++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/net/client/tls_factory.rs diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_factory.rs new file mode 100644 index 000000000..dd346aef1 --- /dev/null +++ b/src/net/client/tls_factory.rs @@ -0,0 +1,96 @@ +//! A factory for TLS connections + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use core::ops::DerefMut; +use futures::Future; +use std::boxed::Box; +use std::pin::Pin; +use std::string::String; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::{ClientConfig, ServerName}; +use tokio_rustls::TlsConnector; + +use crate::net::client::factory::ConnFactory; + +/// Factory object for TLS connections +pub struct TlsConnFactory { + /// Configuration for setting up a TLS connection. + client_config: Arc, + + /// Server name for certificate verification. + server_name: String, + + /// Remote address to connect to. + addr: A, +} + +/// Internal structure that contains the future for creating a new +/// TLS connection. +pub struct Next { + /// Future for creating a new TLS connection. + future: Pin< + Box< + dyn Future, std::io::Error>> + + Send, + >, + >, +} + +impl TlsConnFactory { + /// Function to create a new TLS connection factory + pub fn new( + client_config: Arc, + server_name: &str, + addr: A, + ) -> TlsConnFactory { + Self { + client_config, + server_name: String::from(server_name), + addr, + } + } +} + +impl Future for Next { + type Output = Result, std::io::Error>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, std::io::Error>> { + let me = self.deref_mut(); + let io = ready!(me.future.as_mut().poll(cx))?; + Poll::Ready(Ok(io)) + } +} + +impl + ConnFactory> for TlsConnFactory +{ + fn next( + &self, + ) -> Pin< + Box< + dyn Future, std::io::Error>> + + Send + + '_, + >, + > { + let tls_connection = TlsConnector::from(self.client_config.clone()); + let server_name = + ServerName::try_from(self.server_name.as_str()).unwrap(); + let addr = self.addr.clone(); + Box::pin(Next { + future: Box::pin(async { + let box_connection = Box::new(tls_connection); + let tcp = TcpStream::connect(addr).await?; + box_connection.connect(server_name, tcp).await + }), + }) + } +} From 29cf98fa6b3ceb019467c67b63836aa20631404f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:59:22 +0200 Subject: [PATCH 028/124] A DNS over multiple octet streams transport --- src/net/client/multi_stream.rs | 697 +++++++++++++++++++++++++++++++++ 1 file changed, 697 insertions(+) create mode 100644 src/net/client/multi_stream.rs diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs new file mode 100644 index 000000000..5acaab7cf --- /dev/null +++ b/src/net/client/multi_stream.rs @@ -0,0 +1,697 @@ +//! A DNS over multiple octet streams transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use bytes::Bytes; + +use futures::lock::Mutex as Futures_mutex; +use futures::stream::FuturesUnordered; +use futures::StreamExt; + +use octseq::{Octets, OctetsBuilder}; + +use rand::random; + +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::{sleep_until, Instant}; + +use crate::base::wire::Composer; +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::factory::ConnFactory; +use crate::net::client::octet_stream::Connection as SingleConnection; +use crate::net::client::octet_stream::QueryNoCheck as SingleQuery; + +/// Capacity of the channel that transports [ChanReq]. +const DEF_CHAN_CAP: usize = 8; + +/// Error reported when the connection is closed and +/// [InnerConnection::run] terminated. +const ERR_CONN_CLOSED: &str = "connection closed"; + +/// Response to the DNS request sent by [InnerConnection::run] to [Query]. +#[derive(Debug)] +struct ChanRespOk { + /// id of this connection. + id: u64, + + /// New octet_stream transport. + conn: SingleConnection, +} + +/// The reply to a NewConn request. +type ChanResp = Result, Arc>; + +/// This is the type of sender in [ReqCmd]. +type ReplySender = oneshot::Sender>; + +#[derive(Debug)] +/// Commands that can be requested. +enum ReqCmd { + /// Request for a (new) connection. + /// + /// The id of the previous connection (if any) is passed as well as a + /// channel to send the reply. + NewConn(Option, ReplySender), + + /// Shutdown command. + Shutdown, +} + +#[derive(Debug)] +/// A request to [Connection::run] either for a new octet_stream or to +/// shutdown. +struct ChanReq { + /// A requests consists of a command. + cmd: ReqCmd, +} + +/// The actual implementation of [Connection]. +struct InnerConnection { + /// [InnerConnection::sender] and [InnerConnection::receiver] are + /// part of a single channel. + /// + /// Used by [Query] to send requests to [InnerConnection::run]. + sender: mpsc::Sender>, + + /// receiver part of the channel. + /// + /// Protected by a mutex to allow read/write access by + /// [InnerConnection::run]. + /// The Option is to allow [InnerConnection::run] to signal that the + /// connection is closed. + receiver: Futures_mutex>>>, +} + +#[derive(Clone)] +/// A DNS over octect streams transport. +pub struct Connection { + /// Reference counted [InnerConnection]. + inner: Arc>, +} + +/// Status of a query. Used in [Query]. +enum QueryState { + /// Get a octet_stream transport. + GetConn(oneshot::Receiver>), + + /// Start a query using the transport. + StartQuery(SingleConnection), + + /// Get the result of the query. + GetResult(SingleQuery), + + /// Wait until trying again. + /// + /// The instant represents when the error occured, the duration how + /// long to wait. + Delay(Instant, Duration), + + /// The response has been received and the query is done. + Done, +} + +/// State associated with a failed attempt to create a new octet_stream +/// transport. +#[derive(Clone)] +struct ErrorState { + /// The error we got from the most recent attempt. + error: Arc, + + /// How many times we tried so far. + retries: u64, + + /// When we got an error. + timer: Instant, + + /// Time to wait before trying to create a new connection. + timeout: Duration, +} + +/// State of the current underlying octet_stream transport. +enum SingleConnState { + /// No current octet_stream transport. + None, + + /// Current octet_stream transport. + Some(SingleConnection), + + /// State that deals with an error getting a new octet stream from + /// a factory. + Err(ErrorState), +} + +/// Internal datastructure of [InnerConnection::run] to keep track of +/// the status of the connection. +// The types Status and ConnState are only used in InnerConnection +struct State<'a, F, IO, Octs: OctetsBuilder> { + /// Underlying octet_stream connection. + conn_state: SingleConnState, + + /// Current connection id. + conn_id: u64, + + /// Factory for new octet streams. + factory: F, + + /// Collection of futures for the async run function of the underlying + /// octet_stream. + runners: FuturesUnordered< + Pin> + Send + 'a>>, + >, + + /// Phantom data for type IO + phantom: PhantomData<&'a IO>, +} + +/// This struct represent an active DNS query. +pub struct Query { + /// Request message. + /// + /// The reply message is compared with the request message to see if + /// it matches the query. + // query_msg: Message>, + query_msg: MessageBuilder>>, + + /// Current state of the query. + state: QueryState, + + /// A multi_octet connection object is needed to request new underlying + /// octet_stream transport connections. + conn: Connection, + + /// id of most recent connection. + conn_id: u64, + + /// Number of retries without delay. + imm_retry_count: u16, + + /// Number of retries with delay. + delayed_retry_count: u64, +} + +impl< + Octs: 'static + + AsMut<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send, + > InnerConnection +{ + /// Constructor for [InnerConnection]. + /// + /// This is the implementation of [Connection::new]. + pub fn new() -> io::Result> { + let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); + Ok(Self { + sender: tx, + receiver: Futures_mutex::new(Some(rx)), + }) + } + + /// Main execution function for [InnerConnection]. + /// + /// This function Gets called by [Connection::run]. + /// This function is not async cancellation safe + pub async fn run< + 'a, + F: ConnFactory + Send, + IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, + >( + &self, + factory: F, + ) -> Option<()> { + let mut receiver = { + let mut locked_opt_receiver = self.receiver.lock().await; + let opt_receiver = locked_opt_receiver.take(); + opt_receiver.expect("no receiver present?") + }; + let mut curr_cmd: Option> = None; + + let mut state = State::<'a, F, IO, Octs> { + conn_state: SingleConnState::None, + conn_id: 0, + factory, + runners: FuturesUnordered::< + Pin> + Send>>, + >::new(), + phantom: PhantomData, + }; + + let mut do_stream = false; + let mut stream_fut: Pin< + Box> + Send>, + > = Box::pin(factory_nop()); + let mut opt_chan = None; + + loop { + if let Some(req) = curr_cmd { + assert!(!do_stream); + curr_cmd = None; + match req { + ReqCmd::NewConn(opt_id, chan) => { + if let SingleConnState::Err(error_state) = + &state.conn_state + { + if error_state.timer.elapsed() + < error_state.timeout + { + let resp = + ChanResp::Err(error_state.error.clone()); + + // Ignore errors. We don't care if the receiver + // is gone + _ = chan.send(resp); + continue; + } + + // Try to set up a new connection + } + + // Check if the command has an id greather than the + // current id. + if let Some(id) = opt_id { + if id >= state.conn_id { + // We need a new connection. Remove the + // current one. This is the best place to + // increment conn_id. + state.conn_id += 1; + state.conn_state = SingleConnState::None; + } + } + // If we still have a connection then we can reply + // immediately. + if let SingleConnState::Some(conn) = &state.conn_state + { + let resp = ChanResp::Ok(ChanRespOk { + id: state.conn_id, + conn: conn.clone(), + }); + // Ignore errors. We don't care if the receiver + // is gone + _ = chan.send(resp); + } else { + opt_chan = Some(chan); + stream_fut = Box::pin(state.factory.next()); + do_stream = true; + } + } + ReqCmd::Shutdown => break, + } + } + + if do_stream { + let runners_empty = state.runners.is_empty(); + + loop { + tokio::select! { + res_conn = stream_fut.as_mut() => { + do_stream = false; + stream_fut = Box::pin(factory_nop()); + + if let Err(error) = res_conn { + let error = Arc::new(error); + match state.conn_state { + SingleConnState::None => + state.conn_state = + SingleConnState::Err(ErrorState { + error: error.clone(), + retries: 0, + timer: Instant::now(), + timeout: retry_time(0), + }), + SingleConnState::Some(_) => + panic!("Illegal Some state"), + SingleConnState::Err(error_state) => { + state.conn_state = + SingleConnState::Err(ErrorState { + error: error_state.error.clone(), + retries: error_state.retries+1, + timer: Instant::now(), + timeout: retry_time( + error_state.retries+1), + + }); + } + } + + let resp = ChanResp::Err(error); + let loc_opt_chan = opt_chan.take(); + + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.unwrap().send(resp); + break; + } + + let stream = res_conn.unwrap(); + let conn = SingleConnection::new().unwrap(); + let conn_run = conn.clone(); + + let clo = || async move { conn_run.run(stream).await }; + let fut = clo(); + state.runners.push(Box::pin(fut)); + + let resp = ChanResp::Ok(ChanRespOk { id: state.conn_id, conn: conn.clone(), }); + state.conn_state = SingleConnState::Some(conn); + + let loc_opt_chan = opt_chan.take(); + + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.unwrap().send(resp); + break; + } + _ = state.runners.next(), if !runners_empty => { + } + } + } + continue; + } + + assert!(curr_cmd.is_none()); + let recv_fut = receiver.recv(); + let runners_empty = state.runners.is_empty(); + tokio::select! { + msg = recv_fut => { + if msg.is_none() { + panic!("recv failed"); + } + curr_cmd = Some(msg.unwrap().cmd); + } + _ = state.runners.next(), if !runners_empty => { + } + } + } + + // Avoid new queries + drop(receiver); + + // Wait for existing octet_stream runners to terminate + while !state.runners.is_empty() { + state.runners.next().await; + } + + // Done + Some(()) + } + + /// Request a new connection. + async fn new_conn( + &self, + opt_id: Option, + sender: oneshot::Sender>, + ) -> Result<(), &'static str> { + let req = ChanReq { + cmd: ReqCmd::NewConn(opt_id, sender), + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } + } + + /// Request a shutdown. + async fn shutdown(&self) -> Result<(), &'static str> { + let req = ChanReq { + cmd: ReqCmd::Shutdown, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } + } +} + +impl< + Octs: 'static + + AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send, + > Connection +{ + /// Constructor for [Connection]. + /// + /// Returns a [Connection] wrapped in a [Result](io::Result). + pub fn new() -> io::Result> { + let connection = InnerConnection::new()?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Main execution function for [Connection]. + /// + /// This function has to run in the background or together with + /// any calls to [query](Self::query) or [Query::get_result]. + pub async fn run< + F: ConnFactory + Send, + IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, + >( + &self, + factory: F, + ) -> Option<()> { + self.inner.run(factory).await + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + let (tx, rx) = oneshot::channel(); + self.inner.new_conn(None, tx).await?; + Ok(Query::new(self.clone(), query_msg, rx)) + } + + /// Shutdown this transport. + pub async fn shutdown(&self) -> Result<(), &'static str> { + self.inner.shutdown().await + } + + /// Request a new connection. + async fn new_conn( + &self, + id: u64, + tx: oneshot::Sender>, + ) -> Result<(), &'static str> { + self.inner.new_conn(Some(id), tx).await + } +} + +impl< + Octs: AsRef<[u8]> + + AsMut<[u8]> + + Composer + + OctetsBuilder + + Clone + + Debug + + Send + + 'static, + > Query +{ + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new( + conn: Connection, + query_msg: &mut MessageBuilder>>, + receiver: oneshot::Receiver>, + ) -> Query { + Self { + conn, + query_msg: query_msg.clone(), + state: QueryState::GetConn(receiver), + conn_id: 0, + imm_retry_count: 0, + delayed_retry_count: 0, + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + loop { + match self.state { + QueryState::GetConn(ref mut receiver) => { + let res = receiver.await; + if res.is_err() { + // Assume receive error + self.state = QueryState::Done; + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "receive error", + ))); + } + let res = res.unwrap(); + + // Another Result. This time from executing the request + match res { + Err(_) => { + self.delayed_retry_count += 1; + let retry_time = + retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + Ok(ok_res) => { + let id = ok_res.id; + let conn = ok_res.conn; + + self.conn_id = id; + self.state = QueryState::StartQuery(conn); + continue; + } + } + } + QueryState::StartQuery(ref mut conn) => { + let mut msg = self.query_msg.clone(); + let query_res = conn.query_no_check(&mut msg).await; + match query_res { + Err(err) => { + if err == ERR_CONN_CLOSED { + let (tx, rx) = oneshot::channel(); + let res = self + .conn + .new_conn(self.conn_id, tx) + .await; + if let Err(err) = res { + self.state = QueryState::Done; + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + err, + ))); + } + self.state = QueryState::GetConn(rx); + continue; + } + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + err, + ))); + } + Ok(query) => { + self.state = QueryState::GetResult(query); + continue; + } + } + } + QueryState::GetResult(ref mut query) => { + let reply = query.get_result().await; + + if reply.is_err() { + self.delayed_retry_count += 1; + let retry_time = retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + + let msg = reply.unwrap(); + let query_msg_ref: &[u8] = self.query_msg.as_ref(); + let query_msg_vec = query_msg_ref.to_vec(); + let mut query_msg = + Message::from_octets(query_msg_vec).unwrap(); + let hdr = query_msg.header_mut(); + hdr.set_id(msg.header().id()); + + if !is_answer_ignore_id(&msg, &query_msg) { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "wrong answer", + ))); + } + return Ok(msg); + } + QueryState::Delay(instant, duration) => { + sleep_until(instant + duration).await; + let (tx, rx) = oneshot::channel(); + let res = self.conn.new_conn(self.conn_id, tx).await; + if let Err(err) = res { + self.state = QueryState::Done; + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + err, + ))); + } + self.state = QueryState::GetConn(rx); + continue; + } + QueryState::Done => { + panic!("Already done"); + } + } + } + } +} + +/// Compute the retry timeout based on the number of retries so far. +/// +/// The computation is a random value (in microseconds) between zero and +/// two to the power of the number of retries. +fn retry_time(retries: u64) -> Duration { + let to_secs = if retries > 6 { 60 } else { 1 << retries }; + let to_usecs = to_secs * 1000000; + let rnd: f64 = random(); + let to_usecs = to_usecs as f64 * rnd; + Duration::from_micros(to_usecs as u64) +} + +/// Check if a message is the reply to a query. +/// +/// Avoid checking the id field because the id has been changed in the +/// query that was actually issued. +fn is_answer_ignore_id< + Octs1: Octets + AsRef<[u8]>, + Octs2: Octets + AsRef<[u8]>, +>( + reply: &Message, + query: &Message, +) -> bool { + if !reply.header().qr() + || reply.header_counts().qdcount() != query.header_counts().qdcount() + { + false + } else { + reply.question() == query.question() + } +} + +/// Helper function to create an empty future that is compatible with the +/// future return by a factory. +async fn factory_nop() -> Result { + Err(io::Error::new(io::ErrorKind::Other, "nop")) +} From e49b2c393ef2117adf10a93f03bc5fddc30f4c42 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 15:59:58 +0200 Subject: [PATCH 029/124] Add factory, multi_stream, tcp_factory, tls_factory. --- src/net/client/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index a1c722a18..8a429ed80 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,6 +2,10 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] +pub mod factory; +pub mod multi_stream; pub mod octet_stream; pub mod tcp_channel; +pub mod tcp_factory; pub mod tcp_mutex; +pub mod tls_factory; From c26843731b29a654ff5a336fcc82e028d4dbc81f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 16:06:33 +0200 Subject: [PATCH 030/124] Add tokio-rustls and add rt-multi-thread to tokio --- Cargo.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c7210124..1ef0f3752 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,8 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } +tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } +tokio-rustls = { version = "0", optional = true, features = [] } [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work @@ -47,7 +48,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "futures", "tokio"] +net = ["bytes", "futures", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] From 7ca9893ab86f2c6214441bf57076292129f81141 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 30 Jun 2023 16:16:58 +0200 Subject: [PATCH 031/124] Disable imm_retry_count for the moment --- src/net/client/multi_stream.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 5acaab7cf..2cf07648f 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -193,9 +193,8 @@ pub struct Query { /// id of most recent connection. conn_id: u64, - /// Number of retries without delay. - imm_retry_count: u16, - + // /// Number of retries without delay. + // imm_retry_count: u16, /// Number of retries with delay. delayed_retry_count: u64, } @@ -531,7 +530,7 @@ impl< query_msg: query_msg.clone(), state: QueryState::GetConn(receiver), conn_id: 0, - imm_retry_count: 0, + //imm_retry_count: 0, delayed_retry_count: 0, } } From 3f98ce7817941e481c3b30e5fffbc6c89ff90890 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 3 Jul 2023 15:32:39 +0200 Subject: [PATCH 032/124] replace unwrap with expect and layout --- src/net/client/multi_stream.rs | 137 ++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 64 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 2cf07648f..014991e74 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -224,6 +224,8 @@ impl< /// /// This function Gets called by [Connection::run]. /// This function is not async cancellation safe + #[rustfmt::skip] + pub async fn run< 'a, F: ConnFactory + Send, @@ -316,66 +318,75 @@ impl< loop { tokio::select! { - res_conn = stream_fut.as_mut() => { - do_stream = false; - stream_fut = Box::pin(factory_nop()); - - if let Err(error) = res_conn { - let error = Arc::new(error); - match state.conn_state { - SingleConnState::None => - state.conn_state = - SingleConnState::Err(ErrorState { - error: error.clone(), - retries: 0, - timer: Instant::now(), - timeout: retry_time(0), - }), - SingleConnState::Some(_) => - panic!("Illegal Some state"), - SingleConnState::Err(error_state) => { - state.conn_state = - SingleConnState::Err(ErrorState { - error: error_state.error.clone(), - retries: error_state.retries+1, - timer: Instant::now(), - timeout: retry_time( - error_state.retries+1), - - }); + res_conn = stream_fut.as_mut() => { + do_stream = false; + stream_fut = Box::pin(factory_nop()); + + if let Err(error) = res_conn { + let error = Arc::new(error); + match state.conn_state { + SingleConnState::None => + state.conn_state = + SingleConnState::Err(ErrorState { + error: error.clone(), + retries: 0, + timer: Instant::now(), + timeout: retry_time(0), + }), + SingleConnState::Some(_) => + panic!("Illegal Some state"), + SingleConnState::Err(error_state) => { + state.conn_state = + SingleConnState::Err(ErrorState { + error: error_state.error.clone(), + retries: error_state.retries+1, + timer: Instant::now(), + timeout: retry_time( + error_state.retries+1), + }); + } } - } - let resp = ChanResp::Err(error); - let loc_opt_chan = opt_chan.take(); + let resp = ChanResp::Err(error); + let loc_opt_chan = opt_chan.take(); - // Ignore errors. We don't care if the receiver - // is gone - _ = loc_opt_chan.unwrap().send(resp); - break; - } + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.expect("weird, no channel?") + .send(resp); + break; + } - let stream = res_conn.unwrap(); - let conn = SingleConnection::new().unwrap(); - let conn_run = conn.clone(); + let stream = res_conn + .expect("error case is checked before"); + let conn = SingleConnection::new() + .expect( + "the connect implementation cannot fail"); + let conn_run = conn.clone(); - let clo = || async move { conn_run.run(stream).await }; - let fut = clo(); - state.runners.push(Box::pin(fut)); + let clo = || async move { + conn_run.run(stream).await + }; + let fut = clo(); + state.runners.push(Box::pin(fut)); - let resp = ChanResp::Ok(ChanRespOk { id: state.conn_id, conn: conn.clone(), }); - state.conn_state = SingleConnState::Some(conn); + let resp = ChanResp::Ok(ChanRespOk { + id: state.conn_id, + conn: conn.clone(), + }); + state.conn_state = SingleConnState::Some(conn); - let loc_opt_chan = opt_chan.take(); + let loc_opt_chan = opt_chan.take(); - // Ignore errors. We don't care if the receiver - // is gone - _ = loc_opt_chan.unwrap().send(resp); - break; - } - _ = state.runners.next(), if !runners_empty => { - } + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.expect("weird, no channel?") + .send(resp); + break; } + _ = state.runners.next(), if !runners_empty => { + } + } } continue; } @@ -384,15 +395,15 @@ impl< let recv_fut = receiver.recv(); let runners_empty = state.runners.is_empty(); tokio::select! { - msg = recv_fut => { - if msg.is_none() { - panic!("recv failed"); + msg = recv_fut => { + if msg.is_none() { + panic!("recv failed"); + } + curr_cmd = Some(msg.expect("None is checked before").cmd); } - curr_cmd = Some(msg.unwrap().cmd); + _ = state.runners.next(), if !runners_empty => { + } } - _ = state.runners.next(), if !runners_empty => { - } - } } // Avoid new queries @@ -554,7 +565,7 @@ impl< "receive error", ))); } - let res = res.unwrap(); + let res = res.expect("error is checked before"); // Another Result. This time from executing the request match res { @@ -619,13 +630,11 @@ impl< continue; } - let msg = reply.unwrap(); + let msg = reply.expect("error is checked before"); let query_msg_ref: &[u8] = self.query_msg.as_ref(); let query_msg_vec = query_msg_ref.to_vec(); - let mut query_msg = - Message::from_octets(query_msg_vec).unwrap(); - let hdr = query_msg.header_mut(); - hdr.set_id(msg.header().id()); + let query_msg = Message::from_octets(query_msg_vec) + .expect("how to go from MessageBuild to Message?"); if !is_answer_ignore_id(&msg, &query_msg) { return Err(Arc::new(io::Error::new( From eec1f85023b16d9851841f0b51e6802de1af1928 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 20 Jul 2023 15:29:49 +0200 Subject: [PATCH 033/124] Ignore existing TcpKeepalive --- src/net/client/octet_stream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 2baf11f92..74d25bb7c 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -1025,7 +1025,7 @@ fn add_tcp_keepalive( { let option = option.unwrap(); if let AllOptData::TcpKeepalive(_) = option { - panic!("handle keepalive"); + // Ignore existing TcpKeepalive } else { newopt.push(&option).unwrap(); } From fc80f80d2057d4957c2f34cce6a98e483750679d Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 7 Aug 2023 10:29:58 +0200 Subject: [PATCH 034/124] UDP and UDP+TCP client transports --- src/net/client/mod.rs | 2 + src/net/client/udp.rs | 290 ++++++++++++++++++++++++++++++++++++++ src/net/client/udp_tcp.rs | 211 +++++++++++++++++++++++++++ 3 files changed, 503 insertions(+) create mode 100644 src/net/client/udp.rs create mode 100644 src/net/client/udp_tcp.rs diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 8a429ed80..61cf20ebb 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -9,3 +9,5 @@ pub mod tcp_channel; pub mod tcp_factory; pub mod tcp_mutex; pub mod tls_factory; +pub mod udp; +pub mod udp_tcp; diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs new file mode 100644 index 000000000..5cb94e81f --- /dev/null +++ b/src/net/client/udp.rs @@ -0,0 +1,290 @@ +//! A DNS over UDP transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - cookies +// - random port + +use bytes::Bytes; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::UdpSocket; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::time::{timeout, Duration, Instant}; + +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; + +/// How many times do we try a new random port if we get ‘address in use.’ +const RETRY_RANDOM_PORT: usize = 10; + +/// Maximum number of parallel DNS query over a single UDP transport +/// connection. +const MAX_PARALLEL: usize = 100; + +/// Maximum amount of time to wait for a reply. +const READ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Maximum number of retries after timeouts. +const MAX_RETRIES: u8 = 5; + +/// A UDP transport connection. +#[derive(Clone)] +pub struct Connection { + /// Reference to the actual connection object. + inner: Arc, +} + +impl Connection { + /// Create a new UDP transport connection. + pub fn new(remote_addr: SocketAddr) -> io::Result { + let connection = InnerConnection::new(remote_addr)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Start a new DNS query. + pub async fn query + Clone>( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + self.inner.query(query_msg, self.clone()).await + } + + /// Get a permit from the semaphore to start using a socket. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.inner.get_permit().await + } +} + +/// State of the DNS query. +enum QueryState { + /// Get a semaphore permit. + GetPermit(Connection), + + /// Get a UDP socket. + GetSocket, + + /// Connect the socket. + Connect, + + /// Send the request. + Send, + + /// Receive the reply. + Receive(Instant), +} + +/// The state of a DNS query. +pub struct Query { + /// Address of remote server to connect to. + remote_addr: SocketAddr, + + /// DNS request message. + query_msg: MessageBuilder>>, + + /// Semaphore permit that allow use of socket. + _permit: Option, + + /// UDP socket for communication. + sock: Option, + + /// Current number of retries. + retries: u8, + + /// State of query. + state: QueryState, +} + +impl + Clone> Query { + /// Create new Query object. + fn new( + query_msg: &mut MessageBuilder>>, + remote_addr: SocketAddr, + conn: Connection, + ) -> Query { + Query { + query_msg: query_msg.clone(), + remote_addr, + _permit: None, + sock: None, + retries: 0, + state: QueryState::GetPermit(conn), + } + } + + /// Get the result of a DNS Query. + /// + /// This function is cancel safe. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + let recv_size = 2000; // Should be configurable. + + loop { + match &self.state { + QueryState::GetPermit(conn) => { + // We need to get past the semaphore that limits the + // number of concurrent sockets we can use. + let permit = conn.get_permit().await; + self._permit = Some(permit); + self.state = QueryState::GetSocket; + continue; + } + QueryState::GetSocket => { + self.sock = Some( + Self::udp_bind(self.remote_addr.is_ipv4()).await?, + ); + self.state = QueryState::Connect; + continue; + } + QueryState::Connect => { + self.sock + .as_ref() + .expect("socket should be present") + .connect(self.remote_addr) + .await?; + self.state = QueryState::Send; + continue; + } + QueryState::Send => { + let sent = self + .sock + .as_ref() + .expect("socket should be present") + .send( + self.query_msg + .as_target() + .as_target() + .as_dgram_slice(), + ) + .await?; + if sent + != self + .query_msg + .as_target() + .as_target() + .as_dgram_slice() + .len() + { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "short UDP send", + ))); + } + self.state = QueryState::Receive(Instant::now()); + continue; + } + QueryState::Receive(start) => { + let elapsed = start.elapsed(); + if elapsed > READ_TIMEOUT { + todo!(); + } + let remain = READ_TIMEOUT - elapsed; + + let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout( + remain, + self.sock + .as_ref() + .expect("socket should be present") + .recv(&mut buf), + ) + .await; + if timeout_res.is_err() { + self.retries += 1; + if self.retries < MAX_RETRIES { + self.sock = None; + self.state = QueryState::GetSocket; + continue; + } + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "no response", + ))); + } + let len = + timeout_res.expect("errror case is checked above")?; + buf.truncate(len); + + // We ignore garbage since there is a timer on this whole thing. + let answer = match Message::from_octets(buf.into()) { + Ok(answer) => answer, + Err(_) => continue, + }; + if !answer.is_answer(&self.query_msg.as_message()) { + continue; + } + self.sock = None; + self._permit = None; + return Ok(answer); + } + } + } + } + + /// Bind to a local UDP port. + /// + /// This should explicitly pick a random number in a suitable range of + /// ports. + async fn udp_bind(v4: bool) -> Result { + let mut i = 0; + loop { + let local: SocketAddr = if v4 { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => return Ok(sock), + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(err); + } else { + i += 1 + } + } + } + } + } +} + +/// Actual implementation of the UDP transport connection. +struct InnerConnection { + /// Address of the remote server. + remote_addr: SocketAddr, + + /// Semaphore to limit access to UDP sockets. + semaphore: Arc, +} + +impl InnerConnection { + /// Create new InnerConnection object. + fn new(remote_addr: SocketAddr) -> io::Result { + Ok(Self { + remote_addr, + semaphore: Arc::new(Semaphore::new(MAX_PARALLEL)), + }) + } + + /// Return a Query object that contains the query state. + async fn query + Clone>( + &self, + query_msg: &mut MessageBuilder>>, + conn: Connection, + ) -> Result, &'static str> { + Ok(Query::new(query_msg, self.remote_addr, conn)) + } + + /// Return a permit for a our semaphore. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.semaphore + .clone() + .acquire_owned() + .await + .expect("the semaphore has not been closed") + } +} diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs new file mode 100644 index 000000000..69fa63c87 --- /dev/null +++ b/src/net/client/udp_tcp.rs @@ -0,0 +1,211 @@ +//! A UDP transport that falls back to TCP if the reply is truncated + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - handle shutdown + +use bytes::Bytes; +use octseq::OctetsBuilder; +use std::fmt::Debug; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::base::wire::Composer; +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::multi_stream; +use crate::net::client::tcp_factory::TcpConnFactory; +use crate::net::client::udp; + +/// DNS transport connection that first issue a query over a UDP transport and +/// falls back to TCP if the reply is truncated. +#[derive(Clone)] +pub struct Connection { + /// Reference to the real object that provides the connection. + inner: Arc>, +} + +impl + Connection +{ + /// Create a new connection. + pub fn new(remote_addr: SocketAddr) -> io::Result> { + let connection = InnerConnection::new(remote_addr)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Worker function for a connection object. + pub async fn run(&self) -> Option<()> { + self.inner.run().await + } + + /// Start a query. + pub async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + self.inner.query(query_msg).await + } +} + +/// Object that contains the current state of a query. +pub struct Query { + /// Reqeust message. + query_msg: MessageBuilder>>, + + /// UDP transport to be used. + udp_conn: udp::Connection, + + /// TCP transport to be used. + tcp_conn: multi_stream::Connection, + + /// Current state of the query. + state: QueryState, +} + +/// Status of the query. +enum QueryState { + /// Start a query over the UDP transport. + StartUdpQuery, + + /// Get the result from the UDP transport. + GetUdpResult(udp::Query), + + /// Start a query over the TCP transport. + StartTcpQuery, + + /// Get the result from the TCP transport. + GetTcpResult(multi_stream::Query), +} + +impl< + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + 'static, + > Query +{ + /// Create a new Query object. + /// + /// The initial state is to start with a UDP transport. + fn new( + query_msg: &mut MessageBuilder>>, + udp_conn: udp::Connection, + tcp_conn: multi_stream::Connection, + ) -> Query { + Query { + query_msg: query_msg.clone(), + udp_conn, + tcp_conn, + state: QueryState::StartUdpQuery, + } + } + + /// Get the result of a DNS query. + /// + /// This function is cancel safe. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + loop { + match &mut self.state { + QueryState::StartUdpQuery => { + let query = self + .udp_conn + .query(&mut self.query_msg.clone()) + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + })?; + self.state = QueryState::GetUdpResult(query); + continue; + } + QueryState::GetUdpResult(ref mut query) => { + let reply = query.get_result().await?; + if reply.header().tc() { + self.state = QueryState::StartTcpQuery; + continue; + } + return Ok(reply); + } + QueryState::StartTcpQuery => { + let query = self + .tcp_conn + .query(&mut self.query_msg.clone()) + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + })?; + self.state = QueryState::GetTcpResult(query); + continue; + } + QueryState::GetTcpResult(ref mut query) => { + let reply = query.get_result().await?; + return Ok(reply); + } + } + } + } +} + +/// The actual connection object. +struct InnerConnection { + /// The remote address to connect to. + remote_addr: SocketAddr, + + /// The UDP transport connection. + udp_conn: udp::Connection, + + /// The TCP transport connection. + tcp_conn: multi_stream::Connection, +} + +impl + InnerConnection +{ + /// Create a new InnerConnection object. + /// + /// Create the UDP and TCP connections. Store the remote address because + /// run needs it later. + fn new(remote_addr: SocketAddr) -> io::Result> { + let udp_conn = udp::Connection::new(remote_addr)?; + let tcp_conn = multi_stream::Connection::new()?; + + Ok(Self { + remote_addr, + udp_conn, + tcp_conn, + }) + } + + /// Implementation of the worker function. + /// + /// Create a TCP connection factory and pass that to worker function + /// of the multi_stream object. + pub async fn run(&self) -> Option<()> { + let tcp_factory = TcpConnFactory::new(self.remote_addr); + self.tcp_conn.run(tcp_factory).await + } + + /// Implementation of the query function. + /// + /// Just create a Query object with the state it needs. + async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + Ok(Query::new( + query_msg, + self.udp_conn.clone(), + self.tcp_conn.clone(), + )) + } +} From bc60041bbd05da43aa12e57a5385d52d68e24228 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 9 Aug 2023 17:03:19 +0200 Subject: [PATCH 035/124] Traits for queries and an error type. --- src/net/client/error.rs | 126 ++++++++++++++++++++++++++ src/net/client/mod.rs | 2 + src/net/client/multi_stream.rs | 78 +++++++++++------ src/net/client/octet_stream.rs | 156 +++++++++++++++------------------ src/net/client/query.rs | 36 ++++++++ src/net/client/udp.rs | 62 ++++++++----- src/net/client/udp_tcp.rs | 57 +++++++++--- 7 files changed, 369 insertions(+), 148 deletions(-) create mode 100644 src/net/client/error.rs create mode 100644 src/net/client/query.rs diff --git a/src/net/client/error.rs b/src/net/client/error.rs new file mode 100644 index 000000000..bcd751eed --- /dev/null +++ b/src/net/client/error.rs @@ -0,0 +1,126 @@ +//! Error type for client transports. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use std::error; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +/// Error type for client transports. +#[derive(Clone, Debug)] +pub enum Error { + /// Connection was already closed. + ConnectionClosed, + + /// Octet sequence too short to be a valid DNS message. + ShortMessage, + + /// Stream transport closed because it was idle (for too long). + StreamIdleTimeout, + + /// Error receiving a reply. + StreamReceiveError, + + /// Reading from stream gave an error. + StreamReadError(Arc), + + /// Reading from stream took too long. + StreamReadTimeout, + + /// Too many outstand queries on a single stream transport. + StreamTooManyOutstandingQueries, + + /// Writing to a stream gave an error. + StreamWriteError(Arc), + + /// Reading for a stream ended unexpectedly. + StreamUnexpectedEndOfData, + + /// Binding a UDP socket gave an error. + UdpBind(Arc), + + /// Connecting a UDP socket gave an error. + UdpConnect(Arc), + + /// Receiving from a UDP socket gave an error. + UdpReceive(Arc), + + /// Sending over a UDP socket gaven an error. + UdpSend(Arc), + + /// Sending over a UDP socket gave a partial result. + UdpShortSend, + + /// Timeout receiving a response over a UDP socket. + UdpTimeoutNoResponse, + + /// Reply does not match the query. + WrongReplyForQuery, +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Error::ConnectionClosed => write!(f, "connection closed"), + Error::ShortMessage => { + write!(f, "octet sequence to short to be a valid message") + } + Error::StreamIdleTimeout => { + write!(f, "stream was idle for too long") + } + Error::StreamReceiveError => write!(f, "error receiving a reply"), + Error::StreamReadError(_) => { + write!(f, "error reading from stream") + } + Error::StreamReadTimeout => { + write!(f, "timeout reading from stream") + } + Error::StreamTooManyOutstandingQueries => { + write!(f, "too many outstanding queries on stream") + } + Error::StreamWriteError(_) => { + write!(f, "error writing to stream") + } + Error::StreamUnexpectedEndOfData => { + write!(f, "unexpected end of data") + } + Error::UdpBind(_) => write!(f, "error binding UDP socket"), + Error::UdpConnect(_) => write!(f, "error connecting UDP socket"), + Error::UdpReceive(_) => { + write!(f, "error receiving from UDP socket") + } + Error::UdpSend(_) => write!(f, "error sending to UDP socket"), + Error::UdpShortSend => write!(f, "partial sent to UDP socket"), + Error::UdpTimeoutNoResponse => { + write!(f, "timeout waiting for response") + } + Error::WrongReplyForQuery => { + write!(f, "reply does not match query") + } + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Error::ConnectionClosed => None, + Error::ShortMessage => None, + Error::StreamIdleTimeout => None, + Error::StreamReceiveError => None, + Error::StreamReadError(e) => Some(e), + Error::StreamReadTimeout => None, + Error::StreamTooManyOutstandingQueries => None, + Error::StreamWriteError(e) => Some(e), + Error::StreamUnexpectedEndOfData => None, + Error::UdpBind(e) => Some(e), + Error::UdpConnect(e) => Some(e), + Error::UdpReceive(e) => Some(e), + Error::UdpSend(e) => Some(e), + Error::UdpShortSend => None, + Error::UdpTimeoutNoResponse => None, + Error::WrongReplyForQuery => None, + } + } +} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 61cf20ebb..3c058cbad 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,9 +2,11 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] +pub mod error; pub mod factory; pub mod multi_stream; pub mod octet_stream; +pub mod query; pub mod tcp_channel; pub mod tcp_factory; pub mod tcp_mutex; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 014991e74..c4a2a6582 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -3,6 +3,9 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] +// To do: +// - too many connection errors + use bytes::Bytes; use futures::lock::Mutex as Futures_mutex; @@ -28,9 +31,11 @@ use tokio::time::{sleep_until, Instant}; use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream::Connection as SingleConnection; use crate::net::client::octet_stream::QueryNoCheck as SingleQuery; +use crate::net::client::query::{GetResult, QueryMessage}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -423,7 +428,7 @@ impl< &self, opt_id: Option, sender: oneshot::Sender>, - ) -> Result<(), &'static str> { + ) -> Result<(), Error> { let req = ChanReq { cmd: ReqCmd::NewConn(opt_id, sender), }; @@ -432,7 +437,7 @@ impl< // Send error. The receiver is gone, this means that the // connection is closed. { - Err(ERR_CONN_CLOSED) + Err(Error::ConnectionClosed) } Ok(_) => Ok(()), } @@ -494,10 +499,10 @@ impl< /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query( + pub async fn query_impl( &self, query_msg: &mut MessageBuilder>>, - ) -> Result, &'static str> { + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.new_conn(None, tx).await?; Ok(Query::new(self.clone(), query_msg, rx)) @@ -513,11 +518,24 @@ impl< &self, id: u64, tx: oneshot::Sender>, - ) -> Result<(), &'static str> { + ) -> Result<(), Error> { self.inner.new_conn(Some(id), tx).await } } +impl + QueryMessage, Octs> for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin, Error>> + '_>> { + return Box::pin(self.query_impl(query_msg)); + } +} + impl< Octs: AsRef<[u8]> + AsMut<[u8]> @@ -550,9 +568,7 @@ impl< /// /// This function returns the reply to a DNS query wrapped in a /// [Result]. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { + pub async fn get_result_impl(&mut self) -> Result, Error> { loop { match self.state { QueryState::GetConn(ref mut receiver) => { @@ -560,10 +576,7 @@ impl< if res.is_err() { // Assume receive error self.state = QueryState::Done; - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "receive error", - ))); + return Err(Error::StreamReceiveError); } let res = res.expect("error is checked before"); @@ -592,7 +605,7 @@ impl< let query_res = conn.query_no_check(&mut msg).await; match query_res { Err(err) => { - if err == ERR_CONN_CLOSED { + if let Error::ConnectionClosed = err { let (tx, rx) = oneshot::channel(); let res = self .conn @@ -600,18 +613,12 @@ impl< .await; if let Err(err) = res { self.state = QueryState::Done; - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - err, - ))); + return Err(err); } self.state = QueryState::GetConn(rx); continue; } - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - err, - ))); + return Err(err); } Ok(query) => { self.state = QueryState::GetResult(query); @@ -637,10 +644,7 @@ impl< .expect("how to go from MessageBuild to Message?"); if !is_answer_ignore_id(&msg, &query_msg) { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "wrong answer", - ))); + return Err(Error::WrongReplyForQuery); } return Ok(msg); } @@ -650,10 +654,7 @@ impl< let res = self.conn.new_conn(self.conn_id, tx).await; if let Err(err) = res { self.state = QueryState::Done; - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - err, - ))); + return Err(err); } self.state = QueryState::GetConn(rx); continue; @@ -666,6 +667,25 @@ impl< } } +impl< + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + 'static, + > GetResult for Query +{ + fn get_result( + &mut self, + ) -> Pin, Error>> + '_>> + { + Box::pin(self.get_result_impl()) + } +} + /// Compute the retry timeout based on the number of retries so far. /// /// The computation is a random value (in microseconds) between zero and diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 74d25bb7c..fa7a6d853 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -29,7 +29,10 @@ use bytes; use bytes::{Bytes, BytesMut}; use core::convert::From; use futures::lock::Mutex as Futures_mutex; +use std::boxed::Box; use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; use std::vec::Vec; @@ -40,6 +43,8 @@ use crate::base::{ Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, StreamTarget, }; +use crate::net::client::error::Error; +use crate::net::client::query::{GetResult, QueryMessage}; use crate::rdata::AllRecordData; use octseq::{Octets, OctetsBuilder}; @@ -48,9 +53,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; use tokio::time::sleep; -/// Error returned when too many queries are currently active. -const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; - /// Time to wait on a non-idle connection for the other side to send /// a response on any outstanding query. // Implement a simple response timer to see if the connection and the server @@ -69,10 +71,6 @@ const DEF_CHAN_CAP: usize = 8; /// [InnerConnection::run]. const READ_REPLY_CHAN_CAP: usize = 8; -/// Error reported when the connection is closed and -/// [InnerConnection::run] terminated. -const ERR_CONN_CLOSED: &str = "connection closed"; - /// This is the type of sender in [ChanReq]. type ReplySender = oneshot::Sender; @@ -93,7 +91,7 @@ struct Response { reply: Message, } /// Response to the DNS request sent by [InnerConnection::run] to [Query]. -type ChanResp = Result>; +type ChanResp = Result; /// The actual implementation of [Connection]. #[derive(Debug)] @@ -205,13 +203,13 @@ enum ConnState { IdleTimeout, /// A read error occurred. - ReadError, + ReadError(Error), /// It took too long to receive a (or another) response. ReadTimeout, /// A write error occurred. - WriteError, + WriteError(Error), } /// A DNS message received to [InnerConnection::reader] and sent to @@ -274,11 +272,10 @@ impl + Clone + Composer + Debug + OctetsBuilder> if let Some(instant) = opt_instant { let elapsed = instant.elapsed(); if elapsed > RESPONSE_TIMEOUT { - let error = io::Error::new( - io::ErrorKind::Other, - "read timeout", + Self::error( + Error::StreamReadTimeout, + &mut query_vec, ); - Self::error(error, &mut query_vec); status.state = ConnState::ReadTimeout; break; } @@ -302,8 +299,8 @@ impl + Clone + Composer + Debug + OctetsBuilder> } } ConnState::IdleTimeout - | ConnState::ReadError - | ConnState::WriteError => None, // No timers here + | ConnState::ReadError(_) + | ConnState::WriteError(_) => None, // No timers here ConnState::ReadTimeout => { panic!("should not be in loop with ReadTimeout"); } @@ -343,10 +340,10 @@ impl + Clone + Composer + Debug + OctetsBuilder> // error. panic!("reader terminated"), Err(error) => { - Self::error(error, + Self::error(error.clone(), &mut query_vec); status.state = - ConnState::ReadError; + ConnState::ReadError(error); // Reader failed. Break // out of loop and // shut down @@ -369,10 +366,10 @@ impl + Clone + Composer + Debug + OctetsBuilder> res = write_stream.write_all(msg), if do_write => { if let Err(error) = res { - Self::error(error, - &mut query_vec); + let error = Error::StreamWriteError(Arc::new(error)); + Self::error(error.clone(), &mut query_vec); status.state = - ConnState::WriteError; + ConnState::WriteError(error); break; } else { @@ -400,9 +397,9 @@ impl + Clone + Composer + Debug + OctetsBuilder> // Keep going } ConnState::IdleTimeout => break, - ConnState::ReadError + ConnState::ReadError(_) | ConnState::ReadTimeout - | ConnState::WriteError => { + | ConnState::WriteError(_) => { panic!("Should not be here"); } } @@ -419,7 +416,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> &self, sender: oneshot::Sender, query_msg: &mut MessageBuilder>>, - ) -> Result<(), &'static str> { + ) -> Result<(), Error> { // We should figure out how to get query_msg. let msg_clone = query_msg.clone(); @@ -432,7 +429,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> // Send error. The receiver is gone, this means that the // connection is closed. { - Err(ERR_CONN_CLOSED) + Err(Error::ConnectionClosed) } Ok(_) => Ok(()), } @@ -450,13 +447,13 @@ impl + Clone + Composer + Debug + OctetsBuilder> //sock: &mut ReadStream, mut sock: ReadStream, sender: mpsc::Sender, - ) -> Result<(), std::io::Error> { + ) -> Result<(), Error> { loop { let read_res = sock.read_u16().await; let len = match read_res { Ok(len) => len, Err(error) => { - return Err(error); + return Err(Error::StreamReadError(Arc::new(error))); } } as usize; @@ -479,15 +476,11 @@ impl + Clone + Composer + Debug + OctetsBuilder> match read_res { Ok(readlen) => { if readlen == 0 { - let error = io::Error::new( - io::ErrorKind::Other, - "unexpected end of data", - ); - return Err(error); + return Err(Error::StreamUnexpectedEndOfData); } } Err(error) => { - return Err(error); + return Err(Error::StreamReadError(Arc::new(error))); } }; @@ -504,24 +497,21 @@ impl + Clone + Composer + Debug + OctetsBuilder> } Err(_) => { // The only possible error is short message - let error = - io::Error::new(io::ErrorKind::Other, "short buf"); - return Err(error); + return Err(Error::ShortMessage); } } } } /// An error occured, report the error to all outstanding [Query] objects. - fn error(error: std::io::Error, query_vec: &mut Queries) { + fn error(error: Error, query_vec: &mut Queries) { // Update all requests that are in progress. Don't wait for // any reply that may be on its way. - let arc_error = Arc::new(error); for index in 0..query_vec.vec.len() { if query_vec.vec[index].is_some() { let sender = Self::take_query(query_vec, index) .expect("we tested is_none before"); - _ = sender.send(Err(arc_error.clone())); + _ = sender.send(Err(error.clone())); } } } @@ -604,7 +594,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> reqmsg: &mut Option>, query_vec: &mut Queries, ) { - match status.state { + match &status.state { ConnState::Active(timer) => { // Set timer if we don't have one already if timer.is_none() { @@ -617,31 +607,19 @@ impl + Clone + Composer + Debug + OctetsBuilder> } ConnState::IdleTimeout => { // The connection has been closed. Report error - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "idle timeout", - )))); + _ = req.sender.send(Err(Error::StreamIdleTimeout)); return; } - ConnState::ReadError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "read error", - )))); + ConnState::ReadError(error) => { + _ = req.sender.send(Err(error.clone())); return; } ConnState::ReadTimeout => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "read timeout", - )))); + _ = req.sender.send(Err(Error::StreamReadTimeout)); return; } - ConnState::WriteError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "write error", - )))); + ConnState::WriteError(error) => { + _ = req.sender.send(Err(error.clone())); return; } } @@ -729,7 +707,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> fn insert( sender: oneshot::Sender, query_vec: &mut Queries, - ) -> Result { + ) -> Result { // Fail if there are to many entries already in this vector // We cannot have more than u16::MAX entries because the // index needs to fit in an u16. For efficiency we want to @@ -737,11 +715,9 @@ impl + Clone + Composer + Debug + OctetsBuilder> // 2*count > u16::MAX if 2 * query_vec.count > u16::MAX.into() { // We own sender. So we need to send the error reply here - _ = sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - ERR_TOO_MANY_QUERIES, - )))); - return Err(ERR_TOO_MANY_QUERIES); + let error = Error::StreamTooManyOutstandingQueries; + _ = sender.send(Err(error.clone())); + return Err(error); } let q = Some(sender); @@ -821,10 +797,10 @@ impl< /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query( + async fn query_impl( &self, query_msg: &mut MessageBuilder>>, - ) -> Result { + ) -> Result { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; let msg = &query_msg.as_message(); @@ -838,13 +814,27 @@ impl< pub async fn query_no_check( &self, query_msg: &mut MessageBuilder>>, - ) -> Result { + ) -> Result { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; Ok(QueryNoCheck::new(rx)) } } +impl< + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, + > QueryMessage for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin> + '_>> { + return Box::pin(self.query_impl(query_msg)); + } +} + impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. @@ -865,19 +855,14 @@ impl Query { /// /// This function returns the reply to a DNS query wrapped in a /// [Result]. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { + pub async fn get_result_impl(&mut self) -> Result, Error> { match self.state { QueryState::Busy(ref mut receiver) => { let res = receiver.await; self.state = QueryState::Done; if res.is_err() { // Assume receive error - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "receive error", - ))); + return Err(Error::StreamReceiveError); } let res = res.unwrap(); @@ -892,10 +877,7 @@ impl Query { let msg = resp.reply; if !is_answer_ignore_id(&msg, &self.query_msg) { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "wrong answer", - ))); + return Err(Error::WrongReplyForQuery); } Ok(msg) } @@ -906,6 +888,15 @@ impl Query { } } +impl GetResult for Query { + fn get_result( + &mut self, + ) -> Pin, Error>> + '_>> + { + Box::pin(self.get_result_impl()) + } +} + impl QueryNoCheck { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. @@ -919,19 +910,14 @@ impl QueryNoCheck { /// /// This function returns the reply to a DNS query wrapped in a /// [Result]. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { + pub async fn get_result(&mut self) -> Result, Error> { match self.state { QueryState::Busy(ref mut receiver) => { let res = receiver.await; self.state = QueryState::Done; if res.is_err() { // Assume receive error - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "receive error", - ))); + return Err(Error::StreamReceiveError); } let res = res.unwrap(); diff --git a/src/net/client/query.rs b/src/net/client/query.rs new file mode 100644 index 000000000..5d9d68e56 --- /dev/null +++ b/src/net/client/query.rs @@ -0,0 +1,36 @@ +//! Traits for query transports + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use bytes::Bytes; +use std::boxed::Box; +use std::future::Future; +use std::pin::Pin; +// use std::sync::Arc; + +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::error::Error; + +/// Trait for starting a DNS query based on a message. +pub trait QueryMessage { + /// Query function that takes a message builder type. + /// + /// This function is intended to be cancel safe. + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin> + '_>>; +} + +/// Trait for getting the result of a DNS query. +pub trait GetResult { + /// Get the result of a DNS query. + /// + /// This function is intended to be cancel safe. + fn get_result( + &mut self, + ) -> Pin, Error>> + '_>>; +} diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 5cb94e81f..f912d7262 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -8,14 +8,19 @@ // - random port use bytes::Bytes; +use std::boxed::Box; +use std::future::Future; use std::io; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use tokio::net::UdpSocket; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::{timeout, Duration, Instant}; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::error::Error; +use crate::net::client::query::{GetResult, QueryMessage}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -47,10 +52,10 @@ impl Connection { } /// Start a new DNS query. - pub async fn query + Clone>( + async fn query_impl + Clone>( &self, query_msg: &mut MessageBuilder>>, - ) -> Result, &'static str> { + ) -> Result, Error> { self.inner.query(query_msg, self.clone()).await } @@ -60,6 +65,19 @@ impl Connection { } } +impl + Clone> QueryMessage, Octs> + for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin, Error>> + '_>> { + return Box::pin(self.query_impl(query_msg)); + } +} + /// State of the DNS query. enum QueryState { /// Get a semaphore permit. @@ -119,9 +137,7 @@ impl + Clone> Query { /// Get the result of a DNS Query. /// /// This function is cancel safe. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { + async fn get_result_impl(&mut self) -> Result, Error> { let recv_size = 2000; // Should be configurable. loop { @@ -146,7 +162,8 @@ impl + Clone> Query { .as_ref() .expect("socket should be present") .connect(self.remote_addr) - .await?; + .await + .map_err(|e| Error::UdpConnect(Arc::new(e)))?; self.state = QueryState::Send; continue; } @@ -161,7 +178,8 @@ impl + Clone> Query { .as_target() .as_dgram_slice(), ) - .await?; + .await + .map_err(|e| Error::UdpSend(Arc::new(e)))?; if sent != self .query_msg @@ -170,10 +188,7 @@ impl + Clone> Query { .as_dgram_slice() .len() { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "short UDP send", - ))); + return Err(Error::UdpShortSend); } self.state = QueryState::Receive(Instant::now()); continue; @@ -201,13 +216,11 @@ impl + Clone> Query { self.state = QueryState::GetSocket; continue; } - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "no response", - ))); + return Err(Error::UdpTimeoutNoResponse); } - let len = - timeout_res.expect("errror case is checked above")?; + let len = timeout_res + .expect("errror case is checked above") + .map_err(|e| Error::UdpReceive(Arc::new(e)))?; buf.truncate(len); // We ignore garbage since there is a timer on this whole thing. @@ -230,7 +243,7 @@ impl + Clone> Query { /// /// This should explicitly pick a random number in a suitable range of /// ports. - async fn udp_bind(v4: bool) -> Result { + async fn udp_bind(v4: bool) -> Result { let mut i = 0; loop { let local: SocketAddr = if v4 { @@ -242,7 +255,7 @@ impl + Clone> Query { Ok(sock) => return Ok(sock), Err(err) => { if i == RETRY_RANDOM_PORT { - return Err(err); + return Err(Error::UdpBind(Arc::new(err))); } else { i += 1 } @@ -252,6 +265,15 @@ impl + Clone> Query { } } +impl + Clone> GetResult for Query { + fn get_result( + &mut self, + ) -> Pin, Error>> + '_>> + { + Box::pin(self.get_result_impl()) + } +} + /// Actual implementation of the UDP transport connection. struct InnerConnection { /// Address of the remote server. @@ -275,7 +297,7 @@ impl InnerConnection { &self, query_msg: &mut MessageBuilder>>, conn: Connection, - ) -> Result, &'static str> { + ) -> Result, Error> { Ok(Query::new(query_msg, self.remote_addr, conn)) } diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 69fa63c87..166ce3563 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -8,14 +8,19 @@ use bytes::Bytes; use octseq::OctetsBuilder; +use std::boxed::Box; use std::fmt::Debug; +use std::future::Future; use std::io; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::error::Error; use crate::net::client::multi_stream; +use crate::net::client::query::{GetResult, QueryMessage}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; @@ -44,14 +49,27 @@ impl } /// Start a query. - pub async fn query( + pub async fn query_impl( &self, query_msg: &mut MessageBuilder>>, - ) -> Result, &'static str> { + ) -> Result, Error> { self.inner.query(query_msg).await } } +impl + QueryMessage, Octs> for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin, Error>> + '_>> { + return Box::pin(self.query_impl(query_msg)); + } +} + /// Object that contains the current state of a query. pub struct Query { /// Reqeust message. @@ -112,19 +130,14 @@ impl< /// Get the result of a DNS query. /// /// This function is cancel safe. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { + async fn get_result_impl(&mut self) -> Result, Error> { loop { match &mut self.state { QueryState::StartUdpQuery => { let query = self .udp_conn .query(&mut self.query_msg.clone()) - .await - .map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - })?; + .await?; self.state = QueryState::GetUdpResult(query); continue; } @@ -140,10 +153,7 @@ impl< let query = self .tcp_conn .query(&mut self.query_msg.clone()) - .await - .map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - })?; + .await?; self.state = QueryState::GetTcpResult(query); continue; } @@ -156,6 +166,25 @@ impl< } } +impl< + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + 'static, + > GetResult for Query +{ + fn get_result( + &mut self, + ) -> Pin, Error>> + '_>> + { + Box::pin(self.get_result_impl()) + } +} + /// The actual connection object. struct InnerConnection { /// The remote address to connect to. @@ -201,7 +230,7 @@ impl async fn query( &self, query_msg: &mut MessageBuilder>>, - ) -> Result, &'static str> { + ) -> Result, Error> { Ok(Query::new( query_msg, self.udp_conn.clone(), From e65218a28ed696882bd907c3e30f8f235042c4f7 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 10 Aug 2023 10:44:36 +0200 Subject: [PATCH 036/124] Add Send --- src/net/client/multi_stream.rs | 4 ++-- src/net/client/octet_stream.rs | 8 ++++---- src/net/client/query.rs | 4 ++-- src/net/client/udp.rs | 22 +++++++++------------- src/net/client/udp_tcp.rs | 10 ++++++---- 5 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index c4a2a6582..ee8321076 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -531,7 +531,7 @@ impl query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + '_>> { + ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } @@ -680,7 +680,7 @@ impl< { fn get_result( &mut self, - ) -> Pin, Error>> + '_>> + ) -> Pin, Error>> + Send + '_>> { Box::pin(self.get_result_impl()) } diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index fa7a6d853..05f7a6026 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -769,7 +769,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> } impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder + Send, > Connection { /// Constructor for [Connection]. @@ -822,7 +822,7 @@ impl< } impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder + Send, > QueryMessage for Connection { fn query<'a>( @@ -830,7 +830,7 @@ impl< query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin> + '_>> { + ) -> Pin> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } @@ -891,7 +891,7 @@ impl Query { impl GetResult for Query { fn get_result( &mut self, - ) -> Pin, Error>> + '_>> + ) -> Pin, Error>> + Send + '_>> { Box::pin(self.get_result_impl()) } diff --git a/src/net/client/query.rs b/src/net/client/query.rs index 5d9d68e56..56881a07b 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -22,7 +22,7 @@ pub trait QueryMessage { query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin> + '_>>; + ) -> Pin> + Send + '_>>; } /// Trait for getting the result of a DNS query. @@ -32,5 +32,5 @@ pub trait GetResult { /// This function is intended to be cancel safe. fn get_result( &mut self, - ) -> Pin, Error>> + '_>>; + ) -> Pin, Error>> + Send + '_>>; } diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index f912d7262..34e42e9fd 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -52,7 +52,7 @@ impl Connection { } /// Start a new DNS query. - async fn query_impl + Clone>( + async fn query_impl + Clone + Send>( &self, query_msg: &mut MessageBuilder>>, ) -> Result, Error> { @@ -65,7 +65,7 @@ impl Connection { } } -impl + Clone> QueryMessage, Octs> +impl + Clone + Send> QueryMessage, Octs> for Connection { fn query<'a>( @@ -73,7 +73,7 @@ impl + Clone> QueryMessage, Octs> query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + '_>> { + ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } @@ -117,7 +117,7 @@ pub struct Query { state: QueryState, } -impl + Clone> Query { +impl + Clone + Send> Query { /// Create new Query object. fn new( query_msg: &mut MessageBuilder>>, @@ -168,16 +168,12 @@ impl + Clone> Query { continue; } QueryState::Send => { + let dgram = self.query_msg .as_target() .as_target() .as_dgram_slice(); let sent = self .sock .as_ref() .expect("socket should be present") - .send( - self.query_msg - .as_target() - .as_target() - .as_dgram_slice(), - ) + .send( dgram) .await .map_err(|e| Error::UdpSend(Arc::new(e)))?; if sent @@ -265,10 +261,10 @@ impl + Clone> Query { } } -impl + Clone> GetResult for Query { +impl + Clone + Send> GetResult for Query { fn get_result( &mut self, - ) -> Pin, Error>> + '_>> + ) -> Pin, Error>> + Send + '_>> { Box::pin(self.get_result_impl()) } @@ -293,7 +289,7 @@ impl InnerConnection { } /// Return a Query object that contains the query state. - async fn query + Clone>( + async fn query + Clone + Send>( &self, query_msg: &mut MessageBuilder>>, conn: Connection, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 166ce3563..76491b889 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -65,7 +65,7 @@ impl query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + '_>> { + ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } @@ -134,9 +134,10 @@ impl< loop { match &mut self.state { QueryState::StartUdpQuery => { + let mut msg = self.query_msg.clone(); let query = self .udp_conn - .query(&mut self.query_msg.clone()) + .query(&mut msg) .await?; self.state = QueryState::GetUdpResult(query); continue; @@ -150,9 +151,10 @@ impl< return Ok(reply); } QueryState::StartTcpQuery => { + let mut msg = self.query_msg.clone(); let query = self .tcp_conn - .query(&mut self.query_msg.clone()) + .query(&mut msg) .await?; self.state = QueryState::GetTcpResult(query); continue; @@ -179,7 +181,7 @@ impl< { fn get_result( &mut self, - ) -> Pin, Error>> + '_>> + ) -> Pin, Error>> + Send + '_>> { Box::pin(self.get_result_impl()) } From 4979a6251064e3018945342c25ee3cf79e810b23 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 10 Aug 2023 10:56:57 +0200 Subject: [PATCH 037/124] Fmt --- src/net/client/multi_stream.rs | 8 +++++--- src/net/client/octet_stream.rs | 21 +++++++++++++++++---- src/net/client/query.rs | 4 +++- src/net/client/udp.rs | 16 +++++++++++----- src/net/client/udp_tcp.rs | 22 +++++++++------------- 5 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index ee8321076..6a26c9e95 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -531,7 +531,8 @@ impl query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + Send + '_>> { + ) -> Pin, Error>> + Send + '_>> + { return Box::pin(self.query_impl(query_msg)); } } @@ -680,8 +681,9 @@ impl< { fn get_result( &mut self, - ) -> Pin, Error>> + Send + '_>> - { + ) -> Pin< + Box, Error>> + Send + '_>, + > { Box::pin(self.get_result_impl()) } } diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 05f7a6026..b0f810c9f 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -769,7 +769,13 @@ impl + Clone + Composer + Debug + OctetsBuilder> } impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder + Send, + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send, > Connection { /// Constructor for [Connection]. @@ -822,7 +828,13 @@ impl< } impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder + Send, + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send, > QueryMessage for Connection { fn query<'a>( @@ -891,8 +903,9 @@ impl Query { impl GetResult for Query { fn get_result( &mut self, - ) -> Pin, Error>> + Send + '_>> - { + ) -> Pin< + Box, Error>> + Send + '_>, + > { Box::pin(self.get_result_impl()) } } diff --git a/src/net/client/query.rs b/src/net/client/query.rs index 56881a07b..248648358 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -32,5 +32,7 @@ pub trait GetResult { /// This function is intended to be cancel safe. fn get_result( &mut self, - ) -> Pin, Error>> + Send + '_>>; + ) -> Pin< + Box, Error>> + Send + '_>, + >; } diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 34e42e9fd..d737e816b 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -73,7 +73,8 @@ impl + Clone + Send> QueryMessage, Octs> query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + Send + '_>> { + ) -> Pin, Error>> + Send + '_>> + { return Box::pin(self.query_impl(query_msg)); } } @@ -168,12 +169,16 @@ impl + Clone + Send> Query { continue; } QueryState::Send => { - let dgram = self.query_msg .as_target() .as_target() .as_dgram_slice(); + let dgram = self + .query_msg + .as_target() + .as_target() + .as_dgram_slice(); let sent = self .sock .as_ref() .expect("socket should be present") - .send( dgram) + .send(dgram) .await .map_err(|e| Error::UdpSend(Arc::new(e)))?; if sent @@ -264,8 +269,9 @@ impl + Clone + Send> Query { impl + Clone + Send> GetResult for Query { fn get_result( &mut self, - ) -> Pin, Error>> + Send + '_>> - { + ) -> Pin< + Box, Error>> + Send + '_>, + > { Box::pin(self.get_result_impl()) } } diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 76491b889..b48edf5b6 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -65,7 +65,8 @@ impl query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + Send + '_>> { + ) -> Pin, Error>> + Send + '_>> + { return Box::pin(self.query_impl(query_msg)); } } @@ -134,11 +135,8 @@ impl< loop { match &mut self.state { QueryState::StartUdpQuery => { - let mut msg = self.query_msg.clone(); - let query = self - .udp_conn - .query(&mut msg) - .await?; + let mut msg = self.query_msg.clone(); + let query = self.udp_conn.query(&mut msg).await?; self.state = QueryState::GetUdpResult(query); continue; } @@ -151,11 +149,8 @@ impl< return Ok(reply); } QueryState::StartTcpQuery => { - let mut msg = self.query_msg.clone(); - let query = self - .tcp_conn - .query(&mut msg) - .await?; + let mut msg = self.query_msg.clone(); + let query = self.tcp_conn.query(&mut msg).await?; self.state = QueryState::GetTcpResult(query); continue; } @@ -181,8 +176,9 @@ impl< { fn get_result( &mut self, - ) -> Pin, Error>> + Send + '_>> - { + ) -> Pin< + Box, Error>> + Send + '_>, + > { Box::pin(self.get_result_impl()) } } From fd92f07ebb045fd9dc1b3ae8038921614b11cb62 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 21 Sep 2023 14:42:57 +0200 Subject: [PATCH 038/124] QueryMessage2 and Debug --- src/net/client/multi_stream.rs | 40 ++++++++++++++++++++++++++++-- src/net/client/octet_stream.rs | 3 +++ src/net/client/query.rs | 20 ++++++++++++++- src/net/client/udp.rs | 42 ++++++++++++++++++++++++++++--- src/net/client/udp_tcp.rs | 45 +++++++++++++++++++++++++++++++--- 5 files changed, 140 insertions(+), 10 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 6a26c9e95..471e8e11e 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -35,7 +35,7 @@ use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream::Connection as SingleConnection; use crate::net::client::octet_stream::QueryNoCheck as SingleQuery; -use crate::net::client::query::{GetResult, QueryMessage}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -82,6 +82,7 @@ struct ChanReq { } /// The actual implementation of [Connection]. +#[derive(Debug)] struct InnerConnection { /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. @@ -98,7 +99,7 @@ struct InnerConnection { receiver: Futures_mutex>>>, } -#[derive(Clone)] +#[derive(Clone, Debug)] /// A DNS over octect streams transport. pub struct Connection { /// Reference counted [InnerConnection]. @@ -106,6 +107,7 @@ pub struct Connection { } /// Status of a query. Used in [Query]. +#[derive(Debug)] enum QueryState { /// Get a octet_stream transport. GetConn(oneshot::Receiver>), @@ -180,6 +182,7 @@ struct State<'a, F, IO, Octs: OctetsBuilder> { } /// This struct represent an active DNS query. +#[derive(Debug)] pub struct Query { /// Request message. /// @@ -508,6 +511,20 @@ impl< Ok(Query::new(self.clone(), query_msg, rx)) } + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query_impl2( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.inner.new_conn(None, tx).await?; + let gr = Query::new(self.clone(), query_msg, rx); + Ok(Box::new(gr)) + } + /// Shutdown this transport. pub async fn shutdown(&self) -> Result<(), &'static str> { self.inner.shutdown().await @@ -537,6 +554,25 @@ impl } } +impl + QueryMessage2 for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl2(query_msg)); + } +} + impl< Octs: AsRef<[u8]> + AsMut<[u8]> diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index b0f810c9f..6985eefa9 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -132,6 +132,7 @@ pub struct Connection { } /// Status of a query. Used in [Query]. +#[derive(Debug)] enum QueryState { /// A request is in progress. /// @@ -143,6 +144,7 @@ enum QueryState { } /// This struct represent an active DNS query. +#[derive(Debug)] pub struct Query { /// Request message. /// @@ -157,6 +159,7 @@ pub struct Query { /// This represents that state of an active DNS query if there is no need /// to check that the reply matches the request. The assumption is that the /// caller will do this check. +#[derive(Debug)] pub struct QueryNoCheck { /// Current state of the query. state: QueryState, diff --git a/src/net/client/query.rs b/src/net/client/query.rs index 248648358..6aca0a620 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -5,6 +5,7 @@ use bytes::Bytes; use std::boxed::Box; +use std::fmt::Debug; use std::future::Future; use std::pin::Pin; // use std::sync::Arc; @@ -25,8 +26,25 @@ pub trait QueryMessage { ) -> Pin> + Send + '_>>; } +/// Trait for starting a DNS query based on a message. +pub trait QueryMessage2 { + /// Query function that takes a message builder type. + /// + /// This function is intended to be cancel safe. + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin + Send + '_>>; +} + +/// This type is the actual result type of the future returned by the +/// query function in the QueryMessage2 trait. +type QueryResultOutput = Result, Error>; + /// Trait for getting the result of a DNS query. -pub trait GetResult { +pub trait GetResult: Debug { /// Get the result of a DNS query. /// /// This function is intended to be cancel safe. diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index d737e816b..ae9bddd0b 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use std::boxed::Box; +use std::fmt::Debug; use std::future::Future; use std::io; use std::net::SocketAddr; @@ -20,7 +21,7 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -36,7 +37,7 @@ const READ_TIMEOUT: Duration = Duration::from_secs(5); const MAX_RETRIES: u8 = 5; /// A UDP transport connection. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Connection { /// Reference to the actual connection object. inner: Arc, @@ -59,13 +60,24 @@ impl Connection { self.inner.query(query_msg, self.clone()).await } + /// Start a new DNS query. + async fn query_impl2< + Octs: AsRef<[u8]> + Clone + Debug + Send + 'static, + >( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, Error> { + let gr = self.inner.query(query_msg, self.clone()).await?; + Ok(Box::new(gr)) + } + /// Get a permit from the semaphore to start using a socket. async fn get_permit(&self) -> OwnedSemaphorePermit { self.inner.get_permit().await } } -impl + Clone + Send> QueryMessage, Octs> +impl + Clone + Debug + Send> QueryMessage, Octs> for Connection { fn query<'a>( @@ -79,7 +91,27 @@ impl + Clone + Send> QueryMessage, Octs> } } +impl + Clone + Debug + Send + 'static> QueryMessage2 + for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl2(query_msg)); + } +} + /// State of the DNS query. +#[derive(Debug)] enum QueryState { /// Get a semaphore permit. GetPermit(Connection), @@ -98,6 +130,7 @@ enum QueryState { } /// The state of a DNS query. +#[derive(Debug)] pub struct Query { /// Address of remote server to connect to. remote_addr: SocketAddr, @@ -266,7 +299,7 @@ impl + Clone + Send> Query { } } -impl + Clone + Send> GetResult for Query { +impl + Clone + Debug + Send> GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -277,6 +310,7 @@ impl + Clone + Send> GetResult for Query { } /// Actual implementation of the UDP transport connection. +#[derive(Debug)] struct InnerConnection { /// Address of the remote server. remote_addr: SocketAddr, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index b48edf5b6..b97159ad3 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -20,7 +20,7 @@ use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; use crate::net::client::multi_stream; -use crate::net::client::query::{GetResult, QueryMessage}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; @@ -55,6 +55,15 @@ impl ) -> Result, Error> { self.inner.query(query_msg).await } + + /// Start a query for the QueryMessage2 trait. + async fn query_impl2( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, Error> { + let gr = self.inner.query(query_msg).await?; + Ok(Box::new(gr)) + } } impl @@ -71,7 +80,34 @@ impl } } +impl< + Octs: AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + 'static, + > QueryMessage2 for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl2(query_msg)); + } +} + /// Object that contains the current state of a query. +#[derive(Debug)] pub struct Query { /// Reqeust message. query_msg: MessageBuilder>>, @@ -87,6 +123,7 @@ pub struct Query { } /// Status of the query. +#[derive(Debug)] enum QueryState { /// Start a query over the UDP transport. StartUdpQuery, @@ -136,7 +173,8 @@ impl< match &mut self.state { QueryState::StartUdpQuery => { let mut msg = self.query_msg.clone(); - let query = self.udp_conn.query(&mut msg).await?; + let query = + QueryMessage::query(&self.udp_conn, &mut msg).await?; self.state = QueryState::GetUdpResult(query); continue; } @@ -150,7 +188,8 @@ impl< } QueryState::StartTcpQuery => { let mut msg = self.query_msg.clone(); - let query = self.tcp_conn.query(&mut msg).await?; + let query = + QueryMessage::query(&self.tcp_conn, &mut msg).await?; self.state = QueryState::GetTcpResult(query); continue; } From 0eff1090344389ddf68607e06ddec87d9a171c8a Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 21 Sep 2023 14:43:41 +0200 Subject: [PATCH 039/124] A transport that multiplexes requests over multiple redundant transports --- src/net/client/error.rs | 7 + src/net/client/mod.rs | 1 + src/net/client/redundant.rs | 619 ++++++++++++++++++++++++++++++++++++ 3 files changed, 627 insertions(+) create mode 100644 src/net/client/redundant.rs diff --git a/src/net/client/error.rs b/src/net/client/error.rs index bcd751eed..cda583dd3 100644 --- a/src/net/client/error.rs +++ b/src/net/client/error.rs @@ -57,6 +57,9 @@ pub enum Error { /// Reply does not match the query. WrongReplyForQuery, + + /// No transport available to transmit request. + NoTransportAvailable, } impl Display for Error { @@ -98,6 +101,9 @@ impl Display for Error { Error::WrongReplyForQuery => { write!(f, "reply does not match query") } + Error::NoTransportAvailable => { + write!(f, "no transport available") + } } } } @@ -121,6 +127,7 @@ impl error::Error for Error { Error::UdpShortSend => None, Error::UdpTimeoutNoResponse => None, Error::WrongReplyForQuery => None, + Error::NoTransportAvailable => None, } } } diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 3c058cbad..2da1fb696 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -7,6 +7,7 @@ pub mod factory; pub mod multi_stream; pub mod octet_stream; pub mod query; +pub mod redundant; pub mod tcp_channel; pub mod tcp_factory; pub mod tcp_mutex; diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs new file mode 100644 index 000000000..29e9cf69a --- /dev/null +++ b/src/net/client/redundant.rs @@ -0,0 +1,619 @@ +//! A transport that multiplexes requests over multiple redundant transports. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use bytes::Bytes; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; + +use octseq::OctetsBuilder; + +use rand::random; + +use std::boxed::Box; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::vec::Vec; + +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::time::{sleep_until, Duration, Instant}; + +use crate::base::wire::Composer; +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::error::Error; +use crate::net::client::query::{GetResult, QueryMessage2}; + +/* +Basic algorithm: +- keep track of expected response time for every upstream +- start with the upstream with the lowest expected response time +- set a timer to the expect response time. +- if the timer expires before reply arrives, send the query to the next lowest + and set a timer +- when a reply arrives update the expected response time for the relevant + upstream and for the ones that failed. + +Based on a random number generator: +- pick a different upstream rather then the best but set the timer to the + expected response time of the best. +*/ + +/// Capacity of the channel that transports [ChanReq]. +const DEF_CHAN_CAP: usize = 8; + +/// Time in milliseconds for the initial response time estimate. +const DEFAULT_RT_MS: u64 = 300; + +/// The initial response time estimate for unused connections. +const DEFAULT_RT: Duration = Duration::from_millis(DEFAULT_RT_MS); + +/// Maintain a moving average for the measured response time and the +/// square of that. The window is SMOOTH_N. +const SMOOTH_N: f64 = 8.; + +/// Chance to probe a worse connection. +const PROBE_P: f64 = 0.05; + +/// Avoid sending two requests at the same time. +/// +/// When a worse connection is probed, give it a slight head start. +const PROBE_RT: Duration = Duration::from_millis(1); + +/// This type represents a transport connection. +#[derive(Clone)] +pub struct Connection { + /// Reference to the actual implementation of the connection. + inner: Arc>, +} + +impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> + Connection +{ + /// Create a new connection. + pub fn new() -> io::Result> { + let connection = InnerConnection::new()?; + //test_send(connection); + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Runner function for a connection. + pub async fn run(&self) { + self.inner.run().await + } + + /// Add a transport connection. + pub async fn add( + &self, + conn: Box + Send + Sync>, + ) { + self.inner.add(conn).await + } + + /// Implementation of the query function. + async fn query_impl( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, Error> { + Ok(Box::new(self.inner.query(query_msg.clone()).await.unwrap())) + } +} + +impl< + Octs: Clone + Composer + Debug + OctetsBuilder + Send + Sync + 'static, + > QueryMessage2 for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl(query_msg)); + } +} + +/// This type represents an active query request. +#[derive(Debug)] +pub struct Query { + /// The state of the query + state: QueryState, + + /// The query message + query_msg: MessageBuilder>>, + + /// List of connections identifiers and estimated response times. + conn_rt: Vec, + + /// Channel to send requests to the run function. + sender: mpsc::Sender>, + + /// List of futures for outstanding requests. + fut_list: + FuturesUnordered + Send>>>, + + /// The result from one of the connectons. + result: Option, Error>>, + + /// Index of the connection that returned a result. + res_index: usize, +} + +/// Result of the futures in fut_list. +type FutListOutput = (usize, Result, Error>); + +/// The various states a query can be in. +#[derive(Debug)] +enum QueryState { + /// The initial state + Init, + + /// Start a request on a specific connection. + Probe(usize), + + /// Report the response time for a specific index in the list. + Report(usize), + + /// Wait for one of the requests to finish. + Wait, +} + +impl Query { + /// Create a new query object. + fn new( + query_msg: MessageBuilder>>, + // conns: Vec<&'a dyn QueryMessage>, + mut conn_rt: Vec, + sender: mpsc::Sender>, + ) -> Query { + let conn_rt_len = conn_rt.len(); + println!("before sort:"); + for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { + println!("{}: id {} ert {:?}", i, item.id, item.est_rt); + } + conn_rt.sort_unstable_by(conn_rt_cmp); + println!("after sort:"); + for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { + println!("{}: id {} ert {:?}", i, item.id, item.est_rt); + } + + // Do we want to probe a less performant upstream? + if conn_rt_len > 1 && random::() < PROBE_P { + let index: usize = 1 + random::() % (conn_rt_len - 1); + conn_rt[index].est_rt = PROBE_RT; + + // Sort again + conn_rt.sort_unstable_by(conn_rt_cmp); + println!("sort for probe :"); + for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { + println!("{}: id {} ert {:?}", i, item.id, item.est_rt); + } + } + + Query { + query_msg, + //conns, + conn_rt, + sender, + state: QueryState::Init, + fut_list: FuturesUnordered::new(), + result: None, + res_index: 0, + } + } + + /// Implementation of get_result. + async fn get_result_impl(&mut self) -> Result, Error> { + loop { + match self.state { + QueryState::Init => { + if self.conn_rt.is_empty() { + return Err(Error::NoTransportAvailable); + } + self.state = QueryState::Probe(0); + continue; + } + QueryState::Probe(ind) => { + self.conn_rt[ind].start = Some(Instant::now()); + let fut = start_request( + ind, + self.conn_rt[ind].id, + self.sender.clone(), + self.query_msg.clone(), + ); + self.fut_list.push(Box::pin(fut)); + println!("timeout {:?}", self.conn_rt[ind].est_rt); + let timeout = Instant::now() + self.conn_rt[ind].est_rt; + tokio::select! { + res = self.fut_list.next() => { + println!("got res {:?}", res); + let res = res.unwrap(); + self.result = Some(res.1); + self.res_index= res.0; + + self.state = QueryState::Report(0); + continue; + } + _ = sleep_until(timeout) => { + // Move to the next Probe state if there + // are more upstreams to try, otherwise + // move to the Wait state. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + continue; + } + } + } + QueryState::Report(ind) => { + if ind >= self.conn_rt.len() + || self.conn_rt[ind].start.is_none() + { + // Nothing more to report. Return result. + let res = self.result.take().unwrap(); + return res; + } + + let start = self.conn_rt[ind].start.unwrap(); + let elapsed = start.elapsed(); + println!( + "expected rt was {:?}", + self.conn_rt[ind].est_rt + ); + println!("reporting duration {:?}", elapsed); + let time_report = TimeReport { + id: self.conn_rt[ind].id, + elapsed, + }; + let report = if ind == self.res_index { + // Succesfull entry + ChanReq::Report(time_report) + } else { + // Failed entry + ChanReq::Failure(time_report) + }; + self.sender.send(report).await.unwrap(); + self.state = QueryState::Report(ind + 1); + continue; + } + QueryState::Wait => { + let res = self.fut_list.next().await; + println!("got res {:?}", res); + let res = res.unwrap(); + self.result = Some(res.1); + self.res_index = res.0; + self.state = QueryState::Report(0); + continue; + } + } + } + } +} + +impl< + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + Sync + + 'static, + > GetResult for Query +{ + fn get_result( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_result_impl()) + } +} + +/// Async function to send a request and wait for the reply. +/// +/// This gives a single future that we can put in a list. +async fn start_request( + index: usize, + id: u64, + sender: mpsc::Sender>, + query_msg: MessageBuilder>>, +) -> (usize, Result, Error>) { + let (tx, rx) = oneshot::channel(); + sender + .send(ChanReq::Query(QueryReq { + id, + query_msg: query_msg.clone(), + tx, + })) + .await + .unwrap(); + let mut query = rx.await.unwrap().unwrap(); + let reply = query.get_result().await; + + (index, reply) +} + +/// The commands that can be sent to the run function. +enum ChanReq { + /// Add a connection + Add(AddReq), + + /// Get the list of estimated response times for all connections + GetRT(RTReq), + + /// Start a query + Query(QueryReq), + + /// Report how long it took to get a response + Report(TimeReport), + + /// Report that a connection failed to provide a timely response + Failure(TimeReport), +} + +impl Debug for ChanReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("ChanReq").finish() + } +} + +/// Request to add a new connection +struct AddReq { + /// New connection to add + conn: Box + Send + Sync>, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to an Add request +type AddReply = Result<(), Error>; + +/// Request to give the estimated response times for all connections +struct RTReq /**/ { + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to a RT request +type RTReply = Result, Error>; + +/// Request to start a query +struct QueryReq { + /// Identifier of connection + id: u64, + + /// Request message + query_msg: MessageBuilder>>, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +impl Debug for QueryReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("QueryReq") + .field("id", &self.id) + .field("query_msg", &self.query_msg) + .finish() + } +} + +/// Reply to a query request. +type QueryReply = Result, Error>; + +/// Report the amount of time until success or failure. +#[derive(Debug)] +struct TimeReport { + /// Identifier of the transport connection. + id: u64, + + /// Time spend waiting for a reply. + elapsed: Duration, +} + +/// Connection statistics to compute the estimated response time. +struct ConnStats { + /// Aproximation of the windowed average of response times. + mean: f64, + + /// Aproximation of the windowed average of the square of response times. + mean_sq: f64, +} + +/// Data required to schedule requests and report timing results. +#[derive(Clone, Debug)] +struct ConnRT { + /// Estimated response time. + est_rt: Duration, + + /// Identifier of the connection. + id: u64, + + /// Start of a request using this connection. + start: Option, +} + +/// Compare ConnRT elements based on estimated response time. +fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering { + e1.est_rt.cmp(&e2.est_rt) +} + +/// Type that actually implements the connection. +struct InnerConnection { + /// Receive side of the channel used by the runner. + receiver: Mutex>>>, + + /// To send a request to the runner. + sender: mpsc::Sender>, +} + +impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { + /// Implementation of the new method. + fn new() -> io::Result> { + let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); + Ok(Self { + receiver: Mutex::new(Some(rx)), + sender: tx, + }) + } + + /// Implementation of the run method. + async fn run(&self) { + let mut next_id: u64 = 10; + let mut conn_stats: Vec = Vec::new(); + let mut conn_rt: Vec = Vec::new(); + let mut conns: Vec + Send + Sync>> = + Vec::new(); + + let mut receiver = self.receiver.lock().await; + let opt_receiver = receiver.take(); + drop(receiver); + let mut receiver = opt_receiver.unwrap(); + loop { + let req = receiver.recv().await.unwrap(); + match req { + ChanReq::Add(add_req) => { + let id = next_id; + next_id += 1; + conn_stats.push(ConnStats { + mean: (DEFAULT_RT_MS as f64) / 1000., + mean_sq: 0., + }); + conn_rt.push(ConnRT { + id, + est_rt: DEFAULT_RT, + start: None, + }); + conns.push(add_req.conn); + add_req.tx.send(Ok(())).unwrap(); + } + ChanReq::GetRT(rt_req) => { + rt_req.tx.send(Ok(conn_rt.clone())).unwrap(); + } + ChanReq::Query(mut query_req) => { + println!("QueryReq for id {}", query_req.id); + let opt_ind = + conn_rt.iter().position(|e| e.id == query_req.id); + let ind = opt_ind.unwrap(); + println!("QueryReq for ind {}", ind); + let query = + conns[ind].query(&mut query_req.query_msg).await; + query_req.tx.send(query).unwrap(); + } + ChanReq::Report(time_report) => { + println!( + "for {} time {:?}", + time_report.id, time_report.elapsed + ); + let opt_ind = + conn_rt.iter().position(|e| e.id == time_report.id); + let ind = opt_ind.unwrap(); + println!("Report for ind {}", ind); + let elapsed = time_report.elapsed.as_secs_f64(); + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + println!( + "new mean {} mean_sq {}", + conn_stats[ind].mean, conn_stats[ind].mean_sq + ); + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; + println!("std dev {}", std_dev); + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + println!("new est_rt {:?}", conn_rt[ind].est_rt); + } + ChanReq::Failure(time_report) => { + println!( + "failure for {} time {:?}", + time_report.id, time_report.elapsed + ); + let opt_ind = + conn_rt.iter().position(|e| e.id == time_report.id); + let ind = opt_ind.unwrap(); + println!("Failure Report for ind {}", ind); + let elapsed = time_report.elapsed.as_secs_f64(); + if elapsed < conn_stats[ind].mean { + // Do not update the mean if a + // failure took less time than the + // current mean. + println!("ignoring better time"); + continue; + } + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + println!( + "new mean {} mean_sq {}", + conn_stats[ind].mean, conn_stats[ind].mean_sq + ); + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; + println!("std dev {}", std_dev); + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + println!("new est_rt {:?}", conn_rt[ind].est_rt); + } + } + } + } + + /// Implementation of the add method. + async fn add(&self, conn: Box + Send + Sync>) { + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::Add(AddReq { conn, tx })) + .await + .unwrap(); + rx.await.unwrap().unwrap(); + } + + /// Implementation of the query method. + async fn query( + &'a self, + query_msg: MessageBuilder>>, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::GetRT(RTReq { tx })) + .await + .unwrap(); + let conn_rt = rx.await.unwrap().unwrap(); + Ok(Query::new( + query_msg, + //conns, + conn_rt, + self.sender.clone(), + )) + } +} + +//fn test_send(t: T) -> T { t } From 6e2a907252ed7a81bda1a2a9f2dd4ab5bb3ba886 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 21 Sep 2023 16:36:57 +0200 Subject: [PATCH 040/124] Redo examples. --- Cargo.toml | 7 +- examples/client-transports.rs | 130 +++++++++++++++++++++++++++++++++ examples/tcp-client.rs | 44 ----------- examples/tls-client.rs | 64 ---------------- src/net/client/mod.rs | 5 ++ src/net/client/octet_stream.rs | 8 -- 6 files changed, 136 insertions(+), 122 deletions(-) create mode 100644 examples/client-transports.rs delete mode 100644 examples/tcp-client.rs delete mode 100644 examples/tls-client.rs diff --git a/Cargo.toml b/Cargo.toml index 1ef0f3752..29a835730 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,10 +90,5 @@ name = "client" required-features = ["std", "rand"] [[example]] -name = "tcp-client" +name = "client-transports" required-features = ["net"] - -[[example]] -name = "tls-client" -required-features = ["net"] - diff --git a/examples/client-transports.rs b/examples/client-transports.rs new file mode 100644 index 000000000..33bfedf19 --- /dev/null +++ b/examples/client-transports.rs @@ -0,0 +1,130 @@ +use domain::base::Dname; +use domain::base::Rtype::Aaaa; +use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; +use domain::net::client::multi_stream; +use domain::net::client::query::QueryMessage2; +use domain::net::client::redundant; +use domain::net::client::tcp_factory::TcpConnFactory; +use domain::net::client::tls_factory::TlsConnFactory; +use domain::net::client::udp_tcp; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::sync::Arc; +use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; + +#[tokio::main] +async fn main() { + // Create DNS request message + // Create a message builder wrapping a compressor wrapping a stream + // target. + let mut msg = + MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap(); + msg.header_mut().set_rd(true); + let mut msg = msg.question(); + msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) + .unwrap(); + let mut msg = msg.as_builder_mut().clone(); + + // Destination for UDP and TCP + let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + + // Create a new UDP+TCP transport connection. Pass the destination address + // and port as parameter. + let udptcp_conn = udp_tcp::Connection::new(server_addr).unwrap(); + + // Create a clone for the run function. Start the run function on a + // separate task. + let conn_run = udptcp_conn.clone(); + tokio::spawn(async move { + conn_run.run().await; + }); + + // Send a query message. + let mut query = udptcp_conn.query(&mut msg).await.unwrap(); + + // Get the reply + let reply = query.get_result().await; + println!("UDP+TCP reply: {:?}", reply); + + // Create a factory of TCP connections. Pass the destination address and + // port as parameter. + let tcp_factory = TcpConnFactory::new(server_addr); + + // A muli_stream transport connection sets up new TCP connections when + // needed. + let tcp_conn = multi_stream::Connection::>::new().unwrap(); + + // Start the run function as a separate task. The run function receives + // the factory as a parameter. + let conn_run = tcp_conn.clone(); + tokio::spawn(async move { + conn_run.run(tcp_factory).await; + }); + + // Send a query message. + let mut query = tcp_conn.query(&mut msg).await.unwrap(); + + // Get the reply + let reply = query.get_result().await; + println!("TCP reply: {:?}", reply); + + // Some TLS boiler plate for the root certificates. + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + // TLS config + let client_config = Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(), + ); + + // Currently the only support TLS connections are the ones that have a + // valid certificate. Use a well known public resolver. + let server_addr = SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); + + // Create a new TLS connection factory. We pass the TLS config, the name of + // the remote server and the destination address and port. + let tls_factory = TlsConnFactory::new(client_config, "dns.google", server_addr); + + // Again create a multi_stream transport connection. + let tls_conn = multi_stream::Connection::new().unwrap(); + + // Can start the run function. + let conn_run = tls_conn.clone(); + tokio::spawn(async move { + conn_run.run(tls_factory).await; + }); + + let mut query = tls_conn.query(&mut msg).await.unwrap(); + let reply = query.get_result().await; + println!("TLS reply: {:?}", reply); + + // Create a transport connection for redundant connections. + let redun = redundant::Connection::new().unwrap(); + + // Start the run function on a separate task. + let redun_run = redun.clone(); + tokio::spawn(async move { + redun_run.run().await; + }); + + // Add the previously created transports. + redun.add(Box::new(udptcp_conn)).await; + redun.add(Box::new(tcp_conn)).await; + redun.add(Box::new(tls_conn)).await; + + // Start a few queries. + for _i in 1..10 { + let mut query = redun.query(&mut msg).await.unwrap(); + let reply = query.get_result().await; + println!("redundant connection reply: {:?}", reply); + } +} diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs deleted file mode 100644 index 62e7bba56..000000000 --- a/examples/tcp-client.rs +++ /dev/null @@ -1,44 +0,0 @@ -use domain::base::Dname; -use domain::base::Rtype::Aaaa; -use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; -use domain::net::client::octet_stream::Connection; -use std::net::{IpAddr, SocketAddr}; -use std::str::FromStr; -use tokio::net::TcpStream; - -#[tokio::main] -async fn main() { - // Create DNS request message - // Create a message builder wrapping a compressor wrapping a stream - // target. - let mut msg = MessageBuilder::from_target(StaticCompressor::new( - StreamTarget::new_vec(), - )) - .unwrap(); - msg.header_mut().set_rd(true); - let mut msg = msg.question(); - msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) - .unwrap(); - let mut msg = msg.as_builder_mut().clone(); - - let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); - - let tcp = match TcpStream::connect(server_addr).await { - Err(err) => { - println!("TCP connection failed with {}", err); - return; - } - Ok(tcp) => tcp, - }; - - let conn = Connection::new().unwrap(); - let conn_run = conn.clone(); - - tokio::spawn(async move { - conn_run.run(tcp).await; - }); - - let mut query = conn.query(&mut msg).await.unwrap(); - let reply = query.get_result().await; - println!("reply: {:?}", reply); -} diff --git a/examples/tls-client.rs b/examples/tls-client.rs deleted file mode 100644 index 9cf80dc47..000000000 --- a/examples/tls-client.rs +++ /dev/null @@ -1,64 +0,0 @@ -use domain::base::Dname; -use domain::base::Rtype::Aaaa; -use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; -use domain::net::client::octet_stream::Connection; -use rustls::{ClientConfig, ServerName}; -use std::net::{IpAddr, SocketAddr}; -use std::str::FromStr; -use std::sync::Arc; -use tokio::net::TcpStream; -use tokio_rustls::TlsConnector; - -#[tokio::main] -async fn main() { - // Create DNS request message - // Create a message builder wrapping a compressor wrapping a stream - // target. - let mut msg = MessageBuilder::from_target(StaticCompressor::new( - StreamTarget::new_vec(), - )) - .unwrap(); - msg.header_mut().set_rd(true); - let mut msg = msg.question(); - msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) - .unwrap(); - let mut msg = msg.as_builder_mut().clone(); - - let server_addr = - SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); - - let mut root_store = rustls::RootCertStore::empty(); - root_store.add_server_trust_anchors( - webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }), - ); - let client_config = Arc::new( - ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(), - ); - - let tls_connection = TlsConnector::from(client_config); - - let tcp = TcpStream::connect(server_addr).await.unwrap(); - - let server_name = ServerName::try_from("dns.google").unwrap(); - let tls = tls_connection.connect(server_name, tcp).await.unwrap(); - - let conn = Connection::new().unwrap(); - let conn_run = conn.clone(); - - tokio::spawn(async move { - conn_run.run(tls).await; - }); - - let mut query = conn.query(&mut msg).await.unwrap(); - let reply = query.get_result().await; - println!("reply: {:?}", reply); -} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 2da1fb696..a08a26543 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -2,6 +2,11 @@ #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] +//! # Example with various transport connections +//! ``` +#![doc = include_str!("../../../examples/client-transports.rs")] +//! ``` + pub mod error; pub mod factory; pub mod multi_stream; diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 6985eefa9..8b43ef70a 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -1,12 +1,4 @@ //! A DNS over octet stream transport -//! # Example with TCP connection to port 53 -//! ``` -#![doc = include_str!("../../../examples/tcp-client.rs")] -//! ``` -//! # Example with TLS connection to port 853 -//! ``` -#![doc = include_str!("../../../examples/tls-client.rs")] -//! ``` #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] From 3302a3c2b61361b41c3e73e4bad436f14d3d5d5d Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 21 Sep 2023 16:40:09 +0200 Subject: [PATCH 041/124] Cargo fmt. --- examples/client-transports.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 33bfedf19..83dfd8699 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -17,8 +17,10 @@ async fn main() { // Create DNS request message // Create a message builder wrapping a compressor wrapping a stream // target. - let mut msg = - MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap(); + let mut msg = MessageBuilder::from_target(StaticCompressor::new( + StreamTarget::new_vec(), + )) + .unwrap(); msg.header_mut().set_rd(true); let mut msg = msg.question(); msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) @@ -70,13 +72,15 @@ async fn main() { // Some TLS boiler plate for the root certificates. let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); // TLS config let client_config = Arc::new( @@ -88,11 +92,13 @@ async fn main() { // Currently the only support TLS connections are the ones that have a // valid certificate. Use a well known public resolver. - let server_addr = SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); + let server_addr = + SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); // Create a new TLS connection factory. We pass the TLS config, the name of // the remote server and the destination address and port. - let tls_factory = TlsConnFactory::new(client_config, "dns.google", server_addr); + let tls_factory = + TlsConnFactory::new(client_config, "dns.google", server_addr); // Again create a multi_stream transport connection. let tls_conn = multi_stream::Connection::new().unwrap(); From 35170b7d668b0b33aac80bacecd39874ba709f37 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 3 Oct 2023 17:21:45 +0200 Subject: [PATCH 042/124] Set ID in the header to new random value each time before sending a request. Changed from MessagedBuilder to using Message. --- src/net/client/udp.rs | 57 ++++++++++++++++++++------------------- src/net/client/udp_tcp.rs | 2 +- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index ae9bddd0b..897a0209e 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -7,7 +7,7 @@ // - cookies // - random port -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -56,7 +56,7 @@ impl Connection { async fn query_impl + Clone + Send>( &self, query_msg: &mut MessageBuilder>>, - ) -> Result, Error> { + ) -> Result { self.inner.query(query_msg, self.clone()).await } @@ -77,7 +77,7 @@ impl Connection { } } -impl + Clone + Debug + Send> QueryMessage, Octs> +impl + Clone + Debug + Send> QueryMessage for Connection { fn query<'a>( @@ -85,8 +85,7 @@ impl + Clone + Debug + Send> QueryMessage, Octs> query_msg: &'a mut MessageBuilder< StaticCompressor>, >, - ) -> Pin, Error>> + Send + '_>> - { + ) -> Pin> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } @@ -131,12 +130,12 @@ enum QueryState { /// The state of a DNS query. #[derive(Debug)] -pub struct Query { +pub struct Query { /// Address of remote server to connect to. remote_addr: SocketAddr, /// DNS request message. - query_msg: MessageBuilder>>, + query_msg: Message, /// Semaphore permit that allow use of socket. _permit: Option, @@ -151,15 +150,15 @@ pub struct Query { state: QueryState, } -impl + Clone + Send> Query { +impl Query { /// Create new Query object. fn new( - query_msg: &mut MessageBuilder>>, + query_msg: Message, remote_addr: SocketAddr, conn: Connection, - ) -> Query { + ) -> Query { Query { - query_msg: query_msg.clone(), + query_msg, remote_addr, _permit: None, sock: None, @@ -202,11 +201,11 @@ impl + Clone + Send> Query { continue; } QueryState::Send => { - let dgram = self - .query_msg - .as_target() - .as_target() - .as_dgram_slice(); + // Set random ID in header + let header = self.query_msg.header_mut(); + header.set_random_id(); + let dgram = self.query_msg.as_slice(); + let sent = self .sock .as_ref() @@ -214,14 +213,7 @@ impl + Clone + Send> Query { .send(dgram) .await .map_err(|e| Error::UdpSend(Arc::new(e)))?; - if sent - != self - .query_msg - .as_target() - .as_target() - .as_dgram_slice() - .len() - { + if sent != self.query_msg.as_slice().len() { return Err(Error::UdpShortSend); } self.state = QueryState::Receive(Instant::now()); @@ -262,7 +254,14 @@ impl + Clone + Send> Query { Ok(answer) => answer, Err(_) => continue, }; - if !answer.is_answer(&self.query_msg.as_message()) { + + // Unfortunately we cannot pass query_msg to is_answer + // because is_answer requires Octets, which is not + // implemented by BytesMut. Make a copy. + let query_msg = + Message::from_octets(self.query_msg.as_slice()) + .unwrap(); + if !answer.is_answer(&query_msg) { continue; } self.sock = None; @@ -299,7 +298,7 @@ impl + Clone + Send> Query { } } -impl + Clone + Debug + Send> GetResult for Query { +impl GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -333,7 +332,11 @@ impl InnerConnection { &self, query_msg: &mut MessageBuilder>>, conn: Connection, - ) -> Result, Error> { + ) -> Result { + let slice = query_msg.as_target().as_target().as_dgram_slice(); + let mut bytes = BytesMut::with_capacity(slice.len()); + bytes.extend_from_slice(slice); + let query_msg = Message::from_octets(bytes).unwrap(); Ok(Query::new(query_msg, self.remote_addr, conn)) } diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index b97159ad3..a927276c3 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -129,7 +129,7 @@ enum QueryState { StartUdpQuery, /// Get the result from the UDP transport. - GetUdpResult(udp::Query), + GetUdpResult(udp::Query), /// Start a query over the TCP transport. StartTcpQuery, From 80f8479f3c79085414f134dfb5ec46d396b673e1 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 4 Oct 2023 10:49:26 +0200 Subject: [PATCH 043/124] Introduction of QueryMessage3, which takes a Message instead of a MessageBuilder as in the previous traits. --- src/net/client/multi_stream.rs | 134 +++++++++++-------------------- src/net/client/octet_stream.rs | 141 ++++++++++++++++++--------------- src/net/client/query.rs | 11 +++ src/net/client/redundant.rs | 53 ++++++------- src/net/client/udp.rs | 45 ++++++++++- src/net/client/udp_tcp.rs | 91 ++++++++------------- 6 files changed, 236 insertions(+), 239 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 471e8e11e..5f6193c96 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -12,7 +12,7 @@ use futures::lock::Mutex as Futures_mutex; use futures::stream::FuturesUnordered; use futures::StreamExt; -use octseq::{Octets, OctetsBuilder}; +use octseq::Octets; use rand::random; @@ -29,13 +29,12 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio::time::{sleep_until, Instant}; -use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream::Connection as SingleConnection; use crate::net::client::octet_stream::QueryNoCheck as SingleQuery; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -46,7 +45,7 @@ const ERR_CONN_CLOSED: &str = "connection closed"; /// Response to the DNS request sent by [InnerConnection::run] to [Query]. #[derive(Debug)] -struct ChanRespOk { +struct ChanRespOk> { /// id of this connection. id: u64, @@ -62,7 +61,7 @@ type ReplySender = oneshot::Sender>; #[derive(Debug)] /// Commands that can be requested. -enum ReqCmd { +enum ReqCmd> { /// Request for a (new) connection. /// /// The id of the previous connection (if any) is passed as well as a @@ -76,14 +75,14 @@ enum ReqCmd { #[derive(Debug)] /// A request to [Connection::run] either for a new octet_stream or to /// shutdown. -struct ChanReq { +struct ChanReq> { /// A requests consists of a command. cmd: ReqCmd, } /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection { +struct InnerConnection> { /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. /// @@ -101,14 +100,14 @@ struct InnerConnection { #[derive(Clone, Debug)] /// A DNS over octect streams transport. -pub struct Connection { +pub struct Connection> { /// Reference counted [InnerConnection]. inner: Arc>, } /// Status of a query. Used in [Query]. #[derive(Debug)] -enum QueryState { +enum QueryState> { /// Get a octet_stream transport. GetConn(oneshot::Receiver>), @@ -146,7 +145,7 @@ struct ErrorState { } /// State of the current underlying octet_stream transport. -enum SingleConnState { +enum SingleConnState3> { /// No current octet_stream transport. None, @@ -161,9 +160,9 @@ enum SingleConnState { /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection -struct State<'a, F, IO, Octs: OctetsBuilder> { +struct State3<'a, F, IO, Octs: AsRef<[u8]>> { /// Underlying octet_stream connection. - conn_state: SingleConnState, + conn_state: SingleConnState3, /// Current connection id. conn_id: u64, @@ -183,13 +182,13 @@ struct State<'a, F, IO, Octs: OctetsBuilder> { /// This struct represent an active DNS query. #[derive(Debug)] -pub struct Query { +pub struct Query> { /// Request message. /// /// The reply message is compared with the request message to see if /// it matches the query. // query_msg: Message>, - query_msg: MessageBuilder>>, + query_msg: Message, /// Current state of the query. state: QueryState, @@ -207,15 +206,8 @@ pub struct Query { delayed_retry_count: u64, } -impl< - Octs: 'static - + AsMut<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send, - > InnerConnection +impl + Clone + Octets + Send + 'static> + InnerConnection { /// Constructor for [InnerConnection]. /// @@ -249,8 +241,8 @@ impl< }; let mut curr_cmd: Option> = None; - let mut state = State::<'a, F, IO, Octs> { - conn_state: SingleConnState::None, + let mut state = State3::<'a, F, IO, Octs> { + conn_state: SingleConnState3::None, conn_id: 0, factory, runners: FuturesUnordered::< @@ -271,7 +263,7 @@ impl< curr_cmd = None; match req { ReqCmd::NewConn(opt_id, chan) => { - if let SingleConnState::Err(error_state) = + if let SingleConnState3::Err(error_state) = &state.conn_state { if error_state.timer.elapsed() @@ -297,12 +289,12 @@ impl< // current one. This is the best place to // increment conn_id. state.conn_id += 1; - state.conn_state = SingleConnState::None; + state.conn_state = SingleConnState3::None; } } // If we still have a connection then we can reply // immediately. - if let SingleConnState::Some(conn) = &state.conn_state + if let SingleConnState3::Some(conn) = &state.conn_state { let resp = ChanResp::Ok(ChanRespOk { id: state.conn_id, @@ -333,19 +325,19 @@ impl< if let Err(error) = res_conn { let error = Arc::new(error); match state.conn_state { - SingleConnState::None => + SingleConnState3::None => state.conn_state = - SingleConnState::Err(ErrorState { + SingleConnState3::Err(ErrorState { error: error.clone(), retries: 0, timer: Instant::now(), timeout: retry_time(0), }), - SingleConnState::Some(_) => + SingleConnState3::Some(_) => panic!("Illegal Some state"), - SingleConnState::Err(error_state) => { + SingleConnState3::Err(error_state) => { state.conn_state = - SingleConnState::Err(ErrorState { + SingleConnState3::Err(ErrorState { error: error_state.error.clone(), retries: error_state.retries+1, timer: Instant::now(), @@ -382,7 +374,7 @@ impl< id: state.conn_id, conn: conn.clone(), }); - state.conn_state = SingleConnState::Some(conn); + state.conn_state = SingleConnState3::Some(conn); let loc_opt_chan = opt_chan.take(); @@ -463,16 +455,8 @@ impl< } } -impl< - Octs: 'static - + AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send, - > Connection +impl + Clone + Debug + Octets + Send + Sync + 'static> + Connection { /// Constructor for [Connection]. /// @@ -502,22 +486,9 @@ impl< /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl( + pub async fn query_impl3( &self, - query_msg: &mut MessageBuilder>>, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.inner.new_conn(None, tx).await?; - Ok(Query::new(self.clone(), query_msg, rx)) - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl2( - &self, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.new_conn(None, tx).await?; @@ -540,28 +511,29 @@ impl< } } -impl +impl QueryMessage, Octs> for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< + _query_msg: &'a mut MessageBuilder< StaticCompressor>, >, ) -> Pin, Error>> + Send + '_>> { - return Box::pin(self.query_impl(query_msg)); + todo!(); + /* + return Box::pin(self.query_impl3(query_msg)); + */ } } -impl - QueryMessage2 for Connection +impl + Clone + Debug + Octets + Send + Sync + 'static> + QueryMessage3 for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin< Box< dyn Future, Error>> @@ -569,26 +541,18 @@ impl + '_, >, > { - return Box::pin(self.query_impl2(query_msg)); + return Box::pin(self.query_impl3(query_msg)); } } -impl< - Octs: AsRef<[u8]> - + AsMut<[u8]> - + Composer - + OctetsBuilder - + Clone - + Debug - + Send - + 'static, - > Query +impl + Clone + Debug + Octets + Send + Sync + 'static> + Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. fn new( conn: Connection, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, receiver: oneshot::Receiver>, ) -> Query { Self { @@ -704,16 +668,8 @@ impl< } } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send - + 'static, - > GetResult for Query +impl + Clone + Debug + Octets + Send + Sync + 'static> + GetResult for Query { fn get_result( &mut self, diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 8b43ef70a..a2ffad556 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -29,16 +29,15 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use std::vec::Vec; -use crate::base::wire::Composer; use crate::base::{ opt::{AllOptData, Opt, OptRecord, TcpKeepalive}, Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, StreamTarget, }; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; use crate::rdata::AllRecordData; -use octseq::{Octets, OctetsBuilder}; +use octseq::Octets; use tokio::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -68,9 +67,9 @@ type ReplySender = oneshot::Sender; #[derive(Debug)] /// A request from [Query] to [Connection::run] to start a DNS request. -struct ChanReq { +struct ChanReq> { /// DNS request message - msg: MessageBuilder>>, + msg: Message, /// Sender to send result back to [Query] sender: ReplySender, @@ -87,7 +86,7 @@ type ChanResp = Result; /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection { +struct InnerConnection> { /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. /// @@ -118,7 +117,7 @@ struct Queries { #[derive(Clone, Debug)] /// A single DNS over octect stream connection. -pub struct Connection { +pub struct Connection> { /// Reference counted [InnerConnection]. inner: Arc>, } @@ -212,9 +211,7 @@ enum ConnState { // This type could be local to InnerConnection, but I don't know how type ReaderChanReply = Message; -impl + Clone + Composer + Debug + OctetsBuilder> - InnerConnection -{ +impl + Clone> InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. @@ -410,7 +407,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> pub async fn query( &self, sender: oneshot::Sender, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result<(), Error> { // We should figure out how to get query_msg. let msg_clone = query_msg.clone(); @@ -584,7 +581,7 @@ impl + Clone + Composer + Debug + OctetsBuilder> // Note: maybe reqmsg should be a return value. fn insert_req( &self, - mut req: ChanReq, + req: ChanReq, status: &mut Status, reqmsg: &mut Option>, query_vec: &mut Queries, @@ -643,11 +640,14 @@ impl + Clone + Composer + Debug + OctetsBuilder> // if many spoofed answers arrive over UDP: "TCP, by the // nature of its use of sequence numbers, is far more // resilient against forgery by third parties." - let hdr = req.msg.header_mut(); + + let mut mut_msg = + Message::from_octets(req.msg.as_slice().to_vec()).unwrap(); + let hdr = mut_msg.header_mut(); hdr.set_id(ind16); if status.send_keepalive { - let res_msg = add_tcp_keepalive(&req.msg); + let res_msg = add_tcp_keepalive(&mut_msg); match res_msg { Ok(msg) => { @@ -687,15 +687,26 @@ impl + Clone + Composer + Debug + OctetsBuilder> /// Convert the query message to a vector. // This function should return the vector instead of storing it // through a reference. - fn convert_query + AsRef<[u8]>>( - msg: &MessageBuilder>>, + fn convert_query>( + msg: &Message, reqmsg: &mut Option>, ) { - let vec = msg.as_target().as_target().as_stream_slice(); + // Ideally there should be a write_all_vectored. Until there is one, + // copy to a new Vec and prepend the length octets. + + let slice = msg.as_slice(); + let len = slice.len(); + + println!("convert_query: slice len {}, slice {:?}", len, slice); + + let mut vec = Vec::with_capacity(2 + len); + let len16 = len as u16; + vec.extend_from_slice(&len16.to_be_bytes()); + vec.extend_from_slice(slice); - // Store a clone of the request. That makes life easier - // and requests tend to be small - *reqmsg = Some(vec.to_vec()); + println!("convert_query: vec {:?}", vec); + + *reqmsg = Some(vec); } /// Insert a sender (for the reply) in the query_vec and return the index. @@ -763,16 +774,23 @@ impl + Clone + Composer + Debug + OctetsBuilder> } } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send, - > Connection +impl + AsRef<[u8]> + Clone + Debug + Octets + Send> + QueryMessage for Connection { + fn query<'a>( + &'a self, + _query_msg: &'a mut MessageBuilder< + StaticCompressor>, + >, + ) -> Pin> + Send + '_>> { + todo!(); + /* + return Box::pin(self.query_impl(query_msg)); + */ + } +} + +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). @@ -800,12 +818,12 @@ impl< /// returns a [Query] object wrapped in a [Result]. async fn query_impl( &self, - query_msg: &mut MessageBuilder>>, - ) -> Result { + query_msg: &Message, + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; - let msg = &query_msg.as_message(); - Ok(Query::new(msg, rx)) + let msg = query_msg; + Ok(Box::new(Query::new(msg, rx))) } /// Start a DNS request but do not check if the reply matches the request. @@ -814,7 +832,7 @@ impl< /// match the request avoids having to keep the request around. pub async fn query_no_check( &self, - query_msg: &mut MessageBuilder>>, + query_msg: &mut Message, ) -> Result { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -822,22 +840,19 @@ impl< } } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send, - > QueryMessage for Connection +impl + Clone + Octets + Send + Sync> QueryMessage3 + for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, + query_msg: &'a Message, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, >, - ) -> Pin> + Send + '_>> { + > { return Box::pin(self.query_impl(query_msg)); } } @@ -953,27 +968,27 @@ impl QueryNoCheck { /// This is surprisingly difficult. We need to copy the original message to /// a new MessageBuilder because MessageBuilder has no support for changing the /// opt record. -fn add_tcp_keepalive( - msg: &MessageBuilder>>, -) -> Result< - MessageBuilder>>, - crate::base::message_builder::PushError, -> { +fn add_tcp_keepalive( + msg: &Message, +) -> Result>, crate::base::message_builder::PushError> { // We can't just insert a new option in an existing // opt record. So we have to create new message and copy records // from the old one. And insert our option while copying the opt // record. - let src_clone = msg.clone(); - let source = Message::from_octets( - src_clone.as_target().as_target().as_dgram_slice(), - ) - .unwrap(); - let target = msg.clone(); + let source = msg; + + let mut target = + MessageBuilder::from_target(StaticCompressor::new(Vec::new())) + .unwrap(); + let source_hdr = source.header(); + let target_hdr = target.header_mut(); + target_hdr.set_flags(source_hdr.flags()); + target_hdr.set_opcode(source_hdr.opcode()); + target_hdr.set_rcode(source_hdr.rcode()); + target_hdr.set_id(source_hdr.id()); let source = source.question(); - // Go to additional and back to builder to delete all sections - // except for the header - let mut target = target.additional().builder().question(); + let mut target = target.question(); for rr in source { let rr = rr.unwrap(); target.push(rr)?; @@ -1045,7 +1060,9 @@ fn add_tcp_keepalive( // It would be nice to use .builder() here. But that one deletes all // section. We have to resort to .as_builder() which gives a // reference and then .clone() - Ok(target.as_builder().clone()) + let result = target.as_builder().clone(); + let msg = Message::from_octets(result.finish().into_target()).unwrap(); + Ok(msg) } /// Check if a DNS reply match the query. Ignore whether id fields match. diff --git a/src/net/client/query.rs b/src/net/client/query.rs index 6aca0a620..4b5878bef 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -39,6 +39,17 @@ pub trait QueryMessage2 { ) -> Pin + Send + '_>>; } +/// Trait for starting a DNS query based on a message. +pub trait QueryMessage3 { + /// Query function that takes a message type. + /// + /// This function is intended to be cancel safe. + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin + Send + '_>>; +} + /// This type is the actual result type of the future returned by the /// query function in the QueryMessage2 trait. type QueryResultOutput = Result, Error>; diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 29e9cf69a..8cffa7ee4 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -25,9 +25,9 @@ use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::time::{sleep_until, Duration, Instant}; use crate::base::wire::Composer; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::Message; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage2}; +use crate::net::client::query::{GetResult, QueryMessage3}; /* Basic algorithm: @@ -92,7 +92,7 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> /// Add a transport connection. pub async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) { self.inner.add(conn).await } @@ -100,7 +100,7 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> /// Implementation of the query function. async fn query_impl( &self, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result, Error> { Ok(Box::new(self.inner.query(query_msg.clone()).await.unwrap())) } @@ -108,13 +108,11 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> impl< Octs: Clone + Composer + Debug + OctetsBuilder + Send + Sync + 'static, - > QueryMessage2 for Connection + > QueryMessage3 for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin< Box< dyn Future, Error>> @@ -128,12 +126,12 @@ impl< /// This type represents an active query request. #[derive(Debug)] -pub struct Query { +pub struct Query + Send> { /// The state of the query state: QueryState, /// The query message - query_msg: MessageBuilder>>, + query_msg: Message, /// List of connections identifiers and estimated response times. conn_rt: Vec, @@ -171,11 +169,10 @@ enum QueryState { Wait, } -impl Query { +impl + Clone + Debug + Send + Sync + 'static> Query { /// Create a new query object. fn new( - query_msg: MessageBuilder>>, - // conns: Vec<&'a dyn QueryMessage>, + query_msg: Message, mut conn_rt: Vec, sender: mpsc::Sender>, ) -> Query { @@ -336,7 +333,7 @@ async fn start_request( index: usize, id: u64, sender: mpsc::Sender>, - query_msg: MessageBuilder>>, + query_msg: Message, ) -> (usize, Result, Error>) { let (tx, rx) = oneshot::channel(); sender @@ -380,7 +377,7 @@ impl Debug for ChanReq { /// Request to add a new connection struct AddReq { /// New connection to add - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, /// Channel to send the reply to tx: oneshot::Sender, @@ -404,13 +401,13 @@ struct QueryReq { id: u64, /// Request message - query_msg: MessageBuilder>>, + query_msg: Message, /// Channel to send the reply to tx: oneshot::Sender, } -impl Debug for QueryReq { +impl + Debug + Send> Debug for QueryReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("QueryReq") .field("id", &self.id) @@ -468,7 +465,9 @@ struct InnerConnection { sender: mpsc::Sender>, } -impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { +impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> + InnerConnection +{ /// Implementation of the new method. fn new() -> io::Result> { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); @@ -483,7 +482,7 @@ impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); - let mut conns: Vec + Send + Sync>> = + let mut conns: Vec + Send + Sync>> = Vec::new(); let mut receiver = self.receiver.lock().await; @@ -511,14 +510,13 @@ impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { ChanReq::GetRT(rt_req) => { rt_req.tx.send(Ok(conn_rt.clone())).unwrap(); } - ChanReq::Query(mut query_req) => { + ChanReq::Query(query_req) => { println!("QueryReq for id {}", query_req.id); let opt_ind = conn_rt.iter().position(|e| e.id == query_req.id); let ind = opt_ind.unwrap(); println!("QueryReq for ind {}", ind); - let query = - conns[ind].query(&mut query_req.query_msg).await; + let query = conns[ind].query(&query_req.query_msg).await; query_req.tx.send(query).unwrap(); } ChanReq::Report(time_report) => { @@ -587,7 +585,7 @@ impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { } /// Implementation of the add method. - async fn add(&self, conn: Box + Send + Sync>) { + async fn add(&self, conn: Box + Send + Sync>) { let (tx, rx) = oneshot::channel(); self.sender .send(ChanReq::Add(AddReq { conn, tx })) @@ -599,7 +597,7 @@ impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { /// Implementation of the query method. async fn query( &'a self, - query_msg: MessageBuilder>>, + query_msg: Message, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.sender @@ -607,12 +605,7 @@ impl<'a, Octs: Clone + Debug + Send + Sync + 'static> InnerConnection { .await .unwrap(); let conn_rt = rx.await.unwrap().unwrap(); - Ok(Query::new( - query_msg, - //conns, - conn_rt, - self.sender.clone(), - )) + Ok(Query::new(query_msg, conn_rt, self.sender.clone())) } } diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 897a0209e..8913adb8a 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -21,7 +21,9 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; +use crate::net::client::query::{ + GetResult, QueryMessage, QueryMessage2, QueryMessage3, +}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -71,6 +73,17 @@ impl Connection { Ok(Box::new(gr)) } + /// Start a new DNS query. + async fn query_impl3< + Octs: AsRef<[u8]> + Clone + Debug + Send + 'static, + >( + &self, + query_msg: &Message, + ) -> Result, Error> { + let gr = self.inner.query3(query_msg, self.clone()).await?; + Ok(Box::new(gr)) + } + /// Get a permit from the semaphore to start using a socket. async fn get_permit(&self) -> OwnedSemaphorePermit { self.inner.get_permit().await @@ -109,6 +122,23 @@ impl + Clone + Debug + Send + 'static> QueryMessage2 } } +impl + Clone + Debug + Send + Sync + 'static> + QueryMessage3 for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl3(query_msg)); + } +} + /// State of the DNS query. #[derive(Debug)] enum QueryState { @@ -340,6 +370,19 @@ impl InnerConnection { Ok(Query::new(query_msg, self.remote_addr, conn)) } + /// Return a Query object that contains the query state. + async fn query3 + Clone>( + &self, + query_msg: &Message, + conn: Connection, + ) -> Result { + let slice = query_msg.as_slice(); + let mut bytes = BytesMut::with_capacity(slice.len()); + bytes.extend_from_slice(slice); + let query_msg = Message::from_octets(bytes).unwrap(); + Ok(Query::new(query_msg, self.remote_addr, conn)) + } + /// Return a permit for a our semaphore. async fn get_permit(&self) -> OwnedSemaphorePermit { self.semaphore diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index a927276c3..5aa76e88a 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -7,7 +7,7 @@ // - handle shutdown use bytes::Bytes; -use octseq::OctetsBuilder; +use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -16,23 +16,22 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; -use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; use crate::net::client::multi_stream; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage2}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; /// DNS transport connection that first issue a query over a UDP transport and /// falls back to TCP if the reply is truncated. #[derive(Clone)] -pub struct Connection { +pub struct Connection + Debug> { /// Reference to the real object that provides the connection. inner: Arc>, } -impl +impl + Clone + Debug + Octets + Send + Sync + 'static> Connection { /// Create a new connection. @@ -51,22 +50,25 @@ impl /// Start a query. pub async fn query_impl( &self, - query_msg: &mut MessageBuilder>>, + _query_msg: &mut MessageBuilder>>, ) -> Result, Error> { - self.inner.query(query_msg).await + todo!(); + /* + self.inner.query(query_msg).await + */ } /// Start a query for the QueryMessage2 trait. - async fn query_impl2( + async fn query_impl3( &self, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result, Error> { let gr = self.inner.query(query_msg).await?; Ok(Box::new(gr)) } } -impl +impl + Clone + Debug + Octets + Send + Sync + 'static> QueryMessage, Octs> for Connection { fn query<'a>( @@ -80,21 +82,12 @@ impl } } -impl< - Octs: AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send - + 'static, - > QueryMessage2 for Connection +impl + Clone + Debug + Octets + Send + Sync + 'static> + QueryMessage3 for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin< Box< dyn Future, Error>> @@ -102,15 +95,15 @@ impl< + '_, >, > { - return Box::pin(self.query_impl2(query_msg)); + return Box::pin(self.query_impl3(query_msg)); } } /// Object that contains the current state of a query. #[derive(Debug)] -pub struct Query { +pub struct Query + Debug> { /// Reqeust message. - query_msg: MessageBuilder>>, + query_msg: Message, /// UDP transport to be used. udp_conn: udp::Connection, @@ -119,41 +112,33 @@ pub struct Query { tcp_conn: multi_stream::Connection, /// Current state of the query. - state: QueryState, + state: QueryState, } /// Status of the query. #[derive(Debug)] -enum QueryState { +enum QueryState { /// Start a query over the UDP transport. StartUdpQuery, /// Get the result from the UDP transport. - GetUdpResult(udp::Query), + GetUdpResult(Box), /// Start a query over the TCP transport. StartTcpQuery, /// Get the result from the TCP transport. - GetTcpResult(multi_stream::Query), + GetTcpResult(Box), } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send - + 'static, - > Query +impl + Clone + Debug + Octets + Send + Sync + 'static> + Query { /// Create a new Query object. /// /// The initial state is to start with a UDP transport. fn new( - query_msg: &mut MessageBuilder>>, + query_msg: &Message, udp_conn: udp::Connection, tcp_conn: multi_stream::Connection, ) -> Query { @@ -172,9 +157,9 @@ impl< loop { match &mut self.state { QueryState::StartUdpQuery => { - let mut msg = self.query_msg.clone(); + let msg = self.query_msg.clone(); let query = - QueryMessage::query(&self.udp_conn, &mut msg).await?; + QueryMessage3::query(&self.udp_conn, &msg).await?; self.state = QueryState::GetUdpResult(query); continue; } @@ -187,9 +172,9 @@ impl< return Ok(reply); } QueryState::StartTcpQuery => { - let mut msg = self.query_msg.clone(); + let msg = self.query_msg.clone(); let query = - QueryMessage::query(&self.tcp_conn, &mut msg).await?; + QueryMessage3::query(&self.tcp_conn, &msg).await?; self.state = QueryState::GetTcpResult(query); continue; } @@ -202,16 +187,8 @@ impl< } } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send - + 'static, - > GetResult for Query +impl + Clone + Debug + Octets + Send + Sync + 'static> + GetResult for Query { fn get_result( &mut self, @@ -223,7 +200,7 @@ impl< } /// The actual connection object. -struct InnerConnection { +struct InnerConnection + Debug> { /// The remote address to connect to. remote_addr: SocketAddr, @@ -234,7 +211,7 @@ struct InnerConnection { tcp_conn: multi_stream::Connection, } -impl +impl InnerConnection { /// Create a new InnerConnection object. @@ -266,7 +243,7 @@ impl /// Just create a Query object with the state it needs. async fn query( &self, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result, Error> { Ok(Query::new( query_msg, From 3e0f0d8a47b898aa466b642e73991fae7bad1347 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 4 Oct 2023 11:19:28 +0200 Subject: [PATCH 044/124] Updated client-transports example. --- examples/client-transports.rs | 57 ++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 83dfd8699..e69f99a55 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,15 +1,18 @@ use domain::base::Dname; use domain::base::Rtype::Aaaa; -use domain::base::{MessageBuilder, StaticCompressor, StreamTarget}; +use domain::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use domain::net::client::multi_stream; -use domain::net::client::query::QueryMessage2; +use domain::net::client::octet_stream; +use domain::net::client::query::QueryMessage3; use domain::net::client::redundant; use domain::net::client::tcp_factory::TcpConnFactory; use domain::net::client::tls_factory::TlsConnFactory; +use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::sync::Arc; +use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; #[tokio::main] @@ -25,10 +28,17 @@ async fn main() { let mut msg = msg.question(); msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) .unwrap(); - let mut msg = msg.as_builder_mut().clone(); + + let msg = Message::from_octets( + msg.as_target().as_target().as_dgram_slice().to_vec(), + ) + .unwrap(); + + println!("request msg: {:?}", msg.as_slice()); // Destination for UDP and TCP - let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + let server_addr = + SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. @@ -42,7 +52,7 @@ async fn main() { }); // Send a query message. - let mut query = udptcp_conn.query(&mut msg).await.unwrap(); + let mut query = udptcp_conn.query(&msg).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -64,7 +74,7 @@ async fn main() { }); // Send a query message. - let mut query = tcp_conn.query(&mut msg).await.unwrap(); + let mut query = tcp_conn.query(&msg).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -109,7 +119,7 @@ async fn main() { conn_run.run(tls_factory).await; }); - let mut query = tls_conn.query(&mut msg).await.unwrap(); + let mut query = tls_conn.query(&msg).await.unwrap(); let reply = query.get_result().await; println!("TLS reply: {:?}", reply); @@ -129,8 +139,39 @@ async fn main() { // Start a few queries. for _i in 1..10 { - let mut query = redun.query(&mut msg).await.unwrap(); + let mut query = redun.query(&msg).await.unwrap(); let reply = query.get_result().await; println!("redundant connection reply: {:?}", reply); } + + // Create a new UDP transport connection. Pass the destination address + // and port as parameter. This transport does not retry over TCP if the + // reply is truncated. + let udp_conn = udp::Connection::new(server_addr).unwrap(); + + // Send a query message. + let mut query = udp_conn.query(&msg).await.unwrap(); + + // Get the reply + let reply = query.get_result().await; + println!("UDP reply: {:?}", reply); + + // Create a single TCP transport connection. This is usefull for a + // single request or a small burst of requests. + let tcp_conn = TcpStream::connect(server_addr).await.unwrap(); + + let tcp = octet_stream::Connection::new().unwrap(); + let tcp_worker = tcp.clone(); + + tokio::spawn(async move { + tcp_worker.run(tcp_conn).await; + println!("run terminated"); + }); + + // Send a query message. + let mut query = tcp.query(&msg).await.unwrap(); + + // Get the reply + let reply = query.get_result().await; + println!("TCP reply: {:?}", reply); } From 3e564ba5472a4c0eb029254ce8a1611088fddd32 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 4 Oct 2023 11:55:11 +0200 Subject: [PATCH 045/124] Get rid of QueryMessage2 --- src/net/client/query.rs | 3 +++ src/net/client/udp.rs | 32 +------------------------------- src/net/client/udp_tcp.rs | 2 +- 3 files changed, 5 insertions(+), 32 deletions(-) diff --git a/src/net/client/query.rs b/src/net/client/query.rs index 4b5878bef..f9b263b1d 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -26,6 +26,8 @@ pub trait QueryMessage { ) -> Pin> + Send + '_>>; } +/* +// This trait is replaced with QueryMessage3 /// Trait for starting a DNS query based on a message. pub trait QueryMessage2 { /// Query function that takes a message builder type. @@ -38,6 +40,7 @@ pub trait QueryMessage2 { >, ) -> Pin + Send + '_>>; } +*/ /// Trait for starting a DNS query based on a message. pub trait QueryMessage3 { diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 8913adb8a..133116638 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -22,7 +22,7 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; use crate::net::client::error::Error; use crate::net::client::query::{ - GetResult, QueryMessage, QueryMessage2, QueryMessage3, + GetResult, QueryMessage, QueryMessage3, }; /// How many times do we try a new random port if we get ‘address in use.’ @@ -62,17 +62,6 @@ impl Connection { self.inner.query(query_msg, self.clone()).await } - /// Start a new DNS query. - async fn query_impl2< - Octs: AsRef<[u8]> + Clone + Debug + Send + 'static, - >( - &self, - query_msg: &mut MessageBuilder>>, - ) -> Result, Error> { - let gr = self.inner.query(query_msg, self.clone()).await?; - Ok(Box::new(gr)) - } - /// Start a new DNS query. async fn query_impl3< Octs: AsRef<[u8]> + Clone + Debug + Send + 'static, @@ -103,25 +92,6 @@ impl + Clone + Debug + Send> QueryMessage } } -impl + Clone + Debug + Send + 'static> QueryMessage2 - for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.query_impl2(query_msg)); - } -} - impl + Clone + Debug + Send + Sync + 'static> QueryMessage3 for Connection { diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 5aa76e88a..6aa56e080 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -58,7 +58,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> */ } - /// Start a query for the QueryMessage2 trait. + /// Start a query for the QueryMessage3 trait. async fn query_impl3( &self, query_msg: &Message, From a68d6239a7d02870859db830e68af3074d8b9d9a Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 4 Oct 2023 15:25:49 +0200 Subject: [PATCH 046/124] Change the query_msg in QueryMessage from MessageBuilder to Message. --- src/net/client/multi_stream.rs | 27 ++++++++++++++++++--------- src/net/client/octet_stream.rs | 33 +++++++++++++++++++++------------ src/net/client/query.rs | 6 ++---- src/net/client/udp.rs | 33 ++++++++------------------------- src/net/client/udp_tcp.rs | 13 ++++--------- 5 files changed, 53 insertions(+), 59 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 5f6193c96..3d6c5b756 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -29,7 +29,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio::time::{sleep_until, Instant}; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::Message; use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream::Connection as SingleConnection; @@ -206,7 +206,7 @@ pub struct Query> { delayed_retry_count: u64, } -impl + Clone + Octets + Send + 'static> +impl + Clone + Octets + Send + Sync + 'static> InnerConnection { /// Constructor for [InnerConnection]. @@ -482,6 +482,20 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> self.inner.run(factory).await } + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query_impl( + &self, + query_msg: &Message, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.inner.new_conn(None, tx).await?; + let gr = Query::new(self.clone(), query_msg, rx); + Ok(gr) + } + /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and @@ -516,15 +530,10 @@ impl { fn query<'a>( &'a self, - _query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin, Error>> + Send + '_>> { - todo!(); - /* - return Box::pin(self.query_impl3(query_msg)); - */ + return Box::pin(self.query_impl(query_msg)); } } diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index a2ffad556..5f04f982f 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -32,7 +32,6 @@ use std::vec::Vec; use crate::base::{ opt::{AllOptData, Opt, OptRecord, TcpKeepalive}, Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, - StreamTarget, }; use crate::net::client::error::Error; use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; @@ -774,23 +773,19 @@ impl + Clone> InnerConnection { } } -impl + AsRef<[u8]> + Clone + Debug + Octets + Send> - QueryMessage for Connection +impl< + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Debug + Octets + Send + Sync, + > QueryMessage for Connection { fn query<'a>( &'a self, - _query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin> + Send + '_>> { - todo!(); - /* - return Box::pin(self.query_impl(query_msg)); - */ + return Box::pin(self.query_impl(query_msg)); } } -impl Connection { +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). @@ -819,6 +814,20 @@ impl Connection { async fn query_impl( &self, query_msg: &Message, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = query_msg; + Ok(Query::new(msg, rx)) + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + async fn query_impl3( + &self, + query_msg: &Message, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -853,7 +862,7 @@ impl + Clone + Octets + Send + Sync> QueryMessage3 + '_, >, > { - return Box::pin(self.query_impl(query_msg)); + return Box::pin(self.query_impl3(query_msg)); } } diff --git a/src/net/client/query.rs b/src/net/client/query.rs index f9b263b1d..ba141a4f8 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -10,7 +10,7 @@ use std::future::Future; use std::pin::Pin; // use std::sync::Arc; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::Message; use crate::net::client::error::Error; /// Trait for starting a DNS query based on a message. @@ -20,9 +20,7 @@ pub trait QueryMessage { /// This function is intended to be cancel safe. fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin> + Send + '_>>; } diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 133116638..12cafda30 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -19,11 +19,9 @@ use tokio::net::UdpSocket; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::{timeout, Duration, Instant}; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::Message; use crate::net::client::error::Error; -use crate::net::client::query::{ - GetResult, QueryMessage, QueryMessage3, -}; +use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -57,7 +55,7 @@ impl Connection { /// Start a new DNS query. async fn query_impl + Clone + Send>( &self, - query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result { self.inner.query(query_msg, self.clone()).await } @@ -69,7 +67,7 @@ impl Connection { &self, query_msg: &Message, ) -> Result, Error> { - let gr = self.inner.query3(query_msg, self.clone()).await?; + let gr = self.inner.query(query_msg, self.clone()).await?; Ok(Box::new(gr)) } @@ -79,14 +77,12 @@ impl Connection { } } -impl + Clone + Debug + Send> QueryMessage - for Connection +impl + Clone + Debug + Send + Sync> + QueryMessage for Connection { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } @@ -328,20 +324,7 @@ impl InnerConnection { } /// Return a Query object that contains the query state. - async fn query + Clone + Send>( - &self, - query_msg: &mut MessageBuilder>>, - conn: Connection, - ) -> Result { - let slice = query_msg.as_target().as_target().as_dgram_slice(); - let mut bytes = BytesMut::with_capacity(slice.len()); - bytes.extend_from_slice(slice); - let query_msg = Message::from_octets(bytes).unwrap(); - Ok(Query::new(query_msg, self.remote_addr, conn)) - } - - /// Return a Query object that contains the query state. - async fn query3 + Clone>( + async fn query + Clone>( &self, query_msg: &Message, conn: Connection, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 6aa56e080..9ea15c3ec 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -16,7 +16,7 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; -use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::base::Message; use crate::net::client::error::Error; use crate::net::client::multi_stream; use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; @@ -50,12 +50,9 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// Start a query. pub async fn query_impl( &self, - _query_msg: &mut MessageBuilder>>, + query_msg: &Message, ) -> Result, Error> { - todo!(); - /* - self.inner.query(query_msg).await - */ + self.inner.query(query_msg).await } /// Start a query for the QueryMessage3 trait. @@ -73,9 +70,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> { fn query<'a>( &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, + query_msg: &'a Message, ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); From 4462d7ecf1cd707d71d7cd0de6fd7cee20bc58cd Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 4 Oct 2023 15:35:49 +0200 Subject: [PATCH 047/124] tcp_channel has been replaced by octet_stream --- src/net/client/mod.rs | 1 - src/net/client/tcp_channel.rs | 882 ---------------------------------- 2 files changed, 883 deletions(-) delete mode 100644 src/net/client/tcp_channel.rs diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index a08a26543..be97cf9b3 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -13,7 +13,6 @@ pub mod multi_stream; pub mod octet_stream; pub mod query; pub mod redundant; -pub mod tcp_channel; pub mod tcp_factory; pub mod tcp_mutex; pub mod tls_factory; diff --git a/src/net/client/tcp_channel.rs b/src/net/client/tcp_channel.rs deleted file mode 100644 index 41e659842..000000000 --- a/src/net/client/tcp_channel.rs +++ /dev/null @@ -1,882 +0,0 @@ -//! A DNS over TCP transport - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -// RFC 7766 describes DNS over TCP -// RFC 7828 describes the edns-tcp-keepalive option - -// TODO: -// - errors -// - connect errors? Retry after connection refused? -// - server errors -// - ID out of range -// - ID not in use -// - reply for wrong query -// - timeouts -// - request timeout -// - create new TCP connection after end/failure of previous one - -use bytes; -use bytes::{Bytes, BytesMut}; -use core::convert::From; -use futures::lock::Mutex as Futures_mutex; -use std::fmt::Debug; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use std::vec::Vec; - -use crate::base::wire::Composer; -use crate::base::{ - opt::{AllOptData, OptRecord, TcpKeepalive}, - Message, MessageBuilder, StaticCompressor, StreamTarget, -}; -use octseq::{Octets, OctetsBuilder}; - -use tokio::io; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::ReadHalf; -use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::sleep; - -/// Error returned when too many queries are currently active. -const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; - -/// Time to wait on a non-idle TCP connection for the other side to send -/// a response on any outstanding query. -// Implement a simple response timer to see if the connection and the server -// are alive. Set the timer when the connection goes from idle to busy. -// Reset the timer each time a reply arrives. Cancel the timer when the -// connection goes back to idle. When the time expires, mark all outstanding -// queries as timed out and shutdown the connection. -// -// Note: nsd has 120 seconds, unbound has 3 seconds. -const RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); - -/// Capacity of the channel that transports [ChanReq]. -const DEF_CHAN_CAP: usize = 8; - -/// Capacity of a private channel between [InnerTcpConnection::reader] and -/// [InnerTcpConnection::run]. -const READ_REPLY_CHAN_CAP: usize = 8; - -/// Error reported when the TCP connection is closed and -/// [InnerTcpConnection::run] terminated. -const ERR_CONN_CLOSED: &str = "connection closed"; - -/// This is the type of sender in [ChanReq]. -type ReplySender = oneshot::Sender; - -#[derive(Debug)] -/// A request from [Query] to [TcpConnection::run] to start a DNS request. -struct ChanReq { - /// DNS request message - msg: MessageBuilder>>, - - /// Sender to send result back to [Query] - sender: ReplySender, -} - -#[derive(Debug)] -/// a response to a [ChanReq]. -struct Response { - /// The 2 octet id that went into the outgoing DNS request. - /// - /// This id is needed to match the response with the query. - id: u16, - - /// The DNS reply message. - reply: Message, -} -/// Response to the DNS request sent by [InnerTcpConnection::run] to [Query]. -type ChanResp = Result>; - -/// The actual implementation of [TcpConnection]. -struct InnerTcpConnection { - /// TCP connection protected by a mutex to allow read/write access by - /// [InnerTcpConnection::run]. - stream: Futures_mutex, - - /// [InnerTcpConnection::sender] and [InnerTcpConnection::receiver] are - /// part of a single channel. - /// - /// Used by [Query] to send requests to [InnerTcpConnection::run]. - sender: mpsc::Sender>, - - /// receiver part of the channel. - /// - /// Protected by a mutex to allow read/write access by - /// [InnerTcpConnection::run]. - /// The Option is to allow [InnerTcpConnection::run] to signal that the - /// TCP connection is closed. - receiver: Futures_mutex>>>, -} - -/// Internal datastructure of [InnerTcpConnection::run] to keep track of -/// outstanding DNS requests. -struct Queries { - /// The number of elements in [Queries::vec] that are not None. - count: usize, - - /// Index in the [Queries::vec] where to look for a space for a new query. - curr: usize, - - /// Vector of senders to forward a DNS reply message (or error) to. - vec: Vec>, -} - -#[derive(Clone)] -/// A single DNS over TCP connection. -pub struct TcpConnection { - /// Reference counted [InnerTcpConnection]. - inner: Arc>, -} - -/// Status of a query. Used in [Query]. -enum QueryState { - /// A request is in progress. - /// - /// The receiver for receiving the response is part of this state. - Busy(oneshot::Receiver), - - /// The response has been received and the query is done. - Done, -} - -/// This struct represent an active DNS query. -pub struct Query { - /// Request message. - /// - /// The reply message is compared with the request message to see if - /// it matches the query. - query_msg: Message>, - - /// Current state of the query. - state: QueryState, -} - -/// Internal datastructure of [InnerTcpConnection::run] to keep track of -/// the status of the TCP connection. -// The types Status and ConnState are only used in InnerTcpConnection -struct Status { - /// State of the TCP connection. - state: ConnState, - - /// Boolean if we need to include an edns-tcp-keepalive option in an - /// outogoing request. - /// - /// Typically send_keepalive is true at the start of the connection. - /// it gets cleared when we successfully managed to include the option - /// in a request. - send_keepalive: bool, - - /// Time we are allow to keep the TCP connection open when idle. - /// - /// Initially we assume that the idle timeout is zero. A received - /// edns-tcp-keepalive option may change that. - idle_timeout: Option, -} -/// Status of the TCP connection. Used in [Status]. -enum ConnState { - /// The connection is in this state from the start and when at least - /// one active DNS request is present. - /// - /// The instant contains the time of the first request or the - /// most recent response that was received. - Active(Option), - - /// This state represent a TCP connection that went idle and has an - /// idle timeout. - /// - /// The instant contains the time the connection went idle. - Idle(Instant), - - /// This state represent an idle connection where either there was no - /// idle timeout or the idle timer expired. - IdleTimeout, - - /// A read error occurred. - ReadError, - - /// It took too long to receive a (or another) response. - ReadTimeout, - - /// A write error occurred. - WriteError, -} - -/// A DNS message received to [InnerTcpConnection::reader] and sent to -/// [InnerTcpConnection::run]. -// This type could be local to InnerTcpConnection, but I don't know how -type ReaderChanReply = Message; - -impl + Clone + Composer + Debug + OctetsBuilder> - InnerTcpConnection -{ - /// Constructor for [InnerTcpConnection]. - /// - /// This is the implementation of [TcpConnection::connect]. - pub async fn connect( - addr: A, - ) -> io::Result> { - let tcp = TcpStream::connect(addr).await?; - let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); - Ok(Self { - stream: Futures_mutex::new(tcp), - sender: tx, - receiver: Futures_mutex::new(Some(rx)), - }) - } - - /// Main execution function for [InnerTcpConnection]. - /// - /// This function Gets called by [TcpConnection::run]. - /// This function is not async cancellation safe - pub async fn run(&self) -> Option<()> { - let mut stream = self.stream.lock().await; - let (mut read_stream, mut write_stream) = stream.split(); - - let (reply_sender, mut reply_receiver) = - mpsc::channel::(READ_REPLY_CHAN_CAP); - - let reader_fut = Self::reader(&mut read_stream, reply_sender); - tokio::pin!(reader_fut); - - let mut receiver = { - let mut locked_opt_receiver = self.receiver.lock().await; - let opt_receiver = locked_opt_receiver.take(); - opt_receiver.expect("no receiver present?") - }; - - let mut status = Status { - state: ConnState::Active(None), - idle_timeout: None, - send_keepalive: true, - }; - let mut query_vec = Queries { - count: 0, - curr: 0, - vec: Vec::new(), - }; - - let mut reqmsg: Option> = None; - - loop { - let opt_timeout = match status.state { - ConnState::Active(opt_instant) => { - if let Some(instant) = opt_instant { - let elapsed = instant.elapsed(); - if elapsed > RESPONSE_TIMEOUT { - let error = io::Error::new( - io::ErrorKind::Other, - "read timeout", - ); - Self::tcp_error(error, &mut query_vec); - status.state = ConnState::ReadTimeout; - break; - } - Some(RESPONSE_TIMEOUT - elapsed) - } else { - None - } - } - ConnState::Idle(instant) => { - if let Some(timeout) = &status.idle_timeout { - let elapsed = instant.elapsed(); - if elapsed >= *timeout { - // Move to IdleTimeout and end - // the loop - status.state = ConnState::IdleTimeout; - break; - } - Some(*timeout - elapsed) - } else { - panic!("Idle state but no timeout"); - } - } - ConnState::IdleTimeout - | ConnState::ReadError - | ConnState::WriteError => None, // No timers here - ConnState::ReadTimeout => { - panic!("should not be in loop with ReadTimeout"); - } - }; - - // For simplicity, make sure we always have a timeout - let timeout = match opt_timeout { - Some(timeout) => timeout, - None => - // Just use the response timeout - { - RESPONSE_TIMEOUT - } - }; - - let sleep_fut = sleep(timeout); - let recv_fut = receiver.recv(); - - let (do_write, msg) = match &reqmsg { - None => { - let msg: &[u8] = &[]; - (false, msg) - } - Some(msg) => { - let msg: &[u8] = msg; - (true, msg) - } - }; - - tokio::select! { - biased; - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(error) => { - Self::tcp_error(error, - &mut query_vec); - status.state = - ConnState::ReadError; - // Reader failed. Break - // out of loop and - // shut down - break - } - } - } - opt_answer = reply_receiver.recv() => { - let answer = opt_answer.expect("reader died?"); - // Check for a edns-tcp-keepalive option - let opt_record = answer.opt(); - if let Some(ref opts) = opt_record { - Self::handle_opts(opts, - &mut status); - }; - drop(opt_record); - Self::demux_reply(answer, - &mut status, &mut query_vec); - } - res = write_stream.write_all(msg), - if do_write => { - if let Err(error) = res { - Self::tcp_error(error, - &mut query_vec); - status.state = - ConnState::WriteError; - break; - } - else { - reqmsg = None; - } - } - res = recv_fut, if !do_write => { - match res { - Some(req) => - self.insert_req(req, &mut status, - &mut reqmsg, &mut query_vec), - None => panic!("recv failed"), - } - } - _ = sleep_fut => { - // Timeout expired, just - // continue with the loop - } - - } - - // Check if the connection is idle - match status.state { - ConnState::Active(_) | ConnState::Idle(_) => { - // Keep going - } - ConnState::IdleTimeout => break, - ConnState::ReadError - | ConnState::ReadTimeout - | ConnState::WriteError => { - panic!("Should not be here"); - } - } - } - - // Send FIN - _ = write_stream.shutdown().await; - - None - } - - /// This function sends a DNS request to [InnerTcpConnection::run]. - pub async fn query( - &self, - sender: oneshot::Sender, - query_msg: &mut MessageBuilder>>, - ) -> Result<(), &'static str> { - // We should figure out how to get query_msg. - let msg_clone = query_msg.clone(); - - let req = ChanReq { - sender, - msg: msg_clone, - }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - { - Err(ERR_CONN_CLOSED) - } - Ok(_) => Ok(()), - } - } - - /// This function reads a DNS message from the TCP connection and sends - /// it to [InnerTcpConnection::run]. - /// - /// Reading has to be done in two steps: first read a two octet value - /// the specifies the length of the message, and then read in a loop the - /// body of the message. - /// - /// This function is not async cancellation safe. - async fn reader( - sock: &mut ReadHalf<'_>, - sender: mpsc::Sender, - ) -> Result<(), std::io::Error> { - loop { - let read_res = sock.read_u16().await; - let len = match read_res { - Ok(len) => len, - Err(error) => { - return Err(error); - } - } as usize; - - let mut buf = BytesMut::with_capacity(len); - - loop { - let curlen = buf.len(); - if curlen >= len { - if curlen > len { - panic!( - "reader: got too much data {curlen}, expetect {len}"); - } - - // We got what we need - break; - } - - let read_res = sock.read_buf(&mut buf).await; - - match read_res { - Ok(readlen) => { - if readlen == 0 { - let error = io::Error::new( - io::ErrorKind::Other, - "unexpected end of data", - ); - return Err(error); - } - } - Err(error) => { - return Err(error); - } - }; - - // Check if we are done at the head of the loop - } - - let reply_message = Message::::from_octets(buf.into()); - match reply_message { - Ok(answer) => { - sender - .send(answer) - .await - .expect("can't send reply to run"); - } - Err(_) => { - // The only possible error is short message - let error = - io::Error::new(io::ErrorKind::Other, "short buf"); - return Err(error); - } - } - } - } - - /// An error occured, report the error to all outstanding [Query] objects. - fn tcp_error(error: std::io::Error, query_vec: &mut Queries) { - // Update all requests that are in progress. Don't wait for - // any reply that may be on its way. - let arc_error = Arc::new(error); - for index in 0..query_vec.vec.len() { - if query_vec.vec[index].is_some() { - let sender = Self::take_query(query_vec, index) - .expect("we tested is_none before"); - _ = sender.send(Err(arc_error.clone())); - } - } - } - - /// Handle received EDNS options, in particular the edns-tcp-keepalive - /// option. - fn handle_opts>( - opts: &OptRecord, - status: &mut Status, - ) { - for option in opts.opt().iter().flatten() { - if let AllOptData::TcpKeepalive(tcpkeepalive) = option { - Self::handle_keepalive(tcpkeepalive, status); - } - } - } - - /// Demultiplex a DNS reply and send it to the right [Query] object. - /// - /// In addition, the status is updated to IdleTimeout or Idle if there - /// are no remaining pending requests. - fn demux_reply( - answer: Message, - status: &mut Status, - query_vec: &mut Queries, - ) { - // We got an answer, reset the timer - status.state = ConnState::Active(Some(Instant::now())); - - let ind16 = answer.header().id(); - let index: usize = ind16.into(); - - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the TCP connection as broken - return; - } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { - None => { - // No query with this ID. We should - // mark the TCP connection as broken - return; - } - Some(_) => { - let sender = Self::take_query(query_vec, index).unwrap(); - let ind16: u16 = index.try_into().unwrap(); - let reply = Response { - reply: answer, - id: ind16, - }; - _ = sender.send(Ok(reply)); - } - } - if query_vec.count == 0 { - // Clear the activity timer. There is no need to do - // this because state will be set to either IdleTimeout - // or Idle just below. However, it is nicer to keep - // this independent. - status.state = ConnState::Active(None); - - status.state = if status.idle_timeout.is_none() { - // Assume that we can just move to IdleTimeout - // state - ConnState::IdleTimeout - } else { - ConnState::Idle(Instant::now()) - } - } - } - - /// Insert a request in query_vec and return the request to be sent - /// in *reqmsg. - /// - /// First the status is checked, an error is returned if not Active or - /// idle. Addend a edns-tcp-keepalive option if needed. - // Note: maybe reqmsg should be a return value. - fn insert_req( - &self, - mut req: ChanReq, - status: &mut Status, - reqmsg: &mut Option>, - query_vec: &mut Queries, - ) { - match status.state { - ConnState::Active(timer) => { - // Set timer if we don't have one already - if timer.is_none() { - status.state = ConnState::Active(Some(Instant::now())); - } - } - ConnState::Idle(_) => { - // Go back to active - status.state = ConnState::Active(Some(Instant::now())); - } - ConnState::IdleTimeout => { - // The connection has been closed. Report error - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "idle timeout", - )))); - return; - } - ConnState::ReadError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "read error", - )))); - return; - } - ConnState::ReadTimeout => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "read timeout", - )))); - return; - } - ConnState::WriteError => { - _ = req.sender.send(Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "write error", - )))); - return; - } - } - - // Note that insert may fail if there are too many - // outstanding queires. First call insert before checking - // send_keepalive. - // XXX - let index = Self::insert(req.sender, query_vec).unwrap(); - - let ind16: u16 = index.try_into().unwrap(); - - // We set the ID to the array index. Defense in depth - // suggests that a random ID is better because it works - // even if TCP sequence numbers could be predicted. However, - // Section 9.3 of RFC 5452 recommends retrying over TCP - // if many spoofed answers arrive over UDP: "TCP, by the - // nature of its use of sequence numbers, is far more - // resilient against forgery by third parties." - let hdr = req.msg.header_mut(); - hdr.set_id(ind16); - - if status.send_keepalive { - let mut msgadd = req.msg.clone().additional(); - - // send an empty keepalive option - let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); - match res { - Ok(_) => { - Self::convert_query(&msgadd, reqmsg); - status.send_keepalive = false; - } - Err(_) => { - // Adding keepalive option - // failed. Send the original - // request. - Self::convert_query(&req.msg, reqmsg); - } - } - } else { - Self::convert_query(&req.msg, reqmsg); - } - } - - /// Take an element out of query_vec. - fn take_query( - query_vec: &mut Queries, - index: usize, - ) -> Option { - let query = query_vec.vec[index].take(); - query_vec.count -= 1; - query - } - - /// Handle a received edns-tcp-keepalive option. - fn handle_keepalive(opt_value: TcpKeepalive, status: &mut Status) { - if let Some(value) = opt_value.timeout() { - let value_dur = Duration::from(value); - status.idle_timeout = Some(value_dur); - } - } - - /// Convert the query message to a vector. - // This function should return the vector instead of storing it - // through a reference. - fn convert_query + AsRef<[u8]>>( - msg: &MessageBuilder>>, - reqmsg: &mut Option>, - ) { - let vec = msg.as_target().as_target().as_stream_slice(); - - // Store a clone of the request. That makes life easier - // and requests tend to be small - *reqmsg = Some(vec.to_vec()); - } - - /// Insert a sender (for the reply) in the query_vec and return the index. - fn insert( - sender: oneshot::Sender, - query_vec: &mut Queries, - ) -> Result { - let q = Some(sender); - - // Fail if there are to many entries already in this vector - // We cannot have more than u16::MAX entries because the - // index needs to fit in an u16. For efficiency we want to - // keep the vector half empty. So we return a failure if - // 2*count > u16::MAX - if 2 * query_vec.count > u16::MAX.into() { - return Err(ERR_TOO_MANY_QUERIES); - } - - let vec_len = query_vec.vec.len(); - - // Append if the amount of empty space in the vector is less - // than half. But limit vec_len to u16::MAX - if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { - // Just append - query_vec.vec.push(q); - query_vec.count += 1; - let index = query_vec.vec.len() - 1; - return Ok(index); - } - let loc_curr = query_vec.curr; - - for index in loc_curr..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Nothing until the end of the vector. Try for the entire - // vector - for index in 0..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Still nothing, that is not good - panic!("insert failed"); - } - - /// Insert a sender at a specific position in query_vec and update - /// the statistics. - fn insert_at( - query_vec: &mut Queries, - index: usize, - q: Option, - ) { - query_vec.vec[index] = q; - query_vec.count += 1; - query_vec.curr = index + 1; - } -} - -impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Composer + Debug + OctetsBuilder, - > TcpConnection -{ - /// Constructor for [TcpConnection]. - /// - /// Takes an address ([ToSocketAddrs]) and - /// returns a [TcpConnection] wrapped in a [Result](io::Result). - pub async fn connect( - addr: A, - ) -> io::Result> { - let tcpconnection = InnerTcpConnection::connect(addr).await?; - Ok(Self { - inner: Arc::new(tcpconnection), - }) - } - - /// Main execution function for [TcpConnection]. - /// - /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run(&self) -> Option<()> { - self.inner.run().await - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - pub async fn query( - &self, - query_msg: &mut MessageBuilder>>, - ) -> Result { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = &query_msg.as_message(); - Ok(Query::new(msg, rx)) - } -} - -impl Query { - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new( - query_msg: &Message, - receiver: oneshot::Receiver, - ) -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec).unwrap(); - Self { - query_msg: msg, - state: QueryState::Busy(receiver), - } - } - - /// Get the result of a DNS query. - /// - /// This function returns the reply to a DNS query wrapped in a - /// [Result]. - pub async fn get_result( - &mut self, - ) -> Result, Arc> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "receive error", - ))); - } - let res = res.unwrap(); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let resp = res.unwrap(); - let msg = resp.reply; - - let hdr = self.query_msg.header_mut(); - hdr.set_id(resp.id); - - if !msg.is_answer(&self.query_msg) { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "wrong answer", - ))); - } - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } - } -} From 25fe7a649b29f76e52d3441b336c0d2016e35a8e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 13:24:03 +0200 Subject: [PATCH 048/124] Get rid of most unwraps. --- src/net/client/error.rs | 20 ++++ src/net/client/octet_stream.rs | 89 ++++++++++------- src/net/client/redundant.rs | 170 ++++++++++++++++++++------------- src/net/client/tls_factory.rs | 16 +++- src/net/client/udp.rs | 12 ++- 5 files changed, 201 insertions(+), 106 deletions(-) diff --git a/src/net/client/error.rs b/src/net/client/error.rs index cda583dd3..5c8ba0953 100644 --- a/src/net/client/error.rs +++ b/src/net/client/error.rs @@ -13,6 +13,15 @@ pub enum Error { /// Connection was already closed. ConnectionClosed, + /// PushError from MessageBuilder. + MessageBuilderPushError, + + /// ParseError from Message. + MessageParseError, + + /// Underlying transport not found in redundant connection + RedundantTransportNotFound, + /// Octet sequence too short to be a valid DNS message. ShortMessage, @@ -66,6 +75,14 @@ impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { match self { Error::ConnectionClosed => write!(f, "connection closed"), + Error::MessageBuilderPushError => { + write!(f, "PushError from MessageBuilder") + } + Error::MessageParseError => write!(f, "ParseError from Message"), + Error::RedundantTransportNotFound => write!( + f, + "Underlying transport not found in redundant connection" + ), Error::ShortMessage => { write!(f, "octet sequence to short to be a valid message") } @@ -112,6 +129,9 @@ impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { Error::ConnectionClosed => None, + Error::MessageBuilderPushError => None, + Error::MessageParseError => None, + Error::RedundantTransportNotFound => None, Error::ShortMessage => None, Error::StreamIdleTimeout => None, Error::StreamReceiveError => None, diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 5f04f982f..df3d3686f 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -550,7 +550,8 @@ impl + Clone> InnerConnection { return; } Some(_) => { - let sender = Self::take_query(query_vec, index).unwrap(); + let sender = Self::take_query(query_vec, index) + .expect("sender should be there"); let reply = Response { reply: answer }; _ = sender.send(Ok(reply)); } @@ -630,7 +631,9 @@ impl + Clone> InnerConnection { } }; - let ind16: u16 = index.try_into().unwrap(); + let ind16: u16 = index + .try_into() + .expect("insert should return a value that fits in u16"); // We set the ID to the array index. Defense in depth // suggests that a random ID is better because it works @@ -640,8 +643,8 @@ impl + Clone> InnerConnection { // nature of its use of sequence numbers, is far more // resilient against forgery by third parties." - let mut mut_msg = - Message::from_octets(req.msg.as_slice().to_vec()).unwrap(); + let mut mut_msg = Message::from_octets(req.msg.as_slice().to_vec()) + .expect("Message failed to parse contents of another Message"); let hdr = mut_msg.header_mut(); hdr.set_id(ind16); @@ -875,7 +878,8 @@ impl Query { ) -> Query { let msg_ref: &[u8] = query_msg.as_ref(); let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec).unwrap(); + let msg = Message::from_octets(vec) + .expect("Message failed to parse contents of another Message"); Self { query_msg: msg, state: QueryState::Busy(receiver), @@ -895,7 +899,7 @@ impl Query { // Assume receive error return Err(Error::StreamReceiveError); } - let res = res.unwrap(); + let res = res.expect("already check error case"); // clippy seems to be wrong here. Replacing // the following with 'res?;' doesn't work @@ -904,7 +908,7 @@ impl Query { return Err(err); } - let resp = res.unwrap(); + let resp = res.expect("error case is checked already"); let msg = resp.reply; if !is_answer_ignore_id(&msg, &self.query_msg) { @@ -951,7 +955,7 @@ impl QueryNoCheck { // Assume receive error return Err(Error::StreamReceiveError); } - let res = res.unwrap(); + let res = res.expect("error case is checked already"); // clippy seems to be wrong here. Replacing // the following with 'res?;' doesn't work @@ -960,7 +964,7 @@ impl QueryNoCheck { return Err(err); } - let resp = res.unwrap(); + let resp = res.expect("error case is checked already"); let msg = resp.reply; Ok(msg) @@ -979,7 +983,7 @@ impl QueryNoCheck { /// opt record. fn add_tcp_keepalive( msg: &Message, -) -> Result>, crate::base::message_builder::PushError> { +) -> Result>, Error> { // We can't just insert a new option in an existing // opt record. So we have to create new message and copy records // from the old one. And insert our option while copying the opt @@ -988,7 +992,7 @@ fn add_tcp_keepalive( let mut target = MessageBuilder::from_target(StaticCompressor::new(Vec::new())) - .unwrap(); + .expect("Vec is expected to have enough space"); let source_hdr = source.header(); let target_hdr = target.header_mut(); target_hdr.set_flags(source_hdr.flags()); @@ -999,39 +1003,55 @@ fn add_tcp_keepalive( let source = source.question(); let mut target = target.question(); for rr in source { - let rr = rr.unwrap(); - target.push(rr)?; + let rr = rr.map_err(|_e| Error::MessageParseError)?; + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; } - let mut source = source.answer().unwrap(); + let mut source = + source.answer().map_err(|_e| Error::MessageParseError)?; let mut target = target.answer(); for rr in &mut source { - let rr = rr.unwrap(); + let rr = rr.map_err(|_e| Error::MessageParseError)?; let rr = rr .into_record::>>() - .unwrap() - .unwrap(); - target.push(rr)?; + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; } - let mut source = source.next_section().unwrap().unwrap(); + let mut source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); let mut target = target.authority(); for rr in &mut source { - let rr = rr.unwrap(); + let rr = rr.map_err(|_e| Error::MessageParseError)?; let rr = rr .into_record::>>() - .unwrap() - .unwrap(); - target.push(rr)?; + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; } - let source = source.next_section().unwrap().unwrap(); + let source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); let mut target = target.additional(); let mut found_opt_rr = false; for rr in source { - let rr = rr.unwrap(); + let rr = rr.map_err(|_e| Error::MessageParseError)?; if rr.rtype() == Rtype::Opt { found_opt_rr = true; - let rr = rr.into_record::>().unwrap().unwrap(); + let rr = rr + .into_record::>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); let opt_record = OptRecord::from_record(rr); target .opt(|newopt| { @@ -1052,25 +1072,30 @@ fn add_tcp_keepalive( newopt.tcp_keepalive(None).unwrap(); Ok(()) }) - .unwrap(); + .map_err(|_e| Error::MessageBuilderPushError)?; } else { let rr = rr .into_record::>>() - .unwrap() - .unwrap(); - target.push(rr)?; + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; } } if !found_opt_rr { // send an empty keepalive option - target.opt(|opt| opt.tcp_keepalive(None))?; + target + .opt(|opt| opt.tcp_keepalive(None)) + .map_err(|_e| Error::MessageBuilderPushError)?; } // It would be nice to use .builder() here. But that one deletes all // section. We have to resort to .as_builder() which gives a // reference and then .clone() let result = target.as_builder().clone(); - let msg = Message::from_octets(result.finish().into_target()).unwrap(); + let msg = Message::from_octets(result.finish().into_target()) + .expect("Message should be able to parse output from MessageBuilder"); Ok(msg) } diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 8cffa7ee4..aa7a3cf38 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -93,7 +93,7 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> pub async fn add( &self, conn: Box + Send + Sync>, - ) { + ) -> Result<(), Error> { self.inner.add(conn).await } @@ -102,7 +102,8 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> &self, query_msg: &Message, ) -> Result, Error> { - Ok(Box::new(self.inner.query(query_msg.clone()).await.unwrap())) + let query = self.inner.query(query_msg.clone()).await?; + Ok(Box::new(query)) } } @@ -151,7 +152,7 @@ pub struct Query + Send> { } /// Result of the futures in fut_list. -type FutListOutput = (usize, Result, Error>); +type FutListOutput = Result<(usize, Result, Error>), Error>; /// The various states a query can be in. #[derive(Debug)] @@ -237,7 +238,7 @@ impl + Clone + Debug + Send + Sync + 'static> Query { tokio::select! { res = self.fut_list.next() => { println!("got res {:?}", res); - let res = res.unwrap(); + let res = res.expect("res should not be empty")?; self.result = Some(res.1); self.res_index= res.0; @@ -265,11 +266,16 @@ impl + Clone + Debug + Send + Sync + 'static> Query { || self.conn_rt[ind].start.is_none() { // Nothing more to report. Return result. - let res = self.result.take().unwrap(); + let res = self + .result + .take() + .expect("result should not be empty"); return res; } - let start = self.conn_rt[ind].start.unwrap(); + let start = self.conn_rt[ind] + .start + .expect("start time should not be empty"); let elapsed = start.elapsed(); println!( "expected rt was {:?}", @@ -287,14 +293,17 @@ impl + Clone + Debug + Send + Sync + 'static> Query { // Failed entry ChanReq::Failure(time_report) }; - self.sender.send(report).await.unwrap(); + + // Send could fail but we don't care. + let _ = self.sender.send(report).await; + self.state = QueryState::Report(ind + 1); continue; } QueryState::Wait => { let res = self.fut_list.next().await; println!("got res {:?}", res); - let res = res.unwrap(); + let res = res.expect("res should not be empty")?; self.result = Some(res.1); self.res_index = res.0; self.state = QueryState::Report(0); @@ -334,7 +343,7 @@ async fn start_request( id: u64, sender: mpsc::Sender>, query_msg: Message, -) -> (usize, Result, Error>) { +) -> Result<(usize, Result, Error>), Error> { let (tx, rx) = oneshot::channel(); sender .send(ChanReq::Query(QueryReq { @@ -343,11 +352,11 @@ async fn start_request( tx, })) .await - .unwrap(); - let mut query = rx.await.unwrap().unwrap(); + .expect("send is expected to work"); + let mut query = rx.await.expect("receive is expected to work")?; let reply = query.get_result().await; - (index, reply) + Ok((index, reply)) } /// The commands that can be sent to the run function. @@ -488,9 +497,11 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> let mut receiver = self.receiver.lock().await; let opt_receiver = receiver.take(); drop(receiver); - let mut receiver = opt_receiver.unwrap(); + let mut receiver = + opt_receiver.expect("receiver should not be empty"); loop { - let req = receiver.recv().await.unwrap(); + let req = + receiver.recv().await.expect("receiver should not fail"); match req { ChanReq::Add(add_req) => { let id = next_id; @@ -505,19 +516,33 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> start: None, }); conns.push(add_req.conn); - add_req.tx.send(Ok(())).unwrap(); + + // Don't care if send fails + let _ = add_req.tx.send(Ok(())); } ChanReq::GetRT(rt_req) => { - rt_req.tx.send(Ok(conn_rt.clone())).unwrap(); + // Don't care if send fails + let _ = rt_req.tx.send(Ok(conn_rt.clone())); } ChanReq::Query(query_req) => { println!("QueryReq for id {}", query_req.id); let opt_ind = conn_rt.iter().position(|e| e.id == query_req.id); - let ind = opt_ind.unwrap(); - println!("QueryReq for ind {}", ind); - let query = conns[ind].query(&query_req.query_msg).await; - query_req.tx.send(query).unwrap(); + match opt_ind { + Some(ind) => { + println!("QueryReq for ind {}", ind); + let query = + conns[ind].query(&query_req.query_msg).await; + // Don't care if send fails + let _ = query_req.tx.send(query); + } + None => { + // Don't care if send fails + let _ = query_req + .tx + .send(Err(Error::RedundantTransportNotFound)); + } + } } ChanReq::Report(time_report) => { println!( @@ -526,25 +551,27 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> ); let opt_ind = conn_rt.iter().position(|e| e.id == time_report.id); - let ind = opt_ind.unwrap(); - println!("Report for ind {}", ind); - let elapsed = time_report.elapsed.as_secs_f64(); - conn_stats[ind].mean += - (elapsed - conn_stats[ind].mean) / SMOOTH_N; - let elapsed_sq = elapsed * elapsed; - conn_stats[ind].mean_sq += - (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; - println!( - "new mean {} mean_sq {}", - conn_stats[ind].mean, conn_stats[ind].mean_sq - ); - let mean = conn_stats[ind].mean; - let var = conn_stats[ind].mean_sq - mean * mean; - let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; - println!("std dev {}", std_dev); - let est_rt = mean + 3. * std_dev; - conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); - println!("new est_rt {:?}", conn_rt[ind].est_rt); + if let Some(ind) = opt_ind { + println!("Report for ind {}", ind); + let elapsed = time_report.elapsed.as_secs_f64(); + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + println!( + "new mean {} mean_sq {}", + conn_stats[ind].mean, conn_stats[ind].mean_sq + ); + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = + if var < 0. { 0. } else { f64::sqrt(var) }; + println!("std dev {}", std_dev); + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + println!("new est_rt {:?}", conn_rt[ind].est_rt); + } } ChanReq::Failure(time_report) => { println!( @@ -553,45 +580,50 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> ); let opt_ind = conn_rt.iter().position(|e| e.id == time_report.id); - let ind = opt_ind.unwrap(); - println!("Failure Report for ind {}", ind); - let elapsed = time_report.elapsed.as_secs_f64(); - if elapsed < conn_stats[ind].mean { - // Do not update the mean if a - // failure took less time than the - // current mean. - println!("ignoring better time"); - continue; + if let Some(ind) = opt_ind { + println!("Failure Report for ind {}", ind); + let elapsed = time_report.elapsed.as_secs_f64(); + if elapsed < conn_stats[ind].mean { + // Do not update the mean if a + // failure took less time than the + // current mean. + println!("ignoring better time"); + continue; + } + conn_stats[ind].mean += + (elapsed - conn_stats[ind].mean) / SMOOTH_N; + let elapsed_sq = elapsed * elapsed; + conn_stats[ind].mean_sq += + (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; + println!( + "new mean {} mean_sq {}", + conn_stats[ind].mean, conn_stats[ind].mean_sq + ); + let mean = conn_stats[ind].mean; + let var = conn_stats[ind].mean_sq - mean * mean; + let std_dev = + if var < 0. { 0. } else { f64::sqrt(var) }; + println!("std dev {}", std_dev); + let est_rt = mean + 3. * std_dev; + conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); + println!("new est_rt {:?}", conn_rt[ind].est_rt); } - conn_stats[ind].mean += - (elapsed - conn_stats[ind].mean) / SMOOTH_N; - let elapsed_sq = elapsed * elapsed; - conn_stats[ind].mean_sq += - (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; - println!( - "new mean {} mean_sq {}", - conn_stats[ind].mean, conn_stats[ind].mean_sq - ); - let mean = conn_stats[ind].mean; - let var = conn_stats[ind].mean_sq - mean * mean; - let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; - println!("std dev {}", std_dev); - let est_rt = mean + 3. * std_dev; - conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); - println!("new est_rt {:?}", conn_rt[ind].est_rt); } } } } /// Implementation of the add method. - async fn add(&self, conn: Box + Send + Sync>) { + async fn add( + &self, + conn: Box + Send + Sync>, + ) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); self.sender .send(ChanReq::Add(AddReq { conn, tx })) .await - .unwrap(); - rx.await.unwrap().unwrap(); + .expect("send should not fail"); + rx.await.expect("receive should not fail") } /// Implementation of the query method. @@ -603,8 +635,8 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> self.sender .send(ChanReq::GetRT(RTReq { tx })) .await - .unwrap(); - let conn_rt = rx.await.unwrap().unwrap(); + .expect("send should not fail"); + let conn_rt = rx.await.expect("receive should not fail")?; Ok(Query::new(query_msg, conn_rt, self.sender.clone())) } } diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_factory.rs index dd346aef1..fbf0158af 100644 --- a/src/net/client/tls_factory.rs +++ b/src/net/client/tls_factory.rs @@ -83,7 +83,15 @@ impl > { let tls_connection = TlsConnector::from(self.client_config.clone()); let server_name = - ServerName::try_from(self.server_name.as_str()).unwrap(); + match ServerName::try_from(self.server_name.as_str()) { + Err(_) => { + return Box::pin(error_helper(std::io::Error::new( + std::io::ErrorKind::Other, + "invalid DNS name", + ))); + } + Ok(res) => res, + }; let addr = self.addr.clone(); Box::pin(Next { future: Box::pin(async { @@ -94,3 +102,9 @@ impl }) } } + +async fn error_helper( + err: std::io::Error, +) -> Result, std::io::Error> { + return Err(err); +} diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 12cafda30..f33ec339b 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -254,9 +254,12 @@ impl Query { // Unfortunately we cannot pass query_msg to is_answer // because is_answer requires Octets, which is not // implemented by BytesMut. Make a copy. - let query_msg = - Message::from_octets(self.query_msg.as_slice()) - .unwrap(); + let query_msg = Message::from_octets( + self.query_msg.as_slice(), + ) + .expect( + "Message failed to parse contents of another Message", + ); if !answer.is_answer(&query_msg) { continue; } @@ -332,7 +335,8 @@ impl InnerConnection { let slice = query_msg.as_slice(); let mut bytes = BytesMut::with_capacity(slice.len()); bytes.extend_from_slice(slice); - let query_msg = Message::from_octets(bytes).unwrap(); + let query_msg = Message::from_octets(bytes) + .expect("Message failed to parse contents of another Message"); Ok(Query::new(query_msg, self.remote_addr, conn)) } From 086cd282b76fa4aacdc3fc63c616c4e84ad94564 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 13:46:02 +0200 Subject: [PATCH 049/124] Clippy --- src/net/client/tls_factory.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_factory.rs index fbf0158af..9ad5eeb3d 100644 --- a/src/net/client/tls_factory.rs +++ b/src/net/client/tls_factory.rs @@ -103,8 +103,9 @@ impl } } +/// Helper to return an error as an async function. async fn error_helper( err: std::io::Error, ) -> Result, std::io::Error> { - return Err(err); + Err(err) } From 431a7ba4fd4b40b0e99c6485afb4b06f7091ef64 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 13:58:25 +0200 Subject: [PATCH 050/124] Check result of add. --- examples/client-transports.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index e69f99a55..5b982cebd 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -37,8 +37,7 @@ async fn main() { println!("request msg: {:?}", msg.as_slice()); // Destination for UDP and TCP - let server_addr = - SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. @@ -133,9 +132,9 @@ async fn main() { }); // Add the previously created transports. - redun.add(Box::new(udptcp_conn)).await; - redun.add(Box::new(tcp_conn)).await; - redun.add(Box::new(tls_conn)).await; + redun.add(Box::new(udptcp_conn)).await.unwrap(); + redun.add(Box::new(tcp_conn)).await.unwrap(); + redun.add(Box::new(tls_conn)).await.unwrap(); // Start a few queries. for _i in 1..10 { From 72c36de35572b337e77a40d6db1663c6d7623d6c Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 14:01:38 +0200 Subject: [PATCH 051/124] rt-multi-thread is not needed. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 29a835730..3d4a3f190 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } +tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } tokio-rustls = { version = "0", optional = true, features = [] } [target.'cfg(macos)'.dependencies] From 1d72b130c48cd47fe310d3b4784fd419a78f3b31 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 14:11:59 +0200 Subject: [PATCH 052/124] Added minor version where major version is 0. --- Cargo.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3d4a3f190..c8bc8c572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } -tokio-rustls = { version = "0", optional = true, features = [] } +tokio-rustls = { version = "0.24", optional = true, features = [] } [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work @@ -58,12 +58,11 @@ zonefile = ["bytes", "std"] ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] [dev-dependencies] -rustls = { version = "0" } +rustls = { version = "0.21" } serde_test = "1.0.130" serde_yaml = "0.9" tokio = { version = "1", features = ["rt-multi-thread", "io-util", "net"] } -tokio-rustls = { version = "0" } -webpki-roots = { version = "0" } +webpki-roots = { version = "0.25" } [package.metadata.docs.rs] all-features = true From 9d6fa9201263b718e6d8578e04c5e67eab3b01f8 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 14:35:03 +0200 Subject: [PATCH 053/124] Get rid of Pin { /// new connection. /// /// This method is equivalent to async fn next(&self) -> Result; - fn next( - &self, - ) -> Pin> + Send + '_>>; + + /// Associated type for the return type of next. + type F: Future> + Send; + + /// Get the next IO connection. + fn next(&self) -> Self::F; } diff --git a/src/net/client/tcp_factory.rs b/src/net/client/tcp_factory.rs index 3586db96d..97096ae19 100644 --- a/src/net/client/tcp_factory.rs +++ b/src/net/client/tcp_factory.rs @@ -53,10 +53,11 @@ impl Future for Next { impl ConnFactory for TcpConnFactory { - fn next( - &self, - ) -> Pin> + Send>> - { + type F = Pin< + Box> + Send>, + >; + + fn next(&self) -> Self::F { Box::pin(Next { future: Box::pin(TcpStream::connect(self.addr.clone())), }) diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_factory.rs index 9ad5eeb3d..748233e58 100644 --- a/src/net/client/tls_factory.rs +++ b/src/net/client/tls_factory.rs @@ -72,15 +72,14 @@ impl Future for Next { impl ConnFactory> for TlsConnFactory { - fn next( - &self, - ) -> Pin< + type F = Pin< Box< dyn Future, std::io::Error>> - + Send - + '_, + + Send, >, - > { + >; + + fn next(&self) -> Self::F { let tls_connection = TlsConnector::from(self.client_config.clone()); let server_name = match ServerName::try_from(self.server_name.as_str()) { From 1ab08b6c27f578c1cd32513972fbb3bb0b5a6ca0 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 9 Oct 2023 14:40:09 +0200 Subject: [PATCH 054/124] Feature net depends on std. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c8bc8c572..4b93c0c76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "futures", "tokio", "tokio-rustls"] +net = ["bytes", "futures", "std", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] From 9d7be4c9122aa7d6d0980ddb0e9e7d9e45a48bb5 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 10 Oct 2023 15:13:31 +0200 Subject: [PATCH 055/124] Reorder and section markers --- src/net/client/multi_stream.rs | 646 +++++++++++++++++---------------- src/net/client/octet_stream.rs | 534 ++++++++++++++------------- src/net/client/redundant.rs | 272 +++++++------- src/net/client/tcp_factory.rs | 48 +-- src/net/client/tls_factory.rs | 56 +-- src/net/client/udp_tcp.rs | 8 +- 6 files changed, 803 insertions(+), 761 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 3d6c5b756..8995089b3 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -43,66 +43,140 @@ const DEF_CHAN_CAP: usize = 8; /// [InnerConnection::run] terminated. const ERR_CONN_CLOSED: &str = "connection closed"; -/// Response to the DNS request sent by [InnerConnection::run] to [Query]. -#[derive(Debug)] -struct ChanRespOk> { - /// id of this connection. - id: u64, +//------------ Connection ----------------------------------------------------- - /// New octet_stream transport. - conn: SingleConnection, +#[derive(Clone, Debug)] +/// A DNS over octect streams transport. +pub struct Connection> { + /// Reference counted [InnerConnection]. + inner: Arc>, } -/// The reply to a NewConn request. -type ChanResp = Result, Arc>; +impl + Clone + Debug + Octets + Send + Sync + 'static> + Connection +{ + /// Constructor for [Connection]. + /// + /// Returns a [Connection] wrapped in a [Result](io::Result). + pub fn new() -> io::Result> { + let connection = InnerConnection::new()?; + Ok(Self { + inner: Arc::new(connection), + }) + } -/// This is the type of sender in [ReqCmd]. -type ReplySender = oneshot::Sender>; + /// Main execution function for [Connection]. + /// + /// This function has to run in the background or together with + /// any calls to [query](Self::query) or [Query::get_result]. + pub async fn run< + F: ConnFactory + Send, + IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, + >( + &self, + factory: F, + ) -> Option<()> { + self.inner.run(factory).await + } -#[derive(Debug)] -/// Commands that can be requested. -enum ReqCmd> { - /// Request for a (new) connection. + /// Start a DNS request. /// - /// The id of the previous connection (if any) is passed as well as a - /// channel to send the reply. - NewConn(Option, ReplySender), + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query_impl( + &self, + query_msg: &Message, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.inner.new_conn(None, tx).await?; + let gr = Query::new(self.clone(), query_msg, rx); + Ok(gr) + } - /// Shutdown command. - Shutdown, + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + pub async fn query_impl3( + &self, + query_msg: &Message, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.inner.new_conn(None, tx).await?; + let gr = Query::new(self.clone(), query_msg, rx); + Ok(Box::new(gr)) + } + + /// Shutdown this transport. + pub async fn shutdown(&self) -> Result<(), &'static str> { + self.inner.shutdown().await + } + + /// Request a new connection. + async fn new_conn( + &self, + id: u64, + tx: oneshot::Sender>, + ) -> Result<(), Error> { + self.inner.new_conn(Some(id), tx).await + } } -#[derive(Debug)] -/// A request to [Connection::run] either for a new octet_stream or to -/// shutdown. -struct ChanReq> { - /// A requests consists of a command. - cmd: ReqCmd, +impl + QueryMessage, Octs> for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin, Error>> + Send + '_>> + { + return Box::pin(self.query_impl(query_msg)); + } } -/// The actual implementation of [Connection]. +impl + Clone + Debug + Octets + Send + Sync + 'static> + QueryMessage3 for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl3(query_msg)); + } +} + +//------------ Query ---------------------------------------------------------- + +/// This struct represent an active DNS query. #[derive(Debug)] -struct InnerConnection> { - /// [InnerConnection::sender] and [InnerConnection::receiver] are - /// part of a single channel. +pub struct Query> { + /// Request message. /// - /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + /// The reply message is compared with the request message to see if + /// it matches the query. + // query_msg: Message>, + query_msg: Message, - /// receiver part of the channel. - /// - /// Protected by a mutex to allow read/write access by - /// [InnerConnection::run]. - /// The Option is to allow [InnerConnection::run] to signal that the - /// connection is closed. - receiver: Futures_mutex>>>, -} + /// Current state of the query. + state: QueryState, -#[derive(Clone, Debug)] -/// A DNS over octect streams transport. -pub struct Connection> { - /// Reference counted [InnerConnection]. - inner: Arc>, + /// A multi_octet connection object is needed to request new underlying + /// octet_stream transport connections. + conn: Connection, + + /// id of most recent connection. + conn_id: u64, + + // /// Number of retries without delay. + // imm_retry_count: u16, + /// Number of retries with delay. + delayed_retry_count: u64, } /// Status of a query. Used in [Query]. @@ -127,36 +201,198 @@ enum QueryState> { Done, } -/// State associated with a failed attempt to create a new octet_stream -/// transport. -#[derive(Clone)] -struct ErrorState { - /// The error we got from the most recent attempt. - error: Arc, +/// The reply to a NewConn request. +type ChanResp = Result, Arc>; - /// How many times we tried so far. - retries: u64, +/// Response to the DNS request sent by [InnerConnection::run] to [Query]. +#[derive(Debug)] +struct ChanRespOk> { + /// id of this connection. + id: u64, - /// When we got an error. - timer: Instant, + /// New octet_stream transport. + conn: SingleConnection, +} - /// Time to wait before trying to create a new connection. - timeout: Duration, +impl + Clone + Debug + Octets + Send + Sync + 'static> + Query +{ + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new( + conn: Connection, + query_msg: &Message, + receiver: oneshot::Receiver>, + ) -> Query { + Self { + conn, + query_msg: query_msg.clone(), + state: QueryState::GetConn(receiver), + conn_id: 0, + //imm_retry_count: 0, + delayed_retry_count: 0, + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result_impl(&mut self) -> Result, Error> { + loop { + match self.state { + QueryState::GetConn(ref mut receiver) => { + let res = receiver.await; + if res.is_err() { + // Assume receive error + self.state = QueryState::Done; + return Err(Error::StreamReceiveError); + } + let res = res.expect("error is checked before"); + + // Another Result. This time from executing the request + match res { + Err(_) => { + self.delayed_retry_count += 1; + let retry_time = + retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + Ok(ok_res) => { + let id = ok_res.id; + let conn = ok_res.conn; + + self.conn_id = id; + self.state = QueryState::StartQuery(conn); + continue; + } + } + } + QueryState::StartQuery(ref mut conn) => { + let mut msg = self.query_msg.clone(); + let query_res = conn.query_no_check(&mut msg).await; + match query_res { + Err(err) => { + if let Error::ConnectionClosed = err { + let (tx, rx) = oneshot::channel(); + let res = self + .conn + .new_conn(self.conn_id, tx) + .await; + if let Err(err) = res { + self.state = QueryState::Done; + return Err(err); + } + self.state = QueryState::GetConn(rx); + continue; + } + return Err(err); + } + Ok(query) => { + self.state = QueryState::GetResult(query); + continue; + } + } + } + QueryState::GetResult(ref mut query) => { + let reply = query.get_result().await; + + if reply.is_err() { + self.delayed_retry_count += 1; + let retry_time = retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } + + let msg = reply.expect("error is checked before"); + let query_msg_ref: &[u8] = self.query_msg.as_ref(); + let query_msg_vec = query_msg_ref.to_vec(); + let query_msg = Message::from_octets(query_msg_vec) + .expect("how to go from MessageBuild to Message?"); + + if !is_answer_ignore_id(&msg, &query_msg) { + return Err(Error::WrongReplyForQuery); + } + return Ok(msg); + } + QueryState::Delay(instant, duration) => { + sleep_until(instant + duration).await; + let (tx, rx) = oneshot::channel(); + let res = self.conn.new_conn(self.conn_id, tx).await; + if let Err(err) = res { + self.state = QueryState::Done; + return Err(err); + } + self.state = QueryState::GetConn(rx); + continue; + } + QueryState::Done => { + panic!("Already done"); + } + } + } + } } -/// State of the current underlying octet_stream transport. -enum SingleConnState3> { - /// No current octet_stream transport. - None, +impl + Clone + Debug + Octets + Send + Sync + 'static> + GetResult for Query +{ + fn get_result( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_result_impl()) + } +} - /// Current octet_stream transport. - Some(SingleConnection), +//------------ InnerConnection ------------------------------------------------ - /// State that deals with an error getting a new octet stream from - /// a factory. - Err(ErrorState), +/// The actual implementation of [Connection]. +#[derive(Debug)] +struct InnerConnection> { + /// [InnerConnection::sender] and [InnerConnection::receiver] are + /// part of a single channel. + /// + /// Used by [Query] to send requests to [InnerConnection::run]. + sender: mpsc::Sender>, + + /// receiver part of the channel. + /// + /// Protected by a mutex to allow read/write access by + /// [InnerConnection::run]. + /// The Option is to allow [InnerConnection::run] to signal that the + /// connection is closed. + receiver: Futures_mutex>>>, +} + +#[derive(Debug)] +/// A request to [Connection::run] either for a new octet_stream or to +/// shutdown. +struct ChanReq> { + /// A requests consists of a command. + cmd: ReqCmd, +} + +#[derive(Debug)] +/// Commands that can be requested. +enum ReqCmd> { + /// Request for a (new) connection. + /// + /// The id of the previous connection (if any) is passed as well as a + /// channel to send the reply. + NewConn(Option, ReplySender), + + /// Shutdown command. + Shutdown, } +/// This is the type of sender in [ReqCmd]. +type ReplySender = oneshot::Sender>; + /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection @@ -180,30 +416,34 @@ struct State3<'a, F, IO, Octs: AsRef<[u8]>> { phantom: PhantomData<&'a IO>, } -/// This struct represent an active DNS query. -#[derive(Debug)] -pub struct Query> { - /// Request message. - /// - /// The reply message is compared with the request message to see if - /// it matches the query. - // query_msg: Message>, - query_msg: Message, +/// State of the current underlying octet_stream transport. +enum SingleConnState3> { + /// No current octet_stream transport. + None, + + /// Current octet_stream transport. + Some(SingleConnection), + + /// State that deals with an error getting a new octet stream from + /// a factory. + Err(ErrorState), +} - /// Current state of the query. - state: QueryState, +/// State associated with a failed attempt to create a new octet_stream +/// transport. +#[derive(Clone)] +struct ErrorState { + /// The error we got from the most recent attempt. + error: Arc, - /// A multi_octet connection object is needed to request new underlying - /// octet_stream transport connections. - conn: Connection, + /// How many times we tried so far. + retries: u64, - /// id of most recent connection. - conn_id: u64, + /// When we got an error. + timer: Instant, - // /// Number of retries without delay. - // imm_retry_count: u16, - /// Number of retries with delay. - delayed_retry_count: u64, + /// Time to wait before trying to create a new connection. + timeout: Duration, } impl + Clone + Octets + Send + Sync + 'static> @@ -455,239 +695,7 @@ impl + Clone + Octets + Send + Sync + 'static> } } -impl + Clone + Debug + Octets + Send + Sync + 'static> - Connection -{ - /// Constructor for [Connection]. - /// - /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new() -> io::Result> { - let connection = InnerConnection::new()?; - Ok(Self { - inner: Arc::new(connection), - }) - } - - /// Main execution function for [Connection]. - /// - /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run< - F: ConnFactory + Send, - IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, - >( - &self, - factory: F, - ) -> Option<()> { - self.inner.run(factory).await - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl( - &self, - query_msg: &Message, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.inner.new_conn(None, tx).await?; - let gr = Query::new(self.clone(), query_msg, rx); - Ok(gr) - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl3( - &self, - query_msg: &Message, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.inner.new_conn(None, tx).await?; - let gr = Query::new(self.clone(), query_msg, rx); - Ok(Box::new(gr)) - } - - /// Shutdown this transport. - pub async fn shutdown(&self) -> Result<(), &'static str> { - self.inner.shutdown().await - } - - /// Request a new connection. - async fn new_conn( - &self, - id: u64, - tx: oneshot::Sender>, - ) -> Result<(), Error> { - self.inner.new_conn(Some(id), tx).await - } -} - -impl - QueryMessage, Octs> for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin, Error>> + Send + '_>> - { - return Box::pin(self.query_impl(query_msg)); - } -} - -impl + Clone + Debug + Octets + Send + Sync + 'static> - QueryMessage3 for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.query_impl3(query_msg)); - } -} - -impl + Clone + Debug + Octets + Send + Sync + 'static> - Query -{ - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new( - conn: Connection, - query_msg: &Message, - receiver: oneshot::Receiver>, - ) -> Query { - Self { - conn, - query_msg: query_msg.clone(), - state: QueryState::GetConn(receiver), - conn_id: 0, - //imm_retry_count: 0, - delayed_retry_count: 0, - } - } - - /// Get the result of a DNS query. - /// - /// This function returns the reply to a DNS query wrapped in a - /// [Result]. - pub async fn get_result_impl(&mut self) -> Result, Error> { - loop { - match self.state { - QueryState::GetConn(ref mut receiver) => { - let res = receiver.await; - if res.is_err() { - // Assume receive error - self.state = QueryState::Done; - return Err(Error::StreamReceiveError); - } - let res = res.expect("error is checked before"); - - // Another Result. This time from executing the request - match res { - Err(_) => { - self.delayed_retry_count += 1; - let retry_time = - retry_time(self.delayed_retry_count); - self.state = - QueryState::Delay(Instant::now(), retry_time); - continue; - } - Ok(ok_res) => { - let id = ok_res.id; - let conn = ok_res.conn; - - self.conn_id = id; - self.state = QueryState::StartQuery(conn); - continue; - } - } - } - QueryState::StartQuery(ref mut conn) => { - let mut msg = self.query_msg.clone(); - let query_res = conn.query_no_check(&mut msg).await; - match query_res { - Err(err) => { - if let Error::ConnectionClosed = err { - let (tx, rx) = oneshot::channel(); - let res = self - .conn - .new_conn(self.conn_id, tx) - .await; - if let Err(err) = res { - self.state = QueryState::Done; - return Err(err); - } - self.state = QueryState::GetConn(rx); - continue; - } - return Err(err); - } - Ok(query) => { - self.state = QueryState::GetResult(query); - continue; - } - } - } - QueryState::GetResult(ref mut query) => { - let reply = query.get_result().await; - - if reply.is_err() { - self.delayed_retry_count += 1; - let retry_time = retry_time(self.delayed_retry_count); - self.state = - QueryState::Delay(Instant::now(), retry_time); - continue; - } - - let msg = reply.expect("error is checked before"); - let query_msg_ref: &[u8] = self.query_msg.as_ref(); - let query_msg_vec = query_msg_ref.to_vec(); - let query_msg = Message::from_octets(query_msg_vec) - .expect("how to go from MessageBuild to Message?"); - - if !is_answer_ignore_id(&msg, &query_msg) { - return Err(Error::WrongReplyForQuery); - } - return Ok(msg); - } - QueryState::Delay(instant, duration) => { - sleep_until(instant + duration).await; - let (tx, rx) = oneshot::channel(); - let res = self.conn.new_conn(self.conn_id, tx).await; - if let Err(err) = res { - self.state = QueryState::Done; - return Err(err); - } - self.state = QueryState::GetConn(rx); - continue; - } - QueryState::Done => { - panic!("Already done"); - } - } - } - } -} - -impl + Clone + Debug + Octets + Send + Sync + 'static> - GetResult for Query -{ - fn get_result( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_result_impl()) - } -} +//------------ Utility -------------------------------------------------------- /// Compute the retry timeout based on the number of retries so far. /// diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index df3d3686f..d7c785d17 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -61,27 +61,254 @@ const DEF_CHAN_CAP: usize = 8; /// [InnerConnection::run]. const READ_REPLY_CHAN_CAP: usize = 8; -/// This is the type of sender in [ChanReq]. -type ReplySender = oneshot::Sender; +//------------ Connection ----------------------------------------------------- + +#[derive(Clone, Debug)] +/// A single DNS over octect stream connection. +pub struct Connection> { + /// Reference counted [InnerConnection]. + inner: Arc>, +} +impl Connection { + /// Constructor for [Connection]. + /// + /// Returns a [Connection] wrapped in a [Result](io::Result). + pub fn new() -> io::Result> { + let connection = InnerConnection::new()?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Main execution function for [Connection]. + /// + /// This function has to run in the background or together with + /// any calls to [query](Self::query) or [Query::get_result]. + pub async fn run( + &self, + io: IO, + ) -> Option<()> { + self.inner.run(io).await + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + async fn query_impl( + &self, + query_msg: &Message, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = query_msg; + Ok(Query::new(msg, rx)) + } + + /// Start a DNS request. + /// + /// This function takes a precomposed message as a parameter and + /// returns a [Query] object wrapped in a [Result]. + async fn query_impl3( + &self, + query_msg: &Message, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + let msg = query_msg; + Ok(Box::new(Query::new(msg, rx))) + } + + /// Start a DNS request but do not check if the reply matches the request. + /// + /// This function is similar to [Self::query]. Not checking if the reply + /// match the request avoids having to keep the request around. + pub async fn query_no_check( + &self, + query_msg: &mut Message, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.inner.query(tx, query_msg).await?; + Ok(QueryNoCheck::new(rx)) + } +} + +impl< + Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Debug + Octets + Send + Sync, + > QueryMessage for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin> + Send + '_>> { + return Box::pin(self.query_impl(query_msg)); + } +} + +impl + Clone + Octets + Send + Sync> QueryMessage3 + for Connection +{ + fn query<'a>( + &'a self, + query_msg: &'a Message, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.query_impl3(query_msg)); + } +} + +//------------ Query ---------------------------------------------------------- + +/// This struct represent an active DNS query. #[derive(Debug)] -/// A request from [Query] to [Connection::run] to start a DNS request. -struct ChanReq> { - /// DNS request message - msg: Message, +pub struct Query { + /// Request message. + /// + /// The reply message is compared with the request message to see if + /// it matches the query. + query_msg: Message>, - /// Sender to send result back to [Query] - sender: ReplySender, + /// Current state of the query. + state: QueryState, } +/// Status of a query. Used in [Query]. #[derive(Debug)] -/// a response to a [ChanReq]. -struct Response { - /// The DNS reply message. - reply: Message, +enum QueryState { + /// A request is in progress. + /// + /// The receiver for receiving the response is part of this state. + Busy(oneshot::Receiver), + + /// The response has been received and the query is done. + Done, } -/// Response to the DNS request sent by [InnerConnection::run] to [Query]. -type ChanResp = Result; + +impl Query { + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new( + query_msg: &Message, + receiver: oneshot::Receiver, + ) -> Query { + let msg_ref: &[u8] = query_msg.as_ref(); + let vec = msg_ref.to_vec(); + let msg = Message::from_octets(vec) + .expect("Message failed to parse contents of another Message"); + Self { + query_msg: msg, + state: QueryState::Busy(receiver), + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result_impl(&mut self) -> Result, Error> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Error::StreamReceiveError); + } + let res = res.expect("already check error case"); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.expect("error case is checked already"); + let msg = resp.reply; + + if !is_answer_ignore_id(&msg, &self.query_msg) { + return Err(Error::WrongReplyForQuery); + } + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} + +impl GetResult for Query { + fn get_result( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_result_impl()) + } +} + +//------------ QueryNoCheck --------------------------------------------------- + +/// This represents that state of an active DNS query if there is no need +/// to check that the reply matches the request. The assumption is that the +/// caller will do this check. +#[derive(Debug)] +pub struct QueryNoCheck { + /// Current state of the query. + state: QueryState, +} + +impl QueryNoCheck { + /// Constructor for [Query], takes a DNS query and a receiver for the + /// reply. + fn new(receiver: oneshot::Receiver) -> QueryNoCheck { + Self { + state: QueryState::Busy(receiver), + } + } + + /// Get the result of a DNS query. + /// + /// This function returns the reply to a DNS query wrapped in a + /// [Result]. + pub async fn get_result(&mut self) -> Result, Error> { + match self.state { + QueryState::Busy(ref mut receiver) => { + let res = receiver.await; + self.state = QueryState::Done; + if res.is_err() { + // Assume receive error + return Err(Error::StreamReceiveError); + } + let res = res.expect("error case is checked already"); + + // clippy seems to be wrong here. Replacing + // the following with 'res?;' doesn't work + #[allow(clippy::question_mark)] + if let Err(err) = res { + return Err(err); + } + + let resp = res.expect("error case is checked already"); + let msg = resp.reply; + + Ok(msg) + } + QueryState::Done => { + panic!("Already done"); + } + } + } +} + +//------------ InnerConnection ------------------------------------------------ /// The actual implementation of [Connection]. #[derive(Debug)] @@ -92,13 +319,36 @@ struct InnerConnection> { /// Used by [Query] to send requests to [InnerConnection::run]. sender: mpsc::Sender>, - /// receiver part of the channel. - /// - /// Protected by a mutex to allow read/write access by - /// [InnerConnection::run]. - /// The Option is to allow [InnerConnection::run] to signal that the - /// connection is closed. - receiver: Futures_mutex>>>, + /// receiver part of the channel. + /// + /// Protected by a mutex to allow read/write access by + /// [InnerConnection::run]. + /// The Option is to allow [InnerConnection::run] to signal that the + /// connection is closed. + receiver: Futures_mutex>>>, +} + +#[derive(Debug)] +/// A request from [Query] to [Connection::run] to start a DNS request. +struct ChanReq> { + /// DNS request message + msg: Message, + + /// Sender to send result back to [Query] + sender: ReplySender, +} + +/// This is the type of sender in [ChanReq]. +type ReplySender = oneshot::Sender; + +/// Response to the DNS request sent by [InnerConnection::run] to [Query]. +type ChanResp = Result; + +#[derive(Debug)] +/// a response to a [ChanReq]. +struct Response { + /// The DNS reply message. + reply: Message, } /// Internal datastructure of [InnerConnection::run] to keep track of @@ -114,47 +364,6 @@ struct Queries { vec: Vec>, } -#[derive(Clone, Debug)] -/// A single DNS over octect stream connection. -pub struct Connection> { - /// Reference counted [InnerConnection]. - inner: Arc>, -} - -/// Status of a query. Used in [Query]. -#[derive(Debug)] -enum QueryState { - /// A request is in progress. - /// - /// The receiver for receiving the response is part of this state. - Busy(oneshot::Receiver), - - /// The response has been received and the query is done. - Done, -} - -/// This struct represent an active DNS query. -#[derive(Debug)] -pub struct Query { - /// Request message. - /// - /// The reply message is compared with the request message to see if - /// it matches the query. - query_msg: Message>, - - /// Current state of the query. - state: QueryState, -} - -/// This represents that state of an active DNS query if there is no need -/// to check that the reply matches the request. The assumption is that the -/// caller will do this check. -#[derive(Debug)] -pub struct QueryNoCheck { - /// Current state of the query. - state: QueryState, -} - /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection @@ -176,6 +385,7 @@ struct Status { /// edns-tcp-keepalive option may change that. idle_timeout: Option, } + /// Status of the connection. Used in [Status]. enum ConnState { /// The connection is in this state from the start and when at least @@ -776,205 +986,7 @@ impl + Clone> InnerConnection { } } -impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Debug + Octets + Send + Sync, - > QueryMessage for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin> + Send + '_>> { - return Box::pin(self.query_impl(query_msg)); - } -} - -impl Connection { - /// Constructor for [Connection]. - /// - /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new() -> io::Result> { - let connection = InnerConnection::new()?; - Ok(Self { - inner: Arc::new(connection), - }) - } - - /// Main execution function for [Connection]. - /// - /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run( - &self, - io: IO, - ) -> Option<()> { - self.inner.run(io).await - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - async fn query_impl( - &self, - query_msg: &Message, - ) -> Result { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = query_msg; - Ok(Query::new(msg, rx)) - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - async fn query_impl3( - &self, - query_msg: &Message, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = query_msg; - Ok(Box::new(Query::new(msg, rx))) - } - - /// Start a DNS request but do not check if the reply matches the request. - /// - /// This function is similar to [Self::query]. Not checking if the reply - /// match the request avoids having to keep the request around. - pub async fn query_no_check( - &self, - query_msg: &mut Message, - ) -> Result { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - Ok(QueryNoCheck::new(rx)) - } -} - -impl + Clone + Octets + Send + Sync> QueryMessage3 - for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.query_impl3(query_msg)); - } -} - -impl Query { - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new( - query_msg: &Message, - receiver: oneshot::Receiver, - ) -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec) - .expect("Message failed to parse contents of another Message"); - Self { - query_msg: msg, - state: QueryState::Busy(receiver), - } - } - - /// Get the result of a DNS query. - /// - /// This function returns the reply to a DNS query wrapped in a - /// [Result]. - pub async fn get_result_impl(&mut self) -> Result, Error> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Error::StreamReceiveError); - } - let res = res.expect("already check error case"); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let resp = res.expect("error case is checked already"); - let msg = resp.reply; - - if !is_answer_ignore_id(&msg, &self.query_msg) { - return Err(Error::WrongReplyForQuery); - } - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } - } -} - -impl GetResult for Query { - fn get_result( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_result_impl()) - } -} - -impl QueryNoCheck { - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new(receiver: oneshot::Receiver) -> QueryNoCheck { - Self { - state: QueryState::Busy(receiver), - } - } - - /// Get the result of a DNS query. - /// - /// This function returns the reply to a DNS query wrapped in a - /// [Result]. - pub async fn get_result(&mut self) -> Result, Error> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Error::StreamReceiveError); - } - let res = res.expect("error case is checked already"); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let resp = res.expect("error case is checked already"); - let msg = resp.reply; - - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } - } -} +//------------ Utility -------------------------------------------------------- /// Add an edns-tcp-keepalive option to a MessageBuilder. /// diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index aa7a3cf38..f58a017ec 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -65,6 +65,8 @@ const PROBE_P: f64 = 0.05; /// When a worse connection is probed, give it a slight head start. const PROBE_RT: Duration = Duration::from_millis(1); +//------------ Connection ----------------------------------------------------- + /// This type represents a transport connection. #[derive(Clone)] pub struct Connection { @@ -125,6 +127,8 @@ impl< } } +//------------ Query ---------------------------------------------------------- + /// This type represents an active query request. #[derive(Debug)] pub struct Query + Send> { @@ -151,9 +155,6 @@ pub struct Query + Send> { res_index: usize, } -/// Result of the futures in fut_list. -type FutListOutput = Result<(usize, Result, Error>), Error>; - /// The various states a query can be in. #[derive(Debug)] enum QueryState { @@ -170,6 +171,110 @@ enum QueryState { Wait, } +/// The commands that can be sent to the run function. +enum ChanReq { + /// Add a connection + Add(AddReq), + + /// Get the list of estimated response times for all connections + GetRT(RTReq), + + /// Start a query + Query(QueryReq), + + /// Report how long it took to get a response + Report(TimeReport), + + /// Report that a connection failed to provide a timely response + Failure(TimeReport), +} + +impl Debug for ChanReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("ChanReq").finish() + } +} + +/// Request to add a new connection +struct AddReq { + /// New connection to add + conn: Box + Send + Sync>, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to an Add request +type AddReply = Result<(), Error>; + +/// Request to give the estimated response times for all connections +struct RTReq /**/ { + /// Channel to send the reply to + tx: oneshot::Sender, +} + +/// Reply to a RT request +type RTReply = Result, Error>; + +/// Request to start a query +struct QueryReq { + /// Identifier of connection + id: u64, + + /// Request message + query_msg: Message, + + /// Channel to send the reply to + tx: oneshot::Sender, +} + +impl + Debug + Send> Debug for QueryReq { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("QueryReq") + .field("id", &self.id) + .field("query_msg", &self.query_msg) + .finish() + } +} + +/// Reply to a query request. +type QueryReply = Result, Error>; + +/// Report the amount of time until success or failure. +#[derive(Debug)] +struct TimeReport { + /// Identifier of the transport connection. + id: u64, + + /// Time spend waiting for a reply. + elapsed: Duration, +} + +/// Connection statistics to compute the estimated response time. +struct ConnStats { + /// Aproximation of the windowed average of response times. + mean: f64, + + /// Aproximation of the windowed average of the square of response times. + mean_sq: f64, +} + +/// Data required to schedule requests and report timing results. +#[derive(Clone, Debug)] +struct ConnRT { + /// Estimated response time. + est_rt: Duration, + + /// Identifier of the connection. + id: u64, + + /// Start of a request using this connection. + start: Option, +} + +/// Result of the futures in fut_list. +type FutListOutput = Result<(usize, Result, Error>), Error>; + impl + Clone + Debug + Send + Sync + 'static> Query { /// Create a new query object. fn new( @@ -335,135 +440,7 @@ impl< } } -/// Async function to send a request and wait for the reply. -/// -/// This gives a single future that we can put in a list. -async fn start_request( - index: usize, - id: u64, - sender: mpsc::Sender>, - query_msg: Message, -) -> Result<(usize, Result, Error>), Error> { - let (tx, rx) = oneshot::channel(); - sender - .send(ChanReq::Query(QueryReq { - id, - query_msg: query_msg.clone(), - tx, - })) - .await - .expect("send is expected to work"); - let mut query = rx.await.expect("receive is expected to work")?; - let reply = query.get_result().await; - - Ok((index, reply)) -} - -/// The commands that can be sent to the run function. -enum ChanReq { - /// Add a connection - Add(AddReq), - - /// Get the list of estimated response times for all connections - GetRT(RTReq), - - /// Start a query - Query(QueryReq), - - /// Report how long it took to get a response - Report(TimeReport), - - /// Report that a connection failed to provide a timely response - Failure(TimeReport), -} - -impl Debug for ChanReq { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - f.debug_struct("ChanReq").finish() - } -} - -/// Request to add a new connection -struct AddReq { - /// New connection to add - conn: Box + Send + Sync>, - - /// Channel to send the reply to - tx: oneshot::Sender, -} - -/// Reply to an Add request -type AddReply = Result<(), Error>; - -/// Request to give the estimated response times for all connections -struct RTReq /**/ { - /// Channel to send the reply to - tx: oneshot::Sender, -} - -/// Reply to a RT request -type RTReply = Result, Error>; - -/// Request to start a query -struct QueryReq { - /// Identifier of connection - id: u64, - - /// Request message - query_msg: Message, - - /// Channel to send the reply to - tx: oneshot::Sender, -} - -impl + Debug + Send> Debug for QueryReq { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - f.debug_struct("QueryReq") - .field("id", &self.id) - .field("query_msg", &self.query_msg) - .finish() - } -} - -/// Reply to a query request. -type QueryReply = Result, Error>; - -/// Report the amount of time until success or failure. -#[derive(Debug)] -struct TimeReport { - /// Identifier of the transport connection. - id: u64, - - /// Time spend waiting for a reply. - elapsed: Duration, -} - -/// Connection statistics to compute the estimated response time. -struct ConnStats { - /// Aproximation of the windowed average of response times. - mean: f64, - - /// Aproximation of the windowed average of the square of response times. - mean_sq: f64, -} - -/// Data required to schedule requests and report timing results. -#[derive(Clone, Debug)] -struct ConnRT { - /// Estimated response time. - est_rt: Duration, - - /// Identifier of the connection. - id: u64, - - /// Start of a request using this connection. - start: Option, -} - -/// Compare ConnRT elements based on estimated response time. -fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering { - e1.est_rt.cmp(&e2.est_rt) -} +//------------ InnerConnection ------------------------------------------------ /// Type that actually implements the connection. struct InnerConnection { @@ -641,4 +618,33 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> } } -//fn test_send(t: T) -> T { t } +//------------ Utility -------------------------------------------------------- + +/// Async function to send a request and wait for the reply. +/// +/// This gives a single future that we can put in a list. +async fn start_request( + index: usize, + id: u64, + sender: mpsc::Sender>, + query_msg: Message, +) -> Result<(usize, Result, Error>), Error> { + let (tx, rx) = oneshot::channel(); + sender + .send(ChanReq::Query(QueryReq { + id, + query_msg: query_msg.clone(), + tx, + })) + .await + .expect("send is expected to work"); + let mut query = rx.await.expect("receive is expected to work")?; + let reply = query.get_result().await; + + Ok((index, reply)) +} + +/// Compare ConnRT elements based on estimated response time. +fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering { + e1.est_rt.cmp(&e2.est_rt) +} diff --git a/src/net/client/tcp_factory.rs b/src/net/client/tcp_factory.rs index 97096ae19..6abecdfae 100644 --- a/src/net/client/tcp_factory.rs +++ b/src/net/client/tcp_factory.rs @@ -13,21 +13,14 @@ use tokio::net::{TcpStream, ToSocketAddrs}; use crate::net::client::factory::ConnFactory; +//------------ TcpConnFactory ------------------------------------------------- + /// This a connection factory that produces TCP connections. pub struct TcpConnFactory { /// Remote address to connect to. addr: A, } -/// This is an internal structure that provides the future for a new -/// connection. -pub struct Next { - /// Future for creating a new TCP connection. - future: Pin< - Box> + Send>, - >, -} - impl TcpConnFactory { /// Create a new factory. /// @@ -37,19 +30,6 @@ impl TcpConnFactory { } } -impl Future for Next { - type Output = Result; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let me = self.deref_mut(); - let io = ready!(me.future.as_mut().poll(cx))?; - Poll::Ready(Ok(io)) - } -} - impl ConnFactory for TcpConnFactory { @@ -63,3 +43,27 @@ impl ConnFactory }) } } + +//------------ Next ----------------------------------------------------------- + +/// This is an internal structure that provides the future for a new +/// connection. +pub struct Next { + /// Future for creating a new TCP connection. + future: Pin< + Box> + Send>, + >, +} + +impl Future for Next { + type Output = Result; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let me = self.deref_mut(); + let io = ready!(me.future.as_mut().poll(cx))?; + Poll::Ready(Ok(io)) + } +} diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_factory.rs index 748233e58..64b680bd3 100644 --- a/src/net/client/tls_factory.rs +++ b/src/net/client/tls_factory.rs @@ -17,6 +17,8 @@ use tokio_rustls::TlsConnector; use crate::net::client::factory::ConnFactory; +//------------ TlsConnFactory ------------------------------------------------- + /// Factory object for TLS connections pub struct TlsConnFactory { /// Configuration for setting up a TLS connection. @@ -29,18 +31,6 @@ pub struct TlsConnFactory { addr: A, } -/// Internal structure that contains the future for creating a new -/// TLS connection. -pub struct Next { - /// Future for creating a new TLS connection. - future: Pin< - Box< - dyn Future, std::io::Error>> - + Send, - >, - >, -} - impl TlsConnFactory { /// Function to create a new TLS connection factory pub fn new( @@ -56,19 +46,6 @@ impl TlsConnFactory { } } -impl Future for Next { - type Output = Result, std::io::Error>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, std::io::Error>> { - let me = self.deref_mut(); - let io = ready!(me.future.as_mut().poll(cx))?; - Poll::Ready(Ok(io)) - } -} - impl ConnFactory> for TlsConnFactory { @@ -102,6 +79,35 @@ impl } } +//------------ Next ----------------------------------------------------------- + +/// Internal structure that contains the future for creating a new +/// TLS connection. +pub struct Next { + /// Future for creating a new TLS connection. + future: Pin< + Box< + dyn Future, std::io::Error>> + + Send, + >, + >, +} + +impl Future for Next { + type Output = Result, std::io::Error>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, std::io::Error>> { + let me = self.deref_mut(); + let io = ready!(me.future.as_mut().poll(cx))?; + Poll::Ready(Ok(io)) + } +} + +//------------ Utility -------------------------------------------------------- + /// Helper to return an error as an async function. async fn error_helper( err: std::io::Error, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 9ea15c3ec..a2f4a2e29 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -23,7 +23,9 @@ use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; -/// DNS transport connection that first issue a query over a UDP transport and +//------------ Connection ----------------------------------------------------- + +/// DNS transport connection that first issues a query over a UDP transport and /// falls back to TCP if the reply is truncated. #[derive(Clone)] pub struct Connection + Debug> { @@ -94,6 +96,8 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } } +//------------ Query ---------------------------------------------------------- + /// Object that contains the current state of a query. #[derive(Debug)] pub struct Query + Debug> { @@ -194,6 +198,8 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } } +//------------ InnerConnection ------------------------------------------------ + /// The actual connection object. struct InnerConnection + Debug> { /// The remote address to connect to. From 9ff795bc6fda405cba8e148aaf99ed153b1ecd89 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 10 Oct 2023 17:15:46 +0200 Subject: [PATCH 056/124] New GetResult using a nested future to be cancel-safe --- src/net/client/udp.rs | 185 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 179 insertions(+), 6 deletions(-) diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index f33ec339b..88665763a 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut}; use std::boxed::Box; -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; use std::future::Future; use std::io; use std::net::SocketAddr; @@ -36,6 +36,8 @@ const READ_TIMEOUT: Duration = Duration::from_secs(5); /// Maximum number of retries after timeouts. const MAX_RETRIES: u8 = 5; +//------------ Connection ----------------------------------------------------- + /// A UDP transport connection. #[derive(Clone, Debug)] pub struct Connection { @@ -56,7 +58,7 @@ impl Connection { async fn query_impl + Clone + Send>( &self, query_msg: &Message, - ) -> Result { + ) -> Result { self.inner.query(query_msg, self.clone()).await } @@ -78,12 +80,13 @@ impl Connection { } impl + Clone + Debug + Send + Sync> - QueryMessage for Connection + QueryMessage for Connection { fn query<'a>( &'a self, query_msg: &'a Message, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + '_>> + { return Box::pin(self.query_impl(query_msg)); } } @@ -105,6 +108,10 @@ impl + Clone + Debug + Send + Sync + 'static> } } +//------------ Query ---------------------------------------------------------- + +/* + /// State of the DNS query. #[derive(Debug)] enum QueryState { @@ -123,6 +130,9 @@ enum QueryState { /// Receive the reply. Receive(Instant), } +*/ + +/* /// The state of a DNS query. #[derive(Debug)] @@ -307,6 +317,169 @@ impl GetResult for Query { } } +*/ + +//------------ Query2 --------------------------------------------------------- + +/// The state of a DNS query. +pub struct Query2 { + /// Future that does the actual work of GetResult. + get_result_fut: + Pin, Error>> + Send>>, +} + +impl Query2 { + /// Create new Query object. + fn new( + query_msg: Message, + remote_addr: SocketAddr, + conn: Connection, + ) -> Query2 { + Query2 { + get_result_fut: Box::pin(Self::get_result_impl2( + query_msg, + remote_addr, + conn, + )), + } + } + + /// Async function that waits for the future stored in Query to complete. + async fn get_result_impl(&mut self) -> Result, Error> { + (&mut self.get_result_fut).await + } + + /// Get the result of a DNS Query. + /// + /// This function is not cancel safe. + async fn get_result_impl2( + mut query_msg: Message, + remote_addr: SocketAddr, + conn: Connection, + ) -> Result, Error> { + let recv_size = 2000; // Should be configurable. + + let mut retries: u8 = 0; + + // We need to get past the semaphore that limits the + // number of concurrent sockets we can use. + let _permit = conn.get_permit().await; + + loop { + let sock = Some(Self::udp_bind(remote_addr.is_ipv4()).await?); + + sock.as_ref() + .expect("socket should be present") + .connect(remote_addr) + .await + .map_err(|e| Error::UdpConnect(Arc::new(e)))?; + + // Set random ID in header + let header = query_msg.header_mut(); + header.set_random_id(); + let dgram = query_msg.as_slice(); + + let sent = sock + .as_ref() + .expect("socket should be present") + .send(dgram) + .await + .map_err(|e| Error::UdpSend(Arc::new(e)))?; + if sent != query_msg.as_slice().len() { + return Err(Error::UdpShortSend); + } + + let start = Instant::now(); + let elapsed = start.elapsed(); + if elapsed > READ_TIMEOUT { + todo!(); + } + let remain = READ_TIMEOUT - elapsed; + + let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout( + remain, + sock.as_ref() + .expect("socket should be present") + .recv(&mut buf), + ) + .await; + if timeout_res.is_err() { + retries += 1; + if retries < MAX_RETRIES { + continue; + } + return Err(Error::UdpTimeoutNoResponse); + } + let len = timeout_res + .expect("errror case is checked above") + .map_err(|e| Error::UdpReceive(Arc::new(e)))?; + buf.truncate(len); + + // We ignore garbage since there is a timer on this whole thing. + let answer = match Message::from_octets(buf.into()) { + Ok(answer) => answer, + Err(_) => continue, + }; + + // Unfortunately we cannot pass query_msg to is_answer + // because is_answer requires Octets, which is not + // implemented by BytesMut. Make a copy. + let query_msg = Message::from_octets(query_msg.as_slice()) + .expect( + "Message failed to parse contents of another Message", + ); + if !answer.is_answer(&query_msg) { + continue; + } + return Ok(answer); + } + } + + /// Bind to a local UDP port. + /// + /// This should explicitly pick a random number in a suitable range of + /// ports. + async fn udp_bind(v4: bool) -> Result { + let mut i = 0; + loop { + let local: SocketAddr = if v4 { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => return Ok(sock), + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(Error::UdpBind(Arc::new(err))); + } else { + i += 1 + } + } + } + } + } +} + +impl Debug for Query2 { + fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { + todo!() + } +} + +impl GetResult for Query2 { + fn get_result( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_result_impl()) + } +} + +//------------ InnerConnection ------------------------------------------------ + /// Actual implementation of the UDP transport connection. #[derive(Debug)] struct InnerConnection { @@ -331,13 +504,13 @@ impl InnerConnection { &self, query_msg: &Message, conn: Connection, - ) -> Result { + ) -> Result { let slice = query_msg.as_slice(); let mut bytes = BytesMut::with_capacity(slice.len()); bytes.extend_from_slice(slice); let query_msg = Message::from_octets(bytes) .expect("Message failed to parse contents of another Message"); - Ok(Query::new(query_msg, self.remote_addr, conn)) + Ok(Query2::new(query_msg, self.remote_addr, conn)) } /// Return a permit for a our semaphore. From cc72dc08fa6ca0d745c0addc03f5d3255be6d65d Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 10 Oct 2023 17:16:13 +0200 Subject: [PATCH 057/124] Remove tcp_mutex. This was an implementation using shared state (instead of communication channels) --- src/net/client/mod.rs | 1 - src/net/client/tcp_mutex.rs | 903 ------------------------------------ 2 files changed, 904 deletions(-) delete mode 100644 src/net/client/tcp_mutex.rs diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index be97cf9b3..d1f7beebf 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -14,7 +14,6 @@ pub mod octet_stream; pub mod query; pub mod redundant; pub mod tcp_factory; -pub mod tcp_mutex; pub mod tls_factory; pub mod udp; pub mod udp_tcp; diff --git a/src/net/client/tcp_mutex.rs b/src/net/client/tcp_mutex.rs deleted file mode 100644 index 0706d662e..000000000 --- a/src/net/client/tcp_mutex.rs +++ /dev/null @@ -1,903 +0,0 @@ -//! A DNS over TCP transport - -// RFC 7766 describes DNS over TCP -// RFC 7828 describes the edns-tcp-keepalive option - -// TODO: -// - errors -// - connect errors? Retry after connection refused? -// - server errors -// - ID out of range -// - ID not in use -// - reply for wrong query -// - timeouts -// - request timeout -// - limit number of outstanding queries to 32K -// - create new TCP connection after end/failure of previous one - -use bytes::{Bytes, BytesMut}; -use std::collections::VecDeque; -use std::ops::DerefMut; -use std::sync::Arc; -use std::sync::Mutex as Std_mutex; -use std::time::{Duration, Instant}; -use std::vec::Vec; - -use crate::base::wire::Composer; -use crate::base::{ - opt::{AllOptData, OptRecord, TcpKeepalive}, - Message, MessageBuilder, StaticCompressor, StreamTarget, -}; -use octseq::{Octets, OctetsBuilder}; - -use tokio::io; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio::sync::Notify; -use tokio::time::sleep; - -const ERR_IDLE_TIMEOUT: &str = "idle connection was closed"; -const ERR_READ_ERROR: &str = "read error"; -const ERR_READ_TIMEOUT: &str = "read timeout"; -const ERR_WRITE_ERROR: &str = "write error"; -const ERR_TOO_MANY_QUERIES: &str = "too many outstanding queries"; - -// Implement a simple response timer to see if the connection and the server -// are alive. Set the timer when the connection goes from idle to busy. -// Reset the timer each time a reply arrives. Cancel the timer when the -// connection goes back to idle. When the time expires, mark all outstanding -// queries as timed out and shutdown the connection. -// -// Note: nsd has 120 seconds, unbound has 3 seconds. -const RESPONSE_TIMEOUT_S: u64 = 19; - -enum SingleQueryState { - Busy, - Done(Result, Arc>), - Canceled, -} - -struct SingleQuery { - state: SingleQueryState, - complete: Arc, -} - -struct Queries { - // Number of queries in the vector. The count of element that are - // not None - count: usize, - - // Number of queries that are still waiting for an answer - busy: usize, - - // Index in the vector where to look for a space for a new query - curr: usize, - - vec: Vec>, -} - -enum ConnState { - Active(Option), - Idle(Instant), - IdleTimeout, - ReadError, - ReadTimeout, - WriteError, -} - -struct Status { - state: ConnState, - - // For edns-tcp-keepalive, we have a boolean the specifies if we - // need to send one (typically at the start of the connection). - // Initially we assume that the idle timeout is zero. A received - // edns-tcp-keepalive option may change that. What the best way to - // specify time in Rust? Currently we specify it in milliseconds. - send_keepalive: bool, - idle_timeout: Option, - do_shutdown: bool, -} - -struct InnerTcpConnection { - stream: Std_mutex>, - - /* status */ - status: Std_mutex, - - /* Vector with outstanding queries */ - query_vec: Std_mutex, - - /* Vector with outstanding requests that need to be transmitted */ - tx_queue: Std_mutex>>, - - worker_notify: Notify, - writer_notify: Notify, -} - -pub struct TcpConnection { - inner: Arc, -} - -enum QueryState { - Busy(usize), // index - Done, -} - -pub struct Query { - transport: Arc, - query_msg: Message>, - state: QueryState, -} - -impl InnerTcpConnection { - pub async fn connect( - addr: A, - ) -> io::Result { - let tcp = TcpStream::connect(addr).await?; - Ok(Self { - stream: Std_mutex::new(Some(tcp)), - status: Std_mutex::new(Status { - state: ConnState::Active(None), - send_keepalive: true, - idle_timeout: None, - do_shutdown: false, - }), - query_vec: Std_mutex::new(Queries { - count: 0, - busy: 0, - curr: 0, - vec: Vec::new(), - }), - tx_queue: Std_mutex::new(VecDeque::new()), - worker_notify: Notify::new(), - writer_notify: Notify::new(), - }) - } - - // Take a query out of query_vec and decrement the query count - fn take_query(&self, index: usize) -> Option { - let mut query_vec = self.query_vec.lock().unwrap(); - self.vec_take_query(query_vec.deref_mut(), index) - } - - // Very similar to take_query, but sometime the caller already has - // a lock on the mutex - fn vec_take_query( - &self, - query_vec: &mut Queries, - index: usize, - ) -> Option { - let query = query_vec.vec[index].take(); - query_vec.count -= 1; - if query_vec.count == 0 { - // The worker may be waiting for this - self.worker_notify.notify_one(); - } - query - } - - fn insert_answer(&self, answer: Message) { - // We got an answer, reset the timer - let mut status = self.status.lock().unwrap(); - status.state = ConnState::Active(Some(Instant::now())); - drop(status); - - let ind16 = answer.header().id(); - let index: usize = ind16.into(); - - let mut query_vec = self.query_vec.lock().unwrap(); - - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the TCP connection as broken - return; - } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { - None => { - // No query with this ID. We should - // mark the TCP connection as broken - return; - } - Some(query) => { - match &query.state { - SingleQueryState::Busy => { - query.state = SingleQueryState::Done(Ok(answer)); - query.complete.notify_one(); - } - SingleQueryState::Canceled => { - //`The query has been - // canceled already - // Clean up. - let _ = - self.vec_take_query(query_vec.deref_mut(), index); - } - SingleQueryState::Done(_) => { - // Already got a - // result. - return; - } - } - } - } - query_vec.busy -= 1; - if query_vec.busy == 0 { - let mut status = self.status.lock().unwrap(); - - // Clear the activity timer. There is no need to do - // this because state will be set to either IdleTimeout - // or Idle just below. However, it is nicer to keep - // this indenpendent. - status.state = ConnState::Active(None); - - if status.idle_timeout.is_none() { - // Assume that we can just move to IdleTimeout - // state - status.state = ConnState::IdleTimeout; - - // Notify the worker. Then the worker can - // close the tcp connection - self.worker_notify.notify_one(); - } else { - status.state = ConnState::Idle(Instant::now()); - - // Notify the worker. The worker waits for - // the timeout to expire - self.worker_notify.notify_one(); - } - } - } - - fn handle_keepalive(&self, opt_value: TcpKeepalive) { - if let Some(value) = opt_value.timeout() { - let mut status = self.status.lock().unwrap(); - let value_dur = Duration::from(value); - status.idle_timeout = Some(value_dur); - } - } - - fn handle_opts>( - &self, - opts: &OptRecord, - ) { - for option in opts.opt().iter() { - let opt = option.unwrap(); - if let AllOptData::TcpKeepalive(tcpkeepalive) = opt { - self.handle_keepalive(tcpkeepalive); - } - } - } - - // This function is not async cancellation safe - async fn reader(&self, sock: &mut ReadHalf<'_>) -> Result<(), &str> { - loop { - let read_res = sock.read_u16().await; - let len = match read_res { - Ok(len) => len, - Err(error) => { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } as usize; - - let mut buf = BytesMut::with_capacity(len); - - loop { - let curlen = buf.len(); - if curlen >= len { - if curlen > len { - panic!( - "reader: got too much data {curlen}, expetect {len}"); - } - - // We got what we need - break; - } - - let read_res = sock.read_buf(&mut buf).await; - - match read_res { - Ok(readlen) => { - if readlen == 0 { - let error = io::Error::new( - io::ErrorKind::Other, - "unexpected end of data", - ); - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } - Err(error) => { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - }; - - // Check if we are done at the head of the loop - } - - let reply_message = Message::::from_octets(buf.into()); - - match reply_message { - Ok(answer) => { - // Check for a edns-tcp-keepalive option - let opt_record = answer.opt(); - if let Some(ref opts) = opt_record { - self.handle_opts(opts); - }; - self.insert_answer(answer); - } - Err(_) => { - // The only possible error is short message - let error = - io::Error::new(io::ErrorKind::Other, "short buf"); - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::ReadError; - return Err(ERR_READ_ERROR); - } - } - } - } - - fn tcp_error(&self, error: std::io::Error) { - // Update all requests that are in progress. Don't wait for - // any reply that may be on its way. - let arc_error = Arc::new(error); - let mut query_vec = self.query_vec.lock().unwrap(); - for query in &mut query_vec.vec { - match query { - None => { - continue; - } - Some(q) => { - match q.state { - SingleQueryState::Busy => { - q.state = SingleQueryState::Done(Err( - arc_error.clone() - )); - q.complete.notify_one(); - } - SingleQueryState::Done(_) - | SingleQueryState::Canceled => - // Nothing to do - {} - } - } - } - } - } - - // This function is not async cancellation safe - async fn writer( - &self, - sock: &mut WriteHalf<'_>, - ) -> Result<(), &'static str> { - loop { - loop { - // Check if we need to shutdown - let do_shutdown = { - // Extra block to satisfy clippy - // await_holding_lock - let status = self.status.lock().unwrap(); - status.do_shutdown - // drop(status); - }; - - if do_shutdown { - // Ignore errors - _ = sock.shutdown().await; - - // Do we need to clear do_shutdown? - break; - } - - let head = { - // Extra block to satisfy clippy - // await_holding_lock - let mut tx_queue = self.tx_queue.lock().unwrap(); - tx_queue.pop_front() - // drop(tx_queue); - }; - match head { - Some(vec) => { - let res = sock.write_all(&vec).await; - if let Err(error) = res { - self.tcp_error(error); - let mut status = self.status.lock().unwrap(); - status.state = ConnState::WriteError; - return Err(ERR_WRITE_ERROR); - } - } - None => break, - } - } - - self.writer_notify.notified().await; - } - } - - // This function is not async cancellation safe because it calls - // reader and writer which are not async cancellation safe - pub async fn worker(&self) -> Option<()> { - let mut stream = { - // Extra block to satisfy clippy - // await_holding_lock - let mut opt_stream = self.stream.lock().unwrap(); - opt_stream.take().unwrap() - // drop(opt_stream); - }; - let (mut read_stream, mut write_stream) = stream.split(); - - let reader_fut = self.reader(&mut read_stream); - tokio::pin!(reader_fut); - let writer_fut = self.writer(&mut write_stream); - tokio::pin!(writer_fut); - - loop { - let opt_timeout: Option = { - // Extra block to satisfy clippy - // await_holding_lock - let mut status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(opt_instant) => { - if let Some(instant) = opt_instant { - let timeout = - Duration::from_secs(RESPONSE_TIMEOUT_S); - let elapsed = instant.elapsed(); - if elapsed > timeout { - let error = io::Error::new( - io::ErrorKind::Other, - "read timeout", - ); - self.tcp_error(error); - status.state = ConnState::ReadTimeout; - break; - } - Some(timeout - elapsed) - } else { - None - } - } - ConnState::Idle(instant) => { - if let Some(timeout) = status.idle_timeout { - let elapsed = instant.elapsed(); - if elapsed >= timeout { - // Move to IdleTimeout and end - // the loop - status.state = ConnState::IdleTimeout; - break; - } - Some(timeout - elapsed) - } else { - panic!("Idle state but no timeout"); - } - } - ConnState::IdleTimeout - | ConnState::ReadError - | ConnState::WriteError => None, // No timers here - ConnState::ReadTimeout => { - panic!("should not be in loop with ReadTimeout") - } - } - // drop(status); - }; - - // For simplicity, make sure we always have a timeout - let timeout = match opt_timeout { - Some(timeout) => timeout, - None => - // Just use the response timeout - { - Duration::from_secs(RESPONSE_TIMEOUT_S) - } - }; - - let sleep_fut = sleep(timeout); - let notify_fut = self.worker_notify.notified(); - - tokio::select! { - res = &mut reader_fut => { - match res { - Ok(_) => - // The reader should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Reader failed. Break - // out of loop and - // shut down - break - } - } - res = &mut writer_fut => { - match res { - Ok(_) => - // The writer should not - // terminate without - // error. - panic!("reader terminated"), - Err(_) => - // Writer failed. Break - // out of loop and - // shut down - break - } - } - - _ = sleep_fut => { - // Timeout expired, just - // continue with the loop - } - _ = notify_fut => { - // Got notifies, go through the loop - // to see what changed. - } - - } - - // Check if the connection is idle - let status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(_) | ConnState::Idle(_) => { - // Keep going - } - ConnState::IdleTimeout => break, - ConnState::ReadError - | ConnState::ReadTimeout - | ConnState::WriteError => { - panic!("Should not be here"); - } - } - drop(status); - } - - // We can't see a FIN directly because the writer_fut owns - // write_stream. - { - // Extra block to satisfy clippy - // await_holding_lock - let mut status = self.status.lock().unwrap(); - status.do_shutdown = true; - // drop(status); - }; - - // Kick writer - self.writer_notify.notify_one(); - - // Wait for writer to terminate. Ignore the result. We may - // want a timer here - _ = writer_fut.await; - - // Stay around until the last query result is collected - loop { - { - // Extra block to satisfy clippy - // await_holding_lock - let query_vec = self.query_vec.lock().unwrap(); - if query_vec.count == 0 { - // We are done - break; - } - // drop(query_vec); - } - - self.worker_notify.notified().await; - } - None - } - - fn insert_at( - query_vec: &mut Queries, - index: usize, - q: Option, - ) { - query_vec.vec[index] = q; - query_vec.count += 1; - query_vec.busy += 1; - query_vec.curr = index + 1; - } - - // Insert a message in the query vector. Return the index - fn insert(&self) -> Result { - let q = Some(SingleQuery { - state: SingleQueryState::Busy, - complete: Arc::new(Notify::new()), - }); - let mut query_vec = self.query_vec.lock().unwrap(); - - // Fail if there are to many entries already in this vector - // We cannot have more than u16::MAX entries because the - // index needs to fit in an u16. For efficiency we want to - // keep the vector half empty. So we return a failure if - // 2*count > u16::MAX - if 2 * query_vec.count > u16::MAX.into() { - return Err(ERR_TOO_MANY_QUERIES); - } - - let vec_len = query_vec.vec.len(); - - // Append if the amount of empty space in the vector is less - // than half. But limit vec_len to u16::MAX - if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { - // Just append - query_vec.vec.push(q); - query_vec.count += 1; - query_vec.busy += 1; - let index = query_vec.vec.len() - 1; - return Ok(index); - } - let loc_curr = query_vec.curr; - - for index in loc_curr..vec_len { - match query_vec.vec[index] { - Some(_) => { - // Already in use, just continue - } - None => { - Self::insert_at(&mut query_vec, index, q); - return Ok(index); - } - } - } - - // Nothing until the end of the vector. Try for the entire - // vector - for index in 0..vec_len { - match query_vec.vec[index] { - Some(_) => { - // Already in use, just continue - } - None => { - Self::insert_at(&mut query_vec, index, q); - return Ok(index); - } - } - } - - // Still nothing, that is not good - panic!("insert failed"); - } - - fn queue_query + AsRef<[u8]>>( - &self, - msg: &MessageBuilder>>, - ) { - let vec = msg.as_target().as_target().as_stream_slice(); - - // Store a clone of the request. That makes life easier - // and requests tend to be small - let mut tx_queue = self.tx_queue.lock().unwrap(); - tx_queue.push_back(vec.to_vec()); - } - - pub fn query< - Octs: OctetsBuilder + AsMut<[u8]> + AsRef<[u8]> + Composer + Clone, - >( - &self, - query_msg: &mut MessageBuilder>>, - ) -> Result { - // Check the state of the connection, fail if the connection - // is in IdleTimeout. If the connection is Idle, move it - // back to Active. Also check for the need to send a keepalive - let mut status = self.status.lock().unwrap(); - match status.state { - ConnState::Active(timer) => { - // Set timer if we don't have one already - if timer.is_none() { - status.state = ConnState::Active(Some(Instant::now())); - } - } - ConnState::Idle(_) => { - // Go back to active - status.state = ConnState::Active(Some(Instant::now())); - } - ConnState::IdleTimeout => { - // The connection has been closed. Report error - return Err(ERR_IDLE_TIMEOUT); - } - ConnState::ReadError => { - return Err(ERR_READ_ERROR); - } - ConnState::ReadTimeout => { - return Err(ERR_READ_TIMEOUT); - } - ConnState::WriteError => { - return Err(ERR_WRITE_ERROR); - } - } - - // Note that insert may fail if there are too many - // outstanding queires. First call insert before checking - // send_keepalive. - let index = self.insert()?; - - let mut do_keepalive = false; - if status.send_keepalive { - do_keepalive = true; - status.send_keepalive = false; - } - drop(status); - - let ind16: u16 = index.try_into().unwrap(); - - // We set the ID to the array index. Defense in depth - // suggests that a random ID is better because it works - // even if TCP sequence numbers could be predicted. However, - // Section 9.3 of RFC 5452 recommends retrying over TCP - // if many spoofed answers arrive over UDP: "TCP, by the - // nature of its use of sequence numbers, is far more - // resilient against forgery by third parties." - let hdr = query_msg.header_mut(); - hdr.set_id(ind16); - - if do_keepalive { - let mut msgadd = query_msg.clone().additional(); - - // send an empty keepalive option - let res = msgadd.opt(|opt| opt.tcp_keepalive(None)); - match res { - Ok(_) => self.queue_query(&msgadd), - Err(_) => { - // Adding keepalive option - // failed. Send the original - // request and turn the - // send_keepalive flag back on - let mut status = self.status.lock().unwrap(); - status.send_keepalive = true; - drop(status); - self.queue_query(query_msg); - } - } - } else { - self.queue_query(query_msg); - } - // Now kick the writer to transmit the query - self.writer_notify.notify_one(); - - Ok(index) - } - - pub async fn get_result( - &self, - query_msg: &Message, - index: usize, - ) -> Result, Arc> { - // Wait for reply - let local_notify = { - // Extra block to satisfy clippy - // await_holding_lock - let mut query_vec = self.query_vec.lock().unwrap(); - query_vec.vec[index].as_mut().unwrap().complete.clone() - // drop(query_vec); - }; - local_notify.notified().await; - - // take a look - let opt_q = self.take_query(index); - if let Some(q) = opt_q { - if let SingleQueryState::Done(result) = q.state { - if let Ok(answer) = &result { - if !answer.is_answer(query_msg) { - return Err(Arc::new(io::Error::new( - io::ErrorKind::Other, - "wrong answer", - ))); - } - } - return result; - } - panic!("inconsistent state"); - } - - panic!("inconsistent state"); - } - - fn cancel(&self, index: usize) { - let mut query_vec = self.query_vec.lock().unwrap(); - - match &mut query_vec.vec[index] { - None => { - panic!("Cancel called, but nothing to cancel"); - } - Some(query) => { - match &query.state { - SingleQueryState::Busy => { - query.state = SingleQueryState::Canceled; - } - SingleQueryState::Canceled => { - panic!("Already canceled"); - } - SingleQueryState::Done(_) => { - // Remove the result - let _ = - self.vec_take_query(query_vec.deref_mut(), index); - } - } - } - } - } -} - -impl TcpConnection { - pub async fn connect( - addr: A, - ) -> io::Result { - let tcpconnection = InnerTcpConnection::connect(addr).await?; - Ok(Self { - inner: Arc::new(tcpconnection), - }) - } - pub async fn worker(&self) -> Option<()> { - self.inner.worker().await - } - pub fn query< - OctsBuilder: OctetsBuilder + AsMut<[u8]> + AsRef<[u8]> + Composer + Clone, - >( - &self, - query_msg: &mut MessageBuilder< - StaticCompressor>, - >, - ) -> Result { - let index = self.inner.query(query_msg)?; - let msg = &query_msg.as_message(); - Ok(Query::new(self, msg, index)) - } -} - -impl Query { - fn new( - transport: &TcpConnection, - query_msg: &Message, - index: usize, - ) -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); - let msg = Message::from_octets(vec).unwrap(); - Self { - transport: transport.inner.clone(), - query_msg: msg, - state: QueryState::Busy(index), - } - } - pub async fn get_result( - &mut self, - ) -> Result, Arc> { - // Just the result of get_result on tranport. We should record - // that we got an answer to avoid asking again - match self.state { - QueryState::Busy(index) => { - let result = - self.transport.get_result(&self.query_msg, index).await; - self.state = QueryState::Done; - result - } - QueryState::Done => { - panic!("Already done"); - } - } - } -} - -impl Drop for Query { - fn drop(&mut self) { - match self.state { - QueryState::Busy(index) => { - self.transport.cancel(index); - } - QueryState::Done => { - // Done, nothing to cancel - } - } - } -} From fc4c11dd1cb85f3c081032fbb469f37152f23712 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 16 Oct 2023 10:08:56 +0200 Subject: [PATCH 058/124] Add configuration options. --- src/net/client/error.rs | 10 +++ src/net/client/multi_stream.rs | 53 +++++++++---- src/net/client/octet_stream.rs | 67 +++++++++++++--- src/net/client/udp.rs | 135 +++++++++++++++++++++++++++++---- src/net/client/udp_tcp.rs | 46 +++++++++-- 5 files changed, 262 insertions(+), 49 deletions(-) diff --git a/src/net/client/error.rs b/src/net/client/error.rs index 5c8ba0953..2aeb5e68c 100644 --- a/src/net/client/error.rs +++ b/src/net/client/error.rs @@ -19,6 +19,9 @@ pub enum Error { /// ParseError from Message. MessageParseError, + /// octet_stream configuration error. + OctetStreamConfigError(Arc), + /// Underlying transport not found in redundant connection RedundantTransportNotFound, @@ -49,6 +52,9 @@ pub enum Error { /// Binding a UDP socket gave an error. UdpBind(Arc), + /// UDP configuration error. + UdpConfigError(Arc), + /// Connecting a UDP socket gave an error. UdpConnect(Arc), @@ -79,6 +85,7 @@ impl Display for Error { write!(f, "PushError from MessageBuilder") } Error::MessageParseError => write!(f, "ParseError from Message"), + Error::OctetStreamConfigError(_) => write!(f, "bad config value"), Error::RedundantTransportNotFound => write!( f, "Underlying transport not found in redundant connection" @@ -106,6 +113,7 @@ impl Display for Error { write!(f, "unexpected end of data") } Error::UdpBind(_) => write!(f, "error binding UDP socket"), + Error::UdpConfigError(_) => write!(f, "bad config value"), Error::UdpConnect(_) => write!(f, "error connecting UDP socket"), Error::UdpReceive(_) => { write!(f, "error receiving from UDP socket") @@ -131,6 +139,7 @@ impl error::Error for Error { Error::ConnectionClosed => None, Error::MessageBuilderPushError => None, Error::MessageParseError => None, + Error::OctetStreamConfigError(e) => Some(e), Error::RedundantTransportNotFound => None, Error::ShortMessage => None, Error::StreamIdleTimeout => None, @@ -141,6 +150,7 @@ impl error::Error for Error { Error::StreamWriteError(e) => Some(e), Error::StreamUnexpectedEndOfData => None, Error::UdpBind(e) => Some(e), + Error::UdpConfigError(e) => Some(e), Error::UdpConnect(e) => Some(e), Error::UdpReceive(e) => Some(e), Error::UdpSend(e) => Some(e), diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 8995089b3..962fcdbeb 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -32,8 +32,7 @@ use tokio::time::{sleep_until, Instant}; use crate::base::Message; use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; -use crate::net::client::octet_stream::Connection as SingleConnection; -use crate::net::client::octet_stream::QueryNoCheck as SingleQuery; +use crate::net::client::octet_stream; use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; /// Capacity of the channel that transports [ChanReq]. @@ -43,6 +42,15 @@ const DEF_CHAN_CAP: usize = 8; /// [InnerConnection::run] terminated. const ERR_CONN_CLOSED: &str = "connection closed"; +//------------ Config --------------------------------------------------------- + +/// Configuration for an octet_stream transport connection. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Response timeout. + pub octet_stream: Option, +} + //------------ Connection ----------------------------------------------------- #[derive(Clone, Debug)] @@ -58,8 +66,15 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new() -> io::Result> { - let connection = InnerConnection::new()?; + pub fn new(config: Option) -> Result, Error> { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config)?; Ok(Self { inner: Arc::new(connection), }) @@ -75,7 +90,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> >( &self, factory: F, - ) -> Option<()> { + ) -> Result<(), Error> { self.inner.run(factory).await } @@ -186,10 +201,10 @@ enum QueryState> { GetConn(oneshot::Receiver>), /// Start a query using the transport. - StartQuery(SingleConnection), + StartQuery(octet_stream::Connection), /// Get the result of the query. - GetResult(SingleQuery), + GetResult(octet_stream::QueryNoCheck), /// Wait until trying again. /// @@ -211,7 +226,7 @@ struct ChanRespOk> { id: u64, /// New octet_stream transport. - conn: SingleConnection, + conn: octet_stream::Connection, } impl + Clone + Debug + Octets + Send + Sync + 'static> @@ -354,6 +369,9 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// The actual implementation of [Connection]. #[derive(Debug)] struct InnerConnection> { + /// User configuration values. + config: Config, + /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. /// @@ -422,7 +440,7 @@ enum SingleConnState3> { None, /// Current octet_stream transport. - Some(SingleConnection), + Some(octet_stream::Connection), /// State that deals with an error getting a new octet stream from /// a factory. @@ -452,9 +470,10 @@ impl + Clone + Octets + Send + Sync + 'static> /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. - pub fn new() -> io::Result> { + pub fn new(config: Config) -> Result, Error> { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { + config, sender: tx, receiver: Futures_mutex::new(Some(rx)), }) @@ -473,7 +492,7 @@ impl + Clone + Octets + Send + Sync + 'static> >( &self, factory: F, - ) -> Option<()> { + ) -> Result<(), Error> { let mut receiver = { let mut locked_opt_receiver = self.receiver.lock().await; let opt_receiver = locked_opt_receiver.take(); @@ -599,9 +618,7 @@ impl + Clone + Octets + Send + Sync + 'static> let stream = res_conn .expect("error case is checked before"); - let conn = SingleConnection::new() - .expect( - "the connect implementation cannot fail"); + let conn = octet_stream::Connection::new(self.config.octet_stream.clone())?; let conn_run = conn.clone(); let clo = || async move { @@ -655,7 +672,7 @@ impl + Clone + Octets + Send + Sync + 'static> } // Done - Some(()) + Ok(()) } /// Request a new connection. @@ -734,3 +751,9 @@ fn is_answer_ignore_id< async fn factory_nop() -> Result { Err(io::Error::new(io::ErrorKind::Other, "nop")) } + +/// Check if config is valid. +fn check_config(_config: &Config) -> Result<(), Error> { + // Nothing to check at the moment. + Ok(()) +} diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index d7c785d17..da7e5434d 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -24,6 +24,7 @@ use futures::lock::Mutex as Futures_mutex; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; +use std::io::ErrorKind; use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -38,13 +39,12 @@ use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; use crate::rdata::AllRecordData; use octseq::Octets; -use tokio::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; use tokio::time::sleep; -/// Time to wait on a non-idle connection for the other side to send -/// a response on any outstanding query. +/// Default configuration value for the amount of time to wait on a non-idle +/// connection for the other side to send a response on any outstanding query. // Implement a simple response timer to see if the connection and the server // are alive. Set the timer when the connection goes from idle to busy. // Reset the timer each time a reply arrives. Cancel the timer when the @@ -52,7 +52,13 @@ use tokio::time::sleep; // queries as timed out and shutdown the connection. // // Note: nsd has 120 seconds, unbound has 3 seconds. -const RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); +const DEF_RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); + +/// Minimum configuration value for response_timeout. +const MIN_RESPONSE_TIMEOUT: Duration = Duration::from_millis(1); + +/// Maximum configuration value for response_timeout. +const MAX_RESPONSE_TIMEOUT: Duration = Duration::from_secs(600); /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -61,6 +67,23 @@ const DEF_CHAN_CAP: usize = 8; /// [InnerConnection::run]. const READ_REPLY_CHAN_CAP: usize = 8; +//------------ Config --------------------------------------------------------- + +/// Configuration for an octet_stream transport connection. +#[derive(Clone, Debug)] +pub struct Config { + /// Response timeout. + pub response_timeout: Duration, +} + +impl Default for Config { + fn default() -> Self { + Self { + response_timeout: DEF_RESPONSE_TIMEOUT, + } + } +} + //------------ Connection ----------------------------------------------------- #[derive(Clone, Debug)] @@ -74,8 +97,15 @@ impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new() -> io::Result> { - let connection = InnerConnection::new()?; + pub fn new(config: Option) -> Result, Error> { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config)?; Ok(Self { inner: Arc::new(connection), }) @@ -313,6 +343,9 @@ impl QueryNoCheck { /// The actual implementation of [Connection]. #[derive(Debug)] struct InnerConnection> { + /// User configuration variables. + config: Config, + /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. /// @@ -424,9 +457,10 @@ impl + Clone> InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. - pub fn new() -> io::Result> { + pub fn new(config: Config) -> Result, Error> { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { + config, sender: tx, receiver: Futures_mutex::new(Some(rx)), }) @@ -472,7 +506,7 @@ impl + Clone> InnerConnection { ConnState::Active(opt_instant) => { if let Some(instant) = opt_instant { let elapsed = instant.elapsed(); - if elapsed > RESPONSE_TIMEOUT { + if elapsed > self.config.response_timeout { Self::error( Error::StreamReadTimeout, &mut query_vec, @@ -480,7 +514,7 @@ impl + Clone> InnerConnection { status.state = ConnState::ReadTimeout; break; } - Some(RESPONSE_TIMEOUT - elapsed) + Some(self.config.response_timeout - elapsed) } else { None } @@ -513,7 +547,7 @@ impl + Clone> InnerConnection { None => // Just use the response timeout { - RESPONSE_TIMEOUT + self.config.response_timeout } }; @@ -1127,3 +1161,16 @@ fn is_answer_ignore_id< reply.question() == query.question() } } + +/// Check if config is valid. +fn check_config(config: &Config) -> Result<(), Error> { + if config.response_timeout < MIN_RESPONSE_TIMEOUT + || config.response_timeout > MAX_RESPONSE_TIMEOUT + { + return Err(Error::OctetStreamConfigError(Arc::new( + std::io::Error::new(ErrorKind::Other, "response_timeout"), + ))); + } + + Ok(()) +} diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 88665763a..8a14c922d 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -11,7 +11,7 @@ use bytes::{Bytes, BytesMut}; use std::boxed::Box; use std::fmt::{Debug, Formatter}; use std::future::Future; -use std::io; +use std::io::ErrorKind; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; @@ -26,15 +26,59 @@ use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; -/// Maximum number of parallel DNS query over a single UDP transport -/// connection. -const MAX_PARALLEL: usize = 100; +/// Default configuration value for the maximum number of parallel DNS query +/// over a single UDP transport connection. +const DEF_MAX_PARALLEL: usize = 100; -/// Maximum amount of time to wait for a reply. -const READ_TIMEOUT: Duration = Duration::from_secs(5); +/// Minimum configuration value for max_parallel. +const MIN_MAX_PARALLEL: usize = 1; -/// Maximum number of retries after timeouts. -const MAX_RETRIES: u8 = 5; +/// Maximum configuration value for max_parallel. +const MAX_MAX_PARALLEL: usize = 1000; + +/// Default configuration value for the maximum amount of time to wait for a +/// reply. +const DEF_READ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Minimum configuration value for read_timeout. +const MIN_READ_TIMEOUT: Duration = Duration::from_millis(1); + +/// Maximum configuration value for read_timeout. +const MAX_READ_TIMEOUT: Duration = Duration::from_secs(60); + +/// Default configuration value for maximum number of retries after timeouts. +const DEF_MAX_RETRIES: u8 = 5; + +/// Minimum allowed configuration value for max_retries. +const MIN_MAX_RETRIES: u8 = 1; + +/// Maximum allowed configuration value for max_retries. +const MAX_MAX_RETRIES: u8 = 100; + +//------------ Config --------------------------------------------------------- + +/// Configuration for a UDP transport connection. +#[derive(Clone, Debug)] +pub struct Config { + /// Maximum number of parallel requests for a transport connection. + pub max_parallel: usize, + + /// Read timeout. + pub read_timeout: Duration, + + /// Maimum number of retries. + pub max_retries: u8, +} + +impl Default for Config { + fn default() -> Self { + Self { + max_parallel: DEF_MAX_PARALLEL, + read_timeout: DEF_READ_TIMEOUT, + max_retries: DEF_MAX_RETRIES, + } + } +} //------------ Connection ----------------------------------------------------- @@ -47,8 +91,18 @@ pub struct Connection { impl Connection { /// Create a new UDP transport connection. - pub fn new(remote_addr: SocketAddr) -> io::Result { - let connection = InnerConnection::new(remote_addr)?; + pub fn new( + config: Option, + remote_addr: SocketAddr, + ) -> Result { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config, remote_addr)?; Ok(Self { inner: Arc::new(connection), }) @@ -331,12 +385,14 @@ pub struct Query2 { impl Query2 { /// Create new Query object. fn new( + config: Config, query_msg: Message, remote_addr: SocketAddr, conn: Connection, ) -> Query2 { Query2 { get_result_fut: Box::pin(Self::get_result_impl2( + config, query_msg, remote_addr, conn, @@ -353,6 +409,7 @@ impl Query2 { /// /// This function is not cancel safe. async fn get_result_impl2( + config: Config, mut query_msg: Message, remote_addr: SocketAddr, conn: Connection, @@ -391,10 +448,10 @@ impl Query2 { let start = Instant::now(); let elapsed = start.elapsed(); - if elapsed > READ_TIMEOUT { + if elapsed > config.read_timeout { todo!(); } - let remain = READ_TIMEOUT - elapsed; + let remain = config.read_timeout - elapsed; let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. let timeout_res = timeout( @@ -406,7 +463,7 @@ impl Query2 { .await; if timeout_res.is_err() { retries += 1; - if retries < MAX_RETRIES { + if retries < config.max_retries { continue; } return Err(Error::UdpTimeoutNoResponse); @@ -483,6 +540,9 @@ impl GetResult for Query2 { /// Actual implementation of the UDP transport connection. #[derive(Debug)] struct InnerConnection { + /// User configuration variables. + config: Config, + /// Address of the remote server. remote_addr: SocketAddr, @@ -492,10 +552,15 @@ struct InnerConnection { impl InnerConnection { /// Create new InnerConnection object. - fn new(remote_addr: SocketAddr) -> io::Result { + fn new( + config: Config, + remote_addr: SocketAddr, + ) -> Result { + let max_parallel = config.max_parallel; Ok(Self { + config, remote_addr, - semaphore: Arc::new(Semaphore::new(MAX_PARALLEL)), + semaphore: Arc::new(Semaphore::new(max_parallel)), }) } @@ -510,7 +575,12 @@ impl InnerConnection { bytes.extend_from_slice(slice); let query_msg = Message::from_octets(bytes) .expect("Message failed to parse contents of another Message"); - Ok(Query2::new(query_msg, self.remote_addr, conn)) + Ok(Query2::new( + self.config.clone(), + query_msg, + self.remote_addr, + conn, + )) } /// Return a permit for a our semaphore. @@ -522,3 +592,36 @@ impl InnerConnection { .expect("the semaphore has not been closed") } } + +//------------ Utility -------------------------------------------------------- + +/// Check if config is valid. +fn check_config(config: &Config) -> Result<(), Error> { + if config.max_parallel < MIN_MAX_PARALLEL + || config.max_parallel > MAX_MAX_PARALLEL + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "max_parallel", + )))); + } + + if config.read_timeout < MIN_READ_TIMEOUT + || config.read_timeout > MAX_READ_TIMEOUT + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "read_timeout", + )))); + } + + if config.max_retries < MIN_MAX_RETRIES + || config.max_retries > MAX_MAX_RETRIES + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "max_retries", + )))); + } + Ok(()) +} diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index a2f4a2e29..0cceea385 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -11,7 +11,6 @@ use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; -use std::io; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; @@ -23,6 +22,18 @@ use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; +//------------ Config --------------------------------------------------------- + +/// Configuration for an octet_stream transport connection. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Configuration for the UDP transport. + pub udp: Option, + + /// Configuration for the multi_stream (TCP) transport. + pub multi_stream: Option, +} + //------------ Connection ----------------------------------------------------- /// DNS transport connection that first issues a query over a UDP transport and @@ -37,15 +48,25 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> Connection { /// Create a new connection. - pub fn new(remote_addr: SocketAddr) -> io::Result> { - let connection = InnerConnection::new(remote_addr)?; + pub fn new( + config: Option, + remote_addr: SocketAddr, + ) -> Result, Error> { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config, remote_addr)?; Ok(Self { inner: Arc::new(connection), }) } /// Worker function for a connection object. - pub async fn run(&self) -> Option<()> { + pub async fn run(&self) -> Result<(), Error> { self.inner.run().await } @@ -219,9 +240,12 @@ impl /// /// Create the UDP and TCP connections. Store the remote address because /// run needs it later. - fn new(remote_addr: SocketAddr) -> io::Result> { - let udp_conn = udp::Connection::new(remote_addr)?; - let tcp_conn = multi_stream::Connection::new()?; + fn new( + config: Config, + remote_addr: SocketAddr, + ) -> Result, Error> { + let udp_conn = udp::Connection::new(config.udp, remote_addr)?; + let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; Ok(Self { remote_addr, @@ -234,7 +258,7 @@ impl /// /// Create a TCP connection factory and pass that to worker function /// of the multi_stream object. - pub async fn run(&self) -> Option<()> { + pub async fn run(&self) -> Result<(), Error> { let tcp_factory = TcpConnFactory::new(self.remote_addr); self.tcp_conn.run(tcp_factory).await } @@ -253,3 +277,9 @@ impl )) } } + +/// Check if config is valid. +fn check_config(_config: &Config) -> Result<(), Error> { + // Nothing to check at the moment. + Ok(()) +} From d3bc7067cf07e8c5b00d8e0f70521f57ef377f77 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 15:27:48 +0200 Subject: [PATCH 059/124] Continue receiving if there is an error. Allow an empty question section in the reply if the message is truncated or if there is an error. --- src/net/client/udp.rs | 135 ++++++++++++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 39 deletions(-) diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 8a14c922d..790f52247 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -8,6 +8,7 @@ // - random port use bytes::{Bytes, BytesMut}; +use octseq::Octets; use std::boxed::Box; use std::fmt::{Debug, Formatter}; use std::future::Future; @@ -19,6 +20,7 @@ use tokio::net::UdpSocket; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::{timeout, Duration, Instant}; +use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::error::Error; use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; @@ -324,7 +326,7 @@ impl Query { .expect( "Message failed to parse contents of another Message", ); - if !answer.is_answer(&query_msg) { + if !is_answer(answer, &query_msg) { continue; } self.sock = None; @@ -447,50 +449,66 @@ impl Query2 { } let start = Instant::now(); - let elapsed = start.elapsed(); - if elapsed > config.read_timeout { - todo!(); - } - let remain = config.read_timeout - elapsed; - - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let timeout_res = timeout( - remain, - sock.as_ref() - .expect("socket should be present") - .recv(&mut buf), - ) - .await; - if timeout_res.is_err() { - retries += 1; - if retries < config.max_retries { + + loop { + let elapsed = start.elapsed(); + if elapsed > config.read_timeout { + // Break out of the receive loop and continue in the + // transmit loop. + break; + } + let remain = config.read_timeout - elapsed; + + let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout( + remain, + sock.as_ref() + .expect("socket should be present") + .recv(&mut buf), + ) + .await; + if timeout_res.is_err() { + retries += 1; + if retries < config.max_retries { + // Break out of the receive loop and continue in the + // transmit loop. + break; + } + return Err(Error::UdpTimeoutNoResponse); + } + let len = timeout_res + .expect("errror case is checked above") + .map_err(|e| Error::UdpReceive(Arc::new(e)))?; + buf.truncate(len); + + // We ignore garbage since there is a timer on this whole + // thing. + let answer = match Message::from_octets(buf.into()) { + // Just go back to receiving. + Ok(answer) => answer, + Err(_) => continue, + }; + + // Unfortunately we cannot pass query_msg to is_answer + // because is_answer requires Octets, which is not + // implemented by BytesMut. Make a copy. + let query_msg = Message::from_octets(query_msg.as_slice()) + .expect( + "Message failed to parse contents of another Message", + ); + if !is_answer(&answer, &query_msg) { + // Wrong answer, go back to receiving continue; } - return Err(Error::UdpTimeoutNoResponse); + return Ok(answer); } - let len = timeout_res - .expect("errror case is checked above") - .map_err(|e| Error::UdpReceive(Arc::new(e)))?; - buf.truncate(len); - - // We ignore garbage since there is a timer on this whole thing. - let answer = match Message::from_octets(buf.into()) { - Ok(answer) => answer, - Err(_) => continue, - }; - - // Unfortunately we cannot pass query_msg to is_answer - // because is_answer requires Octets, which is not - // implemented by BytesMut. Make a copy. - let query_msg = Message::from_octets(query_msg.as_slice()) - .expect( - "Message failed to parse contents of another Message", - ); - if !answer.is_answer(&query_msg) { + retries += 1; + if retries < config.max_retries { continue; } - return Ok(answer); + break; } + Err(Error::UdpTimeoutNoResponse) } /// Bind to a local UDP port. @@ -625,3 +643,42 @@ fn check_config(config: &Config) -> Result<(), Error> { } Ok(()) } + +/// Check if a message is a valid reply for a query. Allow the question section +/// to be empty if there is an error or if the reply is truncated. +fn is_answer< + QueryOcts: AsRef<[u8]> + Octets, + ReplyOcts: AsRef<[u8]> + Octets, +>( + reply: &Message, + query: &Message, +) -> bool { + let reply_header = reply.header(); + let reply_hcounts = reply.header_counts(); + + // First check qr and id + if !reply_header.qr() || reply_header.id() != query.header().id() { + return false; + } + + // If either tc is set or the result is an error, then the question + // section can be empty. In that case we require all other sections + // to be empty as well. + if (reply_header.tc() || reply_header.rcode() != Rcode::NoError) + && reply_hcounts.qdcount() == 0 + && reply_hcounts.ancount() == 0 + && reply_hcounts.nscount() == 0 + && reply_hcounts.arcount() == 0 + { + // We can accept this as a valid reply. + return true; + } + + // Remaining checks. The question section in the reply has to be the + // same as in the query. + if reply_hcounts.qdcount() != query.header_counts().qdcount() { + false + } else { + reply.question() == query.question() + } +} From 9dad322ea23351e836b2f6d9340f2420bd76de67 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 15:29:51 +0200 Subject: [PATCH 060/124] Fix a bug where a message with the wrong ID is returned. --- src/net/client/octet_stream.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index da7e5434d..6bbe2460a 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -904,11 +904,11 @@ impl + Clone> InnerConnection { // Adding keepalive option // failed. Send the original // request. - Self::convert_query(&req.msg, reqmsg); + Self::convert_query(&mut_msg, reqmsg); } } } else { - Self::convert_query(&req.msg, reqmsg); + Self::convert_query(&mut_msg, reqmsg); } } @@ -943,15 +943,11 @@ impl + Clone> InnerConnection { let slice = msg.as_slice(); let len = slice.len(); - println!("convert_query: slice len {}, slice {:?}", len, slice); - let mut vec = Vec::with_capacity(2 + len); let len16 = len as u16; vec.extend_from_slice(&len16.to_be_bytes()); vec.extend_from_slice(slice); - println!("convert_query: vec {:?}", vec); - *reqmsg = Some(vec); } From 9f0c404ca60591c1592596d61723cc221fda47ff Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 15:31:47 +0200 Subject: [PATCH 061/124] Allow the question section to be empty if the reply is an error. --- src/net/client/multi_stream.rs | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 962fcdbeb..172ec62a7 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -29,6 +29,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio::time::{sleep_until, Instant}; +use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; @@ -737,9 +738,30 @@ fn is_answer_ignore_id< reply: &Message, query: &Message, ) -> bool { - if !reply.header().qr() - || reply.header_counts().qdcount() != query.header_counts().qdcount() + let reply_header = reply.header(); + let reply_hcounts = reply.header_counts(); + + // First check qr is set + if !reply_header.qr() { + return false; + } + + // If the result is an error, then the question + // section can be empty. In that case we require all other sections + // to be empty as well. + if reply_header.rcode() != Rcode::NoError + && reply_hcounts.qdcount() == 0 + && reply_hcounts.ancount() == 0 + && reply_hcounts.nscount() == 0 + && reply_hcounts.arcount() == 0 { + // We can accept this as a valid reply. + return true; + } + + // Remaining checks. The question section in the reply has to be the + // same as in the query. + if reply_hcounts.qdcount() != query.header_counts().qdcount() { false } else { reply.question() == query.question() From 15c0c5ccbb8e662ac021161776c7a3d5848054a0 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 15:46:59 +0200 Subject: [PATCH 062/124] User configuration. --- src/net/client/redundant.rs | 287 +++++++++++++++++++++++++++--------- 1 file changed, 220 insertions(+), 67 deletions(-) diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index f58a017ec..50f8f61f0 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use futures::stream::FuturesUnordered; use futures::StreamExt; -use octseq::OctetsBuilder; +use octseq::{Octets, OctetsBuilder}; use rand::random; @@ -16,7 +16,6 @@ use std::boxed::Box; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::future::Future; -use std::io; use std::pin::Pin; use std::sync::Arc; use std::vec::Vec; @@ -24,6 +23,7 @@ use std::vec::Vec; use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::time::{sleep_until, Duration, Instant}; +use crate::base::iana::OptRcode; use crate::base::wire::Composer; use crate::base::Message; use crate::net::client::error::Error; @@ -65,6 +65,21 @@ const PROBE_P: f64 = 0.05; /// When a worse connection is probed, give it a slight head start. const PROBE_RT: Duration = Duration::from_millis(1); +//------------ Config --------------------------------------------------------- + +/// User configuration variables. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Defer transport errors. + pub defer_transport_error: bool, + + /// Defer replies that report Refused. + pub defer_refused: bool, + + /// Defer replies that report ServFail. + pub defer_servfail: bool, +} + //------------ Connection ----------------------------------------------------- /// This type represents a transport connection. @@ -78,8 +93,15 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> Connection { /// Create a new connection. - pub fn new() -> io::Result> { - let connection = InnerConnection::new()?; + pub fn new(config: Option) -> Result, Error> { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config)?; //test_send(connection); Ok(Self { inner: Arc::new(connection), @@ -132,6 +154,9 @@ impl< /// This type represents an active query request. #[derive(Debug)] pub struct Query + Send> { + /// User configuration. + config: Config, + /// The state of the query state: QueryState, @@ -148,6 +173,14 @@ pub struct Query + Send> { fut_list: FuturesUnordered + Send>>>, + /// Transport error that should be reported if nothing better shows + /// up. + deferred_transport_error: Option, + + /// Reply that should be returned to the user if nothing better shows + /// up. + deferred_reply: Option>, + /// The result from one of the connectons. result: Option, Error>>, @@ -273,25 +306,18 @@ struct ConnRT { } /// Result of the futures in fut_list. -type FutListOutput = Result<(usize, Result, Error>), Error>; +type FutListOutput = (usize, Result, Error>); impl + Clone + Debug + Send + Sync + 'static> Query { /// Create a new query object. fn new( + config: Config, query_msg: Message, mut conn_rt: Vec, sender: mpsc::Sender>, ) -> Query { let conn_rt_len = conn_rt.len(); - println!("before sort:"); - for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { - println!("{}: id {} ert {:?}", i, item.id, item.est_rt); - } conn_rt.sort_unstable_by(conn_rt_cmp); - println!("after sort:"); - for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { - println!("{}: id {} ert {:?}", i, item.id, item.est_rt); - } // Do we want to probe a less performant upstream? if conn_rt_len > 1 && random::() < PROBE_P { @@ -300,19 +326,18 @@ impl + Clone + Debug + Send + Sync + 'static> Query { // Sort again conn_rt.sort_unstable_by(conn_rt_cmp); - println!("sort for probe :"); - for (i, item) in conn_rt.iter().enumerate().take(conn_rt_len) { - println!("{}: id {} ert {:?}", i, item.id, item.est_rt); - } } Query { + config, query_msg, //conns, conn_rt, sender, state: QueryState::Init, fut_list: FuturesUnordered::new(), + deferred_transport_error: None, + deferred_reply: None, result: None, res_index: 0, } @@ -338,19 +363,70 @@ impl + Clone + Debug + Send + Sync + 'static> Query { self.query_msg.clone(), ); self.fut_list.push(Box::pin(fut)); - println!("timeout {:?}", self.conn_rt[ind].est_rt); let timeout = Instant::now() + self.conn_rt[ind].est_rt; - tokio::select! { - res = self.fut_list.next() => { - println!("got res {:?}", res); - let res = res.expect("res should not be empty")?; + loop { + tokio::select! { + res = self.fut_list.next() => { + let res = res.expect("res should not be empty"); + match res.1 { + Err(ref err) => { + if self.config.defer_transport_error { + if self.deferred_transport_error.is_none() { + self.deferred_transport_error = Some(err.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; + } + // Return error to the user. + } + Ok(ref msg) => { + if skip(msg, &self.config) { + if self.deferred_reply.is_none() { + self.deferred_reply = Some(msg.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; + } + // Now we have a reply that can be + // returned to the user. + } + } self.result = Some(res.1); self.res_index= res.0; self.state = QueryState::Report(0); - continue; - } - _ = sleep_until(timeout) => { + // Break out of receive loop + break; + } + _ = sleep_until(timeout) => { // Move to the next Probe state if there // are more upstreams to try, otherwise // move to the Wait state. @@ -362,9 +438,13 @@ impl + Clone + Debug + Send + Sync + 'static> Query { { QueryState::Wait }; - continue; + // Break out of receive loop + break; + } } } + // Continue with state machine loop + continue; } QueryState::Report(ind) => { if ind >= self.conn_rt.len() @@ -382,11 +462,6 @@ impl + Clone + Debug + Send + Sync + 'static> Query { .start .expect("start time should not be empty"); let elapsed = start.elapsed(); - println!( - "expected rt was {:?}", - self.conn_rt[ind].est_rt - ); - println!("reporting duration {:?}", elapsed); let time_report = TimeReport { id: self.conn_rt[ind].id, elapsed, @@ -406,12 +481,61 @@ impl + Clone + Debug + Send + Sync + 'static> Query { continue; } QueryState::Wait => { - let res = self.fut_list.next().await; - println!("got res {:?}", res); - let res = res.expect("res should not be empty")?; - self.result = Some(res.1); - self.res_index = res.0; - self.state = QueryState::Report(0); + loop { + if self.fut_list.is_empty() { + // We have nothing left. There should be a reply or + // an error. Prefer a reply over an error. + if self.deferred_reply.is_some() { + let msg = self + .deferred_reply + .take() + .expect("just checked for Some"); + return Ok(msg); + } + if self.deferred_transport_error.is_some() { + let err = self + .deferred_transport_error + .take() + .expect("just checked for Some"); + return Err(err); + } + panic!("either deferred_reply or deferred_error should be present"); + } + let res = self.fut_list.next().await; + let res = res.expect("res should not be empty"); + match res.1 { + Err(ref err) => { + if self.config.defer_transport_error { + if self.deferred_transport_error.is_none() + { + self.deferred_transport_error = + Some(err.clone()); + } + // Just continue with the next future, or + // finish if fut_list is empty. + continue; + } + // Return error to the user. + } + Ok(ref msg) => { + if skip(msg, &self.config) { + if self.deferred_reply.is_none() { + self.deferred_reply = + Some(msg.clone()); + } + // Just continue with the next future, or + // finish if fut_list is empty. + continue; + } + // Return reply to user. + } + } + self.result = Some(res.1); + self.res_index = res.0; + self.state = QueryState::Report(0); + // Break out of loop to continue with the state machine + break; + } continue; } } @@ -444,6 +568,9 @@ impl< /// Type that actually implements the connection. struct InnerConnection { + /// User configuation. + config: Config, + /// Receive side of the channel used by the runner. receiver: Mutex>>>, @@ -455,9 +582,10 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> InnerConnection { /// Implementation of the new method. - fn new() -> io::Result> { + fn new(config: Config) -> Result, Error> { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { + config, receiver: Mutex::new(Some(rx)), sender: tx, }) @@ -502,12 +630,10 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> let _ = rt_req.tx.send(Ok(conn_rt.clone())); } ChanReq::Query(query_req) => { - println!("QueryReq for id {}", query_req.id); let opt_ind = conn_rt.iter().position(|e| e.id == query_req.id); match opt_ind { Some(ind) => { - println!("QueryReq for ind {}", ind); let query = conns[ind].query(&query_req.query_msg).await; // Don't care if send fails @@ -522,49 +648,32 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> } } ChanReq::Report(time_report) => { - println!( - "for {} time {:?}", - time_report.id, time_report.elapsed - ); let opt_ind = conn_rt.iter().position(|e| e.id == time_report.id); if let Some(ind) = opt_ind { - println!("Report for ind {}", ind); let elapsed = time_report.elapsed.as_secs_f64(); conn_stats[ind].mean += (elapsed - conn_stats[ind].mean) / SMOOTH_N; let elapsed_sq = elapsed * elapsed; conn_stats[ind].mean_sq += (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; - println!( - "new mean {} mean_sq {}", - conn_stats[ind].mean, conn_stats[ind].mean_sq - ); let mean = conn_stats[ind].mean; let var = conn_stats[ind].mean_sq - mean * mean; let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; - println!("std dev {}", std_dev); let est_rt = mean + 3. * std_dev; conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); - println!("new est_rt {:?}", conn_rt[ind].est_rt); } } ChanReq::Failure(time_report) => { - println!( - "failure for {} time {:?}", - time_report.id, time_report.elapsed - ); let opt_ind = conn_rt.iter().position(|e| e.id == time_report.id); if let Some(ind) = opt_ind { - println!("Failure Report for ind {}", ind); let elapsed = time_report.elapsed.as_secs_f64(); if elapsed < conn_stats[ind].mean { // Do not update the mean if a // failure took less time than the // current mean. - println!("ignoring better time"); continue; } conn_stats[ind].mean += @@ -572,18 +681,12 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> let elapsed_sq = elapsed * elapsed; conn_stats[ind].mean_sq += (elapsed_sq - conn_stats[ind].mean_sq) / SMOOTH_N; - println!( - "new mean {} mean_sq {}", - conn_stats[ind].mean, conn_stats[ind].mean_sq - ); let mean = conn_stats[ind].mean; let var = conn_stats[ind].mean_sq - mean * mean; let std_dev = if var < 0. { 0. } else { f64::sqrt(var) }; - println!("std dev {}", std_dev); let est_rt = mean + 3. * std_dev; conn_rt[ind].est_rt = Duration::from_secs_f64(est_rt); - println!("new est_rt {:?}", conn_rt[ind].est_rt); } } } @@ -614,7 +717,12 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> .await .expect("send should not fail"); let conn_rt = rx.await.expect("receive should not fail")?; - Ok(Query::new(query_msg, conn_rt, self.sender.clone())) + Ok(Query::new( + self.config.clone(), + query_msg, + conn_rt, + self.sender.clone(), + )) } } @@ -628,7 +736,7 @@ async fn start_request( id: u64, sender: mpsc::Sender>, query_msg: Message, -) -> Result<(usize, Result, Error>), Error> { +) -> (usize, Result, Error>) { let (tx, rx) = oneshot::channel(); sender .send(ChanReq::Query(QueryReq { @@ -638,13 +746,58 @@ async fn start_request( })) .await .expect("send is expected to work"); - let mut query = rx.await.expect("receive is expected to work")?; + let mut query = match rx.await.expect("receive is expected to work") { + Err(err) => return (index, Err(err)), + Ok(query) => query, + }; let reply = query.get_result().await; - Ok((index, reply)) + (index, reply) } /// Compare ConnRT elements based on estimated response time. fn conn_rt_cmp(e1: &ConnRT, e2: &ConnRT) -> Ordering { e1.est_rt.cmp(&e2.est_rt) } + +/// Return if this reply should be skipped or not. +fn skip(msg: &Message, config: &Config) -> bool { + // Check if we actually need to check. + if !config.defer_refused && !config.defer_servfail { + return false; + } + + let opt_rcode = get_opt_rcode(msg); + // OptRcode needs PartialEq + if let OptRcode::Refused = opt_rcode { + if config.defer_refused { + return true; + } + } + if let OptRcode::ServFail = opt_rcode { + if config.defer_servfail { + return true; + } + } + + false +} + +/// Get the extended rcode of a message. +fn get_opt_rcode(msg: &Message) -> OptRcode { + let opt = msg.opt(); + match opt { + Some(opt) => opt.rcode(msg.header()), + None => { + // Convert Rcode to OptRcode, this should be part of + // OptRcode + OptRcode::from_int(msg.header().rcode().to_int() as u16) + } + } +} + +/// Check if config is valid. +fn check_config(_config: &Config) -> Result<(), Error> { + // Nothing to check at the moment. + Ok(()) +} From 63d7f467ccb815970a5dc77c693066c4ce154295 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 15:56:46 +0200 Subject: [PATCH 063/124] Update client-transports example. --- examples/client-transports.rs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 5b982cebd..35d5cc5bd 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -41,13 +41,14 @@ async fn main() { // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. - let udptcp_conn = udp_tcp::Connection::new(server_addr).unwrap(); + let udptcp_conn = udp_tcp::Connection::new(None, server_addr).unwrap(); // Create a clone for the run function. Start the run function on a // separate task. let conn_run = udptcp_conn.clone(); tokio::spawn(async move { - conn_run.run().await; + let res = conn_run.run().await; + println!("run exited with {:?}", res); }); // Send a query message. @@ -63,13 +64,14 @@ async fn main() { // A muli_stream transport connection sets up new TCP connections when // needed. - let tcp_conn = multi_stream::Connection::>::new().unwrap(); + let tcp_conn = multi_stream::Connection::>::new(None).unwrap(); // Start the run function as a separate task. The run function receives // the factory as a parameter. let conn_run = tcp_conn.clone(); tokio::spawn(async move { - conn_run.run(tcp_factory).await; + let res = conn_run.run(tcp_factory).await; + println!("run exited with {:?}", res); }); // Send a query message. @@ -110,12 +112,13 @@ async fn main() { TlsConnFactory::new(client_config, "dns.google", server_addr); // Again create a multi_stream transport connection. - let tls_conn = multi_stream::Connection::new().unwrap(); + let tls_conn = multi_stream::Connection::new(None).unwrap(); // Can start the run function. let conn_run = tls_conn.clone(); tokio::spawn(async move { - conn_run.run(tls_factory).await; + let res = conn_run.run(tls_factory).await; + println!("run exited with {:?}", res); }); let mut query = tls_conn.query(&msg).await.unwrap(); @@ -123,7 +126,7 @@ async fn main() { println!("TLS reply: {:?}", reply); // Create a transport connection for redundant connections. - let redun = redundant::Connection::new().unwrap(); + let redun = redundant::Connection::new(None).unwrap(); // Start the run function on a separate task. let redun_run = redun.clone(); @@ -146,7 +149,7 @@ async fn main() { // Create a new UDP transport connection. Pass the destination address // and port as parameter. This transport does not retry over TCP if the // reply is truncated. - let udp_conn = udp::Connection::new(server_addr).unwrap(); + let udp_conn = udp::Connection::new(None, server_addr).unwrap(); // Send a query message. let mut query = udp_conn.query(&msg).await.unwrap(); @@ -159,7 +162,7 @@ async fn main() { // single request or a small burst of requests. let tcp_conn = TcpStream::connect(server_addr).await.unwrap(); - let tcp = octet_stream::Connection::new().unwrap(); + let tcp = octet_stream::Connection::new(None).unwrap(); let tcp_worker = tcp.clone(); tokio::spawn(async move { From 6947f56a770ccf9aef3e1ca71999c75a729a5ce0 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 19 Oct 2023 21:08:11 +0200 Subject: [PATCH 064/124] Updated version of tokio and proc-macro2. --- Cargo.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 4b93c0c76..0026c1cfb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,9 +30,12 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.0", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } +tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } tokio-rustls = { version = "0.24", optional = true, features = [] } +# XXX Force proc-macro2 to at least 1.0.69 for minimal-version build +proc-macro2 = "1.0.69" + [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work libc = { version = "0.2.71", default-features = false, optional = true } From ff37fc722fc68db061e6984214f38d887865e040 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 6 Nov 2023 16:45:41 +0100 Subject: [PATCH 065/124] Switch to BaseMessageBuilder. --- examples/client-transports.rs | 19 +- src/net/client/base_message_builder.rs | 40 ++++ src/net/client/bmb.rs | 166 +++++++++++++++++ src/net/client/mod.rs | 2 + src/net/client/multi_stream.rs | 132 ++++++-------- src/net/client/octet_stream.rs | 242 +++++-------------------- src/net/client/query.rs | 12 +- src/net/client/redundant.rs | 106 +++++------ src/net/client/udp.rs | 90 ++++----- src/net/client/udp_tcp.rs | 90 +++------ 10 files changed, 429 insertions(+), 470 deletions(-) create mode 100644 src/net/client/base_message_builder.rs create mode 100644 src/net/client/bmb.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 35d5cc5bd..1e82ce940 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,9 +1,10 @@ use domain::base::Dname; use domain::base::Rtype::Aaaa; use domain::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use domain::net::client::bmb::BMB; use domain::net::client::multi_stream; use domain::net::client::octet_stream; -use domain::net::client::query::QueryMessage3; +use domain::net::client::query::QueryMessage4; use domain::net::client::redundant; use domain::net::client::tcp_factory::TcpConnFactory; use domain::net::client::tls_factory::TlsConnFactory; @@ -36,6 +37,8 @@ async fn main() { println!("request msg: {:?}", msg.as_slice()); + let bmb = BMB::new(msg); + // Destination for UDP and TCP let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); @@ -52,7 +55,7 @@ async fn main() { }); // Send a query message. - let mut query = udptcp_conn.query(&msg).await.unwrap(); + let mut query = udptcp_conn.query(&bmb).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -64,7 +67,7 @@ async fn main() { // A muli_stream transport connection sets up new TCP connections when // needed. - let tcp_conn = multi_stream::Connection::>::new(None).unwrap(); + let tcp_conn = multi_stream::Connection::new(None).unwrap(); // Start the run function as a separate task. The run function receives // the factory as a parameter. @@ -75,7 +78,7 @@ async fn main() { }); // Send a query message. - let mut query = tcp_conn.query(&msg).await.unwrap(); + let mut query = tcp_conn.query(&bmb).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -121,7 +124,7 @@ async fn main() { println!("run exited with {:?}", res); }); - let mut query = tls_conn.query(&msg).await.unwrap(); + let mut query = tls_conn.query(&bmb).await.unwrap(); let reply = query.get_result().await; println!("TLS reply: {:?}", reply); @@ -141,7 +144,7 @@ async fn main() { // Start a few queries. for _i in 1..10 { - let mut query = redun.query(&msg).await.unwrap(); + let mut query = redun.query(&bmb).await.unwrap(); let reply = query.get_result().await; println!("redundant connection reply: {:?}", reply); } @@ -152,7 +155,7 @@ async fn main() { let udp_conn = udp::Connection::new(None, server_addr).unwrap(); // Send a query message. - let mut query = udp_conn.query(&msg).await.unwrap(); + let mut query = udp_conn.query(&bmb).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -171,7 +174,7 @@ async fn main() { }); // Send a query message. - let mut query = tcp.query(&msg).await.unwrap(); + let mut query = tcp.query(&bmb).await.unwrap(); // Get the reply let reply = query.get_result().await; diff --git a/src/net/client/base_message_builder.rs b/src/net/client/base_message_builder.rs new file mode 100644 index 000000000..943246237 --- /dev/null +++ b/src/net/client/base_message_builder.rs @@ -0,0 +1,40 @@ +//! Trait for building a message by applying changes to a base message. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use crate::base::Header; +use crate::base::Message; +//use crate::base::message_builder::OptBuilder; +use crate::base::opt::TcpKeepalive; + +use std::boxed::Box; +use std::fmt::Debug; +use std::vec::Vec; + +#[derive(Clone, Debug)] +/// Capture the various EDNS options. +pub enum OptTypes { + /// TcpKeepalive variant + TypeTcpKeepalive(TcpKeepalive), +} + +/// A trait that allows construction of a message as a series to changes to +/// an existing message. +pub trait BaseMessageBuilder: Debug + Send + Sync { + /// Return a boxed dyn of the current object. + fn as_box_dyn(&self) -> Box; + + /// Create a message that captures the recorded changes. + fn to_message(&self) -> Message>; + + /// Create a message that captures the recorded changes and convert to + /// a Vec. + fn to_vec(&self) -> Vec; + + /// Return a reference to a mutable Header to record changes to the header. + fn header_mut(&mut self) -> &mut Header; + + /// Add an EDNS option. + fn add_opt(&mut self, opt: OptTypes); +} diff --git a/src/net/client/bmb.rs b/src/net/client/bmb.rs new file mode 100644 index 000000000..0a1392efb --- /dev/null +++ b/src/net/client/bmb.rs @@ -0,0 +1,166 @@ +//! Simple class that implement the BaseMessageBuilder trait. + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +//use bytes::BytesMut; + +//use crate::base::message_builder::OptBuilder; +use crate::base::Header; +use crate::base::Message; +use crate::base::MessageBuilder; +use crate::base::ParsedDname; +use crate::base::Rtype; +use crate::base::StaticCompressor; +use crate::dep::octseq::Octets; +use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::base_message_builder::OptTypes; +use crate::net::client::error::Error; +use crate::rdata::AllRecordData; + +use std::boxed::Box; +use std::fmt::Debug; +use std::vec::Vec; + +#[derive(Clone, Debug)] +/// Object that implements the BaseMessageBuilder trait for a Message object. +pub struct BMB> { + /// Base messages. + msg: Message, + + /// New header. + header: Header, + + /// Collection of EDNS options to add. + opts: Vec, +} + +impl + Debug + Octets> BMB { + /// Create a new BMB object. + pub fn new(msg: Message) -> Self { + let header = msg.header(); + Self { + msg, + header, + opts: Vec::new(), + } + } + + /// Create new message based on the changes to the base message. + fn to_message_impl(&self) -> Result>, Error> { + let source = &self.msg; + + let mut target = + MessageBuilder::from_target(StaticCompressor::new(Vec::new())) + .expect("Vec is expected to have enough space"); + let target_hdr = target.header_mut(); + target_hdr.set_flags(self.header.flags()); + target_hdr.set_opcode(self.header.opcode()); + target_hdr.set_rcode(self.header.rcode()); + target_hdr.set_id(self.header.id()); + + let source = source.question(); + let mut target = target.question(); + for rr in source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + let mut source = + source.answer().map_err(|_e| Error::MessageParseError)?; + let mut target = target.answer(); + for rr in &mut source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + + let mut source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); + let mut target = target.authority(); + for rr in &mut source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + + let source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); + let mut target = target.additional(); + for rr in source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + if rr.rtype() == Rtype::Opt { + } else { + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + } + target + .opt(|opt| { + for o in &self.opts { + match o { + OptTypes::TypeTcpKeepalive(tka) => { + opt.tcp_keepalive(tka.timeout())? + } + } + } + Ok(()) + }) + .map_err(|_e| Error::MessageBuilderPushError)?; + + // It would be nice to use .builder() here. But that one deletes all + // section. We have to resort to .as_builder() which gives a + // reference and then .clone() + let result = target.as_builder().clone(); + let msg = Message::from_octets(result.finish().into_target()).expect( + "Message should be able to parse output from MessageBuilder", + ); + Ok(msg) + } +} + +impl + Clone + Debug + Octets + Send + Sync + 'static> + BaseMessageBuilder for BMB +{ + fn as_box_dyn(&self) -> Box { + Box::new(self.clone()) + } + + fn to_vec(&self) -> Vec { + let msg = self.to_message(); + msg.as_octets().clone() + } + + fn to_message(&self) -> Message> { + self.to_message_impl().unwrap() + } + + fn header_mut(&mut self) -> &mut Header { + &mut self.header + } + + fn add_opt(&mut self, opt: OptTypes) { + self.opts.push(opt); + println!("add_opt: after push: {:?}", self); + } +} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index d1f7beebf..85b900dbf 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -7,6 +7,8 @@ #![doc = include_str!("../../../examples/client-transports.rs")] //! ``` +pub mod base_message_builder; +pub mod bmb; pub mod error; pub mod factory; pub mod multi_stream; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 172ec62a7..ab3f8f69c 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -31,10 +31,11 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; +use crate::net::client::base_message_builder::BaseMessageBuilder; use crate::net::client::error::Error; use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; +use crate::net::client::query::{GetResult, QueryMessage4}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -56,18 +57,16 @@ pub struct Config { #[derive(Clone, Debug)] /// A DNS over octect streams transport. -pub struct Connection> { +pub struct Connection { /// Reference counted [InnerConnection]. - inner: Arc>, + inner: Arc>, } -impl + Clone + Debug + Octets + Send + Sync + 'static> - Connection -{ +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new(config: Option) -> Result, Error> { + pub fn new(config: Option) -> Result { let config = match config { Some(config) => { check_config(&config)?; @@ -99,27 +98,13 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl( - &self, - query_msg: &Message, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.inner.new_conn(None, tx).await?; - let gr = Query::new(self.clone(), query_msg, rx); - Ok(gr) - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl3( + pub async fn query_impl4( &self, - query_msg: &Message, + query_msg: &BMB, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.new_conn(None, tx).await?; - let gr = Query::new(self.clone(), query_msg, rx); + let gr = Query::::new(self.clone(), query_msg, rx); Ok(Box::new(gr)) } @@ -132,30 +117,32 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> async fn new_conn( &self, id: u64, - tx: oneshot::Sender>, + tx: oneshot::Sender>, ) -> Result<(), Error> { self.inner.new_conn(Some(id), tx).await } } -impl - QueryMessage, Octs> for Connection +/* +impl + QueryMessage, Octs> for Connection<> { fn query<'a>( &'a self, query_msg: &'a Message, - ) -> Pin, Error>> + Send + '_>> + ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } +*/ -impl + Clone + Debug + Octets + Send + Sync + 'static> - QueryMessage3 for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a Message, + query_msg: &'a BMB, ) -> Pin< Box< dyn Future, Error>> @@ -163,7 +150,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> + '_, >, > { - return Box::pin(self.query_impl3(query_msg)); + return Box::pin(self.query_impl4(query_msg)); } } @@ -171,20 +158,20 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// This struct represent an active DNS query. #[derive(Debug)] -pub struct Query> { +pub struct Query { /// Request message. /// /// The reply message is compared with the request message to see if /// it matches the query. // query_msg: Message>, - query_msg: Message, + query_msg: BMB, /// Current state of the query. - state: QueryState, + state: QueryState, /// A multi_octet connection object is needed to request new underlying /// octet_stream transport connections. - conn: Connection, + conn: Connection, /// id of most recent connection. conn_id: u64, @@ -197,12 +184,12 @@ pub struct Query> { /// Status of a query. Used in [Query]. #[derive(Debug)] -enum QueryState> { +enum QueryState { /// Get a octet_stream transport. - GetConn(oneshot::Receiver>), + GetConn(oneshot::Receiver>), /// Start a query using the transport. - StartQuery(octet_stream::Connection), + StartQuery(octet_stream::Connection), /// Get the result of the query. GetResult(octet_stream::QueryNoCheck), @@ -218,28 +205,26 @@ enum QueryState> { } /// The reply to a NewConn request. -type ChanResp = Result, Arc>; +type ChanResp = Result, Arc>; /// Response to the DNS request sent by [InnerConnection::run] to [Query]. #[derive(Debug)] -struct ChanRespOk> { +struct ChanRespOk { /// id of this connection. id: u64, /// New octet_stream transport. - conn: octet_stream::Connection, + conn: octet_stream::Connection, } -impl + Clone + Debug + Octets + Send + Sync + 'static> - Query -{ +impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. fn new( - conn: Connection, - query_msg: &Message, - receiver: oneshot::Receiver>, - ) -> Query { + conn: Connection, + query_msg: &BMB, + receiver: oneshot::Receiver>, + ) -> Query { Self { conn, query_msg: query_msg.clone(), @@ -287,8 +272,8 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } } QueryState::StartQuery(ref mut conn) => { - let mut msg = self.query_msg.clone(); - let query_res = conn.query_no_check(&mut msg).await; + let msg = self.query_msg.clone(); + let query_res = conn.query_no_check(&msg).await; match query_res { Err(err) => { if let Error::ConnectionClosed = err { @@ -324,10 +309,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } let msg = reply.expect("error is checked before"); - let query_msg_ref: &[u8] = self.query_msg.as_ref(); - let query_msg_vec = query_msg_ref.to_vec(); - let query_msg = Message::from_octets(query_msg_vec) - .expect("how to go from MessageBuild to Message?"); + let query_msg = self.query_msg.to_message(); if !is_answer_ignore_id(&msg, &query_msg) { return Err(Error::WrongReplyForQuery); @@ -353,9 +335,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } } -impl + Clone + Debug + Octets + Send + Sync + 'static> - GetResult for Query -{ +impl GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -369,7 +349,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection> { +struct InnerConnection { /// User configuration values. config: Config, @@ -377,7 +357,7 @@ struct InnerConnection> { /// part of a single channel. /// /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// receiver part of the channel. /// @@ -385,39 +365,39 @@ struct InnerConnection> { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Futures_mutex>>>, + receiver: Futures_mutex>>>, } #[derive(Debug)] /// A request to [Connection::run] either for a new octet_stream or to /// shutdown. -struct ChanReq> { +struct ChanReq { /// A requests consists of a command. - cmd: ReqCmd, + cmd: ReqCmd, } #[derive(Debug)] /// Commands that can be requested. -enum ReqCmd> { +enum ReqCmd { /// Request for a (new) connection. /// /// The id of the previous connection (if any) is passed as well as a /// channel to send the reply. - NewConn(Option, ReplySender), + NewConn(Option, ReplySender), /// Shutdown command. Shutdown, } /// This is the type of sender in [ReqCmd]. -type ReplySender = oneshot::Sender>; +type ReplySender = oneshot::Sender>; /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection -struct State3<'a, F, IO, Octs: AsRef<[u8]>> { +struct State3<'a, F, IO, BMB> { /// Underlying octet_stream connection. - conn_state: SingleConnState3, + conn_state: SingleConnState3, /// Current connection id. conn_id: u64, @@ -436,12 +416,12 @@ struct State3<'a, F, IO, Octs: AsRef<[u8]>> { } /// State of the current underlying octet_stream transport. -enum SingleConnState3> { +enum SingleConnState3 { /// No current octet_stream transport. None, /// Current octet_stream transport. - Some(octet_stream::Connection), + Some(octet_stream::Connection), /// State that deals with an error getting a new octet stream from /// a factory. @@ -465,13 +445,11 @@ struct ErrorState { timeout: Duration, } -impl + Clone + Octets + Send + Sync + 'static> - InnerConnection -{ +impl InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. - pub fn new(config: Config) -> Result, Error> { + pub fn new(config: Config) -> Result { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { config, @@ -499,9 +477,9 @@ impl + Clone + Octets + Send + Sync + 'static> let opt_receiver = locked_opt_receiver.take(); opt_receiver.expect("no receiver present?") }; - let mut curr_cmd: Option> = None; + let mut curr_cmd: Option> = None; - let mut state = State3::<'a, F, IO, Octs> { + let mut state = State3::<'a, F, IO, BMB> { conn_state: SingleConnState3::None, conn_id: 0, factory, @@ -680,7 +658,7 @@ impl + Clone + Octets + Send + Sync + 'static> async fn new_conn( &self, opt_id: Option, - sender: oneshot::Sender>, + sender: oneshot::Sender>, ) -> Result<(), Error> { let req = ChanReq { cmd: ReqCmd::NewConn(opt_id, sender), diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 6bbe2460a..ebc4b47f5 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -31,12 +31,13 @@ use std::time::{Duration, Instant}; use std::vec::Vec; use crate::base::{ - opt::{AllOptData, Opt, OptRecord, TcpKeepalive}, - Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, + opt::{AllOptData, OptRecord, TcpKeepalive}, + Message, }; +use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::base_message_builder::OptTypes; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; -use crate::rdata::AllRecordData; +use crate::net::client::query::{GetResult, QueryMessage4}; use octseq::Octets; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -88,16 +89,16 @@ impl Default for Config { #[derive(Clone, Debug)] /// A single DNS over octect stream connection. -pub struct Connection> { +pub struct Connection { /// Reference counted [InnerConnection]. - inner: Arc>, + inner: Arc>, } -impl Connection { +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new(config: Option) -> Result, Error> { + pub fn new(config: Option) -> Result { let config = match config { Some(config) => { check_config(&config)?; @@ -126,23 +127,9 @@ impl Connection { /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - async fn query_impl( + async fn query_impl4( &self, - query_msg: &Message, - ) -> Result { - let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = query_msg; - Ok(Query::new(msg, rx)) - } - - /// Start a DNS request. - /// - /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - async fn query_impl3( - &self, - query_msg: &Message, + query_msg: &BMB, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -156,7 +143,7 @@ impl Connection { /// match the request avoids having to keep the request around. pub async fn query_no_check( &self, - query_msg: &mut Message, + query_msg: &BMB, ) -> Result { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -164,24 +151,10 @@ impl Connection { } } -impl< - Octs: AsMut<[u8]> + AsRef<[u8]> + Clone + Debug + Octets + Send + Sync, - > QueryMessage for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin> + Send + '_>> { - return Box::pin(self.query_impl(query_msg)); - } -} - -impl + Clone + Octets + Send + Sync> QueryMessage3 - for Connection -{ +impl QueryMessage4 for Connection { fn query<'a>( &'a self, - query_msg: &'a Message, + query_msg: &'a BMB, ) -> Pin< Box< dyn Future, Error>> @@ -189,7 +162,7 @@ impl + Clone + Octets + Send + Sync> QueryMessage3 + '_, >, > { - return Box::pin(self.query_impl3(query_msg)); + return Box::pin(self.query_impl4(query_msg)); } } @@ -223,12 +196,11 @@ enum QueryState { impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. - fn new( - query_msg: &Message, + fn new( + query_msg: &BMB, receiver: oneshot::Receiver, ) -> Query { - let msg_ref: &[u8] = query_msg.as_ref(); - let vec = msg_ref.to_vec(); + let vec = query_msg.to_vec(); let msg = Message::from_octets(vec) .expect("Message failed to parse contents of another Message"); Self { @@ -342,7 +314,7 @@ impl QueryNoCheck { /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection> { +struct InnerConnection { /// User configuration variables. config: Config, @@ -350,7 +322,7 @@ struct InnerConnection> { /// part of a single channel. /// /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// receiver part of the channel. /// @@ -358,14 +330,14 @@ struct InnerConnection> { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Futures_mutex>>>, + receiver: Futures_mutex>>>, } #[derive(Debug)] /// A request from [Query] to [Connection::run] to start a DNS request. -struct ChanReq> { +struct ChanReq { /// DNS request message - msg: Message, + msg: BMB, /// Sender to send result back to [Query] sender: ReplySender, @@ -453,11 +425,11 @@ enum ConnState { // This type could be local to InnerConnection, but I don't know how type ReaderChanReply = Message; -impl + Clone> InnerConnection { +impl InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. - pub fn new(config: Config) -> Result, Error> { + pub fn new(config: Config) -> Result { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { config, @@ -650,14 +622,13 @@ impl + Clone> InnerConnection { pub async fn query( &self, sender: oneshot::Sender, - query_msg: &Message, + query_msg: &BMB, ) -> Result<(), Error> { // We should figure out how to get query_msg. - let msg_clone = query_msg.clone(); let req = ChanReq { sender, - msg: msg_clone, + msg: query_msg.clone(), }; match self.sender.send(req).await { Err(_) => @@ -825,7 +796,7 @@ impl + Clone> InnerConnection { // Note: maybe reqmsg should be a return value. fn insert_req( &self, - req: ChanReq, + mut req: ChanReq, status: &mut Status, reqmsg: &mut Option>, query_vec: &mut Queries, @@ -887,29 +858,17 @@ impl + Clone> InnerConnection { // nature of its use of sequence numbers, is far more // resilient against forgery by third parties." - let mut mut_msg = Message::from_octets(req.msg.as_slice().to_vec()) - .expect("Message failed to parse contents of another Message"); - let hdr = mut_msg.header_mut(); + let hdr = req.msg.header_mut(); hdr.set_id(ind16); if status.send_keepalive { - let res_msg = add_tcp_keepalive(&mut_msg); + let res = add_tcp_keepalive(&mut req.msg); - match res_msg { - Ok(msg) => { - Self::convert_query(&msg, reqmsg); - status.send_keepalive = false; - } - Err(_) => { - // Adding keepalive option - // failed. Send the original - // request. - Self::convert_query(&mut_msg, reqmsg); - } + if let Ok(()) = res { + status.send_keepalive = false; } - } else { - Self::convert_query(&mut_msg, reqmsg); } + Self::convert_query(&req.msg, reqmsg); } /// Take an element out of query_vec. @@ -933,20 +892,20 @@ impl + Clone> InnerConnection { /// Convert the query message to a vector. // This function should return the vector instead of storing it // through a reference. - fn convert_query>( - msg: &Message, + fn convert_query( + msg: &dyn BaseMessageBuilder, reqmsg: &mut Option>, ) { // Ideally there should be a write_all_vectored. Until there is one, // copy to a new Vec and prepend the length octets. - let slice = msg.as_slice(); + let slice = msg.to_vec(); let len = slice.len(); let mut vec = Vec::with_capacity(2 + len); let len16 = len as u16; vec.extend_from_slice(&len16.to_be_bytes()); - vec.extend_from_slice(slice); + vec.extend_from_slice(&slice); *reqmsg = Some(vec); } @@ -1018,127 +977,12 @@ impl + Clone> InnerConnection { //------------ Utility -------------------------------------------------------- -/// Add an edns-tcp-keepalive option to a MessageBuilder. -/// -/// This is surprisingly difficult. We need to copy the original message to -/// a new MessageBuilder because MessageBuilder has no support for changing the -/// opt record. -fn add_tcp_keepalive( - msg: &Message, -) -> Result>, Error> { - // We can't just insert a new option in an existing - // opt record. So we have to create new message and copy records - // from the old one. And insert our option while copying the opt - // record. - let source = msg; - - let mut target = - MessageBuilder::from_target(StaticCompressor::new(Vec::new())) - .expect("Vec is expected to have enough space"); - let source_hdr = source.header(); - let target_hdr = target.header_mut(); - target_hdr.set_flags(source_hdr.flags()); - target_hdr.set_opcode(source_hdr.opcode()); - target_hdr.set_rcode(source_hdr.rcode()); - target_hdr.set_id(source_hdr.id()); - - let source = source.question(); - let mut target = target.question(); - for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - let mut source = - source.answer().map_err(|_e| Error::MessageParseError)?; - let mut target = target.answer(); - for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - let mut source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); - let mut target = target.authority(); - for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - let source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); - let mut target = target.additional(); - let mut found_opt_rr = false; - for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - if rr.rtype() == Rtype::Opt { - found_opt_rr = true; - let rr = rr - .into_record::>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - let opt_record = OptRecord::from_record(rr); - target - .opt(|newopt| { - newopt - .set_udp_payload_size(opt_record.udp_payload_size()); - newopt.set_version(opt_record.version()); - newopt.set_dnssec_ok(opt_record.dnssec_ok()); - for option in opt_record.opt().iter::>() - { - let option = option.unwrap(); - if let AllOptData::TcpKeepalive(_) = option { - // Ignore existing TcpKeepalive - } else { - newopt.push(&option).unwrap(); - } - } - // send an empty keepalive option - newopt.tcp_keepalive(None).unwrap(); - Ok(()) - }) - .map_err(|_e| Error::MessageBuilderPushError)?; - } else { - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - } - if !found_opt_rr { - // send an empty keepalive option - target - .opt(|opt| opt.tcp_keepalive(None)) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - // It would be nice to use .builder() here. But that one deletes all - // section. We have to resort to .as_builder() which gives a - // reference and then .clone() - let result = target.as_builder().clone(); - let msg = Message::from_octets(result.finish().into_target()) - .expect("Message should be able to parse output from MessageBuilder"); - Ok(msg) +/// Add an edns-tcp-keepalive option to a BaseMessageBuilder. +fn add_tcp_keepalive( + msg: &mut BMB, +) -> Result<(), Error> { + msg.add_opt(OptTypes::TypeTcpKeepalive(TcpKeepalive::new(None))); + Ok(()) } /// Check if a DNS reply match the query. Ignore whether id fields match. diff --git a/src/net/client/query.rs b/src/net/client/query.rs index ba141a4f8..2e1510810 100644 --- a/src/net/client/query.rs +++ b/src/net/client/query.rs @@ -8,7 +8,6 @@ use std::boxed::Box; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; -// use std::sync::Arc; use crate::base::Message; use crate::net::client::error::Error; @@ -51,6 +50,17 @@ pub trait QueryMessage3 { ) -> Pin + Send + '_>>; } +/// Trait for starting a DNS query based on a base message builder. +pub trait QueryMessage4 { + /// Query function that takes a BaseMessageBuilder type. + /// + /// This function is intended to be cancel safe. + fn query<'a>( + &'a self, + query_msg: &'a BMB, + ) -> Pin + Send + '_>>; +} + /// This type is the actual result type of the future returned by the /// query function in the QueryMessage2 trait. type QueryResultOutput = Result, Error>; diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 50f8f61f0..267cd0414 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -8,7 +8,7 @@ use bytes::Bytes; use futures::stream::FuturesUnordered; use futures::StreamExt; -use octseq::{Octets, OctetsBuilder}; +use octseq::Octets; use rand::random; @@ -24,10 +24,9 @@ use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::time::{sleep_until, Duration, Instant}; use crate::base::iana::OptRcode; -use crate::base::wire::Composer; use crate::base::Message; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage3}; +use crate::net::client::query::{GetResult, QueryMessage4}; /* Basic algorithm: @@ -84,16 +83,14 @@ pub struct Config { /// This type represents a transport connection. #[derive(Clone)] -pub struct Connection { +pub struct Connection { /// Reference to the actual implementation of the connection. - inner: Arc>, + inner: Arc>, } -impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> - Connection -{ +impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { /// Create a new connection. - pub fn new(config: Option) -> Result, Error> { + pub fn new(config: Option) -> Result { let config = match config { Some(config) => { check_config(&config)?; @@ -116,7 +113,7 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> /// Add a transport connection. pub async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { self.inner.add(conn).await } @@ -124,20 +121,19 @@ impl<'a, Octs: Clone + Composer + Debug + Send + Sync + 'static> /// Implementation of the query function. async fn query_impl( &self, - query_msg: &Message, + query_msg: &BMB, ) -> Result, Error> { let query = self.inner.query(query_msg.clone()).await?; Ok(Box::new(query)) } } -impl< - Octs: Clone + Composer + Debug + OctetsBuilder + Send + Sync + 'static, - > QueryMessage3 for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a Message, + query_msg: &'a BMB, ) -> Pin< Box< dyn Future, Error>> @@ -153,7 +149,7 @@ impl< /// This type represents an active query request. #[derive(Debug)] -pub struct Query + Send> { +pub struct Query { /// User configuration. config: Config, @@ -161,13 +157,13 @@ pub struct Query + Send> { state: QueryState, /// The query message - query_msg: Message, + query_msg: BMB, /// List of connections identifiers and estimated response times. conn_rt: Vec, /// Channel to send requests to the run function. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// List of futures for outstanding requests. fut_list: @@ -205,15 +201,15 @@ enum QueryState { } /// The commands that can be sent to the run function. -enum ChanReq { +enum ChanReq { /// Add a connection - Add(AddReq), + Add(AddReq), /// Get the list of estimated response times for all connections GetRT(RTReq), /// Start a query - Query(QueryReq), + Query(QueryReq), /// Report how long it took to get a response Report(TimeReport), @@ -222,16 +218,16 @@ enum ChanReq { Failure(TimeReport), } -impl Debug for ChanReq { +impl Debug for ChanReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("ChanReq").finish() } } /// Request to add a new connection -struct AddReq { +struct AddReq { /// New connection to add - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, /// Channel to send the reply to tx: oneshot::Sender, @@ -250,18 +246,18 @@ struct RTReq /**/ { type RTReply = Result, Error>; /// Request to start a query -struct QueryReq { +struct QueryReq { /// Identifier of connection id: u64, /// Request message - query_msg: Message, + query_msg: BMB, /// Channel to send the reply to tx: oneshot::Sender, } -impl + Debug + Send> Debug for QueryReq { +impl Debug for QueryReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("QueryReq") .field("id", &self.id) @@ -308,14 +304,14 @@ struct ConnRT { /// Result of the futures in fut_list. type FutListOutput = (usize, Result, Error>); -impl + Clone + Debug + Send + Sync + 'static> Query { +impl Query { /// Create a new query object. fn new( config: Config, - query_msg: Message, + query_msg: BMB, mut conn_rt: Vec, - sender: mpsc::Sender>, - ) -> Query { + sender: mpsc::Sender>, + ) -> Self { let conn_rt_len = conn_rt.len(); conn_rt.sort_unstable_by(conn_rt_cmp); @@ -328,7 +324,7 @@ impl + Clone + Debug + Send + Sync + 'static> Query { conn_rt.sort_unstable_by(conn_rt_cmp); } - Query { + Self { config, query_msg, //conns, @@ -543,18 +539,7 @@ impl + Clone + Debug + Send + Sync + 'static> Query { } } -impl< - Octs: AsMut<[u8]> - + AsRef<[u8]> - + Clone - + Composer - + Debug - + OctetsBuilder - + Send - + Sync - + 'static, - > GetResult for Query -{ +impl GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -567,22 +552,20 @@ impl< //------------ InnerConnection ------------------------------------------------ /// Type that actually implements the connection. -struct InnerConnection { +struct InnerConnection { /// User configuation. config: Config, /// Receive side of the channel used by the runner. - receiver: Mutex>>>, + receiver: Mutex>>>, /// To send a request to the runner. - sender: mpsc::Sender>, + sender: mpsc::Sender>, } -impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> - InnerConnection -{ +impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Implementation of the new method. - fn new(config: Config) -> Result, Error> { + fn new(config: Config) -> Result { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); Ok(Self { config, @@ -596,7 +579,7 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); - let mut conns: Vec + Send + Sync>> = + let mut conns: Vec + Send + Sync>> = Vec::new(); let mut receiver = self.receiver.lock().await; @@ -696,7 +679,7 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> /// Implementation of the add method. async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); self.sender @@ -707,10 +690,7 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> } /// Implementation of the query method. - async fn query( - &'a self, - query_msg: Message, - ) -> Result, Error> { + async fn query(&'a self, query_msg: BMB) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.sender .send(ChanReq::GetRT(RTReq { tx })) @@ -731,19 +711,15 @@ impl<'a, Octs: AsRef<[u8]> + Clone + Debug + Send + Sync + 'static> /// Async function to send a request and wait for the reply. /// /// This gives a single future that we can put in a list. -async fn start_request( +async fn start_request( index: usize, id: u64, - sender: mpsc::Sender>, - query_msg: Message, + sender: mpsc::Sender>, + query_msg: BMB, ) -> (usize, Result, Error>) { let (tx, rx) = oneshot::channel(); sender - .send(ChanReq::Query(QueryReq { - id, - query_msg: query_msg.clone(), - tx, - })) + .send(ChanReq::Query(QueryReq { id, query_msg, tx })) .await .expect("send is expected to work"); let mut query = match rx.await.expect("receive is expected to work") { diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 790f52247..09e2db78e 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -7,7 +7,7 @@ // - cookies // - random port -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use octseq::Octets; use std::boxed::Box; use std::fmt::{Debug, Formatter}; @@ -22,8 +22,9 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::iana::Rcode; use crate::base::Message; +use crate::net::client::base_message_builder::BaseMessageBuilder; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; +use crate::net::client::query::{GetResult, QueryMessage4}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -111,19 +112,11 @@ impl Connection { } /// Start a new DNS query. - async fn query_impl + Clone + Send>( - &self, - query_msg: &Message, - ) -> Result { - self.inner.query(query_msg, self.clone()).await - } - - /// Start a new DNS query. - async fn query_impl3< - Octs: AsRef<[u8]> + Clone + Debug + Send + 'static, + async fn query_impl4< + BMB: BaseMessageBuilder + Clone + Send + Sync + 'static, >( &self, - query_msg: &Message, + query_msg: &BMB, ) -> Result, Error> { let gr = self.inner.query(query_msg, self.clone()).await?; Ok(Box::new(gr)) @@ -135,24 +128,12 @@ impl Connection { } } -impl + Clone + Debug + Send + Sync> - QueryMessage for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin> + Send + '_>> - { - return Box::pin(self.query_impl(query_msg)); - } -} - -impl + Clone + Debug + Send + Sync + 'static> - QueryMessage3 for Connection +impl + QueryMessage4 for Connection { fn query<'a>( &'a self, - query_msg: &'a Message, + query_msg: &'a BMB, ) -> Pin< Box< dyn Future, Error>> @@ -160,7 +141,7 @@ impl + Clone + Debug + Send + Sync + 'static> + '_, >, > { - return Box::pin(self.query_impl3(query_msg)); + return Box::pin(self.query_impl4(query_msg)); } } @@ -375,27 +356,27 @@ impl GetResult for Query { */ -//------------ Query2 --------------------------------------------------------- +//------------ Query4 --------------------------------------------------------- /// The state of a DNS query. -pub struct Query2 { +pub struct Query4 { /// Future that does the actual work of GetResult. get_result_fut: Pin, Error>> + Send>>, } -impl Query2 { +impl Query4 { /// Create new Query object. - fn new( + fn new( config: Config, - query_msg: Message, + query_msg: &BMB, remote_addr: SocketAddr, conn: Connection, - ) -> Query2 { - Query2 { + ) -> Self { + Self { get_result_fut: Box::pin(Self::get_result_impl2( config, - query_msg, + query_msg.clone(), remote_addr, conn, )), @@ -410,9 +391,9 @@ impl Query2 { /// Get the result of a DNS Query. /// /// This function is not cancel safe. - async fn get_result_impl2( + async fn get_result_impl2( config: Config, - mut query_msg: Message, + mut query_bmb: BMB, remote_addr: SocketAddr, conn: Connection, ) -> Result, Error> { @@ -434,8 +415,9 @@ impl Query2 { .map_err(|e| Error::UdpConnect(Arc::new(e)))?; // Set random ID in header - let header = query_msg.header_mut(); + let header = query_bmb.header_mut(); header.set_random_id(); + let query_msg = query_bmb.to_message(); let dgram = query_msg.as_slice(); let sent = sock @@ -444,7 +426,7 @@ impl Query2 { .send(dgram) .await .map_err(|e| Error::UdpSend(Arc::new(e)))?; - if sent != query_msg.as_slice().len() { + if sent != dgram.len() { return Err(Error::UdpShortSend); } @@ -489,13 +471,6 @@ impl Query2 { Err(_) => continue, }; - // Unfortunately we cannot pass query_msg to is_answer - // because is_answer requires Octets, which is not - // implemented by BytesMut. Make a copy. - let query_msg = Message::from_octets(query_msg.as_slice()) - .expect( - "Message failed to parse contents of another Message", - ); if !is_answer(&answer, &query_msg) { // Wrong answer, go back to receiving continue; @@ -537,13 +512,13 @@ impl Query2 { } } -impl Debug for Query2 { +impl Debug for Query4 { fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { todo!() } } -impl GetResult for Query2 { +impl GetResult for Query4 { fn get_result( &mut self, ) -> Pin< @@ -583,17 +558,14 @@ impl InnerConnection { } /// Return a Query object that contains the query state. - async fn query + Clone>( + async fn query< + BMB: BaseMessageBuilder + Clone + Send + Sync + 'static, + >( &self, - query_msg: &Message, + query_msg: &BMB, conn: Connection, - ) -> Result { - let slice = query_msg.as_slice(); - let mut bytes = BytesMut::with_capacity(slice.len()); - bytes.extend_from_slice(slice); - let query_msg = Message::from_octets(bytes) - .expect("Message failed to parse contents of another Message"); - Ok(Query2::new( + ) -> Result { + Ok(Query4::new( self.config.clone(), query_msg, self.remote_addr, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 0cceea385..52cc5d341 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -7,7 +7,6 @@ // - handle shutdown use bytes::Bytes; -use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -16,9 +15,10 @@ use std::pin::Pin; use std::sync::Arc; use crate::base::Message; +use crate::net::client::base_message_builder::BaseMessageBuilder; use crate::net::client::error::Error; use crate::net::client::multi_stream; -use crate::net::client::query::{GetResult, QueryMessage, QueryMessage3}; +use crate::net::client::query::{GetResult, QueryMessage4}; use crate::net::client::tcp_factory::TcpConnFactory; use crate::net::client::udp; @@ -39,19 +39,17 @@ pub struct Config { /// DNS transport connection that first issues a query over a UDP transport and /// falls back to TCP if the reply is truncated. #[derive(Clone)] -pub struct Connection + Debug> { +pub struct Connection { /// Reference to the real object that provides the connection. - inner: Arc>, + inner: Arc>, } -impl + Clone + Debug + Octets + Send + Sync + 'static> - Connection -{ +impl Connection { /// Create a new connection. pub fn new( config: Option, remote_addr: SocketAddr, - ) -> Result, Error> { + ) -> Result { let config = match config { Some(config) => { check_config(&config)?; @@ -70,42 +68,22 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> self.inner.run().await } - /// Start a query. - pub async fn query_impl( - &self, - query_msg: &Message, - ) -> Result, Error> { - self.inner.query(query_msg).await - } - - /// Start a query for the QueryMessage3 trait. - async fn query_impl3( + /// Start a query for the QueryMessage4 trait. + async fn query_impl4( &self, - query_msg: &Message, + query_msg: &BMB, ) -> Result, Error> { let gr = self.inner.query(query_msg).await?; Ok(Box::new(gr)) } } -impl + Clone + Debug + Octets + Send + Sync + 'static> - QueryMessage, Octs> for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a Message, - ) -> Pin, Error>> + Send + '_>> - { - return Box::pin(self.query_impl(query_msg)); - } -} - -impl + Clone + Debug + Octets + Send + Sync + 'static> - QueryMessage3 for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a Message, + query_msg: &'a BMB, ) -> Pin< Box< dyn Future, Error>> @@ -113,7 +91,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> + '_, >, > { - return Box::pin(self.query_impl3(query_msg)); + return Box::pin(self.query_impl4(query_msg)); } } @@ -121,15 +99,15 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> /// Object that contains the current state of a query. #[derive(Debug)] -pub struct Query + Debug> { +pub struct Query { /// Reqeust message. - query_msg: Message, + query_msg: BMB, /// UDP transport to be used. udp_conn: udp::Connection, /// TCP transport to be used. - tcp_conn: multi_stream::Connection, + tcp_conn: multi_stream::Connection, /// Current state of the query. state: QueryState, @@ -151,17 +129,15 @@ enum QueryState { GetTcpResult(Box), } -impl + Clone + Debug + Octets + Send + Sync + 'static> - Query -{ +impl Query { /// Create a new Query object. /// /// The initial state is to start with a UDP transport. fn new( - query_msg: &Message, + query_msg: &BMB, udp_conn: udp::Connection, - tcp_conn: multi_stream::Connection, - ) -> Query { + tcp_conn: multi_stream::Connection, + ) -> Query { Query { query_msg: query_msg.clone(), udp_conn, @@ -179,7 +155,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> QueryState::StartUdpQuery => { let msg = self.query_msg.clone(); let query = - QueryMessage3::query(&self.udp_conn, &msg).await?; + QueryMessage4::query(&self.udp_conn, &msg).await?; self.state = QueryState::GetUdpResult(query); continue; } @@ -194,7 +170,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> QueryState::StartTcpQuery => { let msg = self.query_msg.clone(); let query = - QueryMessage3::query(&self.tcp_conn, &msg).await?; + QueryMessage4::query(&self.tcp_conn, &msg).await?; self.state = QueryState::GetTcpResult(query); continue; } @@ -207,8 +183,8 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } } -impl + Clone + Debug + Octets + Send + Sync + 'static> - GetResult for Query +impl GetResult + for Query { fn get_result( &mut self, @@ -222,7 +198,7 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> //------------ InnerConnection ------------------------------------------------ /// The actual connection object. -struct InnerConnection + Debug> { +struct InnerConnection { /// The remote address to connect to. remote_addr: SocketAddr, @@ -230,20 +206,15 @@ struct InnerConnection + Debug> { udp_conn: udp::Connection, /// The TCP transport connection. - tcp_conn: multi_stream::Connection, + tcp_conn: multi_stream::Connection, } -impl - InnerConnection -{ +impl InnerConnection { /// Create a new InnerConnection object. /// /// Create the UDP and TCP connections. Store the remote address because /// run needs it later. - fn new( - config: Config, - remote_addr: SocketAddr, - ) -> Result, Error> { + fn new(config: Config, remote_addr: SocketAddr) -> Result { let udp_conn = udp::Connection::new(config.udp, remote_addr)?; let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; @@ -266,10 +237,7 @@ impl /// Implementation of the query function. /// /// Just create a Query object with the state it needs. - async fn query( - &self, - query_msg: &Message, - ) -> Result, Error> { + async fn query(&self, query_msg: &BMB) -> Result, Error> { Ok(Query::new( query_msg, self.udp_conn.clone(), From 5f72cdf204ecaf2e4976c5aef1beca4da3b8f25a Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 6 Nov 2023 17:38:44 +0100 Subject: [PATCH 066/124] derive Debug --- src/net/client/redundant.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 267cd0414..88cbbd115 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -82,7 +82,7 @@ pub struct Config { //------------ Connection ----------------------------------------------------- /// This type represents a transport connection. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Connection { /// Reference to the actual implementation of the connection. inner: Arc>, @@ -552,6 +552,7 @@ impl GetResult for Query { //------------ InnerConnection ------------------------------------------------ /// Type that actually implements the connection. +#[derive(Debug)] struct InnerConnection { /// User configuation. config: Config, From f46e419bce6b0e56964b516d90f984f1ab0fac49 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 16 Nov 2023 16:06:57 +0100 Subject: [PATCH 067/124] Make sure that dropping all references to connection objects terminates the run methods. --- src/net/client/multi_stream.rs | 52 ++++++++++++++++++---------- src/net/client/octet_stream.rs | 62 +++++++++++++++++++++------------- src/net/client/redundant.rs | 30 +++++++++++----- src/net/client/udp_tcp.rs | 12 ++++--- 4 files changed, 102 insertions(+), 54 deletions(-) diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index ab3f8f69c..bf75a7669 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -8,7 +8,6 @@ use bytes::Bytes; -use futures::lock::Mutex as Futures_mutex; use futures::stream::FuturesUnordered; use futures::StreamExt; @@ -21,7 +20,7 @@ use std::fmt::Debug; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::io; @@ -84,21 +83,21 @@ impl Connection { /// /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run< - F: ConnFactory + Send, - IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, + pub fn run< + F: ConnFactory + Send + 'static, + IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, factory: F, - ) -> Result<(), Error> { - self.inner.run(factory).await + ) -> Pin> + Send>> { + self.inner.run(factory) } /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and /// returns a [Query] object wrapped in a [Result]. - pub async fn query_impl4( + async fn query_impl4( &self, query_msg: &BMB, ) -> Result, Error> { @@ -365,7 +364,7 @@ struct InnerConnection { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Futures_mutex>>>, + receiver: Mutex>>>, } #[derive(Debug)] @@ -454,27 +453,42 @@ impl InnerConnection { Ok(Self { config, sender: tx, - receiver: Futures_mutex::new(Some(rx)), + receiver: Mutex::new(Some(rx)), }) } /// Main execution function for [InnerConnection]. /// /// This function Gets called by [Connection::run]. - /// This function is not async cancellation safe - #[rustfmt::skip] + /// This function is not async cancellation safe. + /// Make sure the resulting future does not contain a reference to self. + pub fn run< + F: ConnFactory + Send + 'static, + IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, + >( + &self, + factory: F, + ) -> Pin> + Send>> { + let mut receiver = self.receiver.lock().unwrap(); + let opt_receiver = receiver.take(); + drop(receiver); + + Box::pin(Self::run_impl(self.config.clone(), factory, opt_receiver)) + } - pub async fn run< + /// Implementation of the run method. This function does not have + /// a reference to self. + #[rustfmt::skip] + async fn run_impl< 'a, F: ConnFactory + Send, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, >( - &self, + config: Config, factory: F, + opt_receiver: Option>> ) -> Result<(), Error> { let mut receiver = { - let mut locked_opt_receiver = self.receiver.lock().await; - let opt_receiver = locked_opt_receiver.take(); opt_receiver.expect("no receiver present?") }; let mut curr_cmd: Option> = None; @@ -597,7 +611,7 @@ impl InnerConnection { let stream = res_conn .expect("error case is checked before"); - let conn = octet_stream::Connection::new(self.config.octet_stream.clone())?; + let conn = octet_stream::Connection::new(config.octet_stream.clone())?; let conn_run = conn.clone(); let clo = || async move { @@ -633,7 +647,9 @@ impl InnerConnection { tokio::select! { msg = recv_fut => { if msg.is_none() { - panic!("recv failed"); + // All references to the connection object have been + // dropped. Shutdown. + break; } curr_cmd = Some(msg.expect("None is checked before").cmd); } diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index ebc4b47f5..e08e211d6 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -20,13 +20,12 @@ use bytes; use bytes::{Bytes, BytesMut}; use core::convert::From; -use futures::lock::Mutex as Futures_mutex; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; use std::io::ErrorKind; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use std::vec::Vec; @@ -94,7 +93,7 @@ pub struct Connection { inner: Arc>, } -impl Connection { +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). @@ -116,11 +115,12 @@ impl Connection { /// /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. - pub async fn run( + /// Worker function for a connection object. + pub fn run( &self, io: IO, - ) -> Option<()> { - self.inner.run(io).await + ) -> Pin> + Send>> { + self.inner.run(io) } /// Start a DNS request. @@ -151,7 +151,9 @@ impl Connection { } } -impl QueryMessage4 for Connection { +impl QueryMessage4 + for Connection +{ fn query<'a>( &'a self, query_msg: &'a BMB, @@ -330,7 +332,7 @@ struct InnerConnection { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Futures_mutex>>>, + receiver: Mutex>>>, } #[derive(Debug)] @@ -425,7 +427,7 @@ enum ConnState { // This type could be local to InnerConnection, but I don't know how type ReaderChanReply = Message; -impl InnerConnection { +impl InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. @@ -434,17 +436,32 @@ impl InnerConnection { Ok(Self { config, sender: tx, - receiver: Futures_mutex::new(Some(rx)), + receiver: Mutex::new(Some(rx)), }) } + /// Run method. + /// + /// Make sure the future does not contain a reference to self. + pub fn run( + &self, + io: IO, + ) -> Pin> + Send>> { + let mut receiver = self.receiver.lock().unwrap(); + let opt_receiver = receiver.take(); + drop(receiver); + + Box::pin(Self::run_impl(self.config.clone(), io, opt_receiver)) + } + /// Main execution function for [InnerConnection]. /// /// This function Gets called by [Connection::run]. /// This function is not async cancellation safe - pub async fn run( - &self, + async fn run_impl( + config: Config, io: IO, + opt_receiver: Option>>, ) -> Option<()> { let (reply_sender, mut reply_receiver) = mpsc::channel::(READ_REPLY_CHAN_CAP); @@ -454,11 +471,7 @@ impl InnerConnection { let reader_fut = Self::reader(&mut read_stream, reply_sender); tokio::pin!(reader_fut); - let mut receiver = { - let mut locked_opt_receiver = self.receiver.lock().await; - let opt_receiver = locked_opt_receiver.take(); - opt_receiver.expect("no receiver present?") - }; + let mut receiver = { opt_receiver.expect("no receiver present?") }; let mut status = Status { state: ConnState::Active(None), @@ -478,7 +491,7 @@ impl InnerConnection { ConnState::Active(opt_instant) => { if let Some(instant) = opt_instant { let elapsed = instant.elapsed(); - if elapsed > self.config.response_timeout { + if elapsed > config.response_timeout { Self::error( Error::StreamReadTimeout, &mut query_vec, @@ -486,7 +499,7 @@ impl InnerConnection { status.state = ConnState::ReadTimeout; break; } - Some(self.config.response_timeout - elapsed) + Some(config.response_timeout - elapsed) } else { None } @@ -519,7 +532,7 @@ impl InnerConnection { None => // Just use the response timeout { - self.config.response_timeout + config.response_timeout } }; @@ -586,9 +599,13 @@ impl InnerConnection { res = recv_fut, if !do_write => { match res { Some(req) => - self.insert_req(req, &mut status, + Self::insert_req(req, &mut status, &mut reqmsg, &mut query_vec), - None => panic!("recv failed"), + None => { + // All references to the connection object have + // been dropped. Shutdown. + break; + } } } _ = sleep_fut => { @@ -795,7 +812,6 @@ impl InnerConnection { /// idle. Addend a edns-tcp-keepalive option if needed. // Note: maybe reqmsg should be a return value. fn insert_req( - &self, mut req: ChanReq, status: &mut Status, reqmsg: &mut Option>, diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 88cbbd115..07f50f0aa 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -18,9 +18,10 @@ use std::fmt::{Debug, Formatter}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use std::sync::Mutex; use std::vec::Vec; -use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot}; use tokio::time::{sleep_until, Duration, Instant}; use crate::base::iana::OptRcode; @@ -106,8 +107,8 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { } /// Runner function for a connection. - pub async fn run(&self) { - self.inner.run().await + pub fn run(&self) -> Pin + Send>> { + self.inner.run() } /// Add a transport connection. @@ -575,22 +576,33 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { }) } + /// Run method. + /// + /// Make sure the future does not contain a reference to self. + fn run(&self) -> Pin + Send>> { + let mut receiver = self.receiver.lock().unwrap(); + let opt_receiver = receiver.take(); + drop(receiver); + + Box::pin(Self::run_impl(opt_receiver)) + } + /// Implementation of the run method. - async fn run(&self) { + async fn run_impl(opt_receiver: Option>>) { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); let mut conns: Vec + Send + Sync>> = Vec::new(); - let mut receiver = self.receiver.lock().await; - let opt_receiver = receiver.take(); - drop(receiver); let mut receiver = opt_receiver.expect("receiver should not be empty"); loop { - let req = - receiver.recv().await.expect("receiver should not fail"); + let req = match receiver.recv().await { + Some(req) => req, + None => break, // All references to connection objects are + // dropped. Shutdown. + }; match req { ChanReq::Add(add_req) => { let id = next_id; diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 52cc5d341..ab32925f1 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -64,8 +64,10 @@ impl Connection { } /// Worker function for a connection object. - pub async fn run(&self) -> Result<(), Error> { - self.inner.run().await + pub fn run( + &self, + ) -> Pin> + Send>> { + self.inner.run() } /// Start a query for the QueryMessage4 trait. @@ -229,9 +231,11 @@ impl InnerConnection { /// /// Create a TCP connection factory and pass that to worker function /// of the multi_stream object. - pub async fn run(&self) -> Result<(), Error> { + fn run(&self) -> Pin> + Send>> { let tcp_factory = TcpConnFactory::new(self.remote_addr); - self.tcp_conn.run(tcp_factory).await + + let fut = self.tcp_conn.run(tcp_factory); + Box::pin(fut) } /// Implementation of the query function. From 7fafb1d8a9260572fabc0320686f569dfe18cd27 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 1 Dec 2023 16:18:53 +0100 Subject: [PATCH 068/124] Less debug output. --- src/net/client/bmb.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/client/bmb.rs b/src/net/client/bmb.rs index 0a1392efb..0c70a0b7e 100644 --- a/src/net/client/bmb.rs +++ b/src/net/client/bmb.rs @@ -161,6 +161,6 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> fn add_opt(&mut self, opt: OptTypes) { self.opts.push(opt); - println!("add_opt: after push: {:?}", self); + //println!("add_opt: after push: {:?}", self); } } From b0e30245b730c9ee31d12a522ffdf8dc1696a416 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 1 Dec 2023 16:19:48 +0100 Subject: [PATCH 069/124] Make sure run tasks terminate. Timeouts to avoid waiting forever. More comments. --- examples/client-transports.rs | 127 +++++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 39 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 1e82ce940..79c3b0e56 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -13,14 +13,20 @@ use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; use tokio::net::TcpStream; +use tokio::time::timeout; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; #[tokio::main] async fn main() { - // Create DNS request message - // Create a message builder wrapping a compressor wrapping a stream - // target. + // Create DNS request message. It would be nice if there was an object + // that implements both MessageBuilder and BaseMEssageBuilder. Until + // that time, first create a message using MessageBuilder, then turn + // that into a Message, and create a BaseMessaBuilder based on the message. + // + // TODO: No need for StreamTarget at the moment, this should also be + // handled better. let mut msg = MessageBuilder::from_target(StaticCompressor::new( StreamTarget::new_vec(), )) @@ -30,59 +36,90 @@ async fn main() { msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) .unwrap(); + // Create a Message to pass to BMB. let msg = Message::from_octets( msg.as_target().as_target().as_dgram_slice().to_vec(), ) .unwrap(); - println!("request msg: {:?}", msg.as_slice()); - + // Transports take a BaseMEssageBuilder to be able to add options along + // the way and only flatten just before actually writing to the network. let bmb = BMB::new(msg); // Destination for UDP and TCP let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + let octet_stream_config = octet_stream::Config { + response_timeout: Duration::from_millis(100), + }; + let multi_stream_config = multi_stream::Config { + octet_stream: Some(octet_stream_config.clone()), + }; + // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. - let udptcp_conn = udp_tcp::Connection::new(None, server_addr).unwrap(); - - // Create a clone for the run function. Start the run function on a - // separate task. - let conn_run = udptcp_conn.clone(); + let udp_config = udp::Config { + max_parallel: 1, + read_timeout: Duration::from_millis(1000), + max_retries: 1, + }; + let udp_tcp_config = udp_tcp::Config { + udp: Some(udp_config.clone()), + multi_stream: Some(multi_stream_config.clone()), + }; + let udptcp_conn = + udp_tcp::Connection::new(Some(udp_tcp_config), server_addr).unwrap(); + + // Start the run function in a separate task. The run function will + // terminate when all references to the connection have been dropped. + // Make sure that the task does not accidentally get a reference to the + // connection. + let run_fut = udptcp_conn.run(); tokio::spawn(async move { - let res = conn_run.run().await; - println!("run exited with {:?}", res); + let res = run_fut.await; + println!("UDP+TCP run exited with {:?}", res); }); // Send a query message. let mut query = udptcp_conn.query(&bmb).await.unwrap(); // Get the reply + println!("Wating for UDP+TCP reply"); let reply = query.get_result().await; println!("UDP+TCP reply: {:?}", reply); + // The query may have a reference to the connection. Drop the query + // when it is no longer needed. + drop(query); + // Create a factory of TCP connections. Pass the destination address and // port as parameter. let tcp_factory = TcpConnFactory::new(server_addr); // A muli_stream transport connection sets up new TCP connections when // needed. - let tcp_conn = multi_stream::Connection::new(None).unwrap(); + let tcp_conn = + multi_stream::Connection::new(Some(multi_stream_config.clone())) + .unwrap(); - // Start the run function as a separate task. The run function receives + // Get a future for the run function. The run function receives // the factory as a parameter. - let conn_run = tcp_conn.clone(); + let run_fut = tcp_conn.run(tcp_factory); tokio::spawn(async move { - let res = conn_run.run(tcp_factory).await; - println!("run exited with {:?}", res); + let res = run_fut.await; + println!("multi TCP run exited with {:?}", res); }); // Send a query message. let mut query = tcp_conn.query(&bmb).await.unwrap(); - // Get the reply - let reply = query.get_result().await; - println!("TCP reply: {:?}", reply); + // Get the reply. A multi_stream connection does not have any timeout. + // Wrap get_result in a timeout. + println!("Wating for multi TCP reply"); + let reply = timeout(Duration::from_millis(500), query.get_result()).await; + println!("multi TCP reply: {:?}", reply); + + drop(query); // Some TLS boiler plate for the root certificates. let mut root_store = RootCertStore::empty(); @@ -106,35 +143,40 @@ async fn main() { // Currently the only support TLS connections are the ones that have a // valid certificate. Use a well known public resolver. - let server_addr = + let google_server_addr = SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); // Create a new TLS connection factory. We pass the TLS config, the name of // the remote server and the destination address and port. let tls_factory = - TlsConnFactory::new(client_config, "dns.google", server_addr); + TlsConnFactory::new(client_config, "dns.google", google_server_addr); // Again create a multi_stream transport connection. - let tls_conn = multi_stream::Connection::new(None).unwrap(); + let tls_conn = + multi_stream::Connection::new(Some(multi_stream_config)).unwrap(); - // Can start the run function. - let conn_run = tls_conn.clone(); + // Start the run function. + let run_fut = tls_conn.run(tls_factory); tokio::spawn(async move { - let res = conn_run.run(tls_factory).await; - println!("run exited with {:?}", res); + let res = run_fut.await; + println!("TLS run exited with {:?}", res); }); let mut query = tls_conn.query(&bmb).await.unwrap(); - let reply = query.get_result().await; + println!("Wating for TLS reply"); + let reply = timeout(Duration::from_millis(500), query.get_result()).await; println!("TLS reply: {:?}", reply); + drop(query); + // Create a transport connection for redundant connections. let redun = redundant::Connection::new(None).unwrap(); // Start the run function on a separate task. - let redun_run = redun.clone(); + let run_fut = redun.run(); tokio::spawn(async move { - redun_run.run().await; + run_fut.await; + println!("redundant run terminated"); }); // Add the previously created transports. @@ -143,16 +185,22 @@ async fn main() { redun.add(Box::new(tls_conn)).await.unwrap(); // Start a few queries. - for _i in 1..10 { + for i in 1..10 { let mut query = redun.query(&bmb).await.unwrap(); let reply = query.get_result().await; - println!("redundant connection reply: {:?}", reply); + if i == 2 { + println!("redundant connection reply: {:?}", reply); + } } + drop(redun); + // Create a new UDP transport connection. Pass the destination address // and port as parameter. This transport does not retry over TCP if the - // reply is truncated. - let udp_conn = udp::Connection::new(None, server_addr).unwrap(); + // reply is truncated. This transport does not have a separate run + // function. + let udp_conn = + udp::Connection::new(Some(udp_config), server_addr).unwrap(); // Send a query message. let mut query = udp_conn.query(&bmb).await.unwrap(); @@ -165,12 +213,11 @@ async fn main() { // single request or a small burst of requests. let tcp_conn = TcpStream::connect(server_addr).await.unwrap(); - let tcp = octet_stream::Connection::new(None).unwrap(); - let tcp_worker = tcp.clone(); - + let tcp = octet_stream::Connection::>>::new(None).unwrap(); + let run_fut = tcp.run(tcp_conn); tokio::spawn(async move { - tcp_worker.run(tcp_conn).await; - println!("run terminated"); + run_fut.await; + println!("single TCP run terminated"); }); // Send a query message. @@ -179,4 +226,6 @@ async fn main() { // Get the reply let reply = query.get_result().await; println!("TCP reply: {:?}", reply); + + drop(tcp); } From 07b9be78f6fa0bed8b8bb26ffe3f8d3b58f675e2 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 4 Dec 2023 16:16:07 +0100 Subject: [PATCH 070/124] Handle UDP payload size. --- src/net/client/base_message_builder.rs | 3 +++ src/net/client/bmb.rs | 34 ++++++++++++++++++-------- src/net/client/udp.rs | 17 +++++++++++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/src/net/client/base_message_builder.rs b/src/net/client/base_message_builder.rs index 943246237..8b7c62804 100644 --- a/src/net/client/base_message_builder.rs +++ b/src/net/client/base_message_builder.rs @@ -35,6 +35,9 @@ pub trait BaseMessageBuilder: Debug + Send + Sync { /// Return a reference to a mutable Header to record changes to the header. fn header_mut(&mut self) -> &mut Header; + /// Set the UDP payload size. + fn set_udp_payload_size(&mut self, value: u16); + /// Add an EDNS option. fn add_opt(&mut self, opt: OptTypes); } diff --git a/src/net/client/bmb.rs b/src/net/client/bmb.rs index 0c70a0b7e..4c0d55968 100644 --- a/src/net/client/bmb.rs +++ b/src/net/client/bmb.rs @@ -33,6 +33,9 @@ pub struct BMB> { /// Collection of EDNS options to add. opts: Vec, + + /// UDP payload size. + udp_payload_size: Option, } impl + Debug + Octets> BMB { @@ -43,6 +46,7 @@ impl + Debug + Octets> BMB { msg, header, opts: Vec::new(), + udp_payload_size: None, } } @@ -115,18 +119,24 @@ impl + Debug + Octets> BMB { .map_err(|_e| Error::MessageBuilderPushError)?; } } - target - .opt(|opt| { - for o in &self.opts { - match o { - OptTypes::TypeTcpKeepalive(tka) => { - opt.tcp_keepalive(tka.timeout())? + + if self.udp_payload_size.is_some() || !self.opts.is_empty() { + target + .opt(|opt| { + if let Some(size) = self.udp_payload_size { + opt.set_udp_payload_size(size) + } + for o in &self.opts { + match o { + OptTypes::TypeTcpKeepalive(tka) => { + opt.tcp_keepalive(tka.timeout())? + } } } - } - Ok(()) - }) - .map_err(|_e| Error::MessageBuilderPushError)?; + Ok(()) + }) + .map_err(|_e| Error::MessageBuilderPushError)?; + } // It would be nice to use .builder() here. But that one deletes all // section. We have to resort to .as_builder() which gives a @@ -159,6 +169,10 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> &mut self.header } + fn set_udp_payload_size(&mut self, value: u16) { + self.udp_payload_size = Some(value); + } + fn add_opt(&mut self, opt: OptTypes) { self.opts.push(opt); //println!("add_opt: after push: {:?}", self); diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 09e2db78e..9987951cf 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -58,6 +58,10 @@ const MIN_MAX_RETRIES: u8 = 1; /// Maximum allowed configuration value for max_retries. const MAX_MAX_RETRIES: u8 = 100; +/// Default UDP payload size. See draft-ietf-dnsop-avoid-fragmentation-15 +/// for discussion. +const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; + //------------ Config --------------------------------------------------------- /// Configuration for a UDP transport connection. @@ -71,6 +75,10 @@ pub struct Config { /// Maimum number of retries. pub max_retries: u8, + + /// EDNS(0) UDP payload size. Set this value to None to be able to create + /// a DNS request without ENDS(0) option. + pub udp_payload_size: Option, } impl Default for Config { @@ -79,6 +87,7 @@ impl Default for Config { max_parallel: DEF_MAX_PARALLEL, read_timeout: DEF_READ_TIMEOUT, max_retries: DEF_MAX_RETRIES, + udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), } } } @@ -372,6 +381,7 @@ impl Query4 { query_msg: &BMB, remote_addr: SocketAddr, conn: Connection, + udp_payload_size: Option, ) -> Self { Self { get_result_fut: Box::pin(Self::get_result_impl2( @@ -379,6 +389,7 @@ impl Query4 { query_msg.clone(), remote_addr, conn, + udp_payload_size, )), } } @@ -396,6 +407,7 @@ impl Query4 { mut query_bmb: BMB, remote_addr: SocketAddr, conn: Connection, + udp_payload_size: Option, ) -> Result, Error> { let recv_size = 2000; // Should be configurable. @@ -417,6 +429,10 @@ impl Query4 { // Set random ID in header let header = query_bmb.header_mut(); header.set_random_id(); + // Set UDP payload size + if let Some(size) = udp_payload_size { + query_bmb.set_udp_payload_size(size) + } let query_msg = query_bmb.to_message(); let dgram = query_msg.as_slice(); @@ -570,6 +586,7 @@ impl InnerConnection { query_msg, self.remote_addr, conn, + self.config.udp_payload_size, )) } From 92b75465c4850a0efb5ba4c08bb20b23e49ce535 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 4 Dec 2023 16:23:29 +0100 Subject: [PATCH 071/124] Simpify message building, add udp_payload_size to UDP config. --- examples/client-transports.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 79c3b0e56..22c5e2412 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,6 +1,6 @@ use domain::base::Dname; use domain::base::Rtype::Aaaa; -use domain::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use domain::base::{Message, MessageBuilder}; use domain::net::client::bmb::BMB; use domain::net::client::multi_stream; use domain::net::client::octet_stream; @@ -24,23 +24,14 @@ async fn main() { // that implements both MessageBuilder and BaseMEssageBuilder. Until // that time, first create a message using MessageBuilder, then turn // that into a Message, and create a BaseMessaBuilder based on the message. - // - // TODO: No need for StreamTarget at the moment, this should also be - // handled better. - let mut msg = MessageBuilder::from_target(StaticCompressor::new( - StreamTarget::new_vec(), - )) - .unwrap(); + let mut msg = MessageBuilder::new_vec(); msg.header_mut().set_rd(true); let mut msg = msg.question(); msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) .unwrap(); // Create a Message to pass to BMB. - let msg = Message::from_octets( - msg.as_target().as_target().as_dgram_slice().to_vec(), - ) - .unwrap(); + let msg = Message::from_octets(msg.as_target().to_vec()).unwrap(); // Transports take a BaseMEssageBuilder to be able to add options along // the way and only flatten just before actually writing to the network. @@ -62,6 +53,7 @@ async fn main() { max_parallel: 1, read_timeout: Duration::from_millis(1000), max_retries: 1, + udp_payload_size: Some(1400), }; let udp_tcp_config = udp_tcp::Config { udp: Some(udp_config.clone()), From c121bf0b8495070764a498fe1ddb16178a5815ea Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 4 Dec 2023 16:36:40 +0100 Subject: [PATCH 072/124] Use into_message. --- examples/client-transports.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 22c5e2412..16acaf493 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,6 +1,6 @@ use domain::base::Dname; +use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; -use domain::base::{Message, MessageBuilder}; use domain::net::client::bmb::BMB; use domain::net::client::multi_stream; use domain::net::client::octet_stream; @@ -31,7 +31,7 @@ async fn main() { .unwrap(); // Create a Message to pass to BMB. - let msg = Message::from_octets(msg.as_target().to_vec()).unwrap(); + let msg = msg.into_message(); // Transports take a BaseMEssageBuilder to be able to add options along // the way and only flatten just before actually writing to the network. From c770a389def45abad12dbb0ba47d6f7e9ab57449 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 4 Dec 2023 16:42:23 +0100 Subject: [PATCH 073/124] Better error handling for TCP connection. --- examples/client-transports.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 16acaf493..b2e3fd25a 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -203,7 +203,16 @@ async fn main() { // Create a single TCP transport connection. This is usefull for a // single request or a small burst of requests. - let tcp_conn = TcpStream::connect(server_addr).await.unwrap(); + let tcp_conn = match TcpStream::connect(server_addr).await { + Ok(conn) => conn, + Err(err) => { + println!( + "TCP Connection to {} failed: {}, exiting", + server_addr, err + ); + return; + } + }; let tcp = octet_stream::Connection::>>::new(None).unwrap(); let run_fut = tcp.run(tcp_conn); From b64418aa352c5c1feb48772e4de4fa55f971c9b1 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 5 Dec 2023 09:06:23 +0100 Subject: [PATCH 074/124] We need 0.21.9 for rustls to pass minimal version tests --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0026c1cfb..6bd7f72ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ zonefile = ["bytes", "std"] ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] [dev-dependencies] -rustls = { version = "0.21" } +rustls = { version = "0.21.9" } serde_test = "1.0.130" serde_yaml = "0.9" tokio = { version = "1", features = ["rt-multi-thread", "io-util", "net"] } From 33d73f81ab23b8ed6a5dbd6fdcd1b382f9f5962f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 5 Dec 2023 10:43:29 +0100 Subject: [PATCH 075/124] Rename ConnFactory, TcpConnFactory, TlsConnFactory to ConnectionStream, TcpConnStream, TlsConnStream. --- examples/client-transports.rs | 20 +++++----- .../{factory.rs => connection_stream.rs} | 4 +- src/net/client/mod.rs | 6 +-- src/net/client/multi_stream.rs | 40 +++++++++---------- .../{tcp_factory.rs => tcp_conn_stream.rs} | 20 +++++----- .../{tls_factory.rs => tls_conn_stream.rs} | 18 ++++----- src/net/client/udp_tcp.rs | 8 ++-- 7 files changed, 58 insertions(+), 58 deletions(-) rename src/net/client/{factory.rs => connection_stream.rs} (85%) rename src/net/client/{tcp_factory.rs => tcp_conn_stream.rs} (72%) rename src/net/client/{tls_factory.rs => tls_conn_stream.rs} (86%) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index b2e3fd25a..f655c3cfb 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -6,8 +6,8 @@ use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::query::QueryMessage4; use domain::net::client::redundant; -use domain::net::client::tcp_factory::TcpConnFactory; -use domain::net::client::tls_factory::TlsConnFactory; +use domain::net::client::tcp_conn_stream::TcpConnStream; +use domain::net::client::tls_conn_stream::TlsConnStream; use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; @@ -84,9 +84,9 @@ async fn main() { // when it is no longer needed. drop(query); - // Create a factory of TCP connections. Pass the destination address and + // Create a stream of TCP connections. Pass the destination address and // port as parameter. - let tcp_factory = TcpConnFactory::new(server_addr); + let tcp_conn_stream = TcpConnStream::new(server_addr); // A muli_stream transport connection sets up new TCP connections when // needed. @@ -95,8 +95,8 @@ async fn main() { .unwrap(); // Get a future for the run function. The run function receives - // the factory as a parameter. - let run_fut = tcp_conn.run(tcp_factory); + // the connection stream as a parameter. + let run_fut = tcp_conn.run(tcp_conn_stream); tokio::spawn(async move { let res = run_fut.await; println!("multi TCP run exited with {:?}", res); @@ -138,17 +138,17 @@ async fn main() { let google_server_addr = SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); - // Create a new TLS connection factory. We pass the TLS config, the name of + // Create a new TLS connection stream. We pass the TLS config, the name of // the remote server and the destination address and port. - let tls_factory = - TlsConnFactory::new(client_config, "dns.google", google_server_addr); + let tls_conn_stream = + TlsConnStream::new(client_config, "dns.google", google_server_addr); // Again create a multi_stream transport connection. let tls_conn = multi_stream::Connection::new(Some(multi_stream_config)).unwrap(); // Start the run function. - let run_fut = tls_conn.run(tls_factory); + let run_fut = tls_conn.run(tls_conn_stream); tokio::spawn(async move { let res = run_fut.await; println!("TLS run exited with {:?}", res); diff --git a/src/net/client/factory.rs b/src/net/client/connection_stream.rs similarity index 85% rename from src/net/client/factory.rs rename to src/net/client/connection_stream.rs index 47e5a6b41..aea49a4b1 100644 --- a/src/net/client/factory.rs +++ b/src/net/client/connection_stream.rs @@ -1,4 +1,4 @@ -//! Trait for connection factories +//! Trait for connection streams #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -9,7 +9,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; /// This trait is for creating new network connections. /// /// The IO type is the type of the resulting connection object. -pub trait ConnFactory { +pub trait ConnectionStream { /// The next method is an asynchronous function that returns a /// new connection. /// diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 85b900dbf..69355ef93 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -9,13 +9,13 @@ pub mod base_message_builder; pub mod bmb; +pub mod connection_stream; pub mod error; -pub mod factory; pub mod multi_stream; pub mod octet_stream; pub mod query; pub mod redundant; -pub mod tcp_factory; -pub mod tls_factory; +pub mod tcp_conn_stream; +pub mod tls_conn_stream; pub mod udp; pub mod udp_tcp; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index bf75a7669..0f930c835 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -31,8 +31,8 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::connection_stream::ConnectionStream; use crate::net::client::error::Error; -use crate::net::client::factory::ConnFactory; use crate::net::client::octet_stream; use crate::net::client::query::{GetResult, QueryMessage4}; @@ -84,13 +84,13 @@ impl Connection { /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. pub fn run< - F: ConnFactory + Send + 'static, + S: ConnectionStream + Send + 'static, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, - factory: F, + stream: S, ) -> Pin> + Send>> { - self.inner.run(factory) + self.inner.run(stream) } /// Start a DNS request. @@ -394,15 +394,15 @@ type ReplySender = oneshot::Sender>; /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection -struct State3<'a, F, IO, BMB> { +struct State3<'a, S, IO, BMB> { /// Underlying octet_stream connection. conn_state: SingleConnState3, /// Current connection id. conn_id: u64, - /// Factory for new octet streams. - factory: F, + /// Connection stream for new octet streams. + stream: S, /// Collection of futures for the async run function of the underlying /// octet_stream. @@ -423,7 +423,7 @@ enum SingleConnState3 { Some(octet_stream::Connection), /// State that deals with an error getting a new octet stream from - /// a factory. + /// a connection stream. Err(ErrorState), } @@ -463,17 +463,17 @@ impl InnerConnection { /// This function is not async cancellation safe. /// Make sure the resulting future does not contain a reference to self. pub fn run< - F: ConnFactory + Send + 'static, + S: ConnectionStream + Send + 'static, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, - factory: F, + stream: S, ) -> Pin> + Send>> { let mut receiver = self.receiver.lock().unwrap(); let opt_receiver = receiver.take(); drop(receiver); - Box::pin(Self::run_impl(self.config.clone(), factory, opt_receiver)) + Box::pin(Self::run_impl(self.config.clone(), stream, opt_receiver)) } /// Implementation of the run method. This function does not have @@ -481,11 +481,11 @@ impl InnerConnection { #[rustfmt::skip] async fn run_impl< 'a, - F: ConnFactory + Send, + S: ConnectionStream + Send, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, >( config: Config, - factory: F, + stream: S, opt_receiver: Option>> ) -> Result<(), Error> { let mut receiver = { @@ -493,10 +493,10 @@ impl InnerConnection { }; let mut curr_cmd: Option> = None; - let mut state = State3::<'a, F, IO, BMB> { + let mut state = State3::<'a, S, IO, BMB> { conn_state: SingleConnState3::None, conn_id: 0, - factory, + stream, runners: FuturesUnordered::< Pin> + Send>>, >::new(), @@ -506,7 +506,7 @@ impl InnerConnection { let mut do_stream = false; let mut stream_fut: Pin< Box> + Send>, - > = Box::pin(factory_nop()); + > = Box::pin(stream_nop()); let mut opt_chan = None; loop { @@ -557,7 +557,7 @@ impl InnerConnection { _ = chan.send(resp); } else { opt_chan = Some(chan); - stream_fut = Box::pin(state.factory.next()); + stream_fut = Box::pin(state.stream.next()); do_stream = true; } } @@ -572,7 +572,7 @@ impl InnerConnection { tokio::select! { res_conn = stream_fut.as_mut() => { do_stream = false; - stream_fut = Box::pin(factory_nop()); + stream_fut = Box::pin(stream_nop()); if let Err(error) = res_conn { let error = Arc::new(error); @@ -763,8 +763,8 @@ fn is_answer_ignore_id< } /// Helper function to create an empty future that is compatible with the -/// future return by a factory. -async fn factory_nop() -> Result { +/// future returned by a connection stream. +async fn stream_nop() -> Result { Err(io::Error::new(io::ErrorKind::Other, "nop")) } diff --git a/src/net/client/tcp_factory.rs b/src/net/client/tcp_conn_stream.rs similarity index 72% rename from src/net/client/tcp_factory.rs rename to src/net/client/tcp_conn_stream.rs index 6abecdfae..4fc1becce 100644 --- a/src/net/client/tcp_factory.rs +++ b/src/net/client/tcp_conn_stream.rs @@ -1,4 +1,4 @@ -//! A factory for TCP connections +//! A stream of TCP connections #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -11,27 +11,27 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use tokio::net::{TcpStream, ToSocketAddrs}; -use crate::net::client::factory::ConnFactory; +use crate::net::client::connection_stream::ConnectionStream; -//------------ TcpConnFactory ------------------------------------------------- +//------------ TcpConnStream -------------------------------------------------- -/// This a connection factory that produces TCP connections. -pub struct TcpConnFactory { +/// This a stream of TCP connections. +pub struct TcpConnStream { /// Remote address to connect to. addr: A, } -impl TcpConnFactory { - /// Create a new factory. +impl TcpConnStream { + /// Create a new TCP connection stream. /// /// addr is the destination address to connect to. - pub fn new(addr: A) -> TcpConnFactory { + pub fn new(addr: A) -> Self { Self { addr } } } -impl ConnFactory - for TcpConnFactory +impl ConnectionStream + for TcpConnStream { type F = Pin< Box> + Send>, diff --git a/src/net/client/tls_factory.rs b/src/net/client/tls_conn_stream.rs similarity index 86% rename from src/net/client/tls_factory.rs rename to src/net/client/tls_conn_stream.rs index 64b680bd3..46777d19f 100644 --- a/src/net/client/tls_factory.rs +++ b/src/net/client/tls_conn_stream.rs @@ -1,4 +1,4 @@ -//! A factory for TLS connections +//! A stream of TLS connections #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -15,12 +15,12 @@ use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; -use crate::net::client::factory::ConnFactory; +use crate::net::client::connection_stream::ConnectionStream; -//------------ TlsConnFactory ------------------------------------------------- +//------------ TlsConnStream ------------------------------------------------- -/// Factory object for TLS connections -pub struct TlsConnFactory { +/// Stream of TLS connections +pub struct TlsConnStream { /// Configuration for setting up a TLS connection. client_config: Arc, @@ -31,13 +31,13 @@ pub struct TlsConnFactory { addr: A, } -impl TlsConnFactory { - /// Function to create a new TLS connection factory +impl TlsConnStream { + /// Function to create a new TLS connection stream pub fn new( client_config: Arc, server_name: &str, addr: A, - ) -> TlsConnFactory { + ) -> Self { Self { client_config, server_name: String::from(server_name), @@ -47,7 +47,7 @@ impl TlsConnFactory { } impl - ConnFactory> for TlsConnFactory + ConnectionStream> for TlsConnStream { type F = Pin< Box< diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index ab32925f1..968c5589e 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -19,7 +19,7 @@ use crate::net::client::base_message_builder::BaseMessageBuilder; use crate::net::client::error::Error; use crate::net::client::multi_stream; use crate::net::client::query::{GetResult, QueryMessage4}; -use crate::net::client::tcp_factory::TcpConnFactory; +use crate::net::client::tcp_conn_stream::TcpConnStream; use crate::net::client::udp; //------------ Config --------------------------------------------------------- @@ -229,12 +229,12 @@ impl InnerConnection { /// Implementation of the worker function. /// - /// Create a TCP connection factory and pass that to worker function + /// Create a TCP connection stream and pass that to worker function /// of the multi_stream object. fn run(&self) -> Pin> + Send>> { - let tcp_factory = TcpConnFactory::new(self.remote_addr); + let tcp_stream = TcpConnStream::new(self.remote_addr); - let fut = self.tcp_conn.run(tcp_factory); + let fut = self.tcp_conn.run(tcp_stream); Box::pin(fut) } From 2151e128d6ab788505ee2f6197b65ed98ec36de0 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 7 Dec 2023 15:45:06 +0100 Subject: [PATCH 076/124] Rename ConnectionStream to AsyncConnect --- examples/client-transports.rs | 18 ++++++++--------- ...{connection_stream.rs => async_connect.rs} | 10 +++++----- src/net/client/mod.rs | 6 +++--- src/net/client/multi_stream.rs | 10 +++++----- .../{tcp_conn_stream.rs => tcp_connect.rs} | 20 +++++++++---------- .../{tls_conn_stream.rs => tls_connect.rs} | 16 +++++++-------- src/net/client/udp_tcp.rs | 8 ++++---- 7 files changed, 44 insertions(+), 44 deletions(-) rename src/net/client/{connection_stream.rs => async_connect.rs} (59%) rename src/net/client/{tcp_conn_stream.rs => tcp_connect.rs} (72%) rename src/net/client/{tls_conn_stream.rs => tls_connect.rs} (88%) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index f655c3cfb..410004828 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -6,8 +6,8 @@ use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::query::QueryMessage4; use domain::net::client::redundant; -use domain::net::client::tcp_conn_stream::TcpConnStream; -use domain::net::client::tls_conn_stream::TlsConnStream; +use domain::net::client::tcp_connect::TcpConnect; +use domain::net::client::tls_connect::TlsConnect; use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; @@ -84,9 +84,9 @@ async fn main() { // when it is no longer needed. drop(query); - // Create a stream of TCP connections. Pass the destination address and + // Create a new TCP connections object. Pass the destination address and // port as parameter. - let tcp_conn_stream = TcpConnStream::new(server_addr); + let tcp_connect = TcpConnect::new(server_addr); // A muli_stream transport connection sets up new TCP connections when // needed. @@ -96,7 +96,7 @@ async fn main() { // Get a future for the run function. The run function receives // the connection stream as a parameter. - let run_fut = tcp_conn.run(tcp_conn_stream); + let run_fut = tcp_conn.run(tcp_connect); tokio::spawn(async move { let res = run_fut.await; println!("multi TCP run exited with {:?}", res); @@ -138,17 +138,17 @@ async fn main() { let google_server_addr = SocketAddr::new(IpAddr::from_str("8.8.8.8").unwrap(), 853); - // Create a new TLS connection stream. We pass the TLS config, the name of + // Create a new TLS connections object. We pass the TLS config, the name of // the remote server and the destination address and port. - let tls_conn_stream = - TlsConnStream::new(client_config, "dns.google", google_server_addr); + let tls_connect = + TlsConnect::new(client_config, "dns.google", google_server_addr); // Again create a multi_stream transport connection. let tls_conn = multi_stream::Connection::new(Some(multi_stream_config)).unwrap(); // Start the run function. - let run_fut = tls_conn.run(tls_conn_stream); + let run_fut = tls_conn.run(tls_connect); tokio::spawn(async move { let res = run_fut.await; println!("TLS run exited with {:?}", res); diff --git a/src/net/client/connection_stream.rs b/src/net/client/async_connect.rs similarity index 59% rename from src/net/client/connection_stream.rs rename to src/net/client/async_connect.rs index aea49a4b1..56bb9c907 100644 --- a/src/net/client/connection_stream.rs +++ b/src/net/client/async_connect.rs @@ -1,4 +1,4 @@ -//! Trait for connection streams +//! Trait for asynchronously creating connections. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -6,18 +6,18 @@ use std::future::Future; use tokio::io::{AsyncRead, AsyncWrite}; -/// This trait is for creating new network connections. +/// This trait is for creating new network connections asynchronously. /// /// The IO type is the type of the resulting connection object. -pub trait ConnectionStream { +pub trait AsyncConnect { /// The next method is an asynchronous function that returns a /// new connection. /// - /// This method is equivalent to async fn next(&self) -> Result; + /// This method is equivalent to async fn connect(&self) -> Result; /// Associated type for the return type of next. type F: Future> + Send; /// Get the next IO connection. - fn next(&self) -> Self::F; + fn connect(&self) -> Self::F; } diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 69355ef93..d44a562ba 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -7,15 +7,15 @@ #![doc = include_str!("../../../examples/client-transports.rs")] //! ``` +pub mod async_connect; pub mod base_message_builder; pub mod bmb; -pub mod connection_stream; pub mod error; pub mod multi_stream; pub mod octet_stream; pub mod query; pub mod redundant; -pub mod tcp_conn_stream; -pub mod tls_conn_stream; +pub mod tcp_connect; +pub mod tls_connect; pub mod udp; pub mod udp_tcp; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 0f930c835..32a960482 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -30,8 +30,8 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; +use crate::net::client::async_connect::AsyncConnect; use crate::net::client::base_message_builder::BaseMessageBuilder; -use crate::net::client::connection_stream::ConnectionStream; use crate::net::client::error::Error; use crate::net::client::octet_stream; use crate::net::client::query::{GetResult, QueryMessage4}; @@ -84,7 +84,7 @@ impl Connection { /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. pub fn run< - S: ConnectionStream + Send + 'static, + S: AsyncConnect + Send + 'static, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, @@ -463,7 +463,7 @@ impl InnerConnection { /// This function is not async cancellation safe. /// Make sure the resulting future does not contain a reference to self. pub fn run< - S: ConnectionStream + Send + 'static, + S: AsyncConnect + Send + 'static, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, @@ -481,7 +481,7 @@ impl InnerConnection { #[rustfmt::skip] async fn run_impl< 'a, - S: ConnectionStream + Send, + S: AsyncConnect + Send, IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, >( config: Config, @@ -557,7 +557,7 @@ impl InnerConnection { _ = chan.send(resp); } else { opt_chan = Some(chan); - stream_fut = Box::pin(state.stream.next()); + stream_fut = Box::pin(state.stream.connect()); do_stream = true; } } diff --git a/src/net/client/tcp_conn_stream.rs b/src/net/client/tcp_connect.rs similarity index 72% rename from src/net/client/tcp_conn_stream.rs rename to src/net/client/tcp_connect.rs index 4fc1becce..7c68a69d3 100644 --- a/src/net/client/tcp_conn_stream.rs +++ b/src/net/client/tcp_connect.rs @@ -1,4 +1,4 @@ -//! A stream of TCP connections +//! Create new TCP connections. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -11,18 +11,18 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use tokio::net::{TcpStream, ToSocketAddrs}; -use crate::net::client::connection_stream::ConnectionStream; +use crate::net::client::async_connect::AsyncConnect; -//------------ TcpConnStream -------------------------------------------------- +//------------ TcpConnect -------------------------------------------------- -/// This a stream of TCP connections. -pub struct TcpConnStream { +/// Create new TCP connections. +pub struct TcpConnect { /// Remote address to connect to. addr: A, } -impl TcpConnStream { - /// Create a new TCP connection stream. +impl TcpConnect { + /// Create new TCP connections. /// /// addr is the destination address to connect to. pub fn new(addr: A) -> Self { @@ -30,14 +30,14 @@ impl TcpConnStream { } } -impl ConnectionStream - for TcpConnStream +impl AsyncConnect + for TcpConnect { type F = Pin< Box> + Send>, >; - fn next(&self) -> Self::F { + fn connect(&self) -> Self::F { Box::pin(Next { future: Box::pin(TcpStream::connect(self.addr.clone())), }) diff --git a/src/net/client/tls_conn_stream.rs b/src/net/client/tls_connect.rs similarity index 88% rename from src/net/client/tls_conn_stream.rs rename to src/net/client/tls_connect.rs index 46777d19f..52d88e6d2 100644 --- a/src/net/client/tls_conn_stream.rs +++ b/src/net/client/tls_connect.rs @@ -1,4 +1,4 @@ -//! A stream of TLS connections +//! Create new TLS connections #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -15,12 +15,12 @@ use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; -use crate::net::client::connection_stream::ConnectionStream; +use crate::net::client::async_connect::AsyncConnect; -//------------ TlsConnStream ------------------------------------------------- +//------------ TlsConnect ----------------------------------------------------- -/// Stream of TLS connections -pub struct TlsConnStream { +/// Create new TLS connections +pub struct TlsConnect { /// Configuration for setting up a TLS connection. client_config: Arc, @@ -31,7 +31,7 @@ pub struct TlsConnStream { addr: A, } -impl TlsConnStream { +impl TlsConnect { /// Function to create a new TLS connection stream pub fn new( client_config: Arc, @@ -47,7 +47,7 @@ impl TlsConnStream { } impl - ConnectionStream> for TlsConnStream + AsyncConnect> for TlsConnect { type F = Pin< Box< @@ -56,7 +56,7 @@ impl >, >; - fn next(&self) -> Self::F { + fn connect(&self) -> Self::F { let tls_connection = TlsConnector::from(self.client_config.clone()); let server_name = match ServerName::try_from(self.server_name.as_str()) { diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 968c5589e..39c4b2f23 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -19,7 +19,7 @@ use crate::net::client::base_message_builder::BaseMessageBuilder; use crate::net::client::error::Error; use crate::net::client::multi_stream; use crate::net::client::query::{GetResult, QueryMessage4}; -use crate::net::client::tcp_conn_stream::TcpConnStream; +use crate::net::client::tcp_connect::TcpConnect; use crate::net::client::udp; //------------ Config --------------------------------------------------------- @@ -229,12 +229,12 @@ impl InnerConnection { /// Implementation of the worker function. /// - /// Create a TCP connection stream and pass that to worker function + /// Create a TCP connect object and pass that to run function /// of the multi_stream object. fn run(&self) -> Pin> + Send>> { - let tcp_stream = TcpConnStream::new(self.remote_addr); + let tcp_connect = TcpConnect::new(self.remote_addr); - let fut = self.tcp_conn.run(tcp_stream); + let fut = self.tcp_conn.run(tcp_connect); Box::pin(fut) } From 552ab5c7e48cd8afb947f04e1958dde70f9c7b6c Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 7 Dec 2023 16:06:14 +0100 Subject: [PATCH 077/124] Move parameter IO to associated type Connection. --- src/net/client/async_connect.rs | 6 ++++-- src/net/client/multi_stream.rs | 16 ++++++++-------- src/net/client/tcp_connect.rs | 3 ++- src/net/client/tls_connect.rs | 5 +++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/net/client/async_connect.rs b/src/net/client/async_connect.rs index 56bb9c907..91101e6dd 100644 --- a/src/net/client/async_connect.rs +++ b/src/net/client/async_connect.rs @@ -9,14 +9,16 @@ use tokio::io::{AsyncRead, AsyncWrite}; /// This trait is for creating new network connections asynchronously. /// /// The IO type is the type of the resulting connection object. -pub trait AsyncConnect { +pub trait AsyncConnect { /// The next method is an asynchronous function that returns a /// new connection. /// /// This method is equivalent to async fn connect(&self) -> Result; + type Connection: AsyncRead + AsyncWrite + Send + Unpin; + /// Associated type for the return type of next. - type F: Future> + Send; + type F: Future> + Send; /// Get the next IO connection. fn connect(&self) -> Self::F; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 32a960482..704506dc0 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -84,8 +84,8 @@ impl Connection { /// This function has to run in the background or together with /// any calls to [query](Self::query) or [Query::get_result]. pub fn run< - S: AsyncConnect + Send + 'static, - IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, + S: AsyncConnect + Send + 'static, + C: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, stream: S, @@ -463,8 +463,8 @@ impl InnerConnection { /// This function is not async cancellation safe. /// Make sure the resulting future does not contain a reference to self. pub fn run< - S: AsyncConnect + Send + 'static, - IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, + S: AsyncConnect + Send + 'static, + C: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, >( &self, stream: S, @@ -481,8 +481,8 @@ impl InnerConnection { #[rustfmt::skip] async fn run_impl< 'a, - S: AsyncConnect + Send, - IO: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, + S: AsyncConnect + Send, + C: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, >( config: Config, stream: S, @@ -493,7 +493,7 @@ impl InnerConnection { }; let mut curr_cmd: Option> = None; - let mut state = State3::<'a, S, IO, BMB> { + let mut state = State3::<'a, S, C, BMB> { conn_state: SingleConnState3::None, conn_id: 0, stream, @@ -505,7 +505,7 @@ impl InnerConnection { let mut do_stream = false; let mut stream_fut: Pin< - Box> + Send>, + Box> + Send>, > = Box::pin(stream_nop()); let mut opt_chan = None; diff --git a/src/net/client/tcp_connect.rs b/src/net/client/tcp_connect.rs index 7c68a69d3..7fd5bf16f 100644 --- a/src/net/client/tcp_connect.rs +++ b/src/net/client/tcp_connect.rs @@ -30,9 +30,10 @@ impl TcpConnect { } } -impl AsyncConnect +impl AsyncConnect for TcpConnect { + type Connection = TcpStream; type F = Pin< Box> + Send>, >; diff --git a/src/net/client/tls_connect.rs b/src/net/client/tls_connect.rs index 52d88e6d2..bb6c39501 100644 --- a/src/net/client/tls_connect.rs +++ b/src/net/client/tls_connect.rs @@ -46,9 +46,10 @@ impl TlsConnect { } } -impl - AsyncConnect> for TlsConnect +impl AsyncConnect + for TlsConnect { + type Connection = TlsStream; type F = Pin< Box< dyn Future, std::io::Error>> From f186a79b50e0617e9b27f9763ae4e90421431171 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 7 Dec 2023 16:09:38 +0100 Subject: [PATCH 078/124] Remove type constraints for Connection. --- src/net/client/async_connect.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/net/client/async_connect.rs b/src/net/client/async_connect.rs index 91101e6dd..547173fef 100644 --- a/src/net/client/async_connect.rs +++ b/src/net/client/async_connect.rs @@ -4,7 +4,6 @@ #![warn(clippy::missing_docs_in_private_items)] use std::future::Future; -use tokio::io::{AsyncRead, AsyncWrite}; /// This trait is for creating new network connections asynchronously. /// @@ -15,7 +14,7 @@ pub trait AsyncConnect { /// /// This method is equivalent to async fn connect(&self) -> Result; - type Connection: AsyncRead + AsyncWrite + Send + Unpin; + type Connection; /// Associated type for the return type of next. type F: Future> + Send; From 53803b5ac4a596edb9afd9f4d7669e2daedbe700 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 7 Dec 2023 16:52:01 +0100 Subject: [PATCH 079/124] Rename BaseMessageBuilder, BMB to ComposeRequest, RequestMessage. --- examples/client-transports.rs | 18 ++-- ..._message_builder.rs => compose_request.rs} | 12 +-- src/net/client/mod.rs | 4 +- src/net/client/multi_stream.rs | 92 +++++++++---------- src/net/client/octet_stream.rs | 51 +++++----- src/net/client/{bmb.rs => request_message.rs} | 21 ++--- src/net/client/udp.rs | 26 +++--- src/net/client/udp_tcp.rs | 28 +++--- 8 files changed, 119 insertions(+), 133 deletions(-) rename src/net/client/{base_message_builder.rs => compose_request.rs} (74%) rename src/net/client/{bmb.rs => request_message.rs} (90%) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 410004828..72768249f 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,11 +1,11 @@ use domain::base::Dname; use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; -use domain::net::client::bmb::BMB; use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::query::QueryMessage4; use domain::net::client::redundant; +use domain::net::client::request_message::RequestMessage; use domain::net::client::tcp_connect::TcpConnect; use domain::net::client::tls_connect::TlsConnect; use domain::net::client::udp; @@ -35,7 +35,7 @@ async fn main() { // Transports take a BaseMEssageBuilder to be able to add options along // the way and only flatten just before actually writing to the network. - let bmb = BMB::new(msg); + let req = RequestMessage::new(msg); // Destination for UDP and TCP let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); @@ -73,7 +73,7 @@ async fn main() { }); // Send a query message. - let mut query = udptcp_conn.query(&bmb).await.unwrap(); + let mut query = udptcp_conn.query(&req).await.unwrap(); // Get the reply println!("Wating for UDP+TCP reply"); @@ -103,7 +103,7 @@ async fn main() { }); // Send a query message. - let mut query = tcp_conn.query(&bmb).await.unwrap(); + let mut query = tcp_conn.query(&req).await.unwrap(); // Get the reply. A multi_stream connection does not have any timeout. // Wrap get_result in a timeout. @@ -154,7 +154,7 @@ async fn main() { println!("TLS run exited with {:?}", res); }); - let mut query = tls_conn.query(&bmb).await.unwrap(); + let mut query = tls_conn.query(&req).await.unwrap(); println!("Wating for TLS reply"); let reply = timeout(Duration::from_millis(500), query.get_result()).await; println!("TLS reply: {:?}", reply); @@ -178,7 +178,7 @@ async fn main() { // Start a few queries. for i in 1..10 { - let mut query = redun.query(&bmb).await.unwrap(); + let mut query = redun.query(&req).await.unwrap(); let reply = query.get_result().await; if i == 2 { println!("redundant connection reply: {:?}", reply); @@ -195,7 +195,7 @@ async fn main() { udp::Connection::new(Some(udp_config), server_addr).unwrap(); // Send a query message. - let mut query = udp_conn.query(&bmb).await.unwrap(); + let mut query = udp_conn.query(&req).await.unwrap(); // Get the reply let reply = query.get_result().await; @@ -214,7 +214,7 @@ async fn main() { } }; - let tcp = octet_stream::Connection::>>::new(None).unwrap(); + let tcp = octet_stream::Connection::new(None).unwrap(); let run_fut = tcp.run(tcp_conn); tokio::spawn(async move { run_fut.await; @@ -222,7 +222,7 @@ async fn main() { }); // Send a query message. - let mut query = tcp.query(&bmb).await.unwrap(); + let mut query = tcp.query(&req).await.unwrap(); // Get the reply let reply = query.get_result().await; diff --git a/src/net/client/base_message_builder.rs b/src/net/client/compose_request.rs similarity index 74% rename from src/net/client/base_message_builder.rs rename to src/net/client/compose_request.rs index 8b7c62804..9bd040479 100644 --- a/src/net/client/base_message_builder.rs +++ b/src/net/client/compose_request.rs @@ -1,12 +1,11 @@ -//! Trait for building a message by applying changes to a base message. +//! Trait for composing a request by applying limited changes. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] +use crate::base::opt::TcpKeepalive; use crate::base::Header; use crate::base::Message; -//use crate::base::message_builder::OptBuilder; -use crate::base::opt::TcpKeepalive; use std::boxed::Box; use std::fmt::Debug; @@ -19,11 +18,10 @@ pub enum OptTypes { TypeTcpKeepalive(TcpKeepalive), } -/// A trait that allows construction of a message as a series to changes to -/// an existing message. -pub trait BaseMessageBuilder: Debug + Send + Sync { +/// A trait that allows composing a request as a series. +pub trait ComposeRequest: Debug + Send + Sync { /// Return a boxed dyn of the current object. - fn as_box_dyn(&self) -> Box; + fn as_box_dyn(&self) -> Box; /// Create a message that captures the recorded changes. fn to_message(&self) -> Message>; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index d44a562ba..5db95a4da 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -8,13 +8,13 @@ //! ``` pub mod async_connect; -pub mod base_message_builder; -pub mod bmb; +pub mod compose_request; pub mod error; pub mod multi_stream; pub mod octet_stream; pub mod query; pub mod redundant; +pub mod request_message; pub mod tcp_connect; pub mod tls_connect; pub mod udp; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 704506dc0..5d1eb847f 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -31,7 +31,7 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::async_connect::AsyncConnect; -use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::octet_stream; use crate::net::client::query::{GetResult, QueryMessage4}; @@ -56,12 +56,12 @@ pub struct Config { #[derive(Clone, Debug)] /// A DNS over octect streams transport. -pub struct Connection { +pub struct Connection { /// Reference counted [InnerConnection]. - inner: Arc>, + inner: Arc>, } -impl Connection { +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). @@ -99,11 +99,11 @@ impl Connection { /// returns a [Query] object wrapped in a [Result]. async fn query_impl4( &self, - query_msg: &BMB, + query_msg: &CR, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.new_conn(None, tx).await?; - let gr = Query::::new(self.clone(), query_msg, rx); + let gr = Query::::new(self.clone(), query_msg, rx); Ok(Box::new(gr)) } @@ -116,32 +116,32 @@ impl Connection { async fn new_conn( &self, id: u64, - tx: oneshot::Sender>, + tx: oneshot::Sender>, ) -> Result<(), Error> { self.inner.new_conn(Some(id), tx).await } } /* -impl - QueryMessage, Octs> for Connection<> +impl + QueryMessage, Octs> for Connection<> { fn query<'a>( &'a self, query_msg: &'a Message, - ) -> Pin, Error>> + Send + '_>> + ) -> Pin, Error>> + Send + '_>> { return Box::pin(self.query_impl(query_msg)); } } */ -impl QueryMessage4 - for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a BMB, + query_msg: &'a CR, ) -> Pin< Box< dyn Future, Error>> @@ -157,20 +157,20 @@ impl QueryMessage4 /// This struct represent an active DNS query. #[derive(Debug)] -pub struct Query { +pub struct Query { /// Request message. /// /// The reply message is compared with the request message to see if /// it matches the query. // query_msg: Message>, - query_msg: BMB, + query_msg: CR, /// Current state of the query. - state: QueryState, + state: QueryState, /// A multi_octet connection object is needed to request new underlying /// octet_stream transport connections. - conn: Connection, + conn: Connection, /// id of most recent connection. conn_id: u64, @@ -183,12 +183,12 @@ pub struct Query { /// Status of a query. Used in [Query]. #[derive(Debug)] -enum QueryState { +enum QueryState { /// Get a octet_stream transport. - GetConn(oneshot::Receiver>), + GetConn(oneshot::Receiver>), /// Start a query using the transport. - StartQuery(octet_stream::Connection), + StartQuery(octet_stream::Connection), /// Get the result of the query. GetResult(octet_stream::QueryNoCheck), @@ -204,26 +204,26 @@ enum QueryState { } /// The reply to a NewConn request. -type ChanResp = Result, Arc>; +type ChanResp = Result, Arc>; /// Response to the DNS request sent by [InnerConnection::run] to [Query]. #[derive(Debug)] -struct ChanRespOk { +struct ChanRespOk { /// id of this connection. id: u64, /// New octet_stream transport. - conn: octet_stream::Connection, + conn: octet_stream::Connection, } -impl Query { +impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. fn new( - conn: Connection, - query_msg: &BMB, - receiver: oneshot::Receiver>, - ) -> Query { + conn: Connection, + query_msg: &CR, + receiver: oneshot::Receiver>, + ) -> Query { Self { conn, query_msg: query_msg.clone(), @@ -334,7 +334,7 @@ impl Query { } } -impl GetResult for Query { +impl GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -348,7 +348,7 @@ impl GetResult for Query { /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection { +struct InnerConnection { /// User configuration values. config: Config, @@ -356,7 +356,7 @@ struct InnerConnection { /// part of a single channel. /// /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// receiver part of the channel. /// @@ -364,39 +364,39 @@ struct InnerConnection { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Mutex>>>, + receiver: Mutex>>>, } #[derive(Debug)] /// A request to [Connection::run] either for a new octet_stream or to /// shutdown. -struct ChanReq { +struct ChanReq { /// A requests consists of a command. - cmd: ReqCmd, + cmd: ReqCmd, } #[derive(Debug)] /// Commands that can be requested. -enum ReqCmd { +enum ReqCmd { /// Request for a (new) connection. /// /// The id of the previous connection (if any) is passed as well as a /// channel to send the reply. - NewConn(Option, ReplySender), + NewConn(Option, ReplySender), /// Shutdown command. Shutdown, } /// This is the type of sender in [ReqCmd]. -type ReplySender = oneshot::Sender>; +type ReplySender = oneshot::Sender>; /// Internal datastructure of [InnerConnection::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in InnerConnection -struct State3<'a, S, IO, BMB> { +struct State3<'a, S, IO, CR> { /// Underlying octet_stream connection. - conn_state: SingleConnState3, + conn_state: SingleConnState3, /// Current connection id. conn_id: u64, @@ -415,12 +415,12 @@ struct State3<'a, S, IO, BMB> { } /// State of the current underlying octet_stream transport. -enum SingleConnState3 { +enum SingleConnState3 { /// No current octet_stream transport. None, /// Current octet_stream transport. - Some(octet_stream::Connection), + Some(octet_stream::Connection), /// State that deals with an error getting a new octet stream from /// a connection stream. @@ -444,7 +444,7 @@ struct ErrorState { timeout: Duration, } -impl InnerConnection { +impl InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. @@ -486,14 +486,14 @@ impl InnerConnection { >( config: Config, stream: S, - opt_receiver: Option>> + opt_receiver: Option>> ) -> Result<(), Error> { let mut receiver = { opt_receiver.expect("no receiver present?") }; - let mut curr_cmd: Option> = None; + let mut curr_cmd: Option> = None; - let mut state = State3::<'a, S, C, BMB> { + let mut state = State3::<'a, S, C, CR> { conn_state: SingleConnState3::None, conn_id: 0, stream, @@ -674,7 +674,7 @@ impl InnerConnection { async fn new_conn( &self, opt_id: Option, - sender: oneshot::Sender>, + sender: oneshot::Sender>, ) -> Result<(), Error> { let req = ChanReq { cmd: ReqCmd::NewConn(opt_id, sender), diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index e08e211d6..77b0f4cfe 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -33,8 +33,8 @@ use crate::base::{ opt::{AllOptData, OptRecord, TcpKeepalive}, Message, }; -use crate::net::client::base_message_builder::BaseMessageBuilder; -use crate::net::client::base_message_builder::OptTypes; +use crate::net::client::compose_request::ComposeRequest; +use crate::net::client::compose_request::OptTypes; use crate::net::client::error::Error; use crate::net::client::query::{GetResult, QueryMessage4}; use octseq::Octets; @@ -88,12 +88,12 @@ impl Default for Config { #[derive(Clone, Debug)] /// A single DNS over octect stream connection. -pub struct Connection { +pub struct Connection { /// Reference counted [InnerConnection]. - inner: Arc>, + inner: Arc>, } -impl Connection { +impl Connection { /// Constructor for [Connection]. /// /// Returns a [Connection] wrapped in a [Result](io::Result). @@ -129,7 +129,7 @@ impl Connection { /// returns a [Query] object wrapped in a [Result]. async fn query_impl4( &self, - query_msg: &BMB, + query_msg: &CR, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -143,7 +143,7 @@ impl Connection { /// match the request avoids having to keep the request around. pub async fn query_no_check( &self, - query_msg: &BMB, + query_msg: &CR, ) -> Result { let (tx, rx) = oneshot::channel(); self.inner.query(tx, query_msg).await?; @@ -151,12 +151,12 @@ impl Connection { } } -impl QueryMessage4 - for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a BMB, + query_msg: &'a CR, ) -> Pin< Box< dyn Future, Error>> @@ -198,8 +198,8 @@ enum QueryState { impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. - fn new( - query_msg: &BMB, + fn new( + query_msg: &CR, receiver: oneshot::Receiver, ) -> Query { let vec = query_msg.to_vec(); @@ -316,7 +316,7 @@ impl QueryNoCheck { /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection { +struct InnerConnection { /// User configuration variables. config: Config, @@ -324,7 +324,7 @@ struct InnerConnection { /// part of a single channel. /// /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// receiver part of the channel. /// @@ -332,14 +332,14 @@ struct InnerConnection { /// [InnerConnection::run]. /// The Option is to allow [InnerConnection::run] to signal that the /// connection is closed. - receiver: Mutex>>>, + receiver: Mutex>>>, } #[derive(Debug)] /// A request from [Query] to [Connection::run] to start a DNS request. -struct ChanReq { +struct ChanReq { /// DNS request message - msg: BMB, + msg: CR, /// Sender to send result back to [Query] sender: ReplySender, @@ -427,7 +427,7 @@ enum ConnState { // This type could be local to InnerConnection, but I don't know how type ReaderChanReply = Message; -impl InnerConnection { +impl InnerConnection { /// Constructor for [InnerConnection]. /// /// This is the implementation of [Connection::new]. @@ -461,7 +461,7 @@ impl InnerConnection { async fn run_impl( config: Config, io: IO, - opt_receiver: Option>>, + opt_receiver: Option>>, ) -> Option<()> { let (reply_sender, mut reply_receiver) = mpsc::channel::(READ_REPLY_CHAN_CAP); @@ -639,7 +639,7 @@ impl InnerConnection { pub async fn query( &self, sender: oneshot::Sender, - query_msg: &BMB, + query_msg: &CR, ) -> Result<(), Error> { // We should figure out how to get query_msg. @@ -812,7 +812,7 @@ impl InnerConnection { /// idle. Addend a edns-tcp-keepalive option if needed. // Note: maybe reqmsg should be a return value. fn insert_req( - mut req: ChanReq, + mut req: ChanReq, status: &mut Status, reqmsg: &mut Option>, query_vec: &mut Queries, @@ -908,10 +908,7 @@ impl InnerConnection { /// Convert the query message to a vector. // This function should return the vector instead of storing it // through a reference. - fn convert_query( - msg: &dyn BaseMessageBuilder, - reqmsg: &mut Option>, - ) { + fn convert_query(msg: &dyn ComposeRequest, reqmsg: &mut Option>) { // Ideally there should be a write_all_vectored. Until there is one, // copy to a new Vec and prepend the length octets. @@ -994,9 +991,7 @@ impl InnerConnection { //------------ Utility -------------------------------------------------------- /// Add an edns-tcp-keepalive option to a BaseMessageBuilder. -fn add_tcp_keepalive( - msg: &mut BMB, -) -> Result<(), Error> { +fn add_tcp_keepalive(msg: &mut CR) -> Result<(), Error> { msg.add_opt(OptTypes::TypeTcpKeepalive(TcpKeepalive::new(None))); Ok(()) } diff --git a/src/net/client/bmb.rs b/src/net/client/request_message.rs similarity index 90% rename from src/net/client/bmb.rs rename to src/net/client/request_message.rs index 4c0d55968..a3c5f20bb 100644 --- a/src/net/client/bmb.rs +++ b/src/net/client/request_message.rs @@ -1,11 +1,8 @@ -//! Simple class that implement the BaseMessageBuilder trait. +//! Simple object that implements the ComposeRequest trait. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] -//use bytes::BytesMut; - -//use crate::base::message_builder::OptBuilder; use crate::base::Header; use crate::base::Message; use crate::base::MessageBuilder; @@ -13,8 +10,8 @@ use crate::base::ParsedDname; use crate::base::Rtype; use crate::base::StaticCompressor; use crate::dep::octseq::Octets; -use crate::net::client::base_message_builder::BaseMessageBuilder; -use crate::net::client::base_message_builder::OptTypes; +use crate::net::client::compose_request::ComposeRequest; +use crate::net::client::compose_request::OptTypes; use crate::net::client::error::Error; use crate::rdata::AllRecordData; @@ -23,9 +20,9 @@ use std::fmt::Debug; use std::vec::Vec; #[derive(Clone, Debug)] -/// Object that implements the BaseMessageBuilder trait for a Message object. -pub struct BMB> { - /// Base messages. +/// Object that implements the ComposeRequest trait for a Message object. +pub struct RequestMessage> { + /// Base message. msg: Message, /// New header. @@ -38,7 +35,7 @@ pub struct BMB> { udp_payload_size: Option, } -impl + Debug + Octets> BMB { +impl + Debug + Octets> RequestMessage { /// Create a new BMB object. pub fn new(msg: Message) -> Self { let header = msg.header(); @@ -150,9 +147,9 @@ impl + Debug + Octets> BMB { } impl + Clone + Debug + Octets + Send + Sync + 'static> - BaseMessageBuilder for BMB + ComposeRequest for RequestMessage { - fn as_box_dyn(&self) -> Box { + fn as_box_dyn(&self) -> Box { Box::new(self.clone()) } diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 9987951cf..02c5dfbba 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -22,7 +22,7 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::query::{GetResult, QueryMessage4}; @@ -122,10 +122,10 @@ impl Connection { /// Start a new DNS query. async fn query_impl4< - BMB: BaseMessageBuilder + Clone + Send + Sync + 'static, + CR: ComposeRequest + Clone + Send + Sync + 'static, >( &self, - query_msg: &BMB, + query_msg: &CR, ) -> Result, Error> { let gr = self.inner.query(query_msg, self.clone()).await?; Ok(Box::new(gr)) @@ -137,12 +137,12 @@ impl Connection { } } -impl - QueryMessage4 for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a BMB, + query_msg: &'a CR, ) -> Pin< Box< dyn Future, Error>> @@ -376,9 +376,9 @@ pub struct Query4 { impl Query4 { /// Create new Query object. - fn new( + fn new( config: Config, - query_msg: &BMB, + query_msg: &CR, remote_addr: SocketAddr, conn: Connection, udp_payload_size: Option, @@ -402,9 +402,9 @@ impl Query4 { /// Get the result of a DNS Query. /// /// This function is not cancel safe. - async fn get_result_impl2( + async fn get_result_impl2( config: Config, - mut query_bmb: BMB, + mut query_bmb: CR, remote_addr: SocketAddr, conn: Connection, udp_payload_size: Option, @@ -574,11 +574,9 @@ impl InnerConnection { } /// Return a Query object that contains the query state. - async fn query< - BMB: BaseMessageBuilder + Clone + Send + Sync + 'static, - >( + async fn query( &self, - query_msg: &BMB, + query_msg: &CR, conn: Connection, ) -> Result { Ok(Query4::new( diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 39c4b2f23..e992c5278 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -15,7 +15,7 @@ use std::pin::Pin; use std::sync::Arc; use crate::base::Message; -use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::multi_stream; use crate::net::client::query::{GetResult, QueryMessage4}; @@ -44,7 +44,7 @@ pub struct Connection { inner: Arc>, } -impl Connection { +impl Connection { /// Create a new connection. pub fn new( config: Option, @@ -73,19 +73,19 @@ impl Connection { /// Start a query for the QueryMessage4 trait. async fn query_impl4( &self, - query_msg: &BMB, + query_msg: &CR, ) -> Result, Error> { let gr = self.inner.query(query_msg).await?; Ok(Box::new(gr)) } } -impl QueryMessage4 - for Connection +impl QueryMessage4 + for Connection { fn query<'a>( &'a self, - query_msg: &'a BMB, + query_msg: &'a CR, ) -> Pin< Box< dyn Future, Error>> @@ -131,15 +131,15 @@ enum QueryState { GetTcpResult(Box), } -impl Query { +impl Query { /// Create a new Query object. /// /// The initial state is to start with a UDP transport. fn new( - query_msg: &BMB, + query_msg: &CR, udp_conn: udp::Connection, - tcp_conn: multi_stream::Connection, - ) -> Query { + tcp_conn: multi_stream::Connection, + ) -> Query { Query { query_msg: query_msg.clone(), udp_conn, @@ -185,9 +185,7 @@ impl Query { } } -impl GetResult - for Query -{ +impl GetResult for Query { fn get_result( &mut self, ) -> Pin< @@ -211,7 +209,7 @@ struct InnerConnection { tcp_conn: multi_stream::Connection, } -impl InnerConnection { +impl InnerConnection { /// Create a new InnerConnection object. /// /// Create the UDP and TCP connections. Store the remote address because @@ -241,7 +239,7 @@ impl InnerConnection { /// Implementation of the query function. /// /// Just create a Query object with the state it needs. - async fn query(&self, query_msg: &BMB) -> Result, Error> { + async fn query(&self, query_msg: &CR) -> Result, Error> { Ok(Query::new( query_msg, self.udp_conn.clone(), From a04da47eb5d7e97a5f95cf8da61abf25a380363e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 8 Dec 2023 11:49:47 +0100 Subject: [PATCH 080/124] Rename QueryMessage4, GetResult to Request, GetResponse. --- examples/client-transports.rs | 36 +++++----- src/net/client/mod.rs | 2 +- src/net/client/multi_stream.rs | 79 +++++++++------------- src/net/client/octet_stream.rs | 71 ++++++++++---------- src/net/client/query.rs | 78 ---------------------- src/net/client/redundant.rs | 118 ++++++++++++++++++--------------- src/net/client/request.rs | 40 +++++++++++ src/net/client/udp.rs | 80 +++++++++++----------- src/net/client/udp_tcp.rs | 118 ++++++++++++++++----------------- 9 files changed, 287 insertions(+), 335 deletions(-) delete mode 100644 src/net/client/query.rs create mode 100644 src/net/client/request.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 72768249f..706a251df 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -3,8 +3,8 @@ use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; use domain::net::client::multi_stream; use domain::net::client::octet_stream; -use domain::net::client::query::QueryMessage4; use domain::net::client::redundant; +use domain::net::client::request::Request; use domain::net::client::request_message::RequestMessage; use domain::net::client::tcp_connect::TcpConnect; use domain::net::client::tls_connect::TlsConnect; @@ -73,16 +73,16 @@ async fn main() { }); // Send a query message. - let mut query = udptcp_conn.query(&req).await.unwrap(); + let mut request = udptcp_conn.request(&req).await.unwrap(); // Get the reply println!("Wating for UDP+TCP reply"); - let reply = query.get_result().await; + let reply = request.get_response().await; println!("UDP+TCP reply: {:?}", reply); // The query may have a reference to the connection. Drop the query // when it is no longer needed. - drop(query); + drop(request); // Create a new TCP connections object. Pass the destination address and // port as parameter. @@ -103,15 +103,16 @@ async fn main() { }); // Send a query message. - let mut query = tcp_conn.query(&req).await.unwrap(); + let mut request = tcp_conn.request(&req).await.unwrap(); // Get the reply. A multi_stream connection does not have any timeout. // Wrap get_result in a timeout. println!("Wating for multi TCP reply"); - let reply = timeout(Duration::from_millis(500), query.get_result()).await; + let reply = + timeout(Duration::from_millis(500), request.get_response()).await; println!("multi TCP reply: {:?}", reply); - drop(query); + drop(request); // Some TLS boiler plate for the root certificates. let mut root_store = RootCertStore::empty(); @@ -154,12 +155,13 @@ async fn main() { println!("TLS run exited with {:?}", res); }); - let mut query = tls_conn.query(&req).await.unwrap(); + let mut request = tls_conn.request(&req).await.unwrap(); println!("Wating for TLS reply"); - let reply = timeout(Duration::from_millis(500), query.get_result()).await; + let reply = + timeout(Duration::from_millis(500), request.get_response()).await; println!("TLS reply: {:?}", reply); - drop(query); + drop(request); // Create a transport connection for redundant connections. let redun = redundant::Connection::new(None).unwrap(); @@ -178,8 +180,8 @@ async fn main() { // Start a few queries. for i in 1..10 { - let mut query = redun.query(&req).await.unwrap(); - let reply = query.get_result().await; + let mut request = redun.request(&req).await.unwrap(); + let reply = request.get_response().await; if i == 2 { println!("redundant connection reply: {:?}", reply); } @@ -195,10 +197,10 @@ async fn main() { udp::Connection::new(Some(udp_config), server_addr).unwrap(); // Send a query message. - let mut query = udp_conn.query(&req).await.unwrap(); + let mut request = udp_conn.request(&req).await.unwrap(); // Get the reply - let reply = query.get_result().await; + let reply = request.get_response().await; println!("UDP reply: {:?}", reply); // Create a single TCP transport connection. This is usefull for a @@ -221,11 +223,11 @@ async fn main() { println!("single TCP run terminated"); }); - // Send a query message. - let mut query = tcp.query(&req).await.unwrap(); + // Send a request message. + let mut request = tcp.request(&req).await.unwrap(); // Get the reply - let reply = query.get_result().await; + let reply = request.get_response().await; println!("TCP reply: {:?}", reply); drop(tcp); diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 5db95a4da..cd2823415 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -12,8 +12,8 @@ pub mod compose_request; pub mod error; pub mod multi_stream; pub mod octet_stream; -pub mod query; pub mod redundant; +pub mod request; pub mod request_message; pub mod tcp_connect; pub mod tls_connect; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 5d1eb847f..ff32c06ca 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -34,7 +34,7 @@ use crate::net::client::async_connect::AsyncConnect; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::octet_stream; -use crate::net::client::query::{GetResult, QueryMessage4}; +use crate::net::client::request::{GetResponse, Request}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; @@ -82,7 +82,7 @@ impl Connection { /// Main execution function for [Connection]. /// /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [Query::get_result]. + /// any calls to [query](Self::query) or [ReqResp::get_response]. pub fn run< S: AsyncConnect + Send + 'static, C: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, @@ -96,14 +96,14 @@ impl Connection { /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - async fn query_impl4( + /// returns a [ReqResp] object wrapped in a [Result]. + async fn query_impl( &self, query_msg: &CR, - ) -> Result, Error> { + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.inner.new_conn(None, tx).await?; - let gr = Query::::new(self.clone(), query_msg, rx); + let gr = ReqResp::::new(self.clone(), query_msg, rx); Ok(Box::new(gr)) } @@ -122,48 +122,32 @@ impl Connection { } } -/* -impl - QueryMessage, Octs> for Connection<> -{ - fn query<'a>( +impl Request for Connection { + fn request<'a>( &'a self, - query_msg: &'a Message, - ) -> Pin, Error>> + Send + '_>> - { - return Box::pin(self.query_impl(query_msg)); - } -} -*/ - -impl QueryMessage4 - for Connection -{ - fn query<'a>( - &'a self, - query_msg: &'a CR, + request_msg: &'a CR, ) -> Pin< Box< - dyn Future, Error>> + dyn Future, Error>> + Send + '_, >, > { - return Box::pin(self.query_impl4(query_msg)); + return Box::pin(self.query_impl(request_msg)); } } -//------------ Query ---------------------------------------------------------- +//------------ ReqResp -------------------------------------------------------- -/// This struct represent an active DNS query. +/// This struct represent an active DNS request. #[derive(Debug)] -pub struct Query { +pub struct ReqResp { /// Request message. /// /// The reply message is compared with the request message to see if /// it matches the query. // query_msg: Message>, - query_msg: CR, + request_msg: CR, /// Current state of the query. state: QueryState, @@ -216,29 +200,30 @@ struct ChanRespOk { conn: octet_stream::Connection, } -impl Query { - /// Constructor for [Query], takes a DNS query and a receiver for the +impl ReqResp { + /// Constructor for [ReqResp], takes a DNS request and a receiver for the /// reply. fn new( conn: Connection, - query_msg: &CR, + request_msg: &CR, receiver: oneshot::Receiver>, - ) -> Query { + ) -> ReqResp { Self { conn, - query_msg: query_msg.clone(), + request_msg: request_msg.clone(), state: QueryState::GetConn(receiver), conn_id: 0, - //imm_retry_count: 0, delayed_retry_count: 0, } } - /// Get the result of a DNS query. + /// Get the result of a DNS request. /// - /// This function returns the reply to a DNS query wrapped in a + /// This function returns the reply to a DNS request wrapped in a /// [Result]. - pub async fn get_result_impl(&mut self) -> Result, Error> { + pub async fn get_response_impl( + &mut self, + ) -> Result, Error> { loop { match self.state { QueryState::GetConn(ref mut receiver) => { @@ -271,7 +256,7 @@ impl Query { } } QueryState::StartQuery(ref mut conn) => { - let msg = self.query_msg.clone(); + let msg = self.request_msg.clone(); let query_res = conn.query_no_check(&msg).await; match query_res { Err(err) => { @@ -308,9 +293,9 @@ impl Query { } let msg = reply.expect("error is checked before"); - let query_msg = self.query_msg.to_message(); + let request_msg = self.request_msg.to_message(); - if !is_answer_ignore_id(&msg, &query_msg) { + if !is_answer_ignore_id(&msg, &request_msg) { return Err(Error::WrongReplyForQuery); } return Ok(msg); @@ -334,13 +319,13 @@ impl Query { } } -impl GetResult for Query { - fn get_result( +impl GetResponse for ReqResp { + fn get_response( &mut self, ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_result_impl()) + Box::pin(self.get_response_impl()) } } @@ -355,7 +340,7 @@ struct InnerConnection { /// [InnerConnection::sender] and [InnerConnection::receiver] are /// part of a single channel. /// - /// Used by [Query] to send requests to [InnerConnection::run]. + /// Used by [ReqResp] to send requests to [InnerConnection::run]. sender: mpsc::Sender>, /// receiver part of the channel. diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 77b0f4cfe..5526471db 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -36,7 +36,7 @@ use crate::base::{ use crate::net::client::compose_request::ComposeRequest; use crate::net::client::compose_request::OptTypes; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage4}; +use crate::net::client::request::{GetResponse, Request}; use octseq::Octets; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -126,15 +126,14 @@ impl Connection { /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and - /// returns a [Query] object wrapped in a [Result]. - async fn query_impl4( + /// returns a [ReqRepl] object wrapped in a [Result]. + async fn request_impl( &self, - query_msg: &CR, - ) -> Result, Error> { + request_msg: &CR, + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; - let msg = query_msg; - Ok(Box::new(Query::new(msg, rx))) + self.inner.request(tx, request_msg).await?; + Ok(Box::new(ReqResp::new(request_msg, rx))) } /// Start a DNS request but do not check if the reply matches the request. @@ -146,38 +145,36 @@ impl Connection { query_msg: &CR, ) -> Result { let (tx, rx) = oneshot::channel(); - self.inner.query(tx, query_msg).await?; + self.inner.request(tx, query_msg).await?; Ok(QueryNoCheck::new(rx)) } } -impl QueryMessage4 - for Connection -{ - fn query<'a>( +impl Request for Connection { + fn request<'a>( &'a self, - query_msg: &'a CR, + request_msg: &'a CR, ) -> Pin< Box< - dyn Future, Error>> + dyn Future, Error>> + Send + '_, >, > { - return Box::pin(self.query_impl4(query_msg)); + return Box::pin(self.request_impl(request_msg)); } } -//------------ Query ---------------------------------------------------------- +//------------ ReqResp -------------------------------------------------------- -/// This struct represent an active DNS query. +/// This struct represent an active DNS request. #[derive(Debug)] -pub struct Query { +pub struct ReqResp { /// Request message. /// /// The reply message is compared with the request message to see if /// it matches the query. - query_msg: Message>, + request_msg: Message>, /// Current state of the query. state: QueryState, @@ -195,27 +192,29 @@ enum QueryState { Done, } -impl Query { +impl ReqResp { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. fn new( - query_msg: &CR, + request_msg: &CR, receiver: oneshot::Receiver, - ) -> Query { - let vec = query_msg.to_vec(); + ) -> ReqResp { + let vec = request_msg.to_vec(); let msg = Message::from_octets(vec) .expect("Message failed to parse contents of another Message"); Self { - query_msg: msg, + request_msg: msg, state: QueryState::Busy(receiver), } } - /// Get the result of a DNS query. + /// Get the result of a DNS request. /// - /// This function returns the reply to a DNS query wrapped in a + /// This function returns the reply to a DNS request wrapped in a /// [Result]. - pub async fn get_result_impl(&mut self) -> Result, Error> { + pub async fn get_response_impl( + &mut self, + ) -> Result, Error> { match self.state { QueryState::Busy(ref mut receiver) => { let res = receiver.await; @@ -236,7 +235,7 @@ impl Query { let resp = res.expect("error case is checked already"); let msg = resp.reply; - if !is_answer_ignore_id(&msg, &self.query_msg) { + if !is_answer_ignore_id(&msg, &self.request_msg) { return Err(Error::WrongReplyForQuery); } Ok(msg) @@ -248,13 +247,13 @@ impl Query { } } -impl GetResult for Query { - fn get_result( +impl GetResponse for ReqResp { + fn get_response( &mut self, ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_result_impl()) + Box::pin(self.get_response_impl()) } } @@ -636,16 +635,14 @@ impl InnerConnection { } /// This function sends a DNS request to [InnerConnection::run]. - pub async fn query( + pub async fn request( &self, sender: oneshot::Sender, - query_msg: &CR, + request_msg: &CR, ) -> Result<(), Error> { - // We should figure out how to get query_msg. - let req = ChanReq { sender, - msg: query_msg.clone(), + msg: request_msg.clone(), }; match self.sender.send(req).await { Err(_) => diff --git a/src/net/client/query.rs b/src/net/client/query.rs deleted file mode 100644 index 2e1510810..000000000 --- a/src/net/client/query.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Traits for query transports - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use bytes::Bytes; -use std::boxed::Box; -use std::fmt::Debug; -use std::future::Future; -use std::pin::Pin; - -use crate::base::Message; -use crate::net::client::error::Error; - -/// Trait for starting a DNS query based on a message. -pub trait QueryMessage { - /// Query function that takes a message builder type. - /// - /// This function is intended to be cancel safe. - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin> + Send + '_>>; -} - -/* -// This trait is replaced with QueryMessage3 -/// Trait for starting a DNS query based on a message. -pub trait QueryMessage2 { - /// Query function that takes a message builder type. - /// - /// This function is intended to be cancel safe. - fn query<'a>( - &'a self, - query_msg: &'a mut MessageBuilder< - StaticCompressor>, - >, - ) -> Pin + Send + '_>>; -} -*/ - -/// Trait for starting a DNS query based on a message. -pub trait QueryMessage3 { - /// Query function that takes a message type. - /// - /// This function is intended to be cancel safe. - fn query<'a>( - &'a self, - query_msg: &'a Message, - ) -> Pin + Send + '_>>; -} - -/// Trait for starting a DNS query based on a base message builder. -pub trait QueryMessage4 { - /// Query function that takes a BaseMessageBuilder type. - /// - /// This function is intended to be cancel safe. - fn query<'a>( - &'a self, - query_msg: &'a BMB, - ) -> Pin + Send + '_>>; -} - -/// This type is the actual result type of the future returned by the -/// query function in the QueryMessage2 trait. -type QueryResultOutput = Result, Error>; - -/// Trait for getting the result of a DNS query. -pub trait GetResult: Debug { - /// Get the result of a DNS query. - /// - /// This function is intended to be cancel safe. - fn get_result( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - >; -} diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 07f50f0aa..417ec3872 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -27,7 +27,7 @@ use tokio::time::{sleep_until, Duration, Instant}; use crate::base::iana::OptRcode; use crate::base::Message; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage4}; +use crate::net::client::request::{GetResponse, Request}; /* Basic algorithm: @@ -114,51 +114,51 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { /// Add a transport connection. pub async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { self.inner.add(conn).await } - /// Implementation of the query function. - async fn query_impl( + /// Implementation of the request function. + async fn request_impl( &self, - query_msg: &BMB, - ) -> Result, Error> { - let query = self.inner.query(query_msg.clone()).await?; - Ok(Box::new(query)) + request_msg: &BMB, + ) -> Result, Error> { + let request = self.inner.request(request_msg.clone()).await?; + Ok(Box::new(request)) } } -impl QueryMessage4 +impl Request for Connection { - fn query<'a>( + fn request<'a>( &'a self, - query_msg: &'a BMB, + request_msg: &'a BMB, ) -> Pin< Box< - dyn Future, Error>> + dyn Future, Error>> + Send + '_, >, > { - return Box::pin(self.query_impl(query_msg)); + return Box::pin(self.request_impl(request_msg)); } } -//------------ Query ---------------------------------------------------------- +//------------ ReqResp -------------------------------------------------------- /// This type represents an active query request. #[derive(Debug)] -pub struct Query { +pub struct ReqResp { /// User configuration. config: Config, /// The state of the query state: QueryState, - /// The query message - query_msg: BMB, + /// The reuqest message + request_msg: BMB, /// List of connections identifiers and estimated response times. conn_rt: Vec, @@ -210,7 +210,7 @@ enum ChanReq { GetRT(RTReq), /// Start a query - Query(QueryReq), + Query(RequestReq), /// Report how long it took to get a response Report(TimeReport), @@ -228,7 +228,7 @@ impl Debug for ChanReq { /// Request to add a new connection struct AddReq { /// New connection to add - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, /// Channel to send the reply to tx: oneshot::Sender, @@ -246,29 +246,29 @@ struct RTReq /**/ { /// Reply to a RT request type RTReply = Result, Error>; -/// Request to start a query -struct QueryReq { +/// Request to start a request +struct RequestReq { /// Identifier of connection id: u64, /// Request message - query_msg: BMB, + request_msg: BMB, /// Channel to send the reply to - tx: oneshot::Sender, + tx: oneshot::Sender, } -impl Debug for QueryReq { +impl Debug for RequestReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - f.debug_struct("QueryReq") + f.debug_struct("RequestReq") .field("id", &self.id) - .field("query_msg", &self.query_msg) + .field("request_msg", &self.request_msg) .finish() } } -/// Reply to a query request. -type QueryReply = Result, Error>; +/// Reply to a request request. +type RequestReply = Result, Error>; /// Report the amount of time until success or failure. #[derive(Debug)] @@ -305,11 +305,11 @@ struct ConnRT { /// Result of the futures in fut_list. type FutListOutput = (usize, Result, Error>); -impl Query { +impl ReqResp { /// Create a new query object. fn new( config: Config, - query_msg: BMB, + request_msg: BMB, mut conn_rt: Vec, sender: mpsc::Sender>, ) -> Self { @@ -327,8 +327,7 @@ impl Query { Self { config, - query_msg, - //conns, + request_msg, conn_rt, sender, state: QueryState::Init, @@ -340,8 +339,8 @@ impl Query { } } - /// Implementation of get_result. - async fn get_result_impl(&mut self) -> Result, Error> { + /// Implementation of get_response. + async fn get_response_impl(&mut self) -> Result, Error> { loop { match self.state { QueryState::Init => { @@ -357,7 +356,7 @@ impl Query { ind, self.conn_rt[ind].id, self.sender.clone(), - self.query_msg.clone(), + self.request_msg.clone(), ); self.fut_list.push(Box::pin(fut)); let timeout = Instant::now() + self.conn_rt[ind].est_rt; @@ -540,13 +539,15 @@ impl Query { } } -impl GetResult for Query { - fn get_result( +impl GetResponse + for ReqResp +{ + fn get_response( &mut self, ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_result_impl()) + Box::pin(self.get_response_impl()) } } @@ -592,8 +593,7 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); - let mut conns: Vec + Send + Sync>> = - Vec::new(); + let mut conns: Vec + Send + Sync>> = Vec::new(); let mut receiver = opt_receiver.expect("receiver should not be empty"); @@ -625,19 +625,20 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { // Don't care if send fails let _ = rt_req.tx.send(Ok(conn_rt.clone())); } - ChanReq::Query(query_req) => { + ChanReq::Query(request_req) => { let opt_ind = - conn_rt.iter().position(|e| e.id == query_req.id); + conn_rt.iter().position(|e| e.id == request_req.id); match opt_ind { Some(ind) => { - let query = - conns[ind].query(&query_req.query_msg).await; + let query = conns[ind] + .request(&request_req.request_msg) + .await; // Don't care if send fails - let _ = query_req.tx.send(query); + let _ = request_req.tx.send(query); } None => { // Don't care if send fails - let _ = query_req + let _ = request_req .tx .send(Err(Error::RedundantTransportNotFound)); } @@ -692,7 +693,7 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Implementation of the add method. async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); self.sender @@ -703,16 +704,19 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { } /// Implementation of the query method. - async fn query(&'a self, query_msg: BMB) -> Result, Error> { + async fn request( + &'a self, + request_msg: BMB, + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.sender .send(ChanReq::GetRT(RTReq { tx })) .await .expect("send should not fail"); let conn_rt = rx.await.expect("receive should not fail")?; - Ok(Query::new( + Ok(ReqResp::new( self.config.clone(), - query_msg, + request_msg, conn_rt, self.sender.clone(), )) @@ -728,18 +732,22 @@ async fn start_request( index: usize, id: u64, sender: mpsc::Sender>, - query_msg: BMB, + request_msg: BMB, ) -> (usize, Result, Error>) { let (tx, rx) = oneshot::channel(); sender - .send(ChanReq::Query(QueryReq { id, query_msg, tx })) + .send(ChanReq::Query(RequestReq { + id, + request_msg, + tx, + })) .await .expect("send is expected to work"); - let mut query = match rx.await.expect("receive is expected to work") { + let mut request = match rx.await.expect("receive is expected to work") { Err(err) => return (index, Err(err)), - Ok(query) => query, + Ok(request) => request, }; - let reply = query.get_result().await; + let reply = request.get_response().await; (index, reply) } diff --git a/src/net/client/request.rs b/src/net/client/request.rs new file mode 100644 index 000000000..e590c3c03 --- /dev/null +++ b/src/net/client/request.rs @@ -0,0 +1,40 @@ +//! Traits for request/response transports + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +use bytes::Bytes; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; + +use crate::base::Message; +use crate::net::client::error::Error; + +/// Trait for starting a DNS request based on a request composer. +pub trait Request { + /// Request function that takes a ComposeRequest type. + /// + /// This function is intended to be cancel safe. + fn request<'a>( + &'a self, + request_msg: &'a CR, + ) -> Pin + Send + '_>>; +} + +/// This type is the actual result type of the future returned by the +/// request function in the Request trait. +type RequestResultOutput = Result, Error>; + +/// Trait for getting the result of a DNS query. +pub trait GetResponse: Debug { + /// Get the result of a DNS request. + /// + /// This function is intended to be cancel safe. + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + >; +} diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 02c5dfbba..771647358 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -24,7 +24,7 @@ use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; -use crate::net::client::query::{GetResult, QueryMessage4}; +use crate::net::client::request::{GetResponse, Request}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -120,14 +120,14 @@ impl Connection { }) } - /// Start a new DNS query. - async fn query_impl4< + /// Start a new DNS request. + async fn request_impl< CR: ComposeRequest + Clone + Send + Sync + 'static, >( &self, - query_msg: &CR, - ) -> Result, Error> { - let gr = self.inner.query(query_msg, self.clone()).await?; + request_msg: &CR, + ) -> Result, Error> { + let gr = self.inner.request(request_msg, self.clone()).await?; Ok(Box::new(gr)) } @@ -137,20 +137,20 @@ impl Connection { } } -impl QueryMessage4 +impl Request for Connection { - fn query<'a>( + fn request<'a>( &'a self, - query_msg: &'a CR, + request_msg: &'a CR, ) -> Pin< Box< - dyn Future, Error>> + dyn Future, Error>> + Send + '_, >, > { - return Box::pin(self.query_impl4(query_msg)); + return Box::pin(self.request_impl(request_msg)); } } @@ -365,28 +365,28 @@ impl GetResult for Query { */ -//------------ Query4 --------------------------------------------------------- +//------------ ReqResp -------------------------------------------------------- -/// The state of a DNS query. -pub struct Query4 { - /// Future that does the actual work of GetResult. - get_result_fut: +/// The state of a DNS request. +pub struct ReqResp { + /// Future that does the actual work of GetResponse. + get_response_fut: Pin, Error>> + Send>>, } -impl Query4 { - /// Create new Query object. +impl ReqResp { + /// Create new ReqResp object. fn new( config: Config, - query_msg: &CR, + request_msg: &CR, remote_addr: SocketAddr, conn: Connection, udp_payload_size: Option, ) -> Self { Self { - get_result_fut: Box::pin(Self::get_result_impl2( + get_response_fut: Box::pin(Self::get_response_impl2( config, - query_msg.clone(), + request_msg.clone(), remote_addr, conn, udp_payload_size, @@ -395,16 +395,16 @@ impl Query4 { } /// Async function that waits for the future stored in Query to complete. - async fn get_result_impl(&mut self) -> Result, Error> { - (&mut self.get_result_fut).await + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.get_response_fut).await } - /// Get the result of a DNS Query. + /// Get the response of a DNS request. /// /// This function is not cancel safe. - async fn get_result_impl2( + async fn get_response_impl2( config: Config, - mut query_bmb: CR, + mut request_bmb: CR, remote_addr: SocketAddr, conn: Connection, udp_payload_size: Option, @@ -427,14 +427,14 @@ impl Query4 { .map_err(|e| Error::UdpConnect(Arc::new(e)))?; // Set random ID in header - let header = query_bmb.header_mut(); + let header = request_bmb.header_mut(); header.set_random_id(); // Set UDP payload size if let Some(size) = udp_payload_size { - query_bmb.set_udp_payload_size(size) + request_bmb.set_udp_payload_size(size) } - let query_msg = query_bmb.to_message(); - let dgram = query_msg.as_slice(); + let request_msg = request_bmb.to_message(); + let dgram = request_msg.as_slice(); let sent = sock .as_ref() @@ -487,7 +487,7 @@ impl Query4 { Err(_) => continue, }; - if !is_answer(&answer, &query_msg) { + if !is_answer(&answer, &request_msg) { // Wrong answer, go back to receiving continue; } @@ -528,19 +528,19 @@ impl Query4 { } } -impl Debug for Query4 { +impl Debug for ReqResp { fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { todo!() } } -impl GetResult for Query4 { - fn get_result( +impl GetResponse for ReqResp { + fn get_response( &mut self, ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_result_impl()) + Box::pin(self.get_response_impl()) } } @@ -574,14 +574,14 @@ impl InnerConnection { } /// Return a Query object that contains the query state. - async fn query( + async fn request( &self, - query_msg: &CR, + request_msg: &CR, conn: Connection, - ) -> Result { - Ok(Query4::new( + ) -> Result { + Ok(ReqResp::new( self.config.clone(), - query_msg, + request_msg, self.remote_addr, conn, self.config.udp_payload_size, diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index e992c5278..e4fd61990 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -18,7 +18,7 @@ use crate::base::Message; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::multi_stream; -use crate::net::client::query::{GetResult, QueryMessage4}; +use crate::net::client::request::{GetResponse, Request}; use crate::net::client::tcp_connect::TcpConnect; use crate::net::client::udp; @@ -70,40 +70,38 @@ impl Connection { self.inner.run() } - /// Start a query for the QueryMessage4 trait. - async fn query_impl4( + /// Start a request for the Request trait. + async fn request_impl( &self, - query_msg: &CR, - ) -> Result, Error> { - let gr = self.inner.query(query_msg).await?; + request_msg: &CR, + ) -> Result, Error> { + let gr = self.inner.request(request_msg).await?; Ok(Box::new(gr)) } } -impl QueryMessage4 - for Connection -{ - fn query<'a>( +impl Request for Connection { + fn request<'a>( &'a self, - query_msg: &'a CR, + request_msg: &'a CR, ) -> Pin< Box< - dyn Future, Error>> + dyn Future, Error>> + Send + '_, >, > { - return Box::pin(self.query_impl4(query_msg)); + return Box::pin(self.request_impl(request_msg)); } } -//------------ Query ---------------------------------------------------------- +//------------ ReqResp -------------------------------------------------------- /// Object that contains the current state of a query. #[derive(Debug)] -pub struct Query { +pub struct ReqResp { /// Reqeust message. - query_msg: BMB, + request_msg: BMB, /// UDP transport to be used. udp_conn: udp::Connection, @@ -111,87 +109,87 @@ pub struct Query { /// TCP transport to be used. tcp_conn: multi_stream::Connection, - /// Current state of the query. + /// Current state of the request. state: QueryState, } /// Status of the query. #[derive(Debug)] enum QueryState { - /// Start a query over the UDP transport. - StartUdpQuery, + /// Start a request over the UDP transport. + StartUdpRequest, - /// Get the result from the UDP transport. - GetUdpResult(Box), + /// Get the response from the UDP transport. + GetUdpResponse(Box), - /// Start a query over the TCP transport. - StartTcpQuery, + /// Start a request over the TCP transport. + StartTcpRequest, - /// Get the result from the TCP transport. - GetTcpResult(Box), + /// Get the response from the TCP transport. + GetTcpResponse(Box), } -impl Query { - /// Create a new Query object. +impl ReqResp { + /// Create a new ReqResp object. /// /// The initial state is to start with a UDP transport. fn new( - query_msg: &CR, + request_msg: &CR, udp_conn: udp::Connection, tcp_conn: multi_stream::Connection, - ) -> Query { - Query { - query_msg: query_msg.clone(), + ) -> ReqResp { + Self { + request_msg: request_msg.clone(), udp_conn, tcp_conn, - state: QueryState::StartUdpQuery, + state: QueryState::StartUdpRequest, } } - /// Get the result of a DNS query. + /// Get the response of a DNS request. /// /// This function is cancel safe. - async fn get_result_impl(&mut self) -> Result, Error> { + async fn get_response_impl(&mut self) -> Result, Error> { loop { match &mut self.state { - QueryState::StartUdpQuery => { - let msg = self.query_msg.clone(); - let query = - QueryMessage4::query(&self.udp_conn, &msg).await?; - self.state = QueryState::GetUdpResult(query); + QueryState::StartUdpRequest => { + let msg = self.request_msg.clone(); + let request = self.udp_conn.request(&msg).await?; + self.state = QueryState::GetUdpResponse(request); continue; } - QueryState::GetUdpResult(ref mut query) => { - let reply = query.get_result().await?; - if reply.header().tc() { - self.state = QueryState::StartTcpQuery; + QueryState::GetUdpResponse(ref mut request) => { + let response = request.get_response().await?; + if response.header().tc() { + self.state = QueryState::StartTcpRequest; continue; } - return Ok(reply); + return Ok(response); } - QueryState::StartTcpQuery => { - let msg = self.query_msg.clone(); - let query = - QueryMessage4::query(&self.tcp_conn, &msg).await?; - self.state = QueryState::GetTcpResult(query); + QueryState::StartTcpRequest => { + let msg = self.request_msg.clone(); + let request = self.tcp_conn.request(&msg).await?; + self.state = QueryState::GetTcpResponse(request); continue; } - QueryState::GetTcpResult(ref mut query) => { - let reply = query.get_result().await?; - return Ok(reply); + QueryState::GetTcpResponse(ref mut query) => { + let response = query.get_response().await?; + return Ok(response); } } } } } -impl GetResult for Query { - fn get_result( +impl GetResponse + for ReqResp +{ + fn get_response( &mut self, ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_result_impl()) + Box::pin(self.get_response_impl()) } } @@ -236,12 +234,12 @@ impl InnerConnection { Box::pin(fut) } - /// Implementation of the query function. + /// Implementation of the request function. /// - /// Just create a Query object with the state it needs. - async fn query(&self, query_msg: &CR) -> Result, Error> { - Ok(Query::new( - query_msg, + /// Just create a ReqResp object with the state it needs. + async fn request(&self, request_msg: &CR) -> Result, Error> { + Ok(ReqResp::new( + request_msg, self.udp_conn.clone(), self.tcp_conn.clone(), )) From 1e9487d1be46936904b558ff8cf52427d9b20a54 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 8 Dec 2023 15:19:51 +0100 Subject: [PATCH 081/124] Add remark about the need for associated types in the future. --- src/net/client/request.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/net/client/request.rs b/src/net/client/request.rs index e590c3c03..d0cd3b79b 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -13,6 +13,9 @@ use crate::base::Message; use crate::net::client::error::Error; /// Trait for starting a DNS request based on a request composer. +/// +/// In the future, the return type of request should become an associated type. +/// However, the use of 'dyn Request' in redundant currently prevents that. pub trait Request { /// Request function that takes a ComposeRequest type. /// @@ -28,6 +31,9 @@ pub trait Request { type RequestResultOutput = Result, Error>; /// Trait for getting the result of a DNS query. +/// +/// In the future, the return type of get_response should become an associated +/// type. However, too many uses of 'dyn GetResponse' currently prevent that. pub trait GetResponse: Debug { /// Get the result of a DNS request. /// From 8147c8b70f659fd63c789a9aaf51a538a14519d9 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 20 Nov 2023 10:03:26 +0100 Subject: [PATCH 082/124] Use crate::net::client transports. --- Cargo.toml | 4 +- src/resolv/stub/mod.rs | 511 +++++++++-------------------------------- 2 files changed, 112 insertions(+), 403 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6bd7f72ae..c5afece9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1", optional = true } -tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync" ] } +tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } tokio-rustls = { version = "0.24", optional = true, features = [] } # XXX Force proc-macro2 to at least 1.0.69 for minimal-version build @@ -45,7 +45,7 @@ default = ["std", "rand"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] interop = ["bytes", "ring"] -resolv = ["bytes", "futures", "smallvec", "std", "tokio", "libc", "rand"] +resolv = ["bytes", "futures", "net", "smallvec", "std", "tokio", "libc", "rand"] resolv-sync = ["resolv", "tokio/rt"] serde = ["dep:serde", "octseq/serde"] sign = ["std"] diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index c6e5a59ff..506f26e34 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -19,6 +19,13 @@ use crate::base::message_builder::{ }; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; +use crate::net::client::base_message_builder::BaseMessageBuilder; +use crate::net::client::bmb; +use crate::net::client::multi_stream; +use crate::net::client::query::QueryMessage4; +use crate::net::client::redundant; +use crate::net::client::tcp_factory::TcpConnFactory; +use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; @@ -26,18 +33,17 @@ use crate::resolv::resolver::{Resolver, SearchNames}; use bytes::Bytes; use octseq::array::Array; use std::boxed::Box; +use std::fmt::Debug; use std::future::Future; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::pin::Pin; -use std::slice::SliceIndex; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::vec::Vec; use std::{io, ops}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpStream, UdpSocket}; #[cfg(feature = "resolv-sync")] use tokio::runtime; +use tokio::sync::Mutex; use tokio::time::timeout; //------------ Sub-modules --------------------------------------------------- @@ -46,9 +52,6 @@ pub mod conf; //------------ Module Configuration ------------------------------------------ -/// How many times do we try a new random port if we get ‘address in use.’ -const RETRY_RANDOM_PORT: usize = 10; - //------------ StubResolver -------------------------------------------------- /// A DNS stub resolver. @@ -70,16 +73,14 @@ const RETRY_RANDOM_PORT: usize = 10; /// [`query()`]: #method.query /// [`run()`]: #method.run /// [`run_with_conf()`]: #method.run_with_conf -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct StubResolver { - /// Preferred servers. - preferred: ServerList, - - /// Streaming servers. - stream: ServerList, + transport: Mutex>>>>, /// Resolver options. options: ResolvOptions, + + servers: Vec, } impl StubResolver { @@ -91,11 +92,10 @@ impl StubResolver { /// Creates a new resolver using the given configuraiton. pub fn from_conf(conf: ResolvConf) -> Self { StubResolver { - preferred: ServerList::from_conf(&conf, |s| { - s.transport.is_preferred() - }), - stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()), + transport: None.into(), options: conf.options, + + servers: conf.servers, } } @@ -118,6 +118,75 @@ impl StubResolver { ) -> Result { Query::new(self)?.run(message).await } + + async fn setup_transport< + BMB: Clone + Debug + BaseMessageBuilder + Send + Sync + 'static, + >( + &self, + ) -> redundant::Connection { + // Create a redundant transport and fill it with the right transports + let redun = redundant::Connection::new(None).unwrap(); + + // Start the run function on a separate task. + let redun_run = redun.clone(); + tokio::spawn(async move { + redun_run.run().await; + }); + + // We have 3 modes of operation: use_vc: only use TCP, ign_tc: only + // UDP no fallback to TCP, and normal with is UDP falling back to TCP. + if self.options.use_vc { + for s in &self.servers { + if let Transport::Tcp = s.transport { + let tcp_factory = TcpConnFactory::new(s.addr); + let tcp_conn = + multi_stream::Connection::new(None).unwrap(); + // TODO: How do we handle this? + // Create a clone for the run function. Start the run function on a + // separate task. + let conn_run = tcp_conn.clone(); + tokio::spawn(async move { + let res = conn_run.run(tcp_factory).await; + println!("run exited with {:?}", res); + }); + redun.add(Box::new(tcp_conn)).await.unwrap(); + } + } + } else if self.options.ign_tc { + todo!(); + } else { + for s in &self.servers { + if let Transport::Udp = s.transport { + let udptcp_conn = + udp_tcp::Connection::new(None, s.addr).unwrap(); + // TODO: How do we handle this? + // Create a clone for the run function. Start the run function on a + // separate task. + let conn_run = udptcp_conn.clone(); + tokio::spawn(async move { + let res = conn_run.run().await; + println!("run exited with {:?}", res); + }); + redun.add(Box::new(udptcp_conn)).await.unwrap(); + } + } + } + + redun + } + + async fn get_transport( + &self, + ) -> redundant::Connection>> { + let mut opt_transport = self.transport.lock().await; + + if opt_transport.is_none() { + let transport = self.setup_transport().await; + *opt_transport = Some(transport); + } + + (*opt_transport).as_ref().unwrap().clone() + } } impl StubResolver { @@ -238,14 +307,7 @@ pub struct Query<'a> { /// The resolver whose configuration we are using. resolver: &'a StubResolver, - /// Are we still in the preferred server list or have gone streaming? - preferred: bool, - - /// The number of attempts, starting with zero. - attempt: usize, - - /// The index in the server list we currently trying. - counter: ServerListCounter, + edns: Arc, /// The preferred error to return. /// @@ -259,23 +321,9 @@ pub struct Query<'a> { impl<'a> Query<'a> { pub fn new(resolver: &'a StubResolver) -> Result { - let (preferred, counter) = - if resolver.options().use_vc || resolver.preferred.is_empty() { - if resolver.stream.is_empty() { - return Err(io::Error::new( - io::ErrorKind::NotFound, - "no servers available", - )); - } - (false, resolver.stream.counter(resolver.options().rotate)) - } else { - (true, resolver.preferred.counter(resolver.options().rotate)) - }; Ok(Query { resolver, - preferred, - attempt: 0, - counter, + edns: Arc::new(AtomicBool::new(true)), error: Err(io::Error::new( io::ErrorKind::TimedOut, "all timed out", @@ -291,26 +339,14 @@ impl<'a> Query<'a> { match self.run_query(&mut message).await { Ok(answer) => { if answer.header().rcode() == Rcode::FormErr - && self.current_server().does_edns() + && self.does_edns() { // FORMERR with EDNS: turn off EDNS and try again. - self.current_server().disable_edns(); + self.disable_edns(); continue; } else if answer.header().rcode() == Rcode::ServFail { // SERVFAIL: go to next server. self.update_error_servfail(answer); - } else if answer.header().tc() - && self.preferred - && !self.resolver.options().ign_tc - { - // Truncated. If we can, switch to stream transports - // and try again. Otherwise return the truncated - // answer. - if self.switch_to_stream() { - continue; - } else { - return Ok(answer); - } } else { // I guess we have an answer ... return Ok(answer); @@ -318,9 +354,7 @@ impl<'a> Query<'a> { } Err(err) => self.update_error(err), } - if !self.next_server() { - return self.error; - } + return self.error; } } @@ -339,18 +373,21 @@ impl<'a> Query<'a> { &mut self, message: &mut QueryMessage, ) -> Result { - let server = self.current_server(); - server.prepare_message(message); - server.query(message).await - } + let msg = Message::from_octets( + message.as_target().as_dgram_slice().to_vec(), + ) + .unwrap(); - fn current_server(&self) -> &ServerInfo { - let list = if self.preferred { - &self.resolver.preferred - } else { - &self.resolver.stream - }; - self.counter.info(list) + let bmb = bmb::BMB::new(msg); + + let transport = self.resolver.get_transport().await; + let mut gr_fut = transport.query(&bmb).await.unwrap(); + let reply = + timeout(self.resolver.options.timeout, gr_fut.get_result()) + .await + .unwrap() + .unwrap(); + Ok(Answer { message: reply }) } fn update_error(&mut self, err: io::Error) { @@ -366,34 +403,12 @@ impl<'a> Query<'a> { self.error = Ok(answer) } - fn switch_to_stream(&mut self) -> bool { - if !self.preferred { - // We already did this. - return false; - } - self.preferred = false; - self.attempt = 0; - self.counter = - self.resolver.stream.counter(self.resolver.options().rotate); - true + pub fn does_edns(&self) -> bool { + self.edns.load(Ordering::Relaxed) } - fn next_server(&mut self) -> bool { - if self.counter.next() { - return true; - } - self.attempt += 1; - if self.attempt >= self.resolver.options().attempts { - return false; - } - self.counter = if self.preferred { - self.resolver - .preferred - .counter(self.resolver.options().rotate) - } else { - self.resolver.stream.counter(self.resolver.options().rotate) - }; - true + pub fn disable_edns(&self) { + self.edns.store(false, Ordering::Relaxed); } } @@ -451,312 +466,6 @@ impl AsRef> for Answer { } } -//------------ ServerInfo ---------------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerInfo { - /// The basic server configuration. - conf: ServerConf, - - /// Whether this server supports EDNS. - /// - /// We start out with assuming it does and unset it if we get a FORMERR. - edns: Arc, -} - -impl ServerInfo { - pub fn does_edns(&self) -> bool { - self.edns.load(Ordering::Relaxed) - } - - pub fn disable_edns(&self) { - self.edns.store(false, Ordering::Relaxed); - } - - pub fn prepare_message(&self, query: &mut QueryMessage) { - query.rewind(); - if self.does_edns() { - query - .opt(|opt| { - opt.set_udp_payload_size(self.conf.udp_payload_size); - Ok(()) - }) - .unwrap(); - } - } - - pub async fn query( - &self, - query: &QueryMessage, - ) -> Result { - let res = match self.conf.transport { - Transport::Udp => { - timeout( - self.conf.request_timeout, - Self::udp_query( - query, - self.conf.addr, - self.conf.recv_size, - ), - ) - .await - } - Transport::Tcp => { - timeout( - self.conf.request_timeout, - Self::tcp_query(query, self.conf.addr), - ) - .await - } - }; - match res { - Ok(Ok(answer)) => Ok(answer), - Ok(Err(err)) => Err(err), - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "request timed out", - )), - } - } - - pub async fn tcp_query( - query: &QueryMessage, - addr: SocketAddr, - ) -> Result { - let mut sock = TcpStream::connect(&addr).await?; - sock.write_all(query.as_target().as_stream_slice()).await?; - - // This loop can be infinite because we have a timeout on this whole - // thing, anyway. - loop { - let mut buf = Vec::new(); - let len = sock.read_u16().await? as u64; - AsyncReadExt::take(&mut sock, len) - .read_to_end(&mut buf) - .await?; - if let Ok(answer) = Message::from_octets(buf.into()) { - if answer.is_answer(&query.as_message()) { - return Ok(answer.into()); - } - // else try with the next message. - } else { - return Err(io::Error::new( - io::ErrorKind::Other, - "short buf", - )); - } - } - } - - pub async fn udp_query( - query: &QueryMessage, - addr: SocketAddr, - recv_size: usize, - ) -> Result { - let sock = Self::udp_bind(addr.is_ipv4()).await?; - sock.connect(addr).await?; - let sent = sock.send(query.as_target().as_dgram_slice()).await?; - if sent != query.as_target().as_dgram_slice().len() { - return Err(io::Error::new( - io::ErrorKind::Other, - "short UDP send", - )); - } - loop { - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let len = sock.recv(&mut buf).await?; - buf.truncate(len); - - // We ignore garbage since there is a timer on this whole thing. - let answer = match Message::from_octets(buf.into()) { - Ok(answer) => answer, - Err(_) => continue, - }; - if !answer.is_answer(&query.as_message()) { - continue; - } - return Ok(answer.into()); - } - } - - async fn udp_bind(v4: bool) -> Result { - let mut i = 0; - loop { - let local: SocketAddr = if v4 { - ([0u8; 4], 0).into() - } else { - ([0u16; 8], 0).into() - }; - match UdpSocket::bind(&local).await { - Ok(sock) => return Ok(sock), - Err(err) => { - if i == RETRY_RANDOM_PORT { - return Err(err); - } else { - i += 1 - } - } - } - } - } -} - -impl From for ServerInfo { - fn from(conf: ServerConf) -> Self { - ServerInfo { - conf, - edns: Arc::new(AtomicBool::new(true)), - } - } -} - -impl<'a> From<&'a ServerConf> for ServerInfo { - fn from(conf: &'a ServerConf) -> Self { - conf.clone().into() - } -} - -//------------ ServerList ---------------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerList { - /// The actual list of servers. - servers: Vec, - - /// Where to start accessing the list. - /// - /// In rotate mode, this value will always keep growing and will have to - /// be used modulo `servers`’s length. - /// - /// When it eventually wraps around the end of usize’s range, there will - /// be a jump in rotation. Since that will happen only oh-so-often, we - /// accept that in favour of simpler code. - start: Arc, -} - -impl ServerList { - pub fn from_conf(conf: &ResolvConf, filter: F) -> Self - where - F: Fn(&ServerConf) -> bool, - { - ServerList { - servers: { - conf.servers - .iter() - .filter(|f| filter(f)) - .map(Into::into) - .collect() - }, - start: Arc::new(AtomicUsize::new(0)), - } - } - - pub fn is_empty(&self) -> bool { - self.servers.is_empty() - } - - pub fn counter(&self, rotate: bool) -> ServerListCounter { - let res = ServerListCounter::new(self); - if rotate { - self.rotate() - } - res - } - - pub fn iter(&self) -> ServerListIter { - ServerListIter::new(self) - } - - pub fn rotate(&self) { - self.start.fetch_add(1, Ordering::SeqCst); - } -} - -impl<'a> IntoIterator for &'a ServerList { - type Item = &'a ServerInfo; - type IntoIter = ServerListIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -impl> ops::Index for ServerList { - type Output = >::Output; - - fn index(&self, index: I) -> &>::Output { - self.servers.index(index) - } -} - -//------------ ServerListCounter --------------------------------------------- - -#[derive(Clone, Debug)] -struct ServerListCounter { - cur: usize, - end: usize, -} - -impl ServerListCounter { - fn new(list: &ServerList) -> Self { - if list.servers.is_empty() { - return ServerListCounter { cur: 0, end: 0 }; - } - - // We modulo the start value here to prevent hick-ups towards the - // end of usize’s range. - let start = list.start.load(Ordering::Relaxed) % list.servers.len(); - ServerListCounter { - cur: start, - end: start + list.servers.len(), - } - } - - #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> bool { - let next = self.cur + 1; - if next < self.end { - self.cur = next; - true - } else { - false - } - } - - pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo { - &list[self.cur % list.servers.len()] - } -} - -//------------ ServerListIter ------------------------------------------------ - -#[derive(Clone, Debug)] -struct ServerListIter<'a> { - servers: &'a ServerList, - counter: ServerListCounter, -} - -impl<'a> ServerListIter<'a> { - fn new(list: &'a ServerList) -> Self { - ServerListIter { - servers: list, - counter: ServerListCounter::new(list), - } - } -} - -impl<'a> Iterator for ServerListIter<'a> { - type Item = &'a ServerInfo; - - fn next(&mut self) -> Option { - if self.counter.next() { - Some(self.counter.info(self.servers)) - } else { - None - } - } -} - //------------ SearchIter ---------------------------------------------------- #[derive(Clone, Debug)] From a4f04ea6d1f2884688f7a66e0a468220b3bf49fb Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 17 Nov 2023 11:59:08 +0100 Subject: [PATCH 083/124] Reduce the number of tasks we need. --- src/resolv/stub/mod.rs | 55 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 506f26e34..b65f68e5a 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -31,6 +31,7 @@ use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; use crate::resolv::resolver::{Resolver, SearchNames}; use bytes::Bytes; +use futures::stream::{FuturesUnordered, StreamExt}; use octseq::array::Array; use std::boxed::Box; use std::fmt::Debug; @@ -128,11 +129,22 @@ impl StubResolver { let redun = redundant::Connection::new(None).unwrap(); // Start the run function on a separate task. - let redun_run = redun.clone(); + let redun_run_fut = redun.run(); + + // It would be nice to have just one task. However redun.run() has to + // execute before we can call redun.add(). However, we need to know + // the type of the elements we add to FuturesUnordered. For the moment + // we have two tasks. tokio::spawn(async move { - redun_run.run().await; + redun_run_fut.await; }); + let fut_list_tcp = FuturesUnordered::new(); + let fut_list_udp_tcp = FuturesUnordered::new(); + + // Start the tasks with empty base transports. We need redun to be + // running before we can add transports. + // We have 3 modes of operation: use_vc: only use TCP, ign_tc: only // UDP no fallback to TCP, and normal with is UDP falling back to TCP. if self.options.use_vc { @@ -145,9 +157,10 @@ impl StubResolver { // Create a clone for the run function. Start the run function on a // separate task. let conn_run = tcp_conn.clone(); - tokio::spawn(async move { - let res = conn_run.run(tcp_factory).await; - println!("run exited with {:?}", res); + fut_list_tcp.push(async move { + let fut = conn_run.run(tcp_factory); + drop(conn_run); + let _res = fut.await; }); redun.add(Box::new(tcp_conn)).await.unwrap(); } @@ -163,15 +176,20 @@ impl StubResolver { // Create a clone for the run function. Start the run function on a // separate task. let conn_run = udptcp_conn.clone(); - tokio::spawn(async move { - let res = conn_run.run().await; - println!("run exited with {:?}", res); + fut_list_udp_tcp.push(async move { + let fut = conn_run.run(); + drop(conn_run); + let _res = fut.await; }); redun.add(Box::new(udptcp_conn)).await.unwrap(); } } } + tokio::spawn(async move { + run(fut_list_tcp, fut_list_udp_tcp).await; + }); + redun } @@ -189,6 +207,27 @@ impl StubResolver { } } +async fn run( + mut fut_list_tcp: FuturesUnordered, + mut fut_list_udp_tcp: FuturesUnordered, +) { + loop { + let tcp_empty = fut_list_tcp.is_empty(); + let udp_tcp_empty = fut_list_udp_tcp.is_empty(); + if tcp_empty && udp_tcp_empty { + break; + } + tokio::select! { + _ = fut_list_tcp.next(), if !tcp_empty => { + // Nothing to do + } + _ = fut_list_udp_tcp.next(), if !udp_tcp_empty => { + // Nothing to do + } + } + } +} + impl StubResolver { pub async fn lookup_addr( &self, From effbd2af1f3664f7dd6dd98353e90462dbc5f3e9 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Tue, 5 Dec 2023 10:53:22 +0100 Subject: [PATCH 084/124] Factory is renamed to connection stream --- src/resolv/stub/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index b65f68e5a..5e0eb355b 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -24,7 +24,7 @@ use crate::net::client::bmb; use crate::net::client::multi_stream; use crate::net::client::query::QueryMessage4; use crate::net::client::redundant; -use crate::net::client::tcp_factory::TcpConnFactory; +use crate::net::client::tcp_conn_stream::TcpConnStream; use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; @@ -150,7 +150,7 @@ impl StubResolver { if self.options.use_vc { for s in &self.servers { if let Transport::Tcp = s.transport { - let tcp_factory = TcpConnFactory::new(s.addr); + let tcp_conn_stream = TcpConnStream::new(s.addr); let tcp_conn = multi_stream::Connection::new(None).unwrap(); // TODO: How do we handle this? @@ -158,7 +158,7 @@ impl StubResolver { // separate task. let conn_run = tcp_conn.clone(); fut_list_tcp.push(async move { - let fut = conn_run.run(tcp_factory); + let fut = conn_run.run(tcp_conn_stream); drop(conn_run); let _res = fut.await; }); From 6a5350c3acb0d437082e26139f19bd7b4eb3be0e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 8 Dec 2023 15:41:30 +0100 Subject: [PATCH 085/124] Adapt to renaming in net/client --- src/resolv/stub/mod.rs | 44 +++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 5e0eb355b..db8130369 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -19,12 +19,12 @@ use crate::base::message_builder::{ }; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; -use crate::net::client::base_message_builder::BaseMessageBuilder; -use crate::net::client::bmb; +use crate::net::client::compose_request::ComposeRequest; use crate::net::client::multi_stream; -use crate::net::client::query::QueryMessage4; use crate::net::client::redundant; -use crate::net::client::tcp_conn_stream::TcpConnStream; +use crate::net::client::request::Request; +use crate::net::client::request_message::RequestMessage; +use crate::net::client::tcp_connect::TcpConnect; use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; @@ -76,7 +76,7 @@ pub mod conf; /// [`run_with_conf()`]: #method.run_with_conf #[derive(Debug)] pub struct StubResolver { - transport: Mutex>>>>, + transport: Mutex>>>>, /// Resolver options. options: ResolvOptions, @@ -121,10 +121,10 @@ impl StubResolver { } async fn setup_transport< - BMB: Clone + Debug + BaseMessageBuilder + Send + Sync + 'static, + CR: Clone + Debug + ComposeRequest + Send + Sync + 'static, >( &self, - ) -> redundant::Connection { + ) -> redundant::Connection { // Create a redundant transport and fill it with the right transports let redun = redundant::Connection::new(None).unwrap(); @@ -150,17 +150,13 @@ impl StubResolver { if self.options.use_vc { for s in &self.servers { if let Transport::Tcp = s.transport { - let tcp_conn_stream = TcpConnStream::new(s.addr); + let tcp_connect = TcpConnect::new(s.addr); let tcp_conn = multi_stream::Connection::new(None).unwrap(); - // TODO: How do we handle this? - // Create a clone for the run function. Start the run function on a - // separate task. - let conn_run = tcp_conn.clone(); + // Start the run function on a separate task. + let run_fut = tcp_conn.run(tcp_connect); fut_list_tcp.push(async move { - let fut = conn_run.run(tcp_conn_stream); - drop(conn_run); - let _res = fut.await; + let _res = run_fut.await; }); redun.add(Box::new(tcp_conn)).await.unwrap(); } @@ -172,14 +168,10 @@ impl StubResolver { if let Transport::Udp = s.transport { let udptcp_conn = udp_tcp::Connection::new(None, s.addr).unwrap(); - // TODO: How do we handle this? - // Create a clone for the run function. Start the run function on a - // separate task. - let conn_run = udptcp_conn.clone(); + // Start the run function on a separate task. + let run_fut = udptcp_conn.run(); fut_list_udp_tcp.push(async move { - let fut = conn_run.run(); - drop(conn_run); - let _res = fut.await; + let _res = run_fut.await; }); redun.add(Box::new(udptcp_conn)).await.unwrap(); } @@ -195,7 +187,7 @@ impl StubResolver { async fn get_transport( &self, - ) -> redundant::Connection>> { + ) -> redundant::Connection>> { let mut opt_transport = self.transport.lock().await; if opt_transport.is_none() { @@ -417,12 +409,12 @@ impl<'a> Query<'a> { ) .unwrap(); - let bmb = bmb::BMB::new(msg); + let request_msg = RequestMessage::new(msg); let transport = self.resolver.get_transport().await; - let mut gr_fut = transport.query(&bmb).await.unwrap(); + let mut gr_fut = transport.request(&request_msg).await.unwrap(); let reply = - timeout(self.resolver.options.timeout, gr_fut.get_result()) + timeout(self.resolver.options.timeout, gr_fut.get_response()) .await .unwrap() .unwrap(); From 34ecd446a4ff460011c686ddb1ef7b666ca1b948 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 8 Dec 2023 15:42:55 +0100 Subject: [PATCH 086/124] Get rid of ign_tc. --- src/resolv/stub/conf.rs | 9 --------- src/resolv/stub/mod.rs | 2 -- 2 files changed, 11 deletions(-) diff --git a/src/resolv/stub/conf.rs b/src/resolv/stub/conf.rs index b40be6139..24554e487 100644 --- a/src/resolv/stub/conf.rs +++ b/src/resolv/stub/conf.rs @@ -66,11 +66,6 @@ pub struct ResolvOptions { /// it is supposed to mean. pub primary: bool, - /// Ignore trunactions errors, don’t retry with TCP. - /// - /// This option is implemented by the query. - pub ign_tc: bool, - /// Set the recursion desired bit in queries. /// /// Enabled by default. @@ -186,7 +181,6 @@ impl Default for ResolvOptions { aa_only: false, use_vc: false, primary: false, - ign_tc: false, stay_open: false, use_inet6: false, rotate: false, @@ -556,9 +550,6 @@ impl fmt::Display for ResolvConf { if self.options.primary { options.push("primary".into()) } - if self.options.ign_tc { - options.push("ign-tc".into()) - } if !self.options.recurse { options.push("no-recurse".into()) } diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index db8130369..23a9d10eb 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -161,8 +161,6 @@ impl StubResolver { redun.add(Box::new(tcp_conn)).await.unwrap(); } } - } else if self.options.ign_tc { - todo!(); } else { for s in &self.servers { if let Transport::Udp = s.transport { From 565ea77fe88667651b0d8ab4f503a22a421fa586 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 20 Dec 2023 12:01:13 +0100 Subject: [PATCH 087/124] Fix features and imports. --- Cargo.toml | 4 ++-- src/net/client/multi_stream.rs | 4 ++-- src/net/client/redundant.rs | 4 ++-- src/net/client/request_message.rs | 3 ++- src/net/client/tcp_connect.rs | 2 +- src/net/client/tls_connect.rs | 2 +- src/resolv/stub/mod.rs | 2 +- 7 files changed, 11 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e3df5779b..7264e3646 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,13 +45,13 @@ default = ["std", "rand"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] interop = ["bytes", "ring"] -resolv = ["bytes", "futures-util", "net", "smallvec", "std", "tokio", "libc", "rand"] +resolv = ["net", "smallvec", "std", "rand"] resolv-sync = ["resolv", "tokio/rt"] serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "futures", "std", "tokio", "tokio-rustls"] +net = ["bytes", "futures-util", "std", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index ff32c06ca..cdb45e725 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -8,8 +8,8 @@ use bytes::Bytes; -use futures::stream::FuturesUnordered; -use futures::StreamExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use octseq::Octets; diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 417ec3872..1f84fe8f3 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -5,8 +5,8 @@ use bytes::Bytes; -use futures::stream::FuturesUnordered; -use futures::StreamExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use octseq::Octets; diff --git a/src/net/client/request_message.rs b/src/net/client/request_message.rs index a3c5f20bb..0fe5714ae 100644 --- a/src/net/client/request_message.rs +++ b/src/net/client/request_message.rs @@ -37,7 +37,8 @@ pub struct RequestMessage> { impl + Debug + Octets> RequestMessage { /// Create a new BMB object. - pub fn new(msg: Message) -> Self { + pub fn new(msg: impl Into>) -> Self { + let msg = msg.into(); let header = msg.header(); Self { msg, diff --git a/src/net/client/tcp_connect.rs b/src/net/client/tcp_connect.rs index 7fd5bf16f..c49cb5d73 100644 --- a/src/net/client/tcp_connect.rs +++ b/src/net/client/tcp_connect.rs @@ -4,7 +4,7 @@ #![warn(clippy::missing_docs_in_private_items)] use core::ops::DerefMut; -use futures::Future; +use core::future::Future; use std::boxed::Box; use std::fmt::Debug; use std::pin::Pin; diff --git a/src/net/client/tls_connect.rs b/src/net/client/tls_connect.rs index bb6c39501..9f45f72b4 100644 --- a/src/net/client/tls_connect.rs +++ b/src/net/client/tls_connect.rs @@ -4,7 +4,7 @@ #![warn(clippy::missing_docs_in_private_items)] use core::ops::DerefMut; -use futures::Future; +use core::future::Future; use std::boxed::Box; use std::pin::Pin; use std::string::String; diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 23a9d10eb..8dc27d49d 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -31,7 +31,7 @@ use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; use crate::resolv::resolver::{Resolver, SearchNames}; use bytes::Bytes; -use futures::stream::{FuturesUnordered, StreamExt}; +use futures_util::stream::{FuturesUnordered, StreamExt}; use octseq::array::Array; use std::boxed::Box; use std::fmt::Debug; From d6f4491ceedc39d1abe5c715207bbed6c908ffa2 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 20 Dec 2023 12:16:28 +0100 Subject: [PATCH 088/124] Fix format. --- src/net/client/tcp_connect.rs | 2 +- src/net/client/tls_connect.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/net/client/tcp_connect.rs b/src/net/client/tcp_connect.rs index c49cb5d73..e4ee597c1 100644 --- a/src/net/client/tcp_connect.rs +++ b/src/net/client/tcp_connect.rs @@ -3,8 +3,8 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] -use core::ops::DerefMut; use core::future::Future; +use core::ops::DerefMut; use std::boxed::Box; use std::fmt::Debug; use std::pin::Pin; diff --git a/src/net/client/tls_connect.rs b/src/net/client/tls_connect.rs index 9f45f72b4..5509a680d 100644 --- a/src/net/client/tls_connect.rs +++ b/src/net/client/tls_connect.rs @@ -3,8 +3,8 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] -use core::ops::DerefMut; use core::future::Future; +use core::ops::DerefMut; use std::boxed::Box; use std::pin::Pin; use std::string::String; From 5f225191b7c3b78cd27202ad7df147aa9da6eca6 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 21 Dec 2023 12:18:14 +0100 Subject: [PATCH 089/124] Move all transport protocol stuff into a single module. --- examples/client-transports.rs | 31 +++---- examples/readzone.rs | 2 +- src/base/message_builder.rs | 92 +++++++++++++------ src/net/client/async_connect.rs | 24 ----- src/net/client/mod.rs | 8 +- src/net/client/multi_stream.rs | 2 +- src/net/client/protocol/connect.rs | 20 ++++ src/net/client/protocol/mod.rs | 9 ++ .../{tcp_connect.rs => protocol/tcp.rs} | 6 +- .../{tls_connect.rs => protocol/tls.rs} | 6 +- src/net/client/udp_tcp.rs | 2 +- src/net/mod.rs | 2 +- src/resolv/stub/mod.rs | 2 +- 13 files changed, 125 insertions(+), 81 deletions(-) delete mode 100644 src/net/client/async_connect.rs create mode 100644 src/net/client/protocol/connect.rs create mode 100644 src/net/client/protocol/mod.rs rename src/net/client/{tcp_connect.rs => protocol/tcp.rs} (94%) rename src/net/client/{tls_connect.rs => protocol/tls.rs} (96%) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 706a251df..476ea3658 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,13 +1,13 @@ +/// Using the `domain::net::client` module for sending a query. use domain::base::Dname; use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; use domain::net::client::multi_stream; use domain::net::client::octet_stream; +use domain::net::client::protocol::{TcpConnect, TlsConnect}; use domain::net::client::redundant; use domain::net::client::request::Request; use domain::net::client::request_message::RequestMessage; -use domain::net::client::tcp_connect::TcpConnect; -use domain::net::client::tls_connect::TlsConnect; use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; @@ -20,31 +20,28 @@ use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; #[tokio::main] async fn main() { - // Create DNS request message. It would be nice if there was an object - // that implements both MessageBuilder and BaseMEssageBuilder. Until - // that time, first create a message using MessageBuilder, then turn - // that into a Message, and create a BaseMessaBuilder based on the message. + // Create DNS request message. + // + // Transports currently take a `RequestMessage` as their input to be able + // to add options along the way. + // + // In the future, it will also be possible to pass in a message or message + // builder directly as input but for now it needs to be converted into a + // `RequestMessage` manually. let mut msg = MessageBuilder::new_vec(); msg.header_mut().set_rd(true); let mut msg = msg.question(); - msg.push((Dname::>::vec_from_str("example.com").unwrap(), Aaaa)) + msg.push((Dname::vec_from_str("example.com").unwrap(), Aaaa)) .unwrap(); - - // Create a Message to pass to BMB. - let msg = msg.into_message(); - - // Transports take a BaseMEssageBuilder to be able to add options along - // the way and only flatten just before actually writing to the network. let req = RequestMessage::new(msg); // Destination for UDP and TCP let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); - let octet_stream_config = octet_stream::Config { - response_timeout: Duration::from_millis(100), - }; let multi_stream_config = multi_stream::Config { - octet_stream: Some(octet_stream_config.clone()), + octet_stream: Some(octet_stream::Config { + response_timeout: Duration::from_millis(100), + }), }; // Create a new UDP+TCP transport connection. Pass the destination address diff --git a/examples/readzone.rs b/examples/readzone.rs index c854eade0..e07cafc89 100644 --- a/examples/readzone.rs +++ b/examples/readzone.rs @@ -15,7 +15,7 @@ fn main() { start.elapsed().unwrap().as_secs_f32() ); let mut i = 0; - while let Some(_) = zone.next_entry().unwrap() { + while zone.next_entry().unwrap().is_some() { i += 1; if i % 100_000_000 == 0 { eprintln!( diff --git a/src/base/message_builder.rs b/src/base/message_builder.rs index 6702bf427..c38ce2ddc 100644 --- a/src/base/message_builder.rs +++ b/src/base/message_builder.rs @@ -300,14 +300,16 @@ impl> MessageBuilder { /// # Conversions /// -impl MessageBuilder { +impl MessageBuilder { /// Converts the message builder into a message builder /// /// This is a no-op. pub fn builder(self) -> MessageBuilder { self } +} +impl MessageBuilder { /// Converts the message builder into a question builder. pub fn question(self) -> QuestionBuilder { QuestionBuilder::new(self) @@ -340,15 +342,14 @@ impl MessageBuilder { pub fn finish(self) -> Target { self.target } +} +impl MessageBuilder { /// Converts the builder into a message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message<::Octets> - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { unsafe { Message::from_octets_unchecked(self.target.freeze()) } } } @@ -448,6 +449,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: MessageBuilder) -> Self { + src.into_message() + } +} + //--- AsRef // // XXX Should we deref down to target? @@ -554,14 +564,16 @@ impl QuestionBuilder { } } -impl QuestionBuilder { +impl QuestionBuilder { /// Converts the question builder into a question builder. /// /// In other words, doesn’t do anything. pub fn question(self) -> QuestionBuilder { self } +} +impl QuestionBuilder { /// Converts the question builder into an answer builder. pub fn answer(self) -> AnswerBuilder { AnswerBuilder::new(self.builder) @@ -587,15 +599,14 @@ impl QuestionBuilder { pub fn finish(self) -> Target { self.builder.finish() } +} +impl QuestionBuilder { /// Converts the question builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.builder.into_message() } } @@ -650,6 +661,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: QuestionBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for QuestionBuilder { @@ -831,15 +851,14 @@ impl AnswerBuilder { pub fn finish(self) -> Target { self.builder.finish() } +} +impl AnswerBuilder { /// Converts the answer builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.builder.into_message() } } @@ -894,6 +913,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AnswerBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AnswerBuilder { @@ -1055,9 +1083,7 @@ impl AuthorityBuilder { self.rewind(); self.answer } -} -impl AuthorityBuilder { /// Converts the authority builder into an authority builder. /// /// This is identical to the identity function. @@ -1076,15 +1102,14 @@ impl AuthorityBuilder { pub fn finish(self) -> Target { self.answer.finish() } +} +impl AuthorityBuilder { /// Converts the authority builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.answer.into_message() } } @@ -1139,6 +1164,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AuthorityBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AuthorityBuilder { @@ -1336,9 +1370,7 @@ impl AdditionalBuilder { self.rewind(); self.authority } -} -impl AdditionalBuilder { /// Converts the additional builder into an additional builder. /// /// In other words, does absolutely nothing. @@ -1350,15 +1382,14 @@ impl AdditionalBuilder { pub fn finish(self) -> Target { self.authority.finish() } +} +impl AdditionalBuilder { /// Converts the additional builder into the final message. /// /// The method will return a message atop whatever octets sequence the /// builder’s octets builder converts into. - pub fn into_message(self) -> Message - where - Target: FreezeBuilder, - { + pub fn into_message(self) -> Message { self.authority.into_message() } } @@ -1413,6 +1444,15 @@ where } } +impl From> for Message +where + Target: FreezeBuilder, +{ + fn from(src: AdditionalBuilder) -> Self { + src.into_message() + } +} + //--- Deref, DerefMut, AsRef, and AsMut impl Deref for AdditionalBuilder { diff --git a/src/net/client/async_connect.rs b/src/net/client/async_connect.rs deleted file mode 100644 index 547173fef..000000000 --- a/src/net/client/async_connect.rs +++ /dev/null @@ -1,24 +0,0 @@ -//! Trait for asynchronously creating connections. - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use std::future::Future; - -/// This trait is for creating new network connections asynchronously. -/// -/// The IO type is the type of the resulting connection object. -pub trait AsyncConnect { - /// The next method is an asynchronous function that returns a - /// new connection. - /// - /// This method is equivalent to async fn connect(&self) -> Result; - - type Connection; - - /// Associated type for the return type of next. - type F: Future> + Send; - - /// Get the next IO connection. - fn connect(&self) -> Self::F; -} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index cd2823415..e3bb151f3 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -7,7 +7,11 @@ #![doc = include_str!("../../../examples/client-transports.rs")] //! ``` -pub mod async_connect; +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +pub mod protocol; + pub mod compose_request; pub mod error; pub mod multi_stream; @@ -15,7 +19,5 @@ pub mod octet_stream; pub mod redundant; pub mod request; pub mod request_message; -pub mod tcp_connect; -pub mod tls_connect; pub mod udp; pub mod udp_tcp; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index cdb45e725..8e348b794 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -30,10 +30,10 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::async_connect::AsyncConnect; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::octet_stream; +use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{GetResponse, Request}; /// Capacity of the channel that transports [ChanReq]. diff --git a/src/net/client/protocol/connect.rs b/src/net/client/protocol/connect.rs new file mode 100644 index 000000000..cd655b4cc --- /dev/null +++ b/src/net/client/protocol/connect.rs @@ -0,0 +1,20 @@ +//! Asynchronously establishing a connection. + +use std::future::Future; +use std::io; + +//------------ AsyncConnect -------------------------------------------------- + +/// Establish a connection asynchronously. +/// +/// +pub trait AsyncConnect { + /// The type of an established connection. + type Connection; + + /// The future establishing the connection. + type Fut: Future> + Send; + + /// Returns a future that establishing a connection. + fn connect(&self) -> Self::Fut; +} diff --git a/src/net/client/protocol/mod.rs b/src/net/client/protocol/mod.rs new file mode 100644 index 000000000..5e8f2bd54 --- /dev/null +++ b/src/net/client/protocol/mod.rs @@ -0,0 +1,9 @@ +//! Underlying transport protocols. + +pub use self::connect::AsyncConnect; +pub use self::tcp::TcpConnect; +pub use self::tls::TlsConnect; + +mod connect; +mod tcp; +mod tls; diff --git a/src/net/client/tcp_connect.rs b/src/net/client/protocol/tcp.rs similarity index 94% rename from src/net/client/tcp_connect.rs rename to src/net/client/protocol/tcp.rs index e4ee597c1..b00cfc2fd 100644 --- a/src/net/client/tcp_connect.rs +++ b/src/net/client/protocol/tcp.rs @@ -11,7 +11,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use tokio::net::{TcpStream, ToSocketAddrs}; -use crate::net::client::async_connect::AsyncConnect; +use crate::net::client::protocol::AsyncConnect; //------------ TcpConnect -------------------------------------------------- @@ -34,11 +34,11 @@ impl AsyncConnect for TcpConnect { type Connection = TcpStream; - type F = Pin< + type Fut = Pin< Box> + Send>, >; - fn connect(&self) -> Self::F { + fn connect(&self) -> Self::Fut { Box::pin(Next { future: Box::pin(TcpStream::connect(self.addr.clone())), }) diff --git a/src/net/client/tls_connect.rs b/src/net/client/protocol/tls.rs similarity index 96% rename from src/net/client/tls_connect.rs rename to src/net/client/protocol/tls.rs index 5509a680d..5e2aa074b 100644 --- a/src/net/client/tls_connect.rs +++ b/src/net/client/protocol/tls.rs @@ -15,7 +15,7 @@ use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; -use crate::net::client::async_connect::AsyncConnect; +use crate::net::client::protocol::AsyncConnect; //------------ TlsConnect ----------------------------------------------------- @@ -50,14 +50,14 @@ impl AsyncConnect for TlsConnect { type Connection = TlsStream; - type F = Pin< + type Fut = Pin< Box< dyn Future, std::io::Error>> + Send, >, >; - fn connect(&self) -> Self::F { + fn connect(&self) -> Self::Fut { let tls_connection = TlsConnector::from(self.client_config.clone()); let server_name = match ServerName::try_from(self.server_name.as_str()) { diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index e4fd61990..b6aa8299b 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -18,8 +18,8 @@ use crate::base::Message; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::error::Error; use crate::net::client::multi_stream; +use crate::net::client::protocol::TcpConnect; use crate::net::client::request::{GetResponse, Request}; -use crate::net::client::tcp_connect::TcpConnect; use crate::net::client::udp; //------------ Config --------------------------------------------------------- diff --git a/src/net/mod.rs b/src/net/mod.rs index ebc56aacd..5eb5b6195 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,4 +1,4 @@ -//! DNS transport protocols +//! Sending and receiving DNS messages #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 8dc27d49d..c7399c09f 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -21,10 +21,10 @@ use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; use crate::net::client::compose_request::ComposeRequest; use crate::net::client::multi_stream; +use crate::net::client::protocol::TcpConnect; use crate::net::client::redundant; use crate::net::client::request::Request; use crate::net::client::request_message::RequestMessage; -use crate::net::client::tcp_connect::TcpConnect; use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; From 01056b69ce44ca0547a2a453691a8386c1ef9ff9 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 21 Dec 2023 17:06:52 +0100 Subject: [PATCH 090/124] Move things around. Have RequestMessage keep an OptRecord. --- examples/client-transports.rs | 21 +- src/base/opt/mod.rs | 167 +++++++++++++- src/net/client/compose_request.rs | 41 ---- src/net/client/error.rs | 163 ------------- src/net/client/mod.rs | 6 +- src/net/client/multi_stream.rs | 6 +- src/net/client/octet_stream.rs | 11 +- src/net/client/protocol.rs | 116 ++++++++++ src/net/client/protocol/connect.rs | 20 -- src/net/client/protocol/mod.rs | 9 - src/net/client/protocol/tcp.rs | 70 ------ src/net/client/protocol/tls.rs | 117 ---------- src/net/client/redundant.rs | 3 +- src/net/client/request.rs | 355 ++++++++++++++++++++++++++++- src/net/client/request_message.rs | 178 --------------- src/net/client/udp.rs | 6 +- src/net/client/udp_tcp.rs | 6 +- src/resolv/stub/mod.rs | 4 +- 18 files changed, 656 insertions(+), 643 deletions(-) delete mode 100644 src/net/client/compose_request.rs delete mode 100644 src/net/client/error.rs create mode 100644 src/net/client/protocol.rs delete mode 100644 src/net/client/protocol/connect.rs delete mode 100644 src/net/client/protocol/mod.rs delete mode 100644 src/net/client/protocol/tcp.rs delete mode 100644 src/net/client/protocol/tls.rs delete mode 100644 src/net/client/request_message.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 476ea3658..6b5d582ce 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -6,13 +6,11 @@ use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::protocol::{TcpConnect, TlsConnect}; use domain::net::client::redundant; -use domain::net::client::request::Request; -use domain::net::client::request_message::RequestMessage; +use domain::net::client::request::{Request, RequestMessage}; use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; -use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::timeout; @@ -124,12 +122,10 @@ async fn main() { )); // TLS config - let client_config = Arc::new( - ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(), - ); + let client_config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); // Currently the only support TLS connections are the ones that have a // valid certificate. Use a well known public resolver. @@ -138,8 +134,11 @@ async fn main() { // Create a new TLS connections object. We pass the TLS config, the name of // the remote server and the destination address and port. - let tls_connect = - TlsConnect::new(client_config, "dns.google", google_server_addr); + let tls_connect = TlsConnect::new( + client_config, + "dns.google".try_into().unwrap(), + google_server_addr, + ); // Again create a multi_stream transport connection. let tls_conn = diff --git a/src/base/opt/mod.rs b/src/base/opt/mod.rs index cb6a6ce0c..5557ce94e 100644 --- a/src/base/opt/mod.rs +++ b/src/base/opt/mod.rs @@ -40,17 +40,17 @@ opt_types! { //============ Module Content ================================================ use super::header::Header; -use super::iana::{OptRcode, OptionCode, Rtype}; +use super::iana::{Class, OptRcode, OptionCode, Rtype}; use super::name::{Dname, ToDname}; use super::rdata::{ComposeRecordData, ParseRecordData, RecordData}; -use super::record::Record; -use super::wire::{Composer, FormError, ParseError}; +use super::record::{Record, Ttl}; +use super::wire::{Compose, Composer, FormError, ParseError}; use crate::utils::base16; use core::cmp::Ordering; use core::convert::TryInto; use core::marker::PhantomData; use core::{fmt, hash, mem}; -use octseq::builder::{OctetsBuilder, ShortBuf}; +use octseq::builder::{EmptyBuilder, OctetsBuilder, ShortBuf}; use octseq::octets::{Octets, OctetsFrom}; use octseq::parse::Parser; @@ -76,6 +76,15 @@ pub struct Opt { octets: Octs, } +impl Opt { + /// Creates empty OPT record data. + pub fn empty() -> Self { + Self { + octets: Octs::empty(), + } + } +} + impl> Opt { /// Creates OPT record data from an octets sequence. /// @@ -83,7 +92,18 @@ impl> Opt { /// options. It does not check whether the options themselves are valid. pub fn from_octets(octets: Octs) -> Result { Opt::check_slice(octets.as_ref())?; - Ok(Opt { octets }) + Ok(unsafe { Self::from_octets_unchecked(octets) }) + } + + /// Creates OPT record data from octets without checking. + /// + /// # Safety + /// + /// The caller needs to ensure that the slice contains correctly encoded + /// OPT record data. The data of the options themselves does not need to + /// be correct. + unsafe fn from_octets_unchecked(octets: Octs) -> Self { + Self { octets } } /// Parses OPT record data from the beginning of a parser. @@ -128,6 +148,12 @@ impl Opt<[u8]> { } } +impl + ?Sized> Opt { + pub fn for_slice_ref(&self) -> Opt<&[u8]> { + unsafe { Opt::from_octets_unchecked(self.octets.as_ref()) } + } +} + impl + ?Sized> Opt { /// Returns the length of the OPT record data. pub fn len(&self) -> usize { @@ -163,6 +189,44 @@ impl + ?Sized> Opt { } } +impl Opt { + /// Appends a new option to the OPT data. + pub fn push( + &mut self, + option: &Opt, + ) -> Result<(), BuildDataError> { + self.push_raw_option(option.code(), option.compose_len(), |target| { + option.compose_option(target) + }) + } + + /// Appends a raw option to the OPT data. + /// + /// The method will append an option with the given option code. The data + /// of the option will be written via the closure `op`. + pub fn push_raw_option( + &mut self, + code: OptionCode, + option_len: u16, + op: F, + ) -> Result<(), BuildDataError> + where + F: FnOnce(&mut Octs) -> Result<(), Octs::AppendError>, + { + LongOptData::check_len( + self.octets + .as_ref() + .len() + .saturating_add(usize::from(option_len)), + )?; + + code.compose(&mut self.octets)?; + option_len.compose(&mut self.octets)?; + op(&mut self.octets)?; + Ok(()) + } +} + //--- OctetsFrom impl OctetsFrom> for Opt @@ -283,7 +347,8 @@ impl + ?Sized> fmt::Debug for Opt { /// /// The OPT record reappropriates the record header for encoding some /// basic information. This type provides access to this information. It -/// consists of the record header accept for its `rdlen` field. +/// consists of the record header with the exception of the fiinal `rdlen` +/// field. /// /// This is so that `OptBuilder` can safely deref to this type. /// @@ -440,6 +505,23 @@ impl OptRecord { } } + /// Converts the OPT record into a regular record. + pub fn as_record(&self) -> Record<&'static Dname<[u8]>, Opt<&[u8]>> + where + Octs: AsRef<[u8]>, + { + Record::new( + Dname::root_slice(), + Class::Int(self.udp_payload_size), + Ttl::from_secs( + u32::from(self.ext_rcode) << 24 + | u32::from(self.version) << 16 + | u32::from(self.flags), + ), + self.data.for_slice_ref(), + ) + } + /// Returns the UDP payload size. /// /// Through this field a sender of a message can signal the maximum size @@ -451,6 +533,11 @@ impl OptRecord { self.udp_payload_size } + /// Sets the UDP payload size. + pub fn set_udp_payload_size(&mut self, value: u16) { + self.udp_payload_size = value + } + /// Returns the extended rcode. /// /// Some of the bits of the rcode are stored in the regular message @@ -484,6 +571,44 @@ impl OptRecord { } } +impl OptRecord { + /// Appends a new option to the OPT data. + pub fn push( + &mut self, + option: &Opt, + ) -> Result<(), BuildDataError> { + self.data.push(option) + } + + /// Appends a raw option to the OPT data. + /// + /// The method will append an option with the given option code. The data + /// of the option will be written via the closure `op`. + pub fn push_raw_option( + &mut self, + code: OptionCode, + option_len: u16, + op: F, + ) -> Result<(), BuildDataError> + where + F: FnOnce(&mut Octs) -> Result<(), Octs::AppendError>, + { + self.data.push_raw_option(code, option_len, op) + } +} + +impl Default for OptRecord { + fn default() -> Self { + Self { + udp_payload_size: 0, + ext_rcode: 0, + version: 0, + flags: 0, + data: Opt::empty(), + } + } +} + //--- From impl From>> for OptRecord { @@ -521,6 +646,20 @@ impl AsRef> for OptRecord { } } +//--- Debug + +impl> fmt::Debug for OptRecord { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("OptRecord") + .field("udp_payload_size", &self.udp_payload_size) + .field("ext_rcord", &self.ext_rcode) + .field("version", &self.version) + .field("flags", &self.flags) + .field("data", &self.data) + .finish() + } +} + //------------ OptionHeader -------------------------------------------------- /// The header of an OPT option. @@ -859,13 +998,27 @@ impl std::error::Error for LongOptData {} /// An error happened while constructing an SVCB value. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum BuildDataError { - /// The value would exceed the allow length of a value. + /// The value would exceed the allowed length of a value. LongOptData, /// The underlying octets builder ran out of buffer space. ShortBuf, } +impl BuildDataError { + /// Converts the error into a `LongOptData` error for ‘endless’ buffers. + /// + /// # Panics + /// + /// This method will panic if the error is of the `ShortBuf` variant. + pub fn unlimited_buf(self) -> LongOptData { + match self { + Self::LongOptData => LongOptData(()), + Self::ShortBuf => panic!("ShortBuf on unlimited buffer"), + } + } +} + impl From for BuildDataError { fn from(_: LongOptData) -> Self { Self::LongOptData diff --git a/src/net/client/compose_request.rs b/src/net/client/compose_request.rs deleted file mode 100644 index 9bd040479..000000000 --- a/src/net/client/compose_request.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! Trait for composing a request by applying limited changes. - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use crate::base::opt::TcpKeepalive; -use crate::base::Header; -use crate::base::Message; - -use std::boxed::Box; -use std::fmt::Debug; -use std::vec::Vec; - -#[derive(Clone, Debug)] -/// Capture the various EDNS options. -pub enum OptTypes { - /// TcpKeepalive variant - TypeTcpKeepalive(TcpKeepalive), -} - -/// A trait that allows composing a request as a series. -pub trait ComposeRequest: Debug + Send + Sync { - /// Return a boxed dyn of the current object. - fn as_box_dyn(&self) -> Box; - - /// Create a message that captures the recorded changes. - fn to_message(&self) -> Message>; - - /// Create a message that captures the recorded changes and convert to - /// a Vec. - fn to_vec(&self) -> Vec; - - /// Return a reference to a mutable Header to record changes to the header. - fn header_mut(&mut self) -> &mut Header; - - /// Set the UDP payload size. - fn set_udp_payload_size(&mut self, value: u16); - - /// Add an EDNS option. - fn add_opt(&mut self, opt: OptTypes); -} diff --git a/src/net/client/error.rs b/src/net/client/error.rs deleted file mode 100644 index 2aeb5e68c..000000000 --- a/src/net/client/error.rs +++ /dev/null @@ -1,163 +0,0 @@ -//! Error type for client transports. - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use std::error; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -/// Error type for client transports. -#[derive(Clone, Debug)] -pub enum Error { - /// Connection was already closed. - ConnectionClosed, - - /// PushError from MessageBuilder. - MessageBuilderPushError, - - /// ParseError from Message. - MessageParseError, - - /// octet_stream configuration error. - OctetStreamConfigError(Arc), - - /// Underlying transport not found in redundant connection - RedundantTransportNotFound, - - /// Octet sequence too short to be a valid DNS message. - ShortMessage, - - /// Stream transport closed because it was idle (for too long). - StreamIdleTimeout, - - /// Error receiving a reply. - StreamReceiveError, - - /// Reading from stream gave an error. - StreamReadError(Arc), - - /// Reading from stream took too long. - StreamReadTimeout, - - /// Too many outstand queries on a single stream transport. - StreamTooManyOutstandingQueries, - - /// Writing to a stream gave an error. - StreamWriteError(Arc), - - /// Reading for a stream ended unexpectedly. - StreamUnexpectedEndOfData, - - /// Binding a UDP socket gave an error. - UdpBind(Arc), - - /// UDP configuration error. - UdpConfigError(Arc), - - /// Connecting a UDP socket gave an error. - UdpConnect(Arc), - - /// Receiving from a UDP socket gave an error. - UdpReceive(Arc), - - /// Sending over a UDP socket gaven an error. - UdpSend(Arc), - - /// Sending over a UDP socket gave a partial result. - UdpShortSend, - - /// Timeout receiving a response over a UDP socket. - UdpTimeoutNoResponse, - - /// Reply does not match the query. - WrongReplyForQuery, - - /// No transport available to transmit request. - NoTransportAvailable, -} - -impl Display for Error { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - match self { - Error::ConnectionClosed => write!(f, "connection closed"), - Error::MessageBuilderPushError => { - write!(f, "PushError from MessageBuilder") - } - Error::MessageParseError => write!(f, "ParseError from Message"), - Error::OctetStreamConfigError(_) => write!(f, "bad config value"), - Error::RedundantTransportNotFound => write!( - f, - "Underlying transport not found in redundant connection" - ), - Error::ShortMessage => { - write!(f, "octet sequence to short to be a valid message") - } - Error::StreamIdleTimeout => { - write!(f, "stream was idle for too long") - } - Error::StreamReceiveError => write!(f, "error receiving a reply"), - Error::StreamReadError(_) => { - write!(f, "error reading from stream") - } - Error::StreamReadTimeout => { - write!(f, "timeout reading from stream") - } - Error::StreamTooManyOutstandingQueries => { - write!(f, "too many outstanding queries on stream") - } - Error::StreamWriteError(_) => { - write!(f, "error writing to stream") - } - Error::StreamUnexpectedEndOfData => { - write!(f, "unexpected end of data") - } - Error::UdpBind(_) => write!(f, "error binding UDP socket"), - Error::UdpConfigError(_) => write!(f, "bad config value"), - Error::UdpConnect(_) => write!(f, "error connecting UDP socket"), - Error::UdpReceive(_) => { - write!(f, "error receiving from UDP socket") - } - Error::UdpSend(_) => write!(f, "error sending to UDP socket"), - Error::UdpShortSend => write!(f, "partial sent to UDP socket"), - Error::UdpTimeoutNoResponse => { - write!(f, "timeout waiting for response") - } - Error::WrongReplyForQuery => { - write!(f, "reply does not match query") - } - Error::NoTransportAvailable => { - write!(f, "no transport available") - } - } - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match self { - Error::ConnectionClosed => None, - Error::MessageBuilderPushError => None, - Error::MessageParseError => None, - Error::OctetStreamConfigError(e) => Some(e), - Error::RedundantTransportNotFound => None, - Error::ShortMessage => None, - Error::StreamIdleTimeout => None, - Error::StreamReceiveError => None, - Error::StreamReadError(e) => Some(e), - Error::StreamReadTimeout => None, - Error::StreamTooManyOutstandingQueries => None, - Error::StreamWriteError(e) => Some(e), - Error::StreamUnexpectedEndOfData => None, - Error::UdpBind(e) => Some(e), - Error::UdpConfigError(e) => Some(e), - Error::UdpConnect(e) => Some(e), - Error::UdpReceive(e) => Some(e), - Error::UdpSend(e) => Some(e), - Error::UdpShortSend => None, - Error::UdpTimeoutNoResponse => None, - Error::WrongReplyForQuery => None, - Error::NoTransportAvailable => None, - } - } -} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index e3bb151f3..4c0e7b1ee 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -10,14 +10,10 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] -pub mod protocol; - -pub mod compose_request; -pub mod error; pub mod multi_stream; pub mod octet_stream; +pub mod protocol; pub mod redundant; pub mod request; -pub mod request_message; pub mod udp; pub mod udp_tcp; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 8e348b794..1bdfb3241 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -30,11 +30,11 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::compose_request::ComposeRequest; -use crate::net::client::error::Error; use crate::net::client::octet_stream; use crate::net::client::protocol::AsyncConnect; -use crate::net::client::request::{GetResponse, Request}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, Request, +}; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 5526471db..1fe109540 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -33,10 +33,9 @@ use crate::base::{ opt::{AllOptData, OptRecord, TcpKeepalive}, Message, }; -use crate::net::client::compose_request::ComposeRequest; -use crate::net::client::compose_request::OptTypes; -use crate::net::client::error::Error; -use crate::net::client::request::{GetResponse, Request}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, Request, +}; use octseq::Octets; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -905,7 +904,7 @@ impl InnerConnection { /// Convert the query message to a vector. // This function should return the vector instead of storing it // through a reference. - fn convert_query(msg: &dyn ComposeRequest, reqmsg: &mut Option>) { + fn convert_query(msg: &CR, reqmsg: &mut Option>) { // Ideally there should be a write_all_vectored. Until there is one, // copy to a new Vec and prepend the length octets. @@ -989,7 +988,7 @@ impl InnerConnection { /// Add an edns-tcp-keepalive option to a BaseMessageBuilder. fn add_tcp_keepalive(msg: &mut CR) -> Result<(), Error> { - msg.add_opt(OptTypes::TypeTcpKeepalive(TcpKeepalive::new(None))); + msg.add_opt(&TcpKeepalive::new(None))?; Ok(()) } diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs new file mode 100644 index 000000000..531a314c8 --- /dev/null +++ b/src/net/client/protocol.rs @@ -0,0 +1,116 @@ +//! Underlying transport protocols. + +use core::future::Future; +use core::pin::Pin; +use std::boxed::Box; +use std::io; +use std::sync::Arc; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::{ClientConfig, ServerName}; +use tokio_rustls::TlsConnector; + +//------------ AsyncConnect -------------------------------------------------- + +/// Establish a connection asynchronously. +/// +/// +pub trait AsyncConnect { + /// The type of an established connection. + type Connection; + + /// The future establishing the connection. + type Fut: Future> + Send; + + /// Returns a future that establishing a connection. + fn connect(&self) -> Self::Fut; +} + +//------------ TcpConnect -------------------------------------------------- + +/// Create new TCP connections. +#[derive(Clone, Copy, Debug)] +pub struct TcpConnect { + /// Remote address to connect to. + addr: Addr, +} + +impl TcpConnect { + /// Create new TCP connections. + /// + /// addr is the destination address to connect to. + pub fn new(addr: Addr) -> Self { + Self { addr } + } +} + +impl AsyncConnect for TcpConnect +where + Addr: ToSocketAddrs + Clone + Send + 'static, +{ + type Connection = TcpStream; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + Box::pin(TcpStream::connect(self.addr.clone())) + } +} + +//------------ TlsConnect ----------------------------------------------------- + +/// Create new TLS connections +#[derive(Clone, Debug)] +pub struct TlsConnect { + /// Configuration for setting up a TLS connection. + client_config: Arc, + + /// Server name for certificate verification. + server_name: ServerName, + + /// Remote address to connect to. + addr: Addr, +} + +impl TlsConnect { + /// Function to create a new TLS connection stream + pub fn new( + client_config: impl Into>, + server_name: ServerName, + addr: Addr, + ) -> Self { + Self { + client_config: client_config.into(), + server_name, + addr, + } + } +} + +impl AsyncConnect for TlsConnect +where + Addr: ToSocketAddrs + Clone + Send + 'static, +{ + type Connection = TlsStream; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + let tls_connection = TlsConnector::from(self.client_config.clone()); + let server_name = self.server_name.clone(); + let addr = self.addr.clone(); + Box::pin(async { + let box_connection = Box::new(tls_connection); + let tcp = TcpStream::connect(addr).await?; + box_connection.connect(server_name, tcp).await + }) + } +} diff --git a/src/net/client/protocol/connect.rs b/src/net/client/protocol/connect.rs deleted file mode 100644 index cd655b4cc..000000000 --- a/src/net/client/protocol/connect.rs +++ /dev/null @@ -1,20 +0,0 @@ -//! Asynchronously establishing a connection. - -use std::future::Future; -use std::io; - -//------------ AsyncConnect -------------------------------------------------- - -/// Establish a connection asynchronously. -/// -/// -pub trait AsyncConnect { - /// The type of an established connection. - type Connection; - - /// The future establishing the connection. - type Fut: Future> + Send; - - /// Returns a future that establishing a connection. - fn connect(&self) -> Self::Fut; -} diff --git a/src/net/client/protocol/mod.rs b/src/net/client/protocol/mod.rs deleted file mode 100644 index 5e8f2bd54..000000000 --- a/src/net/client/protocol/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! Underlying transport protocols. - -pub use self::connect::AsyncConnect; -pub use self::tcp::TcpConnect; -pub use self::tls::TlsConnect; - -mod connect; -mod tcp; -mod tls; diff --git a/src/net/client/protocol/tcp.rs b/src/net/client/protocol/tcp.rs deleted file mode 100644 index b00cfc2fd..000000000 --- a/src/net/client/protocol/tcp.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! Create new TCP connections. - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use core::future::Future; -use core::ops::DerefMut; -use std::boxed::Box; -use std::fmt::Debug; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; -use tokio::net::{TcpStream, ToSocketAddrs}; - -use crate::net::client::protocol::AsyncConnect; - -//------------ TcpConnect -------------------------------------------------- - -/// Create new TCP connections. -pub struct TcpConnect { - /// Remote address to connect to. - addr: A, -} - -impl TcpConnect { - /// Create new TCP connections. - /// - /// addr is the destination address to connect to. - pub fn new(addr: A) -> Self { - Self { addr } - } -} - -impl AsyncConnect - for TcpConnect -{ - type Connection = TcpStream; - type Fut = Pin< - Box> + Send>, - >; - - fn connect(&self) -> Self::Fut { - Box::pin(Next { - future: Box::pin(TcpStream::connect(self.addr.clone())), - }) - } -} - -//------------ Next ----------------------------------------------------------- - -/// This is an internal structure that provides the future for a new -/// connection. -pub struct Next { - /// Future for creating a new TCP connection. - future: Pin< - Box> + Send>, - >, -} - -impl Future for Next { - type Output = Result; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let me = self.deref_mut(); - let io = ready!(me.future.as_mut().poll(cx))?; - Poll::Ready(Ok(io)) - } -} diff --git a/src/net/client/protocol/tls.rs b/src/net/client/protocol/tls.rs deleted file mode 100644 index 5e2aa074b..000000000 --- a/src/net/client/protocol/tls.rs +++ /dev/null @@ -1,117 +0,0 @@ -//! Create new TLS connections - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use core::future::Future; -use core::ops::DerefMut; -use std::boxed::Box; -use std::pin::Pin; -use std::string::String; -use std::sync::Arc; -use std::task::{ready, Context, Poll}; -use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio_rustls::client::TlsStream; -use tokio_rustls::rustls::{ClientConfig, ServerName}; -use tokio_rustls::TlsConnector; - -use crate::net::client::protocol::AsyncConnect; - -//------------ TlsConnect ----------------------------------------------------- - -/// Create new TLS connections -pub struct TlsConnect { - /// Configuration for setting up a TLS connection. - client_config: Arc, - - /// Server name for certificate verification. - server_name: String, - - /// Remote address to connect to. - addr: A, -} - -impl TlsConnect { - /// Function to create a new TLS connection stream - pub fn new( - client_config: Arc, - server_name: &str, - addr: A, - ) -> Self { - Self { - client_config, - server_name: String::from(server_name), - addr, - } - } -} - -impl AsyncConnect - for TlsConnect -{ - type Connection = TlsStream; - type Fut = Pin< - Box< - dyn Future, std::io::Error>> - + Send, - >, - >; - - fn connect(&self) -> Self::Fut { - let tls_connection = TlsConnector::from(self.client_config.clone()); - let server_name = - match ServerName::try_from(self.server_name.as_str()) { - Err(_) => { - return Box::pin(error_helper(std::io::Error::new( - std::io::ErrorKind::Other, - "invalid DNS name", - ))); - } - Ok(res) => res, - }; - let addr = self.addr.clone(); - Box::pin(Next { - future: Box::pin(async { - let box_connection = Box::new(tls_connection); - let tcp = TcpStream::connect(addr).await?; - box_connection.connect(server_name, tcp).await - }), - }) - } -} - -//------------ Next ----------------------------------------------------------- - -/// Internal structure that contains the future for creating a new -/// TLS connection. -pub struct Next { - /// Future for creating a new TLS connection. - future: Pin< - Box< - dyn Future, std::io::Error>> - + Send, - >, - >, -} - -impl Future for Next { - type Output = Result, std::io::Error>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, std::io::Error>> { - let me = self.deref_mut(); - let io = ready!(me.future.as_mut().poll(cx))?; - Poll::Ready(Ok(io)) - } -} - -//------------ Utility -------------------------------------------------------- - -/// Helper to return an error as an async function. -async fn error_helper( - err: std::io::Error, -) -> Result, std::io::Error> { - Err(err) -} diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 1f84fe8f3..2b2a3de6c 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -26,8 +26,7 @@ use tokio::time::{sleep_until, Duration, Instant}; use crate::base::iana::OptRcode; use crate::base::Message; -use crate::net::client::error::Error; -use crate::net::client::request::{GetResponse, Request}; +use crate::net::client::request::{Error, GetResponse, Request}; /* Basic algorithm: diff --git a/src/net/client/request.rs b/src/net/client/request.rs index d0cd3b79b..7ee7ab0a3 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -3,14 +3,46 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] +use crate::base::opt::{ComposeOptData, LongOptData, OptRecord}; +use crate::base::{ + Header, Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, +}; +use crate::rdata::AllRecordData; use bytes::Bytes; +use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; +use std::vec::Vec; +use std::{error, fmt}; -use crate::base::Message; -use crate::net::client::error::Error; +//------------ ComposeRequest ------------------------------------------------ + +/// A trait that allows composing a request as a series. +pub trait ComposeRequest: Debug + Send + Sync { + /// Create a message that captures the recorded changes. + fn to_message(&self) -> Message>; + + /// Create a message that captures the recorded changes and convert to + /// a Vec. + fn to_vec(&self) -> Vec; + + /// Return a reference to a mutable Header to record changes to the header. + fn header_mut(&mut self) -> &mut Header; + + /// Set the UDP payload size. + fn set_udp_payload_size(&mut self, value: u16); + + /// Add an EDNS option. + fn add_opt( + &mut self, + opt: &impl ComposeOptData, + ) -> Result<(), LongOptData>; +} + +//------------ Request ------------------------------------------------------- /// Trait for starting a DNS request based on a request composer. /// @@ -30,6 +62,8 @@ pub trait Request { /// request function in the Request trait. type RequestResultOutput = Result, Error>; +//------------ GetResponse --------------------------------------------------- + /// Trait for getting the result of a DNS query. /// /// In the future, the return type of get_response should become an associated @@ -44,3 +78,320 @@ pub trait GetResponse: Debug { Box, Error>> + Send + '_>, >; } + +//------------ RequestMessage ------------------------------------------------ + +/// Object that implements the ComposeRequest trait for a Message object. +#[derive(Clone, Debug)] +pub struct RequestMessage> { + /// Base message. + msg: Message, + + /// New header. + header: Header, + + /// The OPT record to add if required. + opt: Option>>, +} + +impl + Debug + Octets> RequestMessage { + /// Create a new BMB object. + pub fn new(msg: impl Into>) -> Self { + let msg = msg.into(); + let header = msg.header(); + Self { + msg, + header, + opt: None, + } + } + + /// Returns a mutable reference to the OPT record. + /// + /// Adds one if necessary. + fn opt_mut(&mut self) -> &mut OptRecord> { + self.opt.get_or_insert_with(Default::default) + } + + /// Create new message based on the changes to the base message. + fn to_message_impl(&self) -> Result>, Error> { + let source = &self.msg; + + let mut target = + MessageBuilder::from_target(StaticCompressor::new(Vec::new())) + .expect("Vec is expected to have enough space"); + let target_hdr = target.header_mut(); + target_hdr.set_flags(self.header.flags()); + target_hdr.set_opcode(self.header.opcode()); + target_hdr.set_rcode(self.header.rcode()); + target_hdr.set_id(self.header.id()); + + let source = source.question(); + let mut target = target.question(); + for rr in source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + let mut source = + source.answer().map_err(|_e| Error::MessageParseError)?; + let mut target = target.answer(); + for rr in &mut source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + + let mut source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); + let mut target = target.authority(); + for rr in &mut source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + + let source = source + .next_section() + .map_err(|_e| Error::MessageParseError)? + .expect("section should be present"); + let mut target = target.additional(); + for rr in source { + let rr = rr.map_err(|_e| Error::MessageParseError)?; + if rr.rtype() == Rtype::Opt { + } else { + let rr = rr + .into_record::>>() + .map_err(|_e| Error::MessageParseError)? + .expect("record expected"); + target + .push(rr) + .map_err(|_e| Error::MessageBuilderPushError)?; + } + } + + if let Some(opt) = self.opt.as_ref() { + target + .push(opt.as_record()) + .map_err(|_| Error::MessageBuilderPushError)?; + } + + // It would be nice to use .builder() here. But that one deletes all + // section. We have to resort to .as_builder() which gives a + // reference and then .clone() + let result = target.as_builder().clone(); + let msg = Message::from_octets(result.finish().into_target()).expect( + "Message should be able to parse output from MessageBuilder", + ); + Ok(msg) + } +} + +impl + Clone + Debug + Octets + Send + Sync + 'static> + ComposeRequest for RequestMessage +{ + fn to_vec(&self) -> Vec { + let msg = self.to_message(); + msg.as_octets().clone() + } + + fn to_message(&self) -> Message> { + self.to_message_impl().unwrap() + } + + fn header_mut(&mut self) -> &mut Header { + &mut self.header + } + + fn set_udp_payload_size(&mut self, value: u16) { + self.opt_mut().set_udp_payload_size(value); + } + + fn add_opt( + &mut self, + opt: &impl ComposeOptData, + ) -> Result<(), LongOptData> { + self.opt_mut().push(opt).map_err(|e| e.unlimited_buf()) + } +} + +//------------ Error --------------------------------------------------------- + +/// Error type for client transports. +#[derive(Clone, Debug)] +pub enum Error { + /// Connection was already closed. + ConnectionClosed, + + /// The OPT record has become too long. + OptTooLong, + + /// PushError from MessageBuilder. + MessageBuilderPushError, + + /// ParseError from Message. + MessageParseError, + + /// octet_stream configuration error. + OctetStreamConfigError(Arc), + + /// Underlying transport not found in redundant connection + RedundantTransportNotFound, + + /// Octet sequence too short to be a valid DNS message. + ShortMessage, + + /// Stream transport closed because it was idle (for too long). + StreamIdleTimeout, + + /// Error receiving a reply. + StreamReceiveError, + + /// Reading from stream gave an error. + StreamReadError(Arc), + + /// Reading from stream took too long. + StreamReadTimeout, + + /// Too many outstand queries on a single stream transport. + StreamTooManyOutstandingQueries, + + /// Writing to a stream gave an error. + StreamWriteError(Arc), + + /// Reading for a stream ended unexpectedly. + StreamUnexpectedEndOfData, + + /// Binding a UDP socket gave an error. + UdpBind(Arc), + + /// UDP configuration error. + UdpConfigError(Arc), + + /// Connecting a UDP socket gave an error. + UdpConnect(Arc), + + /// Receiving from a UDP socket gave an error. + UdpReceive(Arc), + + /// Sending over a UDP socket gaven an error. + UdpSend(Arc), + + /// Sending over a UDP socket gave a partial result. + UdpShortSend, + + /// Timeout receiving a response over a UDP socket. + UdpTimeoutNoResponse, + + /// Reply does not match the query. + WrongReplyForQuery, + + /// No transport available to transmit request. + NoTransportAvailable, +} + +impl From for Error { + fn from(_: LongOptData) -> Self { + Self::OptTooLong + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::ConnectionClosed => write!(f, "connection closed"), + Error::OptTooLong => write!(f, "OPT record is too long"), + Error::MessageBuilderPushError => { + write!(f, "PushError from MessageBuilder") + } + Error::MessageParseError => write!(f, "ParseError from Message"), + Error::OctetStreamConfigError(_) => write!(f, "bad config value"), + Error::RedundantTransportNotFound => write!( + f, + "Underlying transport not found in redundant connection" + ), + Error::ShortMessage => { + write!(f, "octet sequence to short to be a valid message") + } + Error::StreamIdleTimeout => { + write!(f, "stream was idle for too long") + } + Error::StreamReceiveError => write!(f, "error receiving a reply"), + Error::StreamReadError(_) => { + write!(f, "error reading from stream") + } + Error::StreamReadTimeout => { + write!(f, "timeout reading from stream") + } + Error::StreamTooManyOutstandingQueries => { + write!(f, "too many outstanding queries on stream") + } + Error::StreamWriteError(_) => { + write!(f, "error writing to stream") + } + Error::StreamUnexpectedEndOfData => { + write!(f, "unexpected end of data") + } + Error::UdpBind(_) => write!(f, "error binding UDP socket"), + Error::UdpConfigError(_) => write!(f, "bad config value"), + Error::UdpConnect(_) => write!(f, "error connecting UDP socket"), + Error::UdpReceive(_) => { + write!(f, "error receiving from UDP socket") + } + Error::UdpSend(_) => write!(f, "error sending to UDP socket"), + Error::UdpShortSend => write!(f, "partial sent to UDP socket"), + Error::UdpTimeoutNoResponse => { + write!(f, "timeout waiting for response") + } + Error::WrongReplyForQuery => { + write!(f, "reply does not match query") + } + Error::NoTransportAvailable => { + write!(f, "no transport available") + } + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Error::ConnectionClosed => None, + Error::OptTooLong => None, + Error::MessageBuilderPushError => None, + Error::MessageParseError => None, + Error::OctetStreamConfigError(e) => Some(e), + Error::RedundantTransportNotFound => None, + Error::ShortMessage => None, + Error::StreamIdleTimeout => None, + Error::StreamReceiveError => None, + Error::StreamReadError(e) => Some(e), + Error::StreamReadTimeout => None, + Error::StreamTooManyOutstandingQueries => None, + Error::StreamWriteError(e) => Some(e), + Error::StreamUnexpectedEndOfData => None, + Error::UdpBind(e) => Some(e), + Error::UdpConfigError(e) => Some(e), + Error::UdpConnect(e) => Some(e), + Error::UdpReceive(e) => Some(e), + Error::UdpSend(e) => Some(e), + Error::UdpShortSend => None, + Error::UdpTimeoutNoResponse => None, + Error::WrongReplyForQuery => None, + Error::NoTransportAvailable => None, + } + } +} diff --git a/src/net/client/request_message.rs b/src/net/client/request_message.rs deleted file mode 100644 index 0fe5714ae..000000000 --- a/src/net/client/request_message.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Simple object that implements the ComposeRequest trait. - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -use crate::base::Header; -use crate::base::Message; -use crate::base::MessageBuilder; -use crate::base::ParsedDname; -use crate::base::Rtype; -use crate::base::StaticCompressor; -use crate::dep::octseq::Octets; -use crate::net::client::compose_request::ComposeRequest; -use crate::net::client::compose_request::OptTypes; -use crate::net::client::error::Error; -use crate::rdata::AllRecordData; - -use std::boxed::Box; -use std::fmt::Debug; -use std::vec::Vec; - -#[derive(Clone, Debug)] -/// Object that implements the ComposeRequest trait for a Message object. -pub struct RequestMessage> { - /// Base message. - msg: Message, - - /// New header. - header: Header, - - /// Collection of EDNS options to add. - opts: Vec, - - /// UDP payload size. - udp_payload_size: Option, -} - -impl + Debug + Octets> RequestMessage { - /// Create a new BMB object. - pub fn new(msg: impl Into>) -> Self { - let msg = msg.into(); - let header = msg.header(); - Self { - msg, - header, - opts: Vec::new(), - udp_payload_size: None, - } - } - - /// Create new message based on the changes to the base message. - fn to_message_impl(&self) -> Result>, Error> { - let source = &self.msg; - - let mut target = - MessageBuilder::from_target(StaticCompressor::new(Vec::new())) - .expect("Vec is expected to have enough space"); - let target_hdr = target.header_mut(); - target_hdr.set_flags(self.header.flags()); - target_hdr.set_opcode(self.header.opcode()); - target_hdr.set_rcode(self.header.rcode()); - target_hdr.set_id(self.header.id()); - - let source = source.question(); - let mut target = target.question(); - for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - let mut source = - source.answer().map_err(|_e| Error::MessageParseError)?; - let mut target = target.answer(); - for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - let mut source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); - let mut target = target.authority(); - for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - let source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); - let mut target = target.additional(); - for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - if rr.rtype() == Rtype::Opt { - } else { - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? - .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - } - - if self.udp_payload_size.is_some() || !self.opts.is_empty() { - target - .opt(|opt| { - if let Some(size) = self.udp_payload_size { - opt.set_udp_payload_size(size) - } - for o in &self.opts { - match o { - OptTypes::TypeTcpKeepalive(tka) => { - opt.tcp_keepalive(tka.timeout())? - } - } - } - Ok(()) - }) - .map_err(|_e| Error::MessageBuilderPushError)?; - } - - // It would be nice to use .builder() here. But that one deletes all - // section. We have to resort to .as_builder() which gives a - // reference and then .clone() - let result = target.as_builder().clone(); - let msg = Message::from_octets(result.finish().into_target()).expect( - "Message should be able to parse output from MessageBuilder", - ); - Ok(msg) - } -} - -impl + Clone + Debug + Octets + Send + Sync + 'static> - ComposeRequest for RequestMessage -{ - fn as_box_dyn(&self) -> Box { - Box::new(self.clone()) - } - - fn to_vec(&self) -> Vec { - let msg = self.to_message(); - msg.as_octets().clone() - } - - fn to_message(&self) -> Message> { - self.to_message_impl().unwrap() - } - - fn header_mut(&mut self) -> &mut Header { - &mut self.header - } - - fn set_udp_payload_size(&mut self, value: u16) { - self.udp_payload_size = Some(value); - } - - fn add_opt(&mut self, opt: OptTypes) { - self.opts.push(opt); - //println!("add_opt: after push: {:?}", self); - } -} diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 771647358..4ad2b2b92 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -22,9 +22,9 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::compose_request::ComposeRequest; -use crate::net::client::error::Error; -use crate::net::client::request::{GetResponse, Request}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, Request, +}; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index b6aa8299b..7e98a112f 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -15,11 +15,11 @@ use std::pin::Pin; use std::sync::Arc; use crate::base::Message; -use crate::net::client::compose_request::ComposeRequest; -use crate::net::client::error::Error; use crate::net::client::multi_stream; use crate::net::client::protocol::TcpConnect; -use crate::net::client::request::{GetResponse, Request}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, Request, +}; use crate::net::client::udp; //------------ Config --------------------------------------------------------- diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index c7399c09f..563a3bcde 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -19,12 +19,10 @@ use crate::base::message_builder::{ }; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; -use crate::net::client::compose_request::ComposeRequest; use crate::net::client::multi_stream; use crate::net::client::protocol::TcpConnect; use crate::net::client::redundant; -use crate::net::client::request::Request; -use crate::net::client::request_message::RequestMessage; +use crate::net::client::request::{ComposeRequest, Request, RequestMessage}; use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; From a895038b98c1ef3854046daf15092312d789e09d Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 21 Dec 2023 17:28:13 +0100 Subject: [PATCH 091/124] Rename Request to SendRequest. --- examples/client-transports.rs | 14 +++++++------- src/net/client/multi_stream.rs | 8 +++++--- src/net/client/octet_stream.rs | 8 +++++--- src/net/client/redundant.rs | 17 +++++++++-------- src/net/client/request.rs | 6 +++--- src/net/client/udp.rs | 6 +++--- src/net/client/udp_tcp.rs | 12 +++++++----- src/resolv/stub/mod.rs | 6 ++++-- 8 files changed, 43 insertions(+), 34 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 6b5d582ce..a0d8469c5 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -6,7 +6,7 @@ use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::protocol::{TcpConnect, TlsConnect}; use domain::net::client::redundant; -use domain::net::client::request::{Request, RequestMessage}; +use domain::net::client::request::{RequestMessage, SendRequest}; use domain::net::client::udp; use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; @@ -68,7 +68,7 @@ async fn main() { }); // Send a query message. - let mut request = udptcp_conn.request(&req).await.unwrap(); + let mut request = udptcp_conn.send_request(&req).await.unwrap(); // Get the reply println!("Wating for UDP+TCP reply"); @@ -98,7 +98,7 @@ async fn main() { }); // Send a query message. - let mut request = tcp_conn.request(&req).await.unwrap(); + let mut request = tcp_conn.send_request(&req).await.unwrap(); // Get the reply. A multi_stream connection does not have any timeout. // Wrap get_result in a timeout. @@ -151,7 +151,7 @@ async fn main() { println!("TLS run exited with {:?}", res); }); - let mut request = tls_conn.request(&req).await.unwrap(); + let mut request = tls_conn.send_request(&req).await.unwrap(); println!("Wating for TLS reply"); let reply = timeout(Duration::from_millis(500), request.get_response()).await; @@ -176,7 +176,7 @@ async fn main() { // Start a few queries. for i in 1..10 { - let mut request = redun.request(&req).await.unwrap(); + let mut request = redun.send_request(&req).await.unwrap(); let reply = request.get_response().await; if i == 2 { println!("redundant connection reply: {:?}", reply); @@ -193,7 +193,7 @@ async fn main() { udp::Connection::new(Some(udp_config), server_addr).unwrap(); // Send a query message. - let mut request = udp_conn.request(&req).await.unwrap(); + let mut request = udp_conn.send_request(&req).await.unwrap(); // Get the reply let reply = request.get_response().await; @@ -220,7 +220,7 @@ async fn main() { }); // Send a request message. - let mut request = tcp.request(&req).await.unwrap(); + let mut request = tcp.send_request(&req).await.unwrap(); // Get the reply let reply = request.get_response().await; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 1bdfb3241..f523e4bf6 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -33,7 +33,7 @@ use crate::base::Message; use crate::net::client::octet_stream; use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, Request, + ComposeRequest, Error, GetResponse, SendRequest, }; /// Capacity of the channel that transports [ChanReq]. @@ -122,8 +122,10 @@ impl Connection { } } -impl Request for Connection { - fn request<'a>( +impl SendRequest + for Connection +{ + fn send_request<'a>( &'a self, request_msg: &'a CR, ) -> Pin< diff --git a/src/net/client/octet_stream.rs b/src/net/client/octet_stream.rs index 1fe109540..2c64983ed 100644 --- a/src/net/client/octet_stream.rs +++ b/src/net/client/octet_stream.rs @@ -34,7 +34,7 @@ use crate::base::{ Message, }; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, Request, + ComposeRequest, Error, GetResponse, SendRequest, }; use octseq::Octets; @@ -149,8 +149,10 @@ impl Connection { } } -impl Request for Connection { - fn request<'a>( +impl SendRequest + for Connection +{ + fn send_request<'a>( &'a self, request_msg: &'a CR, ) -> Pin< diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 2b2a3de6c..c172f5d63 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -26,7 +26,7 @@ use tokio::time::{sleep_until, Duration, Instant}; use crate::base::iana::OptRcode; use crate::base::Message; -use crate::net::client::request::{Error, GetResponse, Request}; +use crate::net::client::request::{Error, GetResponse, SendRequest}; /* Basic algorithm: @@ -113,7 +113,7 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { /// Add a transport connection. pub async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { self.inner.add(conn).await } @@ -128,10 +128,10 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { } } -impl Request +impl SendRequest for Connection { - fn request<'a>( + fn send_request<'a>( &'a self, request_msg: &'a BMB, ) -> Pin< @@ -227,7 +227,7 @@ impl Debug for ChanReq { /// Request to add a new connection struct AddReq { /// New connection to add - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, /// Channel to send the reply to tx: oneshot::Sender, @@ -592,7 +592,8 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); - let mut conns: Vec + Send + Sync>> = Vec::new(); + let mut conns: Vec + Send + Sync>> = + Vec::new(); let mut receiver = opt_receiver.expect("receiver should not be empty"); @@ -630,7 +631,7 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { match opt_ind { Some(ind) => { let query = conns[ind] - .request(&request_req.request_msg) + .send_request(&request_req.request_msg) .await; // Don't care if send fails let _ = request_req.tx.send(query); @@ -692,7 +693,7 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Implementation of the add method. async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); self.sender diff --git a/src/net/client/request.rs b/src/net/client/request.rs index 7ee7ab0a3..dbeff8f2e 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -1,4 +1,4 @@ -//! Traits for request/response transports +//! Requests. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -48,11 +48,11 @@ pub trait ComposeRequest: Debug + Send + Sync { /// /// In the future, the return type of request should become an associated type. /// However, the use of 'dyn Request' in redundant currently prevents that. -pub trait Request { +pub trait SendRequest { /// Request function that takes a ComposeRequest type. /// /// This function is intended to be cancel safe. - fn request<'a>( + fn send_request<'a>( &'a self, request_msg: &'a CR, ) -> Pin + Send + '_>>; diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs index 4ad2b2b92..635335462 100644 --- a/src/net/client/udp.rs +++ b/src/net/client/udp.rs @@ -23,7 +23,7 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, Request, + ComposeRequest, Error, GetResponse, SendRequest, }; /// How many times do we try a new random port if we get ‘address in use.’ @@ -137,10 +137,10 @@ impl Connection { } } -impl Request +impl SendRequest for Connection { - fn request<'a>( + fn send_request<'a>( &'a self, request_msg: &'a CR, ) -> Pin< diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs index 7e98a112f..f11e3cc37 100644 --- a/src/net/client/udp_tcp.rs +++ b/src/net/client/udp_tcp.rs @@ -18,7 +18,7 @@ use crate::base::Message; use crate::net::client::multi_stream; use crate::net::client::protocol::TcpConnect; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, Request, + ComposeRequest, Error, GetResponse, SendRequest, }; use crate::net::client::udp; @@ -80,8 +80,10 @@ impl Connection { } } -impl Request for Connection { - fn request<'a>( +impl SendRequest + for Connection +{ + fn send_request<'a>( &'a self, request_msg: &'a CR, ) -> Pin< @@ -154,7 +156,7 @@ impl ReqResp { match &mut self.state { QueryState::StartUdpRequest => { let msg = self.request_msg.clone(); - let request = self.udp_conn.request(&msg).await?; + let request = self.udp_conn.send_request(&msg).await?; self.state = QueryState::GetUdpResponse(request); continue; } @@ -168,7 +170,7 @@ impl ReqResp { } QueryState::StartTcpRequest => { let msg = self.request_msg.clone(); - let request = self.tcp_conn.request(&msg).await?; + let request = self.tcp_conn.send_request(&msg).await?; self.state = QueryState::GetTcpResponse(request); continue; } diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 563a3bcde..606132b70 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -22,7 +22,9 @@ use crate::base::question::Question; use crate::net::client::multi_stream; use crate::net::client::protocol::TcpConnect; use crate::net::client::redundant; -use crate::net::client::request::{ComposeRequest, Request, RequestMessage}; +use crate::net::client::request::{ + ComposeRequest, RequestMessage, SendRequest, +}; use crate::net::client::udp_tcp; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; @@ -408,7 +410,7 @@ impl<'a> Query<'a> { let request_msg = RequestMessage::new(msg); let transport = self.resolver.get_transport().await; - let mut gr_fut = transport.request(&request_msg).await.unwrap(); + let mut gr_fut = transport.send_request(&request_msg).await.unwrap(); let reply = timeout(self.resolver.options.timeout, gr_fut.get_response()) .await From 36ce917d8d4173541ffbe8201dd4a06bd6f8dcd6 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 28 Dec 2023 13:36:20 +0100 Subject: [PATCH 092/124] Introduce unstable features. --- Cargo.toml | 3 +++ src/lib.rs | 34 +++++++++++++++++++++++++++++++--- src/net/client/mod.rs | 7 ++++--- src/net/client/request.rs | 2 +- src/net/mod.rs | 12 +++++++++++- 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7264e3646..e434306e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,9 @@ tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] +# Unstable features +unstable-client-transport = [] + # This feature should include all features that the CI should include for a # test run. Which is everything except interop. ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] diff --git a/src/lib.rs b/src/lib.rs index 8184d2d4e..e628d5e7f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,9 @@ //! //! Currently, there are the following modules: //! +#![cfg_attr(feature = "net", doc = "* [net]:")] +#![cfg_attr(not(feature = "net"), doc = "* net:")] +//! Sending and receiving DNS message. #![cfg_attr(feature = "resolv", doc = "* [resolv]:")] #![cfg_attr(not(feature = "resolv"), doc = "* resolv:")] //! An asynchronous DNS resolver based on the @@ -48,12 +51,13 @@ //! Finally, the [dep] module contains re-exports of some important //! dependencies to help avoid issues with multiple versions of a crate. //! -//! # Reference of Feature Flags +//! # Reference of feature flags //! -//! The following is the complete list of the feature flags available. +//! The following is the complete list of the feature flags with the +//! exception of unstable features which are described below. //! //! * `bytes`: Enables using the types `Bytes` and `BytesMut` from the -//! [bytes](https://github.com/tokio-rs/bytes) crate as octet sequences. +//! [bytes](https://github.com/tokio-rs/bytes) crate as octet sequences. //! * `chrono`: Adds the [chrono](https://github.com/chronotope/chrono) //! crate as a dependency. This adds support for generating serial numbers //! from time stamps. @@ -104,6 +108,30 @@ #![cfg_attr(feature = "zonefile", doc = " [zonefile]")] #![cfg_attr(not(feature = "zonefile"), doc = " zonefile")] //! module and currently also enables the `bytes` and `std` features. +//! +//! # Unstable features +//! +//! When adding new functionality to the crate, practical experience is +//! necessary to arrive at a good, user friendly design. Unstable features +//! allow adding and rapidly changing new code without having to release +//! versions allowing breaking changes all the time. If you use unstable +//! features, it is best to specify a concrete version as a dependency in +//! `Cargo.toml` using the `=` operator, e.g.: +//! +//! ``` +//! [dependencies] +//! domain = "=0.9.3" +//! ``` +//! +//! Currently, the following unstable features exist: +//! +//! * `unstable-client-transport`: sending and receiving DNS messages from +//! a client perspective; primarily the `net::client` module. +//! +//! Note: Some functionality is currently informally marked as +//! “experimental” since it was introduced before adoption of the concept +//! of unstable features. These will follow proper Semver practice but may +//! significant changes in releases with breakting changes. #![no_std] #![allow(renamed_and_removed_lints)] diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 4c0e7b1ee..d2033e7d7 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -1,12 +1,13 @@ -//! DNS transport protocols -#![cfg(feature = "net")] -#![cfg_attr(docsrs, doc(cfg(feature = "net")))] +//! Sending requests and receiving responses. //! # Example with various transport connections //! ``` #![doc = include_str!("../../../examples/client-transports.rs")] //! ``` +#![cfg(feature = "unstable-client-transport")] +#![cfg_attr(docsrs, doc(cfg(feature = "unstable-client-transport")))] + #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] diff --git a/src/net/client/request.rs b/src/net/client/request.rs index dbeff8f2e..81b11f509 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -1,4 +1,4 @@ -//! Requests. +//! Constructing and sending requests. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] diff --git a/src/net/mod.rs b/src/net/mod.rs index 5eb5b6195..049aed754 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,4 +1,14 @@ -//! Sending and receiving DNS messages +//! Sending and receiving DNS messages. +//! +//! This module provides types, traits, and function for sending and receiving +//! DNS messages. +//! +//! Currently, the module only provides the unstable +#![cfg_attr(feature = "unstable-client-transport", doc = " [`client`]")] +#![cfg_attr(not(feature = "unstable-client-transport"), doc = " `client`")] +//! sub-module intended for sending requests and receiving responses to them. +#![cfg_attr(not(feature = "unstable-client-transport"), doc = " The `unstable-client-transport` feature is necessary to enable this module.")] +//! #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] From 167c0e8d27afde7ff5adcb65f89a872c03e8ea10 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 22 Dec 2023 16:09:44 +0100 Subject: [PATCH 093/124] Initial tests for net::client --- Cargo.toml | 1 + test-data/basic.rpl | 161 +++ tests/net-client.rs | 117 +++ tests/net/deckard/client.rs | 91 ++ tests/net/deckard/connect.rs | 34 + tests/net/deckard/connection.rs | 102 ++ tests/net/deckard/matches.rs | 305 ++++++ tests/net/deckard/mod.rs | 7 + tests/net/deckard/parse_deckard.rs | 609 +++++++++++ tests/net/deckard/parse_query.rs | 1501 ++++++++++++++++++++++++++++ tests/net/deckard/server.rs | 133 +++ tests/net/mod.rs | 1 + 12 files changed, 3062 insertions(+) create mode 100644 test-data/basic.rpl create mode 100644 tests/net-client.rs create mode 100644 tests/net/deckard/client.rs create mode 100644 tests/net/deckard/connect.rs create mode 100644 tests/net/deckard/connection.rs create mode 100644 tests/net/deckard/matches.rs create mode 100644 tests/net/deckard/mod.rs create mode 100644 tests/net/deckard/parse_deckard.rs create mode 100644 tests/net/deckard/parse_query.rs create mode 100644 tests/net/deckard/server.rs create mode 100644 tests/net/mod.rs diff --git a/Cargo.toml b/Cargo.toml index e434306e3..db1f602b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ rustls = { version = "0.21.9" } serde_test = "1.0.130" serde_yaml = "0.9" tokio = { version = "1", features = ["rt-multi-thread", "io-util", "net"] } +tokio-test = "0.4" webpki-roots = { version = "0.25" } [package.metadata.docs.rs] diff --git a/test-data/basic.rpl b/test-data/basic.rpl new file mode 100644 index 000000000..72f453fe0 --- /dev/null +++ b/test-data/basic.rpl @@ -0,0 +1,161 @@ +do-ip6: no + +; config options +; target-fetch-policy: "3 2 1 0 0" +; name: "." + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test iterator with NS falsely declaring referral answer as authoritative. + +; K.ROOT-SERVERS.NET. +RANGE_BEGIN 0 100 + ADDRESS 193.0.14.129 +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +. IN NS +SECTION ANSWER +. IN NS K.ROOT-SERVERS.NET. +SECTION ADDITIONAL +K.ROOT-SERVERS.NET. IN A 193.0.14.129 +ENTRY_END + +; net. +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +net. IN NS +SECTION AUTHORITY +. IN SOA . . 0 0 0 0 0 +ENTRY_END + +; root-servers.net. +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +root-servers.net. IN NS +SECTION ANSWER +root-servers.net. IN NS k.root-servers.net. +SECTION ADDITIONAL +k.root-servers.net. IN A 193.0.14.129 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +root-servers.net. IN A +SECTION AUTHORITY +root-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +k.root-servers.net. IN A +SECTION ANSWER +k.root-servers.net. IN A 193.0.14.129 +SECTION ADDITIONAL +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +k.root-servers.net. IN AAAA +SECTION AUTHORITY +root-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +; gtld-servers.net. +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +gtld-servers.net. IN NS +SECTION ANSWER +gtld-servers.net. IN NS a.gtld-servers.net. +SECTION ADDITIONAL +a.gtld-servers.net. IN A 192.5.6.30 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qname +ADJUST copy_id copy_query +REPLY QR NOERROR +SECTION QUESTION +gtld-servers.net. IN A +SECTION AUTHORITY +gtld-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +a.gtld-servers.net. IN A +SECTION ANSWER +a.gtld-servers.net. IN A 192.5.6.30 +SECTION ADDITIONAL +ENTRY_END + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +a.gtld-servers.net. IN AAAA +SECTION AUTHORITY +gtld-servers.net. IN SOA . . 0 0 0 0 0 +ENTRY_END + +RANGE_END + +; a.gtld-servers.net. +RANGE_BEGIN 0 100 + ADDRESS 192.5.6.30 + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id copy_query +REPLY QR RD NOERROR +SECTION QUESTION +example.com. IN A +SECTION ANSWER +example.com. IN A 93.184.216.34 +ENTRY_END + +RANGE_END + +STEP 1 QUERY +ENTRY_BEGIN +REPLY RD +SECTION QUESTION +example.com. IN A +ENTRY_END + +; recursion happens here. +STEP 10 CHECK_ANSWER +ENTRY_BEGIN +MATCH all +REPLY QR RD RA NOERROR +SECTION QUESTION +example.com. IN A +SECTION ANSWER +example.com. IN A 93.184.216.34 +ENTRY_END + +SCENARIO_END diff --git a/tests/net-client.rs b/tests/net-client.rs new file mode 100644 index 000000000..bda152c00 --- /dev/null +++ b/tests/net-client.rs @@ -0,0 +1,117 @@ +#![cfg(feature = "net")] +mod net; + +use crate::net::deckard::client::do_client; +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::connect::Connect; +use crate::net::deckard::connection::Connection; +use crate::net::deckard::parse_deckard::parse_file; +use domain::net::client::multi_stream; +use domain::net::client::octet_stream; +use domain::net::client::redundant; +use std::fs::File; +use std::net::IpAddr; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio_test; + +const TEST_FILE: &str = "test-data/basic.rpl"; + +#[test] +fn single() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Connection::new(deckard.clone(), step_value.clone()); + let octstr = octet_stream::Connection::new(None).unwrap(); + let run_fut = octstr.run(conn); + tokio::spawn(async move { + run_fut.await; + }); + + do_client(&deckard, octstr, &step_value).await; + }); +} + +#[test] +fn multi() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let ms = multi_stream::Connection::new(None).unwrap(); + let run_fut = ms.run(multi_conn); + tokio::spawn(async move { + run_fut.await.unwrap(); + println!("multi conn run terminated"); + }); + + do_client(&deckard, ms.clone(), &step_value).await; + }); +} + +#[test] +fn redundant() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let ms = multi_stream::Connection::new(None).unwrap(); + let run_fut = ms.run(multi_conn); + tokio::spawn(async move { + run_fut.await.unwrap(); + println!("multi conn run terminated"); + }); + + // Redundant add previous connection. + let redun = redundant::Connection::new(None).unwrap(); + let run_fut = redun.run(); + tokio::spawn(async move { + run_fut.await; + println!("redundant conn run terminated"); + }); + redun.add(Box::new(ms.clone())).await.unwrap(); + + do_client(&deckard, redun, &step_value).await; + }); +} + +#[test] +#[ignore] +// Connect directly to the internet. Disabled by default. +fn tcp() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let server_addr = + SocketAddr::new(IpAddr::from_str("9.9.9.9").unwrap(), 53); + + let tcp_conn = match TcpStream::connect(server_addr).await { + Ok(conn) => conn, + Err(err) => { + println!( + "TCP Connection to {server_addr} failed: {err}, exiting" + ); + return; + } + }; + + let tcp = octet_stream::Connection::new(None).unwrap(); + let run_fut = tcp.run(tcp_conn); + tokio::spawn(async move { + run_fut.await; + println!("single TCP run terminated"); + }); + + do_client(&deckard, tcp, &CurrStepValue::new()).await; + }); +} diff --git a/tests/net/deckard/client.rs b/tests/net/deckard/client.rs new file mode 100644 index 000000000..80a75d91a --- /dev/null +++ b/tests/net/deckard/client.rs @@ -0,0 +1,91 @@ +use crate::net::deckard::matches::match_msg; +use crate::net::deckard::parse_deckard::{Deckard, Entry, Reply, StepType}; +use crate::net::deckard::parse_query; +use bytes::Bytes; + +use domain::base::{Message, MessageBuilder}; +use domain::net::client::request::RequestMessage; +use domain::net::client::request::SendRequest; +use std::sync::Mutex; + +pub async fn do_client>>>( + deckard: &Deckard, + request: R, + step_value: &CurrStepValue, +) { + let mut resp: Option> = None; + + // Assume steps are in order. Maybe we need to define that. + for step in &deckard.scenario.steps { + step_value.set(step.step_value); + match step.step_type { + StepType::Query => { + let reqmsg = entry2reqmsg(step.entry.as_ref().unwrap()); + let mut req = request.send_request(&reqmsg).await.unwrap(); + resp = Some(req.get_response().await.unwrap()); + } + StepType::CheckAnswer => { + let answer = resp.take().unwrap(); + if !match_msg(step.entry.as_ref().unwrap(), &answer, true) { + panic!("reply failed"); + } + } + StepType::TimePasses + | StepType::Traffic + | StepType::CheckTempfile + | StepType::Assign => todo!(), + } + } + println!("Done"); +} + +fn entry2reqmsg(entry: &Entry) -> RequestMessage> { + let sections = entry.sections.as_ref().unwrap(); + let mut msg = MessageBuilder::new_vec().question(); + for q in §ions.question { + let question = match q { + parse_query::Entry::QueryRecord(question) => question, + _ => todo!(), + }; + msg.push(question).unwrap(); + } + let msg = msg.answer(); + for _a in §ions.answer { + todo!(); + } + let msg = msg.authority(); + for _a in §ions.authority { + todo!(); + } + let mut msg = msg.additional(); + for _a in §ions.additional { + todo!(); + } + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + if reply.rd { + msg.header_mut().set_rd(true); + } + let msg = msg.into_message(); + RequestMessage::new(msg) +} + +#[derive(Debug)] +pub struct CurrStepValue { + v: Mutex, +} + +impl CurrStepValue { + pub fn new() -> Self { + Self { v: 0.into() } + } + fn set(&self, v: u64) { + let mut self_v = self.v.lock().unwrap(); + *self_v = v; + } + pub fn get(&self) -> u64 { + *(self.v.lock().unwrap()) + } +} diff --git a/tests/net/deckard/connect.rs b/tests/net/deckard/connect.rs new file mode 100644 index 000000000..287710290 --- /dev/null +++ b/tests/net/deckard/connect.rs @@ -0,0 +1,34 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::connection::Connection; +use crate::net::deckard::parse_deckard::Deckard; +use domain::net::client::protocol::AsyncConnect; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +pub struct Connect { + deckard: Deckard, + step_value: Arc, +} + +impl Connect { + pub fn new(deckard: Deckard, step_value: Arc) -> Connect { + Self { + deckard, + step_value, + } + } +} + +impl AsyncConnect for Connect { + type Connection = Connection; + type Fut = Pin< + Box> + Send>, + >; + + fn connect(&self) -> Self::Fut { + let deckard = self.deckard.clone(); + let step_value = self.step_value.clone(); + Box::pin(async move { Ok(Connection::new(deckard, step_value)) }) + } +} diff --git a/tests/net/deckard/connection.rs b/tests/net/deckard/connection.rs new file mode 100644 index 000000000..ff459141b --- /dev/null +++ b/tests/net/deckard/connection.rs @@ -0,0 +1,102 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::parse_deckard::Deckard; +use crate::net::deckard::server::do_server; +use domain::base::Message; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::Waker; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +#[derive(Debug)] +pub struct Connection { + deckard: Deckard, + step_value: Arc, + waker: Option, + reply: Option>>, + send_body: bool, +} + +impl Connection { + pub fn new( + deckard: Deckard, + step_value: Arc, + ) -> Connection { + Self { + deckard, + step_value, + waker: None, + reply: None, + send_body: false, + } + } +} + +impl AsyncRead for Connection { + fn poll_read( + mut self: Pin<&mut Self>, + context: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.reply.is_some() { + let slice = self.reply.as_ref().unwrap().as_slice(); + let len = slice.len(); + if self.send_body { + buf.put_slice(slice); + self.reply = None; + return Poll::Ready(Ok(())); + } else { + buf.put_slice(&(len as u16).to_be_bytes()); + self.send_body = true; + return Poll::Ready(Ok(())); + } + } + self.reply = None; + self.send_body = false; + self.waker = Some(context.waker().clone()); + Poll::Pending + } +} + +impl AsyncWrite for Connection { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let buflen = buf.len(); + let mut len_str: [u8; 2] = [0; 2]; + len_str.copy_from_slice(&buf[0..2]); + let len = u16::from_be_bytes(len_str) as usize; + if buflen != 2 + len { + panic!("expecting one complete message per write"); + } + let msg = Message::from_octets(buf[2..].to_vec()).unwrap(); + let opt_reply = do_server(&msg, &self.deckard, &self.step_value); + if opt_reply.is_some() { + // Do we need to support more than one reply? + self.reply = opt_reply; + let opt_waker = self.waker.take(); + if let Some(waker) = opt_waker { + waker.wake(); + } + } + Poll::Ready(Ok(buflen)) + } + fn poll_flush( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + todo!() + } + fn poll_shutdown( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + // Do we need to do anything here? + Poll::Ready(Ok(())) + } +} diff --git a/tests/net/deckard/matches.rs b/tests/net/deckard/matches.rs new file mode 100644 index 000000000..1da791495 --- /dev/null +++ b/tests/net/deckard/matches.rs @@ -0,0 +1,305 @@ +use crate::net::deckard::parse_deckard::{Entry, Matches, Reply}; +use crate::net::deckard::parse_query; +use domain::base::iana::Opcode; +use domain::base::iana::OptRcode; +use domain::base::iana::Rtype; +use domain::base::Message; +use domain::base::ParsedDname; +use domain::base::QuestionSection; +use domain::base::RecordSection; +use domain::dep::octseq::Octets; +use domain::rdata::ZoneRecordData; +use domain::zonefile::inplace::Entry as ZonefileEntry; +//use std::fmt::Debug; + +pub fn match_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a>( + entry: &Entry, + msg: &'a Message, + verbose: bool, +) -> bool +where + ::Range<'a>: Clone, +{ + let sections = entry.sections.as_ref().unwrap(); + + let mut matches: Matches = match &entry.matches { + Some(matches) => matches.clone(), + None => Default::default(), + }; + + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + + if matches.all { + matches.opcode = true; + matches.qtype = true; + matches.qname = true; + matches.flags = true; + matches.rcode = true; + matches.answer = true; + matches.authority = true; + matches.additional = true; + } + + if matches.question { + matches.qtype = true; + matches.qname = true; + } + + if matches.additional { + let mut arcount = msg.header_counts().arcount(); + if msg.opt().is_some() { + arcount -= 1; + } + if !match_section( + sections.additional.clone(), + msg.additional().unwrap(), + arcount, + verbose, + ) { + if verbose { + println!("match_msg: additional section does not match"); + } + return false; + } + } + if matches.answer + && !match_section( + sections.answer.clone(), + msg.answer().unwrap(), + msg.header_counts().ancount(), + verbose, + ) + { + if verbose { + todo!(); + } + return false; + } + if matches.authority + && !match_section( + sections.authority.clone(), + msg.authority().unwrap(), + msg.header_counts().nscount(), + verbose, + ) + { + if verbose { + todo!(); + } + return false; + } + if matches.fl_do { + todo!(); + } + if matches.flags { + let header = msg.header(); + if reply.qr != header.qr() { + if verbose { + todo!(); + } + return false; + } + if reply.aa != header.aa() { + if verbose { + println!( + "match_msg: AA does not match, got {}, expected {}", + header.aa(), + reply.aa + ); + } + return false; + } + if reply.tc != header.tc() { + if verbose { + todo!(); + } + return false; + } + if reply.rd != header.rd() { + if verbose { + println!( + "match_msg: RD does not match, got {}, expected {}", + header.aa(), + reply.aa + ); + } + return false; + } + if reply.ad != header.ad() { + if verbose { + todo!(); + } + return false; + } + if reply.cd != header.cd() { + if verbose { + todo!(); + } + return false; + } + } + if matches.opcode { + // Not clear what that means. JUst check if it is Query + if msg.header().opcode() != Opcode::Query { + if verbose { + todo!(); + } + return false; + } + } + if (matches.qname || matches.qtype) + && !match_question( + sections.question.clone(), + msg.question(), + matches.qname, + matches.qtype, + ) + { + if verbose { + println!("match_msg: question section does not match"); + } + return false; + } + if matches.rcode { + let msg_rcode = + get_opt_rcode(&Message::from_octets(msg.as_slice()).unwrap()); + if reply.noerror { + if let OptRcode::NoError = msg_rcode { + // Okay + } else { + if verbose { + todo!(); + } + return false; + } + } else { + println!("reply {reply:?}"); + panic!("no rcode to match?"); + } + } + if matches.subdomain { + todo!() + } + if matches.tcp { + todo!() + } + if matches.ttl { + todo!() + } + if matches.udp { + todo!() + } + + // All checks passed! + true +} + +fn match_section< + 'a, + Octs: Clone + Octets = Octs2> + 'a, + Octs2: AsRef<[u8]> + Clone, +>( + mut match_section: Vec, + msg_section: RecordSection<'a, Octs>, + msg_count: u16, + verbose: bool, +) -> bool { + if match_section.len() != msg_count.into() { + if verbose { + todo!(); + } + return false; + } + 'outer: for msg_rr in msg_section { + let msg_rr = msg_rr.unwrap(); + if msg_rr.rtype() == Rtype::Opt { + continue; + } + for (index, mat_rr) in match_section.iter().enumerate() { + // Remove outer Record + let mat_rr = if let ZonefileEntry::Record(record) = mat_rr { + record + } else { + panic!("include not expected"); + }; + if msg_rr.owner() != mat_rr.owner() { + continue; + } + if msg_rr.class() != mat_rr.class() { + continue; + } + if msg_rr.rtype() != mat_rr.rtype() { + continue; + } + let msg_rdata = msg_rr + .clone() + .into_record::>>() + .unwrap() + .unwrap(); + if msg_rdata.data() != mat_rr.data() { + continue; + } + + // Found one. Delete this entry + match_section.swap_remove(index); + continue 'outer; + } + // Nothing matches + if verbose { + println!( + "no match for record {} {} {}", + msg_rr.owner(), + msg_rr.class(), + msg_rr.rtype() + ); + } + return false; + } + // All entries in the reply were matched. + true +} + +fn match_question( + match_section: Vec, + msg_section: QuestionSection<'_, Octs>, + match_qname: bool, + match_qtype: bool, +) -> bool { + if match_section.is_empty() { + // Nothing to match. + return true; + } + for msg_rr in msg_section { + let msg_rr = msg_rr.unwrap(); + let mat_rr = if let parse_query::Entry::QueryRecord(record) = + &match_section[0] + { + record + } else { + panic!("include not expected"); + }; + if match_qname && msg_rr.qname() != mat_rr.qname() { + return false; + } + if match_qtype && msg_rr.qtype() != mat_rr.qtype() { + return false; + } + } + // All entries in the reply were matched. + true +} + +fn get_opt_rcode(msg: &Message) -> OptRcode { + let opt = msg.opt(); + match opt { + Some(opt) => opt.rcode(msg.header()), + None => { + // Convert Rcode to OptRcode, this should be part of + // OptRcode + OptRcode::from_int(msg.header().rcode().to_int() as u16) + } + } +} diff --git a/tests/net/deckard/mod.rs b/tests/net/deckard/mod.rs new file mode 100644 index 000000000..c3eb548f1 --- /dev/null +++ b/tests/net/deckard/mod.rs @@ -0,0 +1,7 @@ +pub mod client; +pub mod connect; +pub mod connection; +mod matches; +pub mod parse_deckard; +mod parse_query; +mod server; diff --git a/tests/net/deckard/parse_deckard.rs b/tests/net/deckard/parse_deckard.rs new file mode 100644 index 000000000..b7fe2fb54 --- /dev/null +++ b/tests/net/deckard/parse_deckard.rs @@ -0,0 +1,609 @@ +use std::default::Default; +use std::fmt::Debug; +use std::io::{self, BufRead, Read}; +use std::net::IpAddr; + +use crate::net::deckard::parse_query; +use crate::net::deckard::parse_query::Zonefile as QueryZonefile; +use domain::zonefile::inplace::Entry as ZonefileEntry; +use domain::zonefile::inplace::Zonefile; + +const CONFIG_END: &str = "CONFIG_END"; +const SCENARIO_BEGIN: &str = "SCENARIO_BEGIN"; +const SCENARIO_END: &str = "SCENARIO_END"; +const RANGE_BEGIN: &str = "RANGE_BEGIN"; +const RANGE_END: &str = "RANGE_END"; +const ADDRESS: &str = "ADDRESS"; +const ENTRY_BEGIN: &str = "ENTRY_BEGIN"; +const ENTRY_END: &str = "ENTRY_END"; +const MATCH: &str = "MATCH"; +const ADJUST: &str = "ADJUST"; +const REPLY: &str = "REPLY"; +const SECTION: &str = "SECTION"; +const QUESTION: &str = "QUESTION"; +const ANSWER: &str = "ANSWER"; +const AUTHORITY: &str = "AUTHORITY"; +const ADDITIONAL: &str = "ADDITIONAL"; +const STEP: &str = "STEP"; +const STEP_TYPE_QUERY: &str = "QUERY"; +const STEP_TYPE_CHECK_ANSWER: &str = "CHECK_ANSWER"; +const STEP_TYPE_TIME_PASSES: &str = "TIME_PASSES"; +const STEP_TYPE_TRAFFIC: &str = "TRAFFIC"; +const STEP_TYPE_CHECK_TEMPFILE: &str = "CHECK_TEMPFILE"; +const STEP_TYPE_ASSIGN: &str = "ASSIGN"; + +enum Section { + Question, + Answer, + Authority, + Additional, +} + +#[derive(Clone, Debug)] +pub enum StepType { + Query, + CheckAnswer, + TimePasses, + Traffic, + CheckTempfile, + Assign, +} + +#[derive(Clone, Debug, Default)] +pub struct Config { + lines: Vec, +} + +#[derive(Clone, Debug)] +pub struct Deckard { + pub config: Config, + pub scenario: Scenario, +} + +pub fn parse_file(file: F) -> Deckard { + let mut lines = io::BufReader::new(file).lines(); + Deckard { + config: parse_config(&mut lines), + scenario: parse_scenario(&mut lines), + } +} + +fn parse_config>>( + l: &mut Lines, +) -> Config { + let mut config: Config = Default::default(); + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + if clean_line == CONFIG_END { + break; + } + config.lines.push(clean_line.to_string()); + } + config +} + +#[derive(Clone, Debug, Default)] +pub struct Scenario { + pub ranges: Vec, + pub steps: Vec, +} + +pub fn parse_scenario< + Lines: Iterator>, +>( + l: &mut Lines, +) -> Scenario { + let mut scenario: Scenario = Default::default(); + // Find SCENARIO_BEGIN + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == SCENARIO_BEGIN { + break; + } + println!("parse_scenario: garbage line {clean_line:?}"); + panic!("bad line"); + } + + // Find RANGE_BEGIN, STEP, or SCENARIO_END + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == RANGE_BEGIN { + scenario.ranges.push(parse_range(tokens, l)); + continue; + } + if token == STEP { + scenario.steps.push(parse_step(tokens, l)); + continue; + } + if token == SCENARIO_END { + break; + } + todo!(); + } + scenario +} + +#[derive(Clone, Debug, Default)] +pub struct Range { + pub start_value: u64, + pub end_value: u64, + addr: Option, + pub entry: Vec, +} + +fn parse_range>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> Range { + let mut range: Range = Range { + start_value: tokens.next().unwrap().parse::().unwrap(), + end_value: tokens.next().unwrap().parse::().unwrap(), + ..Default::default() + }; + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ADDRESS { + let addr_str = tokens.next().unwrap(); + range.addr = Some(addr_str.parse().unwrap()); + continue; + } + if token == ENTRY_BEGIN { + range.entry.push(parse_entry(l)); + continue; + } + if token == RANGE_END { + break; + } + todo!(); + } + //println!("parse_range: {:?}", range); + range +} + +#[derive(Clone, Debug)] +pub struct Step { + pub step_value: u64, + pub step_type: StepType, + pub entry: Option, +} + +fn parse_step>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> Step { + let step_value = tokens.next().unwrap().parse::().unwrap(); + let step_type_str = tokens.next().unwrap(); + let step_type = if step_type_str == STEP_TYPE_QUERY { + StepType::Query + } else if step_type_str == STEP_TYPE_CHECK_ANSWER { + StepType::CheckAnswer + } else if step_type_str == STEP_TYPE_TIME_PASSES { + StepType::TimePasses + } else if step_type_str == STEP_TYPE_TRAFFIC { + StepType::Traffic + } else if step_type_str == STEP_TYPE_CHECK_TEMPFILE { + StepType::CheckTempfile + } else if step_type_str == STEP_TYPE_ASSIGN { + StepType::Assign + } else { + todo!(); + }; + let mut step = Step { + step_value, + step_type, + entry: None, + }; + + match step.step_type { + StepType::Query => (), // Continue with entry + StepType::CheckAnswer => (), // Continue with entry + StepType::TimePasses => { + println!("parse_step: should handle TIME_PASSES"); + return step; + } + StepType::Traffic => { + println!("parse_step: should handle TRAFFIC"); + return step; + } + StepType::CheckTempfile => { + println!("parse_step: should handle CHECK_TEMPFILE"); + return step; + } + StepType::Assign => { + println!("parse_step: should handle ASSIGN"); + return step; + } + } + + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ENTRY_BEGIN { + step.entry = Some(parse_entry(l)); + //println!("parse_step: {:?}", step); + return step; + } + todo!(); + } +} + +#[derive(Clone, Debug, Default)] +pub struct Entry { + pub matches: Option, + pub adjust: Option, + pub reply: Option, + pub sections: Option, +} + +fn parse_entry>>( + l: &mut Lines, +) -> Entry { + let mut entry = Entry { + matches: None, + adjust: None, + reply: None, + sections: None, + }; + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == MATCH { + entry.matches = Some(parse_match(tokens)); + continue; + } + if token == ADJUST { + entry.adjust = Some(parse_adjust(tokens)); + continue; + } + if token == REPLY { + entry.reply = Some(parse_reply(tokens)); + continue; + } + if token == SECTION { + let (sections, line) = parse_section(tokens, l); + //println!("parse_entry: sections {:?}", sections); + entry.sections = Some(sections); + let clean_line = get_clean_line(line.as_ref()); + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == ENTRY_END { + break; + } + todo!(); + } + if token == ENTRY_END { + break; + } + todo!(); + } + entry +} + +#[derive(Clone, Debug)] +pub struct Sections { + pub question: Vec, + pub answer: Vec, + pub authority: Vec, + pub additional: Vec, +} + +fn parse_section>>( + mut tokens: LineTokens<'_>, + l: &mut Lines, +) -> (Sections, String) { + let mut sections = Sections { + question: Vec::new(), + answer: Vec::new(), + authority: Vec::new(), + additional: Vec::new(), + }; + let next = tokens.next().unwrap(); + let mut section = if next == QUESTION { + Section::Question + } else { + panic!("Bad section {next}"); + }; + // Should extract which section + loop { + let line = l.next().unwrap().unwrap(); + let clean_line = get_clean_line(line.as_ref()); + if clean_line.is_none() { + continue; + } + let clean_line = clean_line.unwrap(); + let mut tokens = LineTokens::new(clean_line); + let token = tokens.next().unwrap(); + if token == SECTION { + let next = tokens.next().unwrap(); + section = if next == QUESTION { + Section::Question + } else if next == ANSWER { + Section::Answer + } else if next == AUTHORITY { + Section::Authority + } else if next == ADDITIONAL { + Section::Additional + } else { + panic!("Bad section {next}"); + }; + continue; + } + if token == ENTRY_END { + return (sections, line); + } + + match section { + Section::Question => { + let mut zonefile = QueryZonefile::new(); + zonefile.extend_from_slice(clean_line.as_ref()); + zonefile.extend_from_slice(b"\n"); + let e = zonefile.next_entry().unwrap(); + sections.question.push(e.unwrap()); + } + Section::Answer | Section::Authority | Section::Additional => { + let mut zonefile = Zonefile::new(); + zonefile.extend_from_slice(b"$ORIGIN .\n"); + zonefile.extend_from_slice(b"ignore 3600 in ns ignore\n"); + zonefile.extend_from_slice(clean_line.as_ref()); + zonefile.extend_from_slice(b"\n"); + let _e = zonefile.next_entry().unwrap(); + let e = zonefile.next_entry().unwrap(); + + let e = e.unwrap(); + match section { + Section::Question => panic!("should not be here"), + Section::Answer => sections.answer.push(e), + Section::Authority => sections.authority.push(e), + Section::Additional => sections.additional.push(e), + } + } + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Matches { + pub additional: bool, + pub all: bool, + pub answer: bool, + pub authority: bool, + pub fl_do: bool, + pub flags: bool, + pub opcode: bool, + pub qname: bool, + pub qtype: bool, + pub question: bool, + pub rcode: bool, + pub subdomain: bool, + pub tcp: bool, + pub ttl: bool, + pub udp: bool, +} + +fn parse_match(mut tokens: LineTokens<'_>) -> Matches { + let mut matches: Matches = Default::default(); + + loop { + let token = match tokens.next() { + None => return matches, + Some(token) => token, + }; + + if token == "all" { + matches.all = true; + } else if token == "DO" { + matches.fl_do = true; + } else if token == "opcode" { + matches.opcode = true; + } else if token == "qname" { + matches.qname = true; + } else if token == "question" { + matches.question = true; + } else if token == "qtype" { + matches.qtype = true; + } else if token == "subdomain" { + matches.subdomain = true; + } else if token == "TCP" { + matches.tcp = true; + } else if token == "ttl" { + matches.ttl = true; + } else if token == "UDP" { + matches.tcp = true; + } else { + println!("should handle match {token:?}"); + todo!(); + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Adjust { + pub copy_id: bool, + pub copy_query: bool, +} + +fn parse_adjust(mut tokens: LineTokens<'_>) -> Adjust { + let mut adjust: Adjust = Default::default(); + + loop { + let token = match tokens.next() { + None => return adjust, + Some(token) => token, + }; + + if token == "copy_id" { + adjust.copy_id = true; + } else if token == "copy_query" { + adjust.copy_query = true; + } else { + println!("should handle adjust {token:?}"); + todo!(); + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Reply { + pub aa: bool, + pub ad: bool, + pub cd: bool, + pub fl_do: bool, + pub formerr: bool, + pub noerror: bool, + pub nxdomain: bool, + pub qr: bool, + pub ra: bool, + pub rd: bool, + pub refused: bool, + pub servfail: bool, + pub tc: bool, + pub yxdomain: bool, +} + +fn parse_reply(mut tokens: LineTokens<'_>) -> Reply { + let mut reply: Reply = Default::default(); + + loop { + let token = match tokens.next() { + None => return reply, + Some(token) => token, + }; + + if token == "AA" { + reply.aa = true; + } else if token == "AD" { + reply.ad = true; + } else if token == "CD" { + reply.cd = true; + } else if token == "DO" { + reply.fl_do = true; + } else if token == "FORMERR" { + reply.formerr = true; + } else if token == "NOERROR" { + reply.noerror = true; + } else if token == "NXDOMAIN" { + reply.nxdomain = true; + } else if token == "QR" { + reply.qr = true; + } else if token == "RA" { + reply.ra = true; + } else if token == "RD" { + reply.rd = true; + } else if token == "REFUSED" { + reply.refused = true; + } else if token == "SERVFAIL" { + reply.servfail = true; + } else if token == "TC" { + reply.tc = true; + } else if token == "YXDOMAIN" { + reply.yxdomain = true; + } else { + println!("should handle reply {token:?}"); + todo!(); + } + } +} + +fn get_clean_line(line: &str) -> Option<&str> { + //println!("get clean line for {:?}", line); + let opt_comment = line.find(';'); + let line = if let Some(index) = opt_comment { + &line[0..index] + } else { + line + }; + let trimmed = line.trim(); + + //println!("line after trim() {:?}", trimmed); + + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } +} + +struct LineTokens<'a> { + str: &'a str, + curr_index: usize, +} + +impl<'a> LineTokens<'a> { + fn new(str: &'a str) -> Self { + Self { str, curr_index: 0 } + } +} + +impl<'a> Iterator for LineTokens<'a> { + type Item = &'a str; + fn next(&mut self) -> Option { + let cur_str = &self.str[self.curr_index..]; + + if cur_str.is_empty() { + return None; + } + + // Assume cur_str starts with a token + for (index, char) in cur_str.char_indices() { + if !char.is_whitespace() { + continue; + } + let start_index = self.curr_index; + let end_index = start_index + index; + + let space_str = &self.str[end_index..]; + + for (index, char) in space_str.char_indices() { + if char.is_whitespace() { + continue; + } + + self.curr_index = end_index + index; + return Some(&self.str[start_index..end_index]); + } + + todo!(); + } + self.curr_index = self.str.len(); + Some(cur_str) + } +} diff --git a/tests/net/deckard/parse_query.rs b/tests/net/deckard/parse_query.rs new file mode 100644 index 000000000..6bf22e4f2 --- /dev/null +++ b/tests/net/deckard/parse_query.rs @@ -0,0 +1,1501 @@ +//! A zonefile scanner keeping data in place. +//! +//! The zonefile scanner provided by this module reads the entire zonefile +//! into memory and tries as much as possible to modify re-use this memory +//! when scanning data. It uses the `Bytes` family of types for safely +//! storing, manipulating, and returning the data and thus requires the +//! `bytes` feature to be enabled. +//! +//! This may or may not be a good strategy. It was primarily implemented to +//! see that the [`Scan`] trait is powerful enough to build such an +//! implementation. +// #![cfg(feature = "bytes")] +// #![cfg_attr(docsrs, doc(cfg(feature = "bytes")))] + +use bytes::buf::UninitSlice; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use core::str::FromStr; +use core::{fmt, str}; +use domain::base::charstr::CharStr; +use domain::base::iana::{Class, Rtype}; +use domain::base::name::{Chain, Dname, RelativeDname, ToDname}; +use domain::base::scan::{ + BadSymbol, ConvertSymbols, EntrySymbol, Scan, Scanner, ScannerError, + Symbol, SymbolOctetsError, +}; +use domain::base::Question; +use domain::base::Ttl; +use domain::dep::octseq::str::Str; + +//------------ Type Aliases -------------------------------------------------- + +/// The type used for scanned domain names. +pub type ScannedDname = Chain, Dname>; + +/// The type used for scanned records. + +pub type ScannedQueryRecord = Question; + +/// The type used for scanned strings. +pub type ScannedString = Str; + +//------------ Zonefile ------------------------------------------------------ + +/// A zonefile to be scanned. +/// +/// A value of this types holds data to be scanned in memory and allows +/// fetching entries by acting as an iterator. +/// +/// The type implements the `bytes::BufMut` trait for appending data directly +/// into the memory buffer. The function [`load`][Self::load] can be used to +/// create a value directly from a reader. +/// +/// Once data has been added, you can simply iterate over the value to +/// get entries. The [`next_entry`][Self::next_entry] method provides an +/// alternative with a more question mark friendly signature. +#[derive(Clone, Debug)] +pub struct Zonefile { + /// This is where we keep the data of the next entry. + buf: SourceBuf, + + /// The current origin. + origin: Option>, + + /// The last owner. + last_owner: Option, + + /// The last TTL. + last_ttl: Option, + + /// The last class. + last_class: Option, +} + +impl Zonefile { + /// Creates a new, empty value. + pub fn new() -> Self { + Self::with_buf(SourceBuf::with_empty_buf(BytesMut::new())) + } + + /// Creates a new, empty value with the given capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_buf(SourceBuf::with_empty_buf(BytesMut::with_capacity( + capacity + 1, + ))) + } + + /// Creates a new value using the given buffer. + fn with_buf(buf: SourceBuf) -> Self { + Zonefile { + buf, + origin: Some(Dname::root_bytes()), + last_owner: None, + last_ttl: Some(Ttl::ZERO), + last_class: None, + } + } +} + +impl Default for Zonefile { + fn default() -> Self { + Self::new() + } +} + +impl<'a> From<&'a str> for Zonefile { + fn from(src: &'a str) -> Self { + Self::from(src.as_bytes()) + } +} + +impl<'a> From<&'a [u8]> for Zonefile { + fn from(src: &'a [u8]) -> Self { + let mut res = Self::with_capacity(src.len() + 1); + res.extend_from_slice(src); + res + } +} + +impl Zonefile { + /// Appends the given slice to the end of the buffer. + pub fn extend_from_slice(&mut self, slice: &[u8]) { + self.buf.buf.extend_from_slice(slice) + } +} + +unsafe impl BufMut for Zonefile { + fn remaining_mut(&self) -> usize { + self.buf.buf.remaining_mut() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + self.buf.buf.advance_mut(cnt); + } + + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.buf.buf.chunk_mut() + } +} + +impl Zonefile { + /// Returns the next entry in the zonefile. + /// + /// Returns `Ok(None)` if the end of the file has been reached. Returns + /// an error if scanning the next entry failed. + /// + /// This method is identical to the `next` method of the iterator + /// implementation but has the return type transposed for easier use + /// with the question mark operator. + pub fn next_entry(&mut self) -> Result, Error> { + loop { + match EntryScanner::new(self)?.scan_entry()? { + ScannedEntry::Entry(entry) => return Ok(Some(entry)), + ScannedEntry::Origin(origin) => self.origin = Some(origin), + ScannedEntry::Ttl(ttl) => self.last_ttl = Some(ttl), + ScannedEntry::Empty => {} + ScannedEntry::Eof => return Ok(None), + } + } + } + + /// Returns the origin name of the zonefile. + fn get_origin(&self) -> Result, EntryError> { + self.origin + .as_ref() + .cloned() + .ok_or_else(EntryError::missing_origin) + } +} + +impl Iterator for Zonefile { + type Item = Result; + + fn next(&mut self) -> Option { + self.next_entry().transpose() + } +} + +//------------ Entry --------------------------------------------------------- + +/// An entry of a zonefile. +#[derive(Clone, Debug)] +pub enum Entry { + /// A DNS record. + QueryRecord(ScannedQueryRecord), + + /// An include directive. + /// + /// When this entry is encountered, the referenced file should be scanned + /// next. If `origin` is given, this file should be scanned with it as the + /// initial origin name, + Include { + /// The path to the file to be included. + path: ScannedString, + + /// The initial origin name of the included file, if provided. + origin: Option>, + }, +} + +//------------ ScannedEntry -------------------------------------------------- + +/// A raw scanned entry of a zonefile. +/// +/// This includes all the entry types that we can handle internally and don’t +/// have to bubble up to the user. +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +enum ScannedEntry { + /// An entry that should be handed to the user. + Entry(Entry), + + /// An `$ORIGIN` directive changing the origin name. + Origin(Dname), + + /// A `$TTL` directive changing the default TTL if it isn’t given. + Ttl(Ttl), + + /// An empty entry. + Empty, + + /// The end of file was reached. + Eof, +} + +//------------ EntryScanner -------------------------------------------------- + +/// The entry scanner for a zonefile. +/// +/// A value of this type is created for each entry. It implements the +/// [`Scanner`] interface. +#[derive(Debug)] +struct EntryScanner<'a> { + /// The zonefile we are working on. + zonefile: &'a mut Zonefile, +} + +impl<'a> EntryScanner<'a> { + /// Creates a new entry scanner using the given zonefile. + fn new(zonefile: &'a mut Zonefile) -> Result { + Ok(EntryScanner { zonefile }) + } + + /// Scans a single entry from the zone file. + fn scan_entry(&mut self) -> Result { + self._scan_entry() + .map_err(|err| self.zonefile.buf.error(err)) + } + + /// Scans a single entry from the zone file. + /// + /// This is identical to `scan_entry` but with a more convenient error + /// type. + fn _scan_entry(&mut self) -> Result { + self.zonefile.buf.next_item()?; + match self.zonefile.buf.cat { + ItemCat::None => Ok(ScannedEntry::Eof), + ItemCat::LineFeed => Ok(ScannedEntry::Empty), + ItemCat::Unquoted | ItemCat::Quoted => { + if self.zonefile.buf.has_space { + // Indented entry: a record with the last owner as the + // owner. + self.scan_owner_record( + match self.zonefile.last_owner.as_ref() { + Some(owner) => owner.clone(), + None => { + return Err(EntryError::missing_last_owner()) + } + }, + false, + ) + } else if self.zonefile.buf.peek_symbol() + == Some(Symbol::Char('$')) + { + self.scan_control() + } else if self.zonefile.buf.skip_at_token()? { + self.scan_at_record() + } else { + self.scan_record() + } + } + } + } + + /// Scans a regular record. + fn scan_record(&mut self) -> Result { + let owner = ScannedDname::scan(self)?; + self.scan_owner_record(owner, true) + } + + /// Scans a regular record with an owner name of `@`. + fn scan_at_record(&mut self) -> Result { + let owner = RelativeDname::empty_bytes() + .chain(match self.zonefile.origin.as_ref().cloned() { + Some(origin) => origin, + None => return Err(EntryError::missing_origin()), + }) + .unwrap(); // Chaining an empty name will always work. + self.scan_owner_record(owner, true) + } + + /// Scans a regular record with an explicit owner name. + fn scan_owner_record( + &mut self, + owner: ScannedDname, + new_owner: bool, + ) -> Result { + let (class, qtype) = self.scan_qcr()?; + + if new_owner { + self.zonefile.last_owner = Some(owner.clone()); + } + + let class = match class { + Some(class) => { + self.zonefile.last_class = Some(class); + class + } + None => match self.zonefile.last_class { + Some(class) => class, + None => return Err(EntryError::missing_last_class()), + }, + }; + + self.zonefile.buf.require_line_feed()?; + + Ok(ScannedEntry::Entry(Entry::QueryRecord(Question::new( + owner, qtype, class, + )))) + } + + /// Scans the class, and type portions of a query record. + fn scan_qcr(&mut self) -> Result<(Option, Rtype), EntryError> { + // Possible options are: + // + // [] [] + // [] [] + + enum Ctr { + Class(Class), + Qtype(Rtype), + } + + let first = self.scan_ascii_str(|s| { + if let Ok(qtype) = Rtype::from_str(s) { + Ok(Ctr::Qtype(qtype)) + } else if let Ok(class) = Class::from_str(s) { + Ok(Ctr::Class(class)) + } else { + Err(EntryError::expected_qtype()) + } + })?; + + match first { + Ctr::Class(class) => { + // We have a class. Now there may be a qtype. + let qtype = self.scan_ascii_str(|s| { + if let Ok(qtype) = Rtype::from_str(s) { + Ok(qtype) + } else { + Err(EntryError::expected_qtype()) + } + })?; + + Ok((Some(class), qtype)) + } + Ctr::Qtype(qtype) => Ok((None, qtype)), + } + } + + /// Scans a control directive. + fn scan_control(&mut self) -> Result { + let ctrl = self.scan_string()?; + if ctrl.eq_ignore_ascii_case("$ORIGIN") { + let origin = self.scan_dname()?.to_dname().unwrap(); + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Origin(origin)) + } else if ctrl.eq_ignore_ascii_case("$INCLUDE") { + let path = self.scan_string()?; + let origin = if !self.zonefile.buf.is_line_feed() { + Some(self.scan_dname()?.to_dname().unwrap()) + } else { + None + }; + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Entry(Entry::Include { path, origin })) + } else if ctrl.eq_ignore_ascii_case("$TTL") { + let ttl = u32::scan(self)?; + self.zonefile.buf.require_line_feed()?; + Ok(ScannedEntry::Ttl(Ttl::from_secs(ttl))) + } else { + Err(EntryError::unknown_control()) + } + } +} + +impl<'a> Scanner for EntryScanner<'a> { + type Octets = Bytes; + type OctetsBuilder = BytesMut; + type Dname = ScannedDname; + type Error = EntryError; + + fn has_space(&self) -> bool { + self.zonefile.buf.has_space + } + + fn continues(&mut self) -> bool { + !matches!(self.zonefile.buf.cat, ItemCat::None | ItemCat::LineFeed) + } + + fn scan_symbols(&mut self, mut op: F) -> Result<(), Self::Error> + where + F: FnMut(Symbol) -> Result<(), Self::Error>, + { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + op(sym)?; + } + self.zonefile.buf.next_item() + } + + fn scan_entry_symbols(&mut self, mut op: F) -> Result<(), Self::Error> + where + F: FnMut(EntrySymbol) -> Result<(), Self::Error>, + { + loop { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + op(sym.into())?; + } + op(EntrySymbol::EndOfToken)?; + self.zonefile.buf.next_item()?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + Ok(()) + } + + fn convert_token>( + &mut self, + mut convert: C, + ) -> Result { + let mut write = 0; + let mut builder = None; + self.convert_one_token(&mut convert, &mut write, &mut builder)?; + if let Some(data) = convert.process_tail()? { + self.append_data(data, &mut write, &mut builder); + } + match builder { + Some(builder) => Ok(builder.freeze()), + None => Ok(self.zonefile.buf.split_to(write).freeze()), + } + } + + fn convert_entry>( + &mut self, + mut convert: C, + ) -> Result { + let mut write = 0; + let mut builder = None; + loop { + self.convert_one_token(&mut convert, &mut write, &mut builder)?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + if let Some(data) = convert.process_tail()? { + self.append_data(data, &mut write, &mut builder); + } + match builder { + Some(builder) => Ok(builder.freeze()), + None => Ok(self.zonefile.buf.split_to(write).freeze()), + } + } + + fn scan_octets(&mut self) -> Result { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_ascii_symbol()?.is_some() {} + + // If we aren’t done yet, we have escaped characters to replace. + let mut write = self.zonefile.buf.start; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + self.zonefile.buf.buf[write] = sym.into_octet()?; + write += 1; + } + + // Done. `write` marks the end. + self.zonefile.buf.next_item()?; + Ok(self.zonefile.buf.split_to(write).freeze()) + } + + fn scan_ascii_str(&mut self, op: F) -> Result + where + F: FnOnce(&str) -> Result, + { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + let mut write = 0; + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_ascii_symbol()?.is_some() { + write += 1; + } + + // If we not reached the end of the token, we have escaped characters + // to replace. + if !matches!(self.zonefile.buf.cat, ItemCat::None) { + while let Some(sym) = self.zonefile.buf.next_symbol()? { + self.zonefile.buf.buf[write] = sym.into_ascii()?; + write += 1; + } + } + + // Done. `write` marks the end. Process via op and return. + let res = op(unsafe { + str::from_utf8_unchecked(&self.zonefile.buf.buf[..write]) + })?; + self.zonefile.buf.next_item()?; + Ok(res) + } + + fn scan_dname(&mut self) -> Result { + // Because the labels in a domain name have their content preceeded + // by the length octet, an unescaped domain name can be almost as is + // if we have one extra octet to the left. Luckily, we always do + // (SourceBuf makes sure of it). + self.zonefile.buf.require_token()?; + + // Let’s prepare everything. We cut off the bits we don’t need with + // the result that the buffer’s start will be 1 and we set `write` + // to be 0, i.e., the start of the buffer. This also means that write + // will contain the length of the domain name assembled so far, so we + // can easily check if it has gotten too long. + assert!(self.zonefile.buf.start > 0, "missing token prefix space"); + self.zonefile.buf.trim_to(self.zonefile.buf.start - 1); + let mut write = 0; + + // Now convert label by label. + loop { + let start = write; + match self.convert_label(&mut write)? { + None => { + // End of token right after a dot, so this is an absolute + // name. Unless we have not done anything yet, then we + // have an empty domain name which is just the origin. + self.zonefile.buf.next_item()?; + if start == 0 { + return RelativeDname::empty_bytes() + .chain(self.zonefile.get_origin()?) + .map_err(|_| EntryError::bad_dname()); + } else { + return unsafe { + RelativeDname::from_octets_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + .chain(Dname::root()) + .map_err(|_| EntryError::bad_dname()) + }; + } + } + Some(true) => { + // Last symbol was a dot: check length and continue. + if write > 254 { + return Err(EntryError::bad_dname()); + } + } + Some(false) => { + // Reached end of token. This means we have a relative + // dname. + self.zonefile.buf.next_item()?; + return unsafe { + RelativeDname::from_octets_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + .chain(self.zonefile.get_origin()?) + .map_err(|_| EntryError::bad_dname()) + }; + } + } + } + } + + fn scan_charstr(&mut self) -> Result, Self::Error> { + self.scan_octets().and_then(|octets| { + CharStr::from_octets(octets) + .map_err(|_| EntryError::bad_charstr()) + }) + } + + fn scan_string(&mut self) -> Result, Self::Error> { + self.zonefile.buf.require_token()?; + + // The result will never be longer than the encoded form, so we can + // trim off everything to the left already. + self.zonefile.buf.trim_to(self.zonefile.buf.start); + + // Skip over symbols that don’t need converting at the beginning. + while self.zonefile.buf.next_char_symbol()?.is_some() {} + + // If we aren’t done yet, we have escaped characters to replace. + let mut write = self.zonefile.buf.start; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + write += sym + .into_char()? + .encode_utf8( + &mut self.zonefile.buf.buf + [write..self.zonefile.buf.start], + ) + .len(); + } + + // Done. `write` marks the end. + self.zonefile.buf.next_item()?; + Ok(unsafe { + Str::from_utf8_unchecked( + self.zonefile.buf.split_to(write).freeze(), + ) + }) + } + + fn scan_charstr_entry(&mut self) -> Result { + // Because char-strings are never longer than their representation + // format, we can definitely do this in place. Specifically, we move + // the content around in such a way that by the end we have the result + // in the space of buf before buf.start. + + // Reminder: char-string are one length byte followed by that many + // content bytes. We use the byte just before self.read as the length + // byte of the first char-string. This way, if there is only one and + // it isn’t escaped, we don’t need to move anything at all. + + // Let’s prepare everything. We cut off the bits we don’t need with + // the result that the buffer’s start will be 1 and we set `write` + // to be 0, i.e., the start of the buffer. This also means that write + // will contain the length of the domain name assembled so far, so we + // can easily check if it has gotten too long. + assert!(self.zonefile.buf.start > 0, "missing token prefix space"); + self.zonefile.buf.trim_to(self.zonefile.buf.start - 1); + let mut write = 0; + + // Now convert token by token. + loop { + self.convert_charstr(&mut write)?; + if self.zonefile.buf.is_line_feed() { + break; + } + } + + Ok(self.zonefile.buf.split_to(write).freeze()) + } + + fn scan_opt_unknown_marker(&mut self) -> Result { + self.zonefile.buf.skip_unknown_marker() + } + + fn octets_builder(&mut self) -> Result { + Ok(BytesMut::new()) + } +} + +impl<'a> EntryScanner<'a> { + /// Converts a single token using a token converter. + fn convert_one_token< + S: From, + C: ConvertSymbols, + >( + &mut self, + convert: &mut C, + write: &mut usize, + builder: &mut Option, + ) -> Result<(), EntryError> { + self.zonefile.buf.require_token()?; + while let Some(sym) = self.zonefile.buf.next_symbol()? { + if let Some(data) = convert.process_symbol(sym.into())? { + self.append_data(data, write, builder); + } + } + self.zonefile.buf.next_item() + } + + /// Appends output data. + /// + /// If the data fits into the portion of the buffer before the current + /// read positiion, puts it there. Otherwise creates a new builder. If + /// it created a new builder or if one was passed in via `builder`, + /// appends the data to that. + fn append_data( + &mut self, + data: &[u8], + write: &mut usize, + builder: &mut Option, + ) { + if let Some(builder) = builder.as_mut() { + builder.extend_from_slice(data); + return; + } + + let new_write = *write + data.len(); + if new_write > self.zonefile.buf.start { + let mut new_builder = BytesMut::with_capacity(new_write); + new_builder.extend_from_slice(&self.zonefile.buf.buf[..*write]); + new_builder.extend_from_slice(data); + *builder = Some(new_builder); + } else { + self.zonefile.buf.buf[*write..new_write].copy_from_slice(data); + *write = new_write; + } + } + + /// Converts a single label of a domain name. + /// + /// The next symbol of the buffer should be the first symbol of the + /// label’s content. The method reads symbols from the buffer and + /// constructs a single label complete with length octets starting at + /// `write`. + /// + /// If it reaches the end of the token before making a label, returns + /// `None`. Otherwise returns whether it encountered a dot at the end of + /// the label. I.e., `Some(true)` means a dot was read as the last symbol + /// and `Some(false)` means the end of token was encountered right after + /// the label. + fn convert_label( + &mut self, + write: &mut usize, + ) -> Result, EntryError> { + let start = *write; + *write += 1; + let latest = *write + 64; // If write goes here, the label is too long + if *write == self.zonefile.buf.start { + // Reading and writing position is equal, so we don’t need to + // convert char symbols. Read char symbols until the end of label + // or an escape sequence. + loop { + match self.zonefile.buf.next_ascii_symbol()? { + Some(b'.') => { + // We found an unescaped dot, ie., end of label. + // Update the length octet and return. + self.zonefile.buf.buf[start] = + (*write - start - 1) as u8; + return Ok(Some(true)); + } + Some(_) => { + // A char symbol. Just increase the write index. + *write += 1; + if *write >= latest { + return Err(EntryError::bad_dname()); + } + } + None => { + // Either we got an escape sequence or we reached the + // end of the token. Break out of the loop and decide + // below. + break; + } + } + } + } + + // Now we need to process the label with potential escape sequences. + loop { + match self.zonefile.buf.next_symbol()? { + None => { + // We reached the end of the token. + if *write > start + 1 { + self.zonefile.buf.buf[start] = + (*write - start - 1) as u8; + return Ok(Some(false)); + } else { + return Ok(None); + } + } + Some(Symbol::Char('.')) => { + // We found an unescaped dot, ie., end of label. + // Update the length octet and return. + self.zonefile.buf.buf[start] = (*write - start - 1) as u8; + return Ok(Some(true)); + } + Some(sym) => { + // Any other symbol: Decode it and proceed to the next + // route. + self.zonefile.buf.buf[*write] = sym.into_octet()?; + *write += 1; + if *write >= latest { + return Err(EntryError::bad_dname()); + } + } + } + } + } + + /// Converts a character string. + fn convert_charstr( + &mut self, + write: &mut usize, + ) -> Result<(), EntryError> { + let start = *write; + *write += 1; + let latest = *write + 255; // If write goes here, charstr is too long + if *write == self.zonefile.buf.start { + // Reading and writing position is equal, so we don’t need to + // convert char symbols. Read char symbols until the end of label + // or an escape sequence. + while self.zonefile.buf.next_ascii_symbol()?.is_some() { + *write += 1; + if *write >= latest { + return Err(EntryError::bad_charstr()); + } + } + } + + // Now we need to process the charstr with potential escape sequences. + loop { + match self.zonefile.buf.next_symbol()? { + None => { + self.zonefile.buf.next_item()?; + self.zonefile.buf.buf[start] = (*write - start - 1) as u8; + return Ok(()); + } + Some(sym) => { + self.zonefile.buf.buf[*write] = sym.into_octet()?; + *write += 1; + if *write >= latest { + return Err(EntryError::bad_charstr()); + } + } + } + } + } +} + +//------------ SourceBuf ----------------------------------------------------- + +/// The buffer to read data from and also into if possible. +#[derive(Clone, Debug)] +struct SourceBuf { + /// The underlying ‘real’ buffer. + /// + /// This buffer contains the data we still need to process. This contains + /// the white space and other octets just before the start of the next + /// token as well since that can be used as extra space for in-place + /// manipulations. + buf: BytesMut, + + /// Where in `buf` is the next symbol to read. + start: usize, + + /// The category of the current item. + cat: ItemCat, + + /// Is the token preceeded by white space? + has_space: bool, + + /// How many unclosed opening parentheses did we see at `start`? + parens: usize, + + /// The line number of the current line. + line_num: usize, + + /// The position of the first character of the current line. + /// + /// This may be negative if we cut off bits of the current line. + line_start: isize, +} + +impl SourceBuf { + /// Create a new empty buffer. + /// + /// Assumes that `buf` is empty. Adds a single byte to the buffer which + /// we would need for parsing if the first token is a domain name. + fn with_empty_buf(mut buf: BytesMut) -> Self { + buf.put_u8(0); + SourceBuf { + buf, + start: 1, + cat: ItemCat::None, + has_space: false, + parens: 0, + line_num: 1, + line_start: 1, + } + } + + /// Enriches an entry error with position information. + fn error(&self, err: EntryError) -> Error { + Error { + err, + line: self.line_num, + col: ((self.start as isize) + 1 - self.line_start) as usize, + } + } + + /// Checks whether the current item is a token. + fn require_token(&self) -> Result<(), EntryError> { + match self.cat { + ItemCat::None => Err(EntryError::short_buf()), + ItemCat::LineFeed => Err(EntryError::end_of_entry()), + ItemCat::Quoted | ItemCat::Unquoted => Ok(()), + } + } + + /// Returns whether the current item is a line feed. + fn is_line_feed(&self) -> bool { + matches!(self.cat, ItemCat::LineFeed) + } + + /// Requires that we have reached a line feed. + fn require_line_feed(&self) -> Result<(), EntryError> { + if self.is_line_feed() { + Ok(()) + } else { + Err(EntryError::trailing_tokens()) + } + } + + /// Returns the next symbol but doesn’t advance the buffer. + /// + /// Returns `None` if the current item is a line feed or end-of-file + /// or if we have reached the end of token or if it is not a valid symbol. + fn peek_symbol(&self) -> Option { + match self.cat { + ItemCat::None | ItemCat::LineFeed => None, + ItemCat::Unquoted => { + let sym = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, _))) => sym, + Ok(None) | Err(_) => return None, + }; + + if sym.is_word_char() { + Some(sym) + } else { + None + } + } + ItemCat::Quoted => { + let sym = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, _))) => sym, + Ok(None) | Err(_) => return None, + }; + + if sym == Symbol::Char('"') { + None + } else { + Some(sym) + } + } + } + } + + /// Skips over the current token if it contains only an `@` symbol. + /// + /// Returns whether it did skip the token. + fn skip_at_token(&mut self) -> Result { + if self.peek_symbol() != Some(Symbol::Char('@')) { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start + 1) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + match self.cat { + ItemCat::None | ItemCat::LineFeed => unreachable!(), + ItemCat::Unquoted => { + if !sym.is_word_char() { + self.start += 1; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } else { + Ok(false) + } + } + ItemCat::Quoted => { + if sym == Symbol::Char('"') { + self.start = sym_end; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } else { + Ok(false) + } + } + } + } + + /// Skips over the unknown marker token. + /// + /// Returns whether it didskip the token. + fn skip_unknown_marker(&mut self) -> Result { + if !matches!(self.cat, ItemCat::Unquoted) { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some(some)) => some, + _ => return Ok(false), + }; + + if sym != Symbol::SimpleEscape(b'#') { + return Ok(false); + } + + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, sym_end) { + Ok(Some(some)) => some, + _ => return Ok(false), + }; + if sym.is_word_char() { + return Ok(false); + } + + self.start = sym_end; + self.cat = ItemCat::None; + self.next_item()?; + Ok(true) + } + + /// Returns the next symbol of the current token. + /// + /// Returns `None` if the current item is a line feed or end-of-file + /// or if we have reached the end of token. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + fn next_symbol(&mut self) -> Result, EntryError> { + self._next_symbol(|sym| Ok(Some(sym))) + } + + /// Returns the next symbol if it is an unescaped ASCII symbol. + /// + /// Returns `None` if the symbol is escaped or not a printable ASCII + /// character or `self.next_symbol` would return `None`. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + #[allow(clippy::manual_range_contains)] // Hard disagree. + fn next_ascii_symbol(&mut self) -> Result, EntryError> { + if matches!(self.cat, ItemCat::None | ItemCat::LineFeed) { + return Ok(None); + } + + let ch = match self.buf.get(self.start) { + Some(ch) => *ch, + None => return Ok(None), + }; + + match self.cat { + ItemCat::Unquoted => { + if ch < 0x21 + || ch > 0x7F + || ch == b'"' + || ch == b'(' + || ch == b')' + || ch == b';' + || ch == b'\\' + { + return Ok(None); + } + } + ItemCat::Quoted => { + if ch == b'"' { + self.start += 1; + self.cat = ItemCat::None; + return Ok(None); + } else if ch < 0x21 || ch > 0x7F || ch == b'\\' { + return Ok(None); + } + } + _ => unreachable!(), + } + self.start += 1; + Ok(Some(ch)) + } + + /// Returns the next symbol if it is unescaped. + /// + /// Returns `None` if the symbol is escaped or `self.next_symbol` would + /// return `None`. + /// + /// If it returns `Some(_)`, advances `self.start` to the start of the + /// next symbol. + fn next_char_symbol(&mut self) -> Result, EntryError> { + self._next_symbol(|sym| { + if let Symbol::Char(ch) = sym { + Ok(Some(ch)) + } else { + Ok(None) + } + }) + } + + /// Internal helper for `next_symbol` and friends. + /// + /// This only exists so we don’t have to copy and paste the fiddely part + /// of the logic. It behaves like `next_symbol` but provides an option + /// for the called to decide whether they want the symbol or not. + #[inline] + fn _next_symbol(&mut self, want: F) -> Result, EntryError> + where + F: Fn(Symbol) -> Result, EntryError>, + { + match self.cat { + ItemCat::None | ItemCat::LineFeed => Ok(None), + ItemCat::Unquoted => { + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + if !sym.is_word_char() { + self.cat = ItemCat::None; + Ok(None) + } else { + match want(sym)? { + Some(some) => { + self.start = sym_end; + Ok(Some(some)) + } + None => Ok(None), + } + } + } + ItemCat::Quoted => { + let (sym, sym_end) = + match Symbol::from_slice_index(&self.buf, self.start) { + Ok(Some((sym, sym_end))) => (sym, sym_end), + Ok(None) => return Err(EntryError::short_buf()), + Err(err) => return Err(EntryError::bad_symbol(err)), + }; + + let res = match want(sym)? { + Some(some) => some, + None => return Ok(None), + }; + + if sym == Symbol::Char('"') { + self.start = sym_end; + self.cat = ItemCat::None; + Ok(None) + } else { + self.start = sym_end; + if sym == Symbol::Char('\n') { + self.line_num += 1; + self.line_start = self.start as isize; + } + Ok(Some(res)) + } + } + } + } + + /// Prepares the next item. + /// + /// # Panics + /// + /// This method must only ever by called if the current item is + /// not a token or if the current token has been read all the way to the + /// end. The latter is true if [`Self::next_symbol`] has returned + /// `Ok(None)` at least once. + /// + /// If the current item is a token and has not been read all the way to + /// the end, the method will panic to maintain consistency of the data. + fn next_item(&mut self) -> Result<(), EntryError> { + assert!( + matches!(self.cat, ItemCat::None | ItemCat::LineFeed), + "token not completely read ({:?} at {}:{})", + self.cat, + self.line_num, + ((self.start as isize) + 1 - self.line_start) as usize, + ); + + self.has_space = false; + + loop { + let ch = match self.buf.get(self.start) { + Some(&ch) => ch, + None => { + self.cat = ItemCat::None; + return Ok(()); + } + }; + + // Skip and mark actual white space. + if matches!(ch, b' ' | b'\t' | b'\r') { + self.has_space = true; + self.start += 1; + } + // CR: ignore for compatibility with Windows-style line endings. + else if ch == b'\r' { + self.start += 1; + } + // Opening parenthesis: increase group level. + else if ch == b'(' { + self.parens += 1; + self.start += 1; + } + // Closing parenthesis: decrease group level or error out. + else if ch == b')' { + if self.parens > 0 { + self.parens -= 1; + self.start += 1; + } else { + return Err(EntryError::unbalanced_parens()); + } + } + // Semicolon: comment -- skip to line end. + else if ch == b';' { + self.start += 1; + while let Some(true) = + self.buf.get(self.start).map(|ch| *ch != b'\n') + { + self.start += 1; + } + // Next iteration deals with the LF. + } + // Line end: skip over it. Ignore if we are inside a paren group. + else if ch == b'\n' { + self.start += 1; + self.line_num += 1; + self.line_start = self.start as isize; + if self.parens == 0 { + self.cat = ItemCat::LineFeed; + break; + } + } + // Double quote: quoted token + else if ch == b'"' { + self.start += 1; + self.cat = ItemCat::Quoted; + break; + } + // Else: unquoted token + else { + self.cat = ItemCat::Unquoted; + break; + } + } + Ok(()) + } + + /// Splits off the beginning of the buffer up to the given index. + /// + /// # Panics + /// + /// The method panics if `at` is greater than `self.start`. + fn split_to(&mut self, at: usize) -> BytesMut { + assert!(at <= self.start); + let res = self.buf.split_to(at); + self.start -= at; + self.line_start -= at as isize; + res + } + + /// Splits off the beginning of the buffer but doesn’t return it. + /// + /// # Panics + /// + /// The method panics if `at` is greater than `self.start`. + fn trim_to(&mut self, at: usize) { + assert!(at <= self.start); + self.buf.advance(at); + self.start -= at; + self.line_start -= at as isize; + } +} + +//------------ ItemCat ------------------------------------------------------- + +/// The category of the current item in a source buffer. +#[allow(dead_code)] // XXX +#[derive(Clone, Copy, Debug)] +enum ItemCat { + /// We don’t currently have an item. + /// + /// This is used to indicate that we have reached the end of a token or + /// that we have reached the end of the buffer. + // + // XXX: We might need a separate category for EOF. But let’s see if we + // can get away with mixing this up, first. + None, + + /// An unquoted normal token. + /// + /// This is a token that did not start with a double quote and will end + /// at the next white space. + Unquoted, + + /// A quoted normal token. + /// + /// This is a token that did start with a double quote and will end at + /// the next unescaped double quote. + /// + /// Note that the start position of the buffer indicates the first + /// character that is part of the content, i.e., the position right after + /// the opening double quote. + Quoted, + + /// A line feed. + /// + /// This is an empty token. The start position is right after the actual + /// line feed. + LineFeed, +} + +//------------ EntryError ---------------------------------------------------- + +/// An error returned by the entry scanner. +#[derive(Debug)] +struct EntryError(&'static str); + +impl EntryError { + fn bad_symbol(_err: SymbolOctetsError) -> Self { + EntryError("bad symbol") + } + + fn bad_charstr() -> Self { + EntryError("bad charstr") + } + + fn bad_dname() -> Self { + EntryError("bad dname") + } + + fn unbalanced_parens() -> Self { + EntryError("unbalanced parens") + } + + fn missing_last_owner() -> Self { + EntryError("missing last owner") + } + + fn missing_last_class() -> Self { + EntryError("missing last class") + } + + fn missing_origin() -> Self { + EntryError("missing origin") + } + + fn expected_qtype() -> Self { + EntryError("expected qtype") + } + + fn unknown_control() -> Self { + EntryError("unknown control") + } +} + +impl ScannerError for EntryError { + fn custom(msg: &'static str) -> Self { + EntryError(msg) + } + + fn end_of_entry() -> Self { + Self("unexpected end of entry") + } + + fn short_buf() -> Self { + Self("short buffer") + } + + fn trailing_tokens() -> Self { + Self("trailing tokens") + } +} + +impl From for EntryError { + fn from(_: SymbolOctetsError) -> Self { + EntryError("symbol octets error") + } +} + +impl From for EntryError { + fn from(_: BadSymbol) -> Self { + EntryError("bad symbol") + } +} + +impl fmt::Display for EntryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.0.as_ref()) + } +} + +//#[cfg(feature = "std")] +impl std::error::Error for EntryError {} + +//------------ Error --------------------------------------------------------- + +#[derive(Debug)] +pub struct Error { + err: EntryError, + line: usize, + col: usize, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}: {}", self.line, self.col, self.err) + } +} + +//#[cfg(feature = "std")] +impl std::error::Error for Error {} + +//============ Tests ========================================================= + +/* +#[cfg(test)] +#[cfg(feature = "std")] +mod test { + use super::*; + use std::vec::Vec; + + fn with_entry(s: &str, op: impl FnOnce(EntryScanner)) { + let mut zone = Zonefile::with_capacity(s.len()); + zone.extend_from_slice(s.as_bytes()); + let entry = EntryScanner::new(&mut zone).unwrap(); + entry.zonefile.buf.next_item().unwrap(); + op(entry) + } + + #[test] + fn scan_symbols() { + fn test(zone: &str, tok: impl AsRef<[u8]>) { + with_entry(zone, |mut entry| { + let mut tok = tok.as_ref(); + entry + .scan_symbols(|sym| { + let sym = sym.into_octet().unwrap(); + assert_eq!(sym, tok[0]); + tok = &tok[1..]; + Ok(()) + }) + .unwrap(); + }); + } + + test(" unquoted\n", b"unquoted"); + test(" unquoted ", b"unquoted"); + test("unquoted ", b"unquoted"); + test("unqu\\oted ", b"unquoted"); + test("unqu\\111ted ", b"unquoted"); + test(" \"quoted\"\n", b"quoted"); + test(" \"quoted\" ", b"quoted"); + test("\"quoted\" ", b"quoted"); + } + + #[derive(serde::Deserialize)] + #[allow(clippy::type_complexity)] + struct TestCase { + origin: Dname, + zonefile: std::string::String, + result: Vec, ZoneRecordData>>>, + } + + impl TestCase { + fn test(yaml: &str) { + let case = serde_yaml::from_str::(yaml).unwrap(); + let mut input = case.zonefile.as_bytes(); + let mut zone = Zonefile::load(&mut input).unwrap(); + zone.set_origin(case.origin); + let mut result = case.result.as_slice(); + while let Some(entry) = zone.next_entry().unwrap() { + match entry { + Entry::Record(record) => { + let (first, tail) = result.split_first().unwrap(); + assert_eq!(first, &record); + result = tail; + } + _ => panic!(), + } + } + } + } + + #[test] + fn test_data() { + TestCase::test(include_str!("../../test-data/zonefiles/basic.yaml")); + TestCase::test(include_str!("../../test-data/zonefiles/escape.yaml")); + TestCase::test(include_str!("../../test-data/zonefiles/unknown.yaml")); + } +} +*/ diff --git a/tests/net/deckard/server.rs b/tests/net/deckard/server.rs new file mode 100644 index 000000000..f8eecf760 --- /dev/null +++ b/tests/net/deckard/server.rs @@ -0,0 +1,133 @@ +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::matches::match_msg; +use crate::net::deckard::parse_deckard; +use crate::net::deckard::parse_deckard::{Adjust, Deckard, Reply}; +use crate::net::deckard::parse_query; +use domain::base::iana::rcode::Rcode; +use domain::base::{Message, MessageBuilder}; +use domain::dep::octseq::Octets; +use domain::zonefile::inplace::Entry as ZonefileEntry; + +pub fn do_server<'a, Oct: Clone + Octets + 'a>( + msg: &'a Message, + deckard: &Deckard, + step_value: &CurrStepValue, +) -> Option>> +where + ::Range<'a>: Clone, +{ + let ranges = &deckard.scenario.ranges; + let step = step_value.get(); + for range in ranges { + if step < range.start_value || step > range.end_value { + continue; + } + for entry in &range.entry { + if !match_msg(entry, msg, false) { + continue; + } + let reply = do_adjust(entry, msg); + return Some(reply); + } + } + todo!(); +} + +fn do_adjust( + entry: &parse_deckard::Entry, + reqmsg: &Message, +) -> Message> { + let sections = entry.sections.as_ref().unwrap(); + let adjust: Adjust = match &entry.adjust { + Some(adjust) => adjust.clone(), + None => Default::default(), + }; + let mut msg = MessageBuilder::new_vec().question(); + if adjust.copy_query { + for q in reqmsg.question() { + msg.push(q.unwrap()).unwrap(); + } + } else { + for q in §ions.question { + let question = match q { + parse_query::Entry::QueryRecord(question) => question, + _ => todo!(), + }; + msg.push(question).unwrap(); + } + } + let mut msg = msg.answer(); + for a in §ions.answer { + let rec = if let ZonefileEntry::Record(record) = a { + record + } else { + panic!("include not expected") + }; + msg.push(rec).unwrap(); + } + let mut msg = msg.authority(); + for a in §ions.authority { + let rec = if let ZonefileEntry::Record(record) = a { + record + } else { + panic!("include not expected") + }; + msg.push(rec).unwrap(); + } + let mut msg = msg.additional(); + for _a in §ions.additional { + todo!(); + } + let reply: Reply = match &entry.reply { + Some(reply) => reply.clone(), + None => Default::default(), + }; + if reply.aa { + msg.header_mut().set_aa(true); + } + if reply.ad { + todo!() + } + if reply.cd { + todo!() + } + if reply.fl_do { + todo!() + } + if reply.formerr { + todo!() + } + if reply.noerror { + msg.header_mut().set_rcode(Rcode::NoError); + } + if reply.nxdomain { + todo!() + } + if reply.qr { + msg.header_mut().set_qr(true); + } + if reply.ra { + todo!() + } + if reply.rd { + msg.header_mut().set_rd(true); + } + if reply.refused { + todo!() + } + if reply.servfail { + todo!() + } + if reply.tc { + todo!() + } + if reply.yxdomain { + todo!() + } + if adjust.copy_id { + msg.header_mut().set_id(reqmsg.header().id()); + } else { + todo!(); + } + msg.into_message() +} diff --git a/tests/net/mod.rs b/tests/net/mod.rs new file mode 100644 index 000000000..4e7b62367 --- /dev/null +++ b/tests/net/mod.rs @@ -0,0 +1 @@ +pub mod deckard; From dd4e2360208e16c3413c7295d7dcc0acb518be4e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Wed, 27 Dec 2023 16:20:26 +0100 Subject: [PATCH 094/124] Datagram transports. This abstracts from UDP. --- examples/client-transports.rs | 27 ++- src/net/client/dgram.rs | 436 ++++++++++++++++++++++++++++++++++ src/net/client/mod.rs | 1 + src/net/client/protocol.rs | 147 +++++++++++- 4 files changed, 608 insertions(+), 3 deletions(-) create mode 100644 src/net/client/dgram.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index a0d8469c5..475c18721 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -2,9 +2,10 @@ use domain::base::Dname; use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; +use domain::net::client::dgram; use domain::net::client::multi_stream; use domain::net::client::octet_stream; -use domain::net::client::protocol::{TcpConnect, TlsConnect}; +use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; use domain::net::client::redundant; use domain::net::client::request::{RequestMessage, SendRequest}; use domain::net::client::udp; @@ -34,7 +35,8 @@ async fn main() { let req = RequestMessage::new(msg); // Destination for UDP and TCP - let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + let server_addr = + SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); let multi_stream_config = multi_stream::Config { octet_stream: Some(octet_stream::Config { @@ -199,6 +201,27 @@ async fn main() { let reply = request.get_response().await; println!("UDP reply: {:?}", reply); + // Create a new datagram transport connection. Pass the destination address + // and port as parameter. This transport does not retry over TCP if the + // reply is truncated. This transport does not have a separate run + // function. + let dgram_config = dgram::Config { + max_parallel: 1, + read_timeout: Duration::from_millis(1000), + max_retries: 1, + udp_payload_size: Some(1400), + }; + let udp_connect = UdpConnect::new(server_addr); + let dgram_conn = + dgram::Connection::new(Some(dgram_config), udp_connect).unwrap(); + + // Send a query message. + let mut request = dgram_conn.send_request(&req).await.unwrap(); + + // Get the reply + let reply = request.get_response().await; + println!("Dgram reply: {:?}", reply); + // Create a single TCP transport connection. This is usefull for a // single request or a small burst of requests. let tcp_conn = match TcpStream::connect(server_addr).await { diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs new file mode 100644 index 000000000..dcce66304 --- /dev/null +++ b/src/net/client/dgram.rs @@ -0,0 +1,436 @@ +//! A DNS over datagram transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - cookies + +use bytes::Bytes; +use octseq::Octets; +use std::boxed::Box; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::io::ErrorKind; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::time::{timeout, Duration, Instant}; + +use crate::base::iana::Rcode; +use crate::base::Message; +use crate::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramSend, +}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; + +/// Default configuration value for the maximum number of parallel DNS query +/// over a single datagram transport connection. +const DEF_MAX_PARALLEL: usize = 100; + +/// Minimum configuration value for max_parallel. +const MIN_MAX_PARALLEL: usize = 1; + +/// Maximum configuration value for max_parallel. +const MAX_MAX_PARALLEL: usize = 1000; + +/// Default configuration value for the maximum amount of time to wait for a +/// reply. +const DEF_READ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Minimum configuration value for read_timeout. +const MIN_READ_TIMEOUT: Duration = Duration::from_millis(1); + +/// Maximum configuration value for read_timeout. +const MAX_READ_TIMEOUT: Duration = Duration::from_secs(60); + +/// Default configuration value for maximum number of retries after timeouts. +const DEF_MAX_RETRIES: u8 = 5; + +/// Minimum allowed configuration value for max_retries. +const MIN_MAX_RETRIES: u8 = 1; + +/// Maximum allowed configuration value for max_retries. +const MAX_MAX_RETRIES: u8 = 100; + +/// Default UDP payload size. See draft-ietf-dnsop-avoid-fragmentation-15 +/// for discussion. +const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; + +//------------ Config --------------------------------------------------------- + +/// Configuration for a datagram transport connection. +#[derive(Clone, Debug)] +pub struct Config { + /// Maximum number of parallel requests for a transport connection. + pub max_parallel: usize, + + /// Read timeout. + pub read_timeout: Duration, + + /// Maimum number of retries. + pub max_retries: u8, + + /// EDNS(0) UDP payload size. Set this value to None to be able to create + /// a DNS request without ENDS(0) option. + pub udp_payload_size: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + max_parallel: DEF_MAX_PARALLEL, + read_timeout: DEF_READ_TIMEOUT, + max_retries: DEF_MAX_RETRIES, + udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), + } + } +} + +//------------ Connection ----------------------------------------------------- + +/// A datagram transport connection. +#[derive(Clone, Debug)] +pub struct Connection { + /// Reference to the actual connection object. + inner: Arc>, +} + +impl< + S: AsyncConnect + Clone + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + > Connection +{ + /// Create a new datagram transport connection. + pub fn new( + config: Option, + connect: S, + ) -> Result, Error> { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config, connect)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Start a new DNS request. + async fn request_impl< + CR: ComposeRequest + Clone + Send + Sync + 'static, + >( + &self, + request_msg: &CR, + ) -> Result, Error> { + let gr = self.inner.request(request_msg, self.clone()).await?; + Ok(Box::new(gr)) + } + + /// Get a permit from the semaphore to start using a socket. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.inner.get_permit().await + } +} + +impl< + S: AsyncConnect + Clone + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + CR: ComposeRequest + Clone + Send + Sync + 'static, + > SendRequest for Connection +{ + fn send_request<'a>( + &'a self, + request_msg: &'a CR, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.request_impl(request_msg)); + } +} + +//------------ ReqResp -------------------------------------------------------- + +/// The state of a DNS request. +pub struct ReqResp { + /// Future that does the actual work of GetResponse. + get_response_fut: + Pin, Error>> + Send>>, +} + +impl ReqResp { + /// Create new ReqResp object. + fn new< + S: AsyncConnect + Clone + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + CR: ComposeRequest + Clone + Send + Sync + 'static, + >( + config: Config, + request_msg: &CR, + conn: Connection, + udp_payload_size: Option, + connect: S, + ) -> Self { + Self { + get_response_fut: Box::pin(Self::get_response_impl2( + config, + request_msg.clone(), + conn, + udp_payload_size, + connect, + )), + } + } + + /// Async function that waits for the future stored in Query to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.get_response_fut).await + } + + /// Get the response of a DNS request. + /// + /// This function is not cancel safe. + async fn get_response_impl2< + S: AsyncConnect + Clone + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + CR: ComposeRequest, + >( + config: Config, + mut request_bmb: CR, + conn: Connection, + udp_payload_size: Option, + connect: S, + ) -> Result, Error> { + let recv_size = 2000; // Should be configurable. + + let mut retries: u8 = 0; + + // We need to get past the semaphore that limits the + // number of concurrent sockets we can use. + let _permit = conn.get_permit().await; + + loop { + let sock = connect + .connect() + .await + .map_err(|e| Error::UdpConnect(Arc::new(e)))?; + + // Set random ID in header + let header = request_bmb.header_mut(); + header.set_random_id(); + // Set UDP payload size + if let Some(size) = udp_payload_size { + request_bmb.set_udp_payload_size(size) + } + let request_msg = request_bmb.to_message(); + let dgram = request_msg.as_slice(); + + let sent = sock + .send(dgram) + .await + .map_err(|e| Error::UdpSend(Arc::new(e)))?; + if sent != dgram.len() { + return Err(Error::UdpShortSend); + } + + let start = Instant::now(); + + loop { + let elapsed = start.elapsed(); + if elapsed > config.read_timeout { + // Break out of the receive loop and continue in the + // transmit loop. + break; + } + let remain = config.read_timeout - elapsed; + + let buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout(remain, sock.recv(buf)).await; + if timeout_res.is_err() { + retries += 1; + if retries < config.max_retries { + // Break out of the receive loop and continue in the + // transmit loop. + break; + } + return Err(Error::UdpTimeoutNoResponse); + } + let buf = timeout_res + .expect("errror case is checked above") + .map_err(|e| Error::UdpReceive(Arc::new(e)))?; + + // We ignore garbage since there is a timer on this whole + // thing. + let answer = match Message::from_octets(buf.into()) { + // Just go back to receiving. + Ok(answer) => answer, + Err(_) => continue, + }; + + if !is_answer(&answer, &request_msg) { + // Wrong answer, go back to receiving + continue; + } + return Ok(answer); + } + retries += 1; + if retries < config.max_retries { + continue; + } + break; + } + Err(Error::UdpTimeoutNoResponse) + } +} + +impl Debug for ReqResp { + fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { + todo!() + } +} + +impl GetResponse for ReqResp { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} + +//------------ InnerConnection ------------------------------------------------ + +/// Actual implementation of the datagram transport connection. +#[derive(Debug)] +struct InnerConnection { + /// User configuration variables. + config: Config, + + /// Connections to datagram sockets. + connect: S, + + /// Semaphore to limit access to UDP sockets. + semaphore: Arc, +} + +impl< + S: AsyncConnect + Clone + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + > InnerConnection +{ + /// Create new InnerConnection object. + fn new(config: Config, connect: S) -> Result, Error> { + let max_parallel = config.max_parallel; + Ok(Self { + config, + connect, + semaphore: Arc::new(Semaphore::new(max_parallel)), + }) + } + + /// Return a Query object that contains the query state. + async fn request( + &self, + request_msg: &CR, + conn: Connection, + ) -> Result { + Ok(ReqResp::new( + self.config.clone(), + request_msg, + conn, + self.config.udp_payload_size, + self.connect.clone(), + )) + } + + /// Return a permit for a our semaphore. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.semaphore + .clone() + .acquire_owned() + .await + .expect("the semaphore has not been closed") + } +} + +//------------ Utility -------------------------------------------------------- + +/// Check if config is valid. +fn check_config(config: &Config) -> Result<(), Error> { + if config.max_parallel < MIN_MAX_PARALLEL + || config.max_parallel > MAX_MAX_PARALLEL + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "max_parallel", + )))); + } + + if config.read_timeout < MIN_READ_TIMEOUT + || config.read_timeout > MAX_READ_TIMEOUT + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "read_timeout", + )))); + } + + if config.max_retries < MIN_MAX_RETRIES + || config.max_retries > MAX_MAX_RETRIES + { + return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( + ErrorKind::Other, + "max_retries", + )))); + } + Ok(()) +} + +/// Check if a message is a valid reply for a query. Allow the question section +/// to be empty if there is an error or if the reply is truncated. +fn is_answer< + QueryOcts: AsRef<[u8]> + Octets, + ReplyOcts: AsRef<[u8]> + Octets, +>( + reply: &Message, + query: &Message, +) -> bool { + let reply_header = reply.header(); + let reply_hcounts = reply.header_counts(); + + // First check qr and id + if !reply_header.qr() || reply_header.id() != query.header().id() { + return false; + } + + // If either tc is set or the result is an error, then the question + // section can be empty. In that case we require all other sections + // to be empty as well. + if (reply_header.tc() || reply_header.rcode() != Rcode::NoError) + && reply_hcounts.qdcount() == 0 + && reply_hcounts.ancount() == 0 + && reply_hcounts.nscount() == 0 + && reply_hcounts.arcount() == 0 + { + // We can accept this as a valid reply. + return true; + } + + // Remaining checks. The question section in the reply has to be the + // same as in the query. + if reply_hcounts.qdcount() != query.header_counts().qdcount() { + false + } else { + reply.question() == query.question() + } +} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index d2033e7d7..dfbd72087 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -11,6 +11,7 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] +pub mod dgram; pub mod multi_stream; pub mod octet_stream; pub mod protocol; diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs index 531a314c8..566ea6f81 100644 --- a/src/net/client/protocol.rs +++ b/src/net/client/protocol.rs @@ -4,12 +4,17 @@ use core::future::Future; use core::pin::Pin; use std::boxed::Box; use std::io; +use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::{TcpStream, ToSocketAddrs}; +use std::vec::Vec; +use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket}; use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; +/// How many times do we try a new random port if we get ‘address in use.’ +const RETRY_RANDOM_PORT: usize = 10; + //------------ AsyncConnect -------------------------------------------------- /// Establish a connection asynchronously. @@ -114,3 +119,143 @@ where }) } } + +//------------ AsyncDgramRecv ------------------------------------------------- + +/// Receive a datagram packet asynchronously. +/// +/// +pub trait AsyncDgramRecv { + /// The future performing the receive operation. + type Fut: Future, io::Error>> + Send; + + /// Returns a future that performs the receive operation. + fn recv(&self, buf: Vec) -> Self::Fut; +} + +//------------ AsyncDgramSend ------------------------------------------------- + +/// Send a datagram packet asynchronously. +/// +/// +pub trait AsyncDgramSend { + /// The future performing the send operation. + type Fut: Future> + Send; + + /// Returns a future that performs the send operation. + fn send(&self, buf: &[u8]) -> Self::Fut; +} + +//------------ UdpConnect -------------------------------------------------- + +/// Create new TCP connections. +#[derive(Clone, Copy, Debug)] +pub struct UdpConnect { + /// Remote address to connect to. + addr: SocketAddr, +} + +impl UdpConnect { + /// Create new UDP connections. + /// + /// addr is the destination address to connect to. + pub fn new(addr: SocketAddr) -> Self { + Self { addr } + } +} + +impl AsyncConnect for UdpConnect { + type Connection = UdpDgram; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + Box::pin(UdpDgram::new(self.addr)) + } +} + +/// A single UDP 'connection' +pub struct UdpDgram { + /// Underlying UDP socket + sock: Arc, +} + +impl UdpDgram { + /// Create a new UdpDgram object. + async fn new(addr: SocketAddr) -> Result { + let sock = Self::udp_bind(addr.is_ipv4()).await?; + sock.connect(addr).await?; + Ok(Self { + sock: Arc::new(sock), + }) + } + /// Bind to a local UDP port. + /// + /// This should explicitly pick a random number in a suitable range of + /// ports. + async fn udp_bind(v4: bool) -> Result { + let mut i = 0; + loop { + let local: SocketAddr = if v4 { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => return Ok(sock), + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(err); + } else { + i += 1 + } + } + } + } + } +} + +impl AsyncDgramRecv for UdpDgram { + type Fut = + Pin, io::Error>> + Send>>; + fn recv(&self, mut buf: Vec) -> Self::Fut { + let sock = self.sock.clone(); + Box::pin(async move { + let len = sock.recv(&mut buf).await?; + buf.truncate(len); + Ok(buf) + }) + } +} + +impl AsyncDgramSend for UdpDgram { + type Fut = Pin> + Send>>; + fn send(&self, buf: &[u8]) -> Self::Fut { + let sock = self.sock.clone(); + let buf = buf.to_vec(); + Box::pin(async move { sock.send(&buf).await }) + } +} + +/* +struct Sender { + sock: Arc, + buf: Vec +} + +impl Sender { + fn new() -> Self { Self } +} + +impl Future for Sender { + type Output = Result; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> + Poll { + self.sock.poll_send(cx, &self.buf) + } +} +*/ From 5f01c4f399cfaa95a830de4806d5200f486febf5 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 10:53:29 +0100 Subject: [PATCH 095/124] Test for dgram. --- examples/client-transports.rs | 3 +- tests/net-client.rs | 16 ++++++ tests/net/deckard/dgram.rs | 96 +++++++++++++++++++++++++++++++++++ tests/net/deckard/mod.rs | 1 + 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/net/deckard/dgram.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 475c18721..265dcbe77 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -35,8 +35,7 @@ async fn main() { let req = RequestMessage::new(msg); // Destination for UDP and TCP - let server_addr = - SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); + let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); let multi_stream_config = multi_stream::Config { octet_stream: Some(octet_stream::Config { diff --git a/tests/net-client.rs b/tests/net-client.rs index bda152c00..297a686e2 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -5,7 +5,9 @@ use crate::net::deckard::client::do_client; use crate::net::deckard::client::CurrStepValue; use crate::net::deckard::connect::Connect; use crate::net::deckard::connection::Connection; +use crate::net::deckard::dgram::Dgram; use crate::net::deckard::parse_deckard::parse_file; +use domain::net::client::dgram; use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::redundant; @@ -19,6 +21,20 @@ use tokio_test; const TEST_FILE: &str = "test-data/basic.rpl"; +#[test] +fn dgram() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Dgram::new(deckard.clone(), step_value.clone()); + let octstr = dgram::Connection::new(None, conn).unwrap(); + + do_client(&deckard, octstr, &step_value).await; + }); +} + #[test] fn single() { tokio_test::block_on(async { diff --git a/tests/net/deckard/dgram.rs b/tests/net/deckard/dgram.rs new file mode 100644 index 000000000..e14e501c7 --- /dev/null +++ b/tests/net/deckard/dgram.rs @@ -0,0 +1,96 @@ +//! Provide server-side of datagram protocols + +use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::parse_deckard::Deckard; +use crate::net::deckard::server::do_server; +use domain::base::Message; +use domain::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramSend, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; + +#[derive(Clone)] +pub struct Dgram { + deckard: Deckard, + step_value: Arc, +} + +impl Dgram { + pub fn new(deckard: Deckard, step_value: Arc) -> Self { + Self { + deckard, + step_value, + } + } +} + +impl AsyncConnect for Dgram { + type Connection = DgramConnection; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + fn connect(&self) -> Self::Fut { + let deckard = self.deckard.clone(); + let step_value = self.step_value.clone(); + Box::pin(async move { Ok(DgramConnection::new(deckard, step_value)) }) + } +} + +pub struct DgramConnection { + deckard: Deckard, + step_value: Arc, + + sender: mpsc::Sender>>, + receiver: Arc>>>>, +} + +impl DgramConnection { + fn new(deckard: Deckard, step_value: Arc) -> Self { + let (sender, receiver) = mpsc::channel(2); + Self { + deckard, + step_value, + sender, + receiver: Arc::new(Mutex::new(receiver)), + } + } +} +impl AsyncDgramRecv for DgramConnection { + type Fut = + Pin, std::io::Error>> + Send>>; + fn recv(&self, buf: Vec) -> Self::Fut { + let arc_m_rec = self.receiver.clone(); + Box::pin(async move { + let mut rec = arc_m_rec.lock().await; + let msg = (*rec).recv().await.unwrap(); + let msg_octets = msg.into_octets(); + if msg_octets.len() > buf.len() { + panic!("test returned reply that is bigger than buffer"); + } + Ok(msg_octets) + }) + } +} + +impl AsyncDgramSend for DgramConnection { + type Fut = + Pin> + Send>>; + fn send(&self, buf: &[u8]) -> Self::Fut { + let msg = Message::from_octets(buf).unwrap(); + let opt_reply = do_server(&msg, &self.deckard, &self.step_value); + let sender = self.sender.clone(); + let len = buf.len(); + Box::pin(async move { + if opt_reply.is_some() { + sender.send(opt_reply.unwrap()).await.unwrap(); + } + Ok(len) + }) + } +} diff --git a/tests/net/deckard/mod.rs b/tests/net/deckard/mod.rs index c3eb548f1..c61857ce9 100644 --- a/tests/net/deckard/mod.rs +++ b/tests/net/deckard/mod.rs @@ -1,6 +1,7 @@ pub mod client; pub mod connect; pub mod connection; +pub mod dgram; mod matches; pub mod parse_deckard; mod parse_query; From 4e6f4cddef0a0c28eff117eea6f766eb2990311c Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 15:11:12 +0100 Subject: [PATCH 096/124] Introduce dgram_stream to replace udp_tcp --- examples/client-transports.rs | 25 +-- src/net/client/dgram_stream.rs | 290 +++++++++++++++++++++++++++++++++ src/net/client/mod.rs | 2 +- 3 files changed, 305 insertions(+), 12 deletions(-) create mode 100644 src/net/client/dgram_stream.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 265dcbe77..88833436d 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -3,13 +3,13 @@ use domain::base::Dname; use domain::base::MessageBuilder; use domain::base::Rtype::Aaaa; use domain::net::client::dgram; +use domain::net::client::dgram_stream; use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; use domain::net::client::redundant; use domain::net::client::request::{RequestMessage, SendRequest}; use domain::net::client::udp; -use domain::net::client::udp_tcp; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::time::Duration; @@ -51,18 +51,27 @@ async fn main() { max_retries: 1, udp_payload_size: Some(1400), }; - let udp_tcp_config = udp_tcp::Config { - udp: Some(udp_config.clone()), + let dgram_config = dgram::Config { + max_parallel: 1, + read_timeout: Duration::from_millis(1000), + max_retries: 1, + udp_payload_size: Some(1400), + }; + let dgram_stream_config = dgram_stream::Config { + dgram: Some(dgram_config.clone()), multi_stream: Some(multi_stream_config.clone()), }; + let udp_connect = UdpConnect::new(server_addr); + let tcp_connect = TcpConnect::new(server_addr); let udptcp_conn = - udp_tcp::Connection::new(Some(udp_tcp_config), server_addr).unwrap(); + dgram_stream::Connection::new(Some(dgram_stream_config), udp_connect) + .unwrap(); // Start the run function in a separate task. The run function will // terminate when all references to the connection have been dropped. // Make sure that the task does not accidentally get a reference to the // connection. - let run_fut = udptcp_conn.run(); + let run_fut = udptcp_conn.run(tcp_connect); tokio::spawn(async move { let res = run_fut.await; println!("UDP+TCP run exited with {:?}", res); @@ -204,12 +213,6 @@ async fn main() { // and port as parameter. This transport does not retry over TCP if the // reply is truncated. This transport does not have a separate run // function. - let dgram_config = dgram::Config { - max_parallel: 1, - read_timeout: Duration::from_millis(1000), - max_retries: 1, - udp_payload_size: Some(1400), - }; let udp_connect = UdpConnect::new(server_addr); let dgram_conn = dgram::Connection::new(Some(dgram_config), udp_connect).unwrap(); diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs new file mode 100644 index 000000000..89b96249d --- /dev/null +++ b/src/net/client/dgram_stream.rs @@ -0,0 +1,290 @@ +//! A UDP transport that falls back to TCP if the reply is truncated + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - handle shutdown + +use bytes::Bytes; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use crate::base::Message; +use crate::net::client::dgram; +use crate::net::client::multi_stream; +use crate::net::client::protocol::{ + AsyncConnect, AsyncDgramRecv, AsyncDgramSend, +}; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +//------------ Config --------------------------------------------------------- + +/// Configuration for an octet_stream transport connection. +#[derive(Clone, Debug, Default)] +pub struct Config { + /// Configuration for the UDP transport. + pub dgram: Option, + + /// Configuration for the multi_stream (TCP) transport. + pub multi_stream: Option, +} + +//------------ Connection ----------------------------------------------------- + +/// DNS transport connection that first issues a query over a UDP transport and +/// falls back to TCP if the reply is truncated. +#[derive(Clone)] +pub struct Connection { + /// Reference to the real object that provides the connection. + inner: Arc>, +} + +impl< + S: AsyncConnect + Clone + Debug + Send + Sync + 'static, + CR: ComposeRequest + Clone + 'static, + > Connection +where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, +{ + /// Create a new connection. + pub fn new( + config: Option, + dgram_connect: S, + ) -> Result { + let config = match config { + Some(config) => { + check_config(&config)?; + config + } + None => Default::default(), + }; + let connection = InnerConnection::new(config, dgram_connect)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Worker function for a connection object. + pub fn run( + &self, + stream_connect: SC, + ) -> Pin> + Send>> + where + SC::Connection: AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, + { + self.inner.run(stream_connect) + } + + /// Start a request for the Request trait. + async fn request_impl( + &self, + request_msg: &CR, + ) -> Result, Error> { + let gr = self.inner.request(request_msg).await?; + Ok(Box::new(gr)) + } +} + +impl< + S: AsyncConnect + Clone + Debug + Send + Sync + 'static, + CR: ComposeRequest + Clone + 'static, + > SendRequest for Connection +where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, +{ + fn send_request<'a>( + &'a self, + request_msg: &'a CR, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + return Box::pin(self.request_impl(request_msg)); + } +} + +//------------ ReqResp -------------------------------------------------------- + +/// Object that contains the current state of a query. +#[derive(Debug)] +pub struct ReqResp { + /// Reqeust message. + request_msg: BMB, + + /// UDP transport to be used. + udp_conn: dgram::Connection, + + /// TCP transport to be used. + tcp_conn: multi_stream::Connection, + + /// Current state of the request. + state: QueryState, +} + +/// Status of the query. +#[derive(Debug)] +enum QueryState { + /// Start a request over the UDP transport. + StartUdpRequest, + + /// Get the response from the UDP transport. + GetUdpResponse(Box), + + /// Start a request over the TCP transport. + StartTcpRequest, + + /// Get the response from the TCP transport. + GetTcpResponse(Box), +} + +impl< + S: AsyncConnect + Clone + Send + Sync + 'static, + CR: ComposeRequest + Clone + 'static, + > ReqResp +{ + /// Create a new ReqResp object. + /// + /// The initial state is to start with a UDP transport. + fn new( + request_msg: &CR, + udp_conn: dgram::Connection, + tcp_conn: multi_stream::Connection, + ) -> ReqResp { + Self { + request_msg: request_msg.clone(), + udp_conn, + tcp_conn, + state: QueryState::StartUdpRequest, + } + } + + /// Get the response of a DNS request. + /// + /// This function is cancel safe. + async fn get_response_impl(&mut self) -> Result, Error> + where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + { + loop { + match &mut self.state { + QueryState::StartUdpRequest => { + let msg = self.request_msg.clone(); + let request = self.udp_conn.send_request(&msg).await?; + self.state = QueryState::GetUdpResponse(request); + continue; + } + QueryState::GetUdpResponse(ref mut request) => { + let response = request.get_response().await?; + if response.header().tc() { + self.state = QueryState::StartTcpRequest; + continue; + } + return Ok(response); + } + QueryState::StartTcpRequest => { + let msg = self.request_msg.clone(); + let request = self.tcp_conn.send_request(&msg).await?; + self.state = QueryState::GetTcpResponse(request); + continue; + } + QueryState::GetTcpResponse(ref mut query) => { + let response = query.get_response().await?; + return Ok(response); + } + } + } + } +} + +impl< + S: AsyncConnect + Clone + Debug + Send + Sync + 'static, + CR: ComposeRequest + Clone + Debug + 'static, + > GetResponse for ReqResp +where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, +{ + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) + } +} + +//------------ InnerConnection ------------------------------------------------ + +/// The actual connection object. +struct InnerConnection { + /// The UDP transport connection. + udp_conn: dgram::Connection, + + /// The TCP transport connection. + tcp_conn: multi_stream::Connection, +} + +impl< + S: AsyncConnect + Clone + Send + Sync + 'static, + CR: ComposeRequest + Clone + 'static, + > InnerConnection +{ + /// Create a new InnerConnection object. + /// + /// Create the UDP and TCP connections. Store the remote address because + /// run needs it later. + fn new(config: Config, dgram_connect: S) -> Result + where + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + { + let udp_conn = dgram::Connection::new(config.dgram, dgram_connect)?; + let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; + + Ok(Self { udp_conn, tcp_conn }) + } + + /// Implementation of the worker function. + /// + /// Create a TCP connect object and pass that to run function + /// of the multi_stream object. + fn run( + &self, + stream_connect: SC, + ) -> Pin> + Send>> + where + SC::Connection: AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, + { + let fut = self.tcp_conn.run(stream_connect); + Box::pin(fut) + } + + /// Implementation of the request function. + /// + /// Just create a ReqResp object with the state it needs. + async fn request( + &self, + request_msg: &CR, + ) -> Result, Error> { + Ok(ReqResp::new( + request_msg, + self.udp_conn.clone(), + self.tcp_conn.clone(), + )) + } +} + +/// Check if config is valid. +fn check_config(_config: &Config) -> Result<(), Error> { + // Nothing to check at the moment. + Ok(()) +} diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index dfbd72087..8f7a38748 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -12,10 +12,10 @@ #![warn(clippy::missing_docs_in_private_items)] pub mod dgram; +pub mod dgram_stream; pub mod multi_stream; pub mod octet_stream; pub mod protocol; pub mod redundant; pub mod request; pub mod udp; -pub mod udp_tcp; From f2bbdf7d6fcee9151a930cf10fd5807b8f3669d6 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 15:17:59 +0100 Subject: [PATCH 097/124] Test for dgram_stream. --- tests/net-client.rs | 21 +++++++++++++++++++++ tests/net/deckard/dgram.rs | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/net-client.rs b/tests/net-client.rs index 297a686e2..494d0f431 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -8,6 +8,7 @@ use crate::net::deckard::connection::Connection; use crate::net::deckard::dgram::Dgram; use crate::net::deckard::parse_deckard::parse_file; use domain::net::client::dgram; +use domain::net::client::dgram_stream; use domain::net::client::multi_stream; use domain::net::client::octet_stream; use domain::net::client::redundant; @@ -72,6 +73,26 @@ fn multi() { }); } +#[test] +fn dgram_stream() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let step_value = Arc::new(CurrStepValue::new()); + let conn = Dgram::new(deckard.clone(), step_value.clone()); + let multi_conn = Connect::new(deckard.clone(), step_value.clone()); + let ds = dgram_stream::Connection::new(None, conn).unwrap(); + let run_fut = ds.run(multi_conn); + tokio::spawn(async move { + run_fut.await.unwrap(); + println!("dgram_stream conn run terminated"); + }); + + do_client(&deckard, ds, &step_value).await; + }); +} + #[test] fn redundant() { tokio_test::block_on(async { diff --git a/tests/net/deckard/dgram.rs b/tests/net/deckard/dgram.rs index e14e501c7..c902e5473 100644 --- a/tests/net/deckard/dgram.rs +++ b/tests/net/deckard/dgram.rs @@ -12,7 +12,7 @@ use std::pin::Pin; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Dgram { deckard: Deckard, step_value: Arc, From b42c6b62210f760e4c7f32b1d56a8032012850be Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 15:28:39 +0100 Subject: [PATCH 098/124] Switch to dgram_stream --- src/resolv/stub/mod.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 606132b70..a4e2ad43d 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -20,12 +20,12 @@ use crate::base::message_builder::{ use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; use crate::net::client::multi_stream; -use crate::net::client::protocol::TcpConnect; +use crate::net::client::protocol::{TcpConnect, UdpConnect}; use crate::net::client::redundant; use crate::net::client::request::{ ComposeRequest, RequestMessage, SendRequest, }; -use crate::net::client::udp_tcp; +use crate::net::client::dgram_stream; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; @@ -164,10 +164,12 @@ impl StubResolver { } else { for s in &self.servers { if let Transport::Udp = s.transport { + let udp_connect = UdpConnect::new(s.addr); + let tcp_connect = TcpConnect::new(s.addr); let udptcp_conn = - udp_tcp::Connection::new(None, s.addr).unwrap(); + dgram_stream::Connection::new(None, udp_connect).unwrap(); // Start the run function on a separate task. - let run_fut = udptcp_conn.run(); + let run_fut = udptcp_conn.run(tcp_connect); fut_list_udp_tcp.push(async move { let _res = run_fut.await; }); From 240bdfc9be47527af2acce40c475dac420f130fd Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 15:29:32 +0100 Subject: [PATCH 099/124] Remove udp and udp_tcp. --- examples/client-transports.rs | 21 -- src/net/client/mod.rs | 1 - src/net/client/udp.rs | 671 ---------------------------------- src/net/client/udp_tcp.rs | 255 ------------- 4 files changed, 948 deletions(-) delete mode 100644 src/net/client/udp.rs delete mode 100644 src/net/client/udp_tcp.rs diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 88833436d..1cb22b9c2 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -9,7 +9,6 @@ use domain::net::client::octet_stream; use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; use domain::net::client::redundant; use domain::net::client::request::{RequestMessage, SendRequest}; -use domain::net::client::udp; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::time::Duration; @@ -45,12 +44,6 @@ async fn main() { // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. - let udp_config = udp::Config { - max_parallel: 1, - read_timeout: Duration::from_millis(1000), - max_retries: 1, - udp_payload_size: Some(1400), - }; let dgram_config = dgram::Config { max_parallel: 1, read_timeout: Duration::from_millis(1000), @@ -195,20 +188,6 @@ async fn main() { drop(redun); - // Create a new UDP transport connection. Pass the destination address - // and port as parameter. This transport does not retry over TCP if the - // reply is truncated. This transport does not have a separate run - // function. - let udp_conn = - udp::Connection::new(Some(udp_config), server_addr).unwrap(); - - // Send a query message. - let mut request = udp_conn.send_request(&req).await.unwrap(); - - // Get the reply - let reply = request.get_response().await; - println!("UDP reply: {:?}", reply); - // Create a new datagram transport connection. Pass the destination address // and port as parameter. This transport does not retry over TCP if the // reply is truncated. This transport does not have a separate run diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 8f7a38748..3720af80b 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -18,4 +18,3 @@ pub mod octet_stream; pub mod protocol; pub mod redundant; pub mod request; -pub mod udp; diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs deleted file mode 100644 index 635335462..000000000 --- a/src/net/client/udp.rs +++ /dev/null @@ -1,671 +0,0 @@ -//! A DNS over UDP transport - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -// To do: -// - cookies -// - random port - -use bytes::Bytes; -use octseq::Octets; -use std::boxed::Box; -use std::fmt::{Debug, Formatter}; -use std::future::Future; -use std::io::ErrorKind; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use tokio::net::UdpSocket; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -use tokio::time::{timeout, Duration, Instant}; - -use crate::base::iana::Rcode; -use crate::base::Message; -use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, -}; - -/// How many times do we try a new random port if we get ‘address in use.’ -const RETRY_RANDOM_PORT: usize = 10; - -/// Default configuration value for the maximum number of parallel DNS query -/// over a single UDP transport connection. -const DEF_MAX_PARALLEL: usize = 100; - -/// Minimum configuration value for max_parallel. -const MIN_MAX_PARALLEL: usize = 1; - -/// Maximum configuration value for max_parallel. -const MAX_MAX_PARALLEL: usize = 1000; - -/// Default configuration value for the maximum amount of time to wait for a -/// reply. -const DEF_READ_TIMEOUT: Duration = Duration::from_secs(5); - -/// Minimum configuration value for read_timeout. -const MIN_READ_TIMEOUT: Duration = Duration::from_millis(1); - -/// Maximum configuration value for read_timeout. -const MAX_READ_TIMEOUT: Duration = Duration::from_secs(60); - -/// Default configuration value for maximum number of retries after timeouts. -const DEF_MAX_RETRIES: u8 = 5; - -/// Minimum allowed configuration value for max_retries. -const MIN_MAX_RETRIES: u8 = 1; - -/// Maximum allowed configuration value for max_retries. -const MAX_MAX_RETRIES: u8 = 100; - -/// Default UDP payload size. See draft-ietf-dnsop-avoid-fragmentation-15 -/// for discussion. -const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; - -//------------ Config --------------------------------------------------------- - -/// Configuration for a UDP transport connection. -#[derive(Clone, Debug)] -pub struct Config { - /// Maximum number of parallel requests for a transport connection. - pub max_parallel: usize, - - /// Read timeout. - pub read_timeout: Duration, - - /// Maimum number of retries. - pub max_retries: u8, - - /// EDNS(0) UDP payload size. Set this value to None to be able to create - /// a DNS request without ENDS(0) option. - pub udp_payload_size: Option, -} - -impl Default for Config { - fn default() -> Self { - Self { - max_parallel: DEF_MAX_PARALLEL, - read_timeout: DEF_READ_TIMEOUT, - max_retries: DEF_MAX_RETRIES, - udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), - } - } -} - -//------------ Connection ----------------------------------------------------- - -/// A UDP transport connection. -#[derive(Clone, Debug)] -pub struct Connection { - /// Reference to the actual connection object. - inner: Arc, -} - -impl Connection { - /// Create a new UDP transport connection. - pub fn new( - config: Option, - remote_addr: SocketAddr, - ) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config, remote_addr)?; - Ok(Self { - inner: Arc::new(connection), - }) - } - - /// Start a new DNS request. - async fn request_impl< - CR: ComposeRequest + Clone + Send + Sync + 'static, - >( - &self, - request_msg: &CR, - ) -> Result, Error> { - let gr = self.inner.request(request_msg, self.clone()).await?; - Ok(Box::new(gr)) - } - - /// Get a permit from the semaphore to start using a socket. - async fn get_permit(&self) -> OwnedSemaphorePermit { - self.inner.get_permit().await - } -} - -impl SendRequest - for Connection -{ - fn send_request<'a>( - &'a self, - request_msg: &'a CR, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.request_impl(request_msg)); - } -} - -//------------ Query ---------------------------------------------------------- - -/* - -/// State of the DNS query. -#[derive(Debug)] -enum QueryState { - /// Get a semaphore permit. - GetPermit(Connection), - - /// Get a UDP socket. - GetSocket, - - /// Connect the socket. - Connect, - - /// Send the request. - Send, - - /// Receive the reply. - Receive(Instant), -} -*/ - -/* - -/// The state of a DNS query. -#[derive(Debug)] -pub struct Query { - /// Address of remote server to connect to. - remote_addr: SocketAddr, - - /// DNS request message. - query_msg: Message, - - /// Semaphore permit that allow use of socket. - _permit: Option, - - /// UDP socket for communication. - sock: Option, - - /// Current number of retries. - retries: u8, - - /// State of query. - state: QueryState, -} - -impl Query { - /// Create new Query object. - fn new( - query_msg: Message, - remote_addr: SocketAddr, - conn: Connection, - ) -> Query { - Query { - query_msg, - remote_addr, - _permit: None, - sock: None, - retries: 0, - state: QueryState::GetPermit(conn), - } - } - - /// Get the result of a DNS Query. - /// - /// This function is cancel safe. - async fn get_result_impl(&mut self) -> Result, Error> { - let recv_size = 2000; // Should be configurable. - - loop { - match &self.state { - QueryState::GetPermit(conn) => { - // We need to get past the semaphore that limits the - // number of concurrent sockets we can use. - let permit = conn.get_permit().await; - self._permit = Some(permit); - self.state = QueryState::GetSocket; - continue; - } - QueryState::GetSocket => { - self.sock = Some( - Self::udp_bind(self.remote_addr.is_ipv4()).await?, - ); - self.state = QueryState::Connect; - continue; - } - QueryState::Connect => { - self.sock - .as_ref() - .expect("socket should be present") - .connect(self.remote_addr) - .await - .map_err(|e| Error::UdpConnect(Arc::new(e)))?; - self.state = QueryState::Send; - continue; - } - QueryState::Send => { - // Set random ID in header - let header = self.query_msg.header_mut(); - header.set_random_id(); - let dgram = self.query_msg.as_slice(); - - let sent = self - .sock - .as_ref() - .expect("socket should be present") - .send(dgram) - .await - .map_err(|e| Error::UdpSend(Arc::new(e)))?; - if sent != self.query_msg.as_slice().len() { - return Err(Error::UdpShortSend); - } - self.state = QueryState::Receive(Instant::now()); - continue; - } - QueryState::Receive(start) => { - let elapsed = start.elapsed(); - if elapsed > READ_TIMEOUT { - todo!(); - } - let remain = READ_TIMEOUT - elapsed; - - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let timeout_res = timeout( - remain, - self.sock - .as_ref() - .expect("socket should be present") - .recv(&mut buf), - ) - .await; - if timeout_res.is_err() { - self.retries += 1; - if self.retries < MAX_RETRIES { - self.sock = None; - self.state = QueryState::GetSocket; - continue; - } - return Err(Error::UdpTimeoutNoResponse); - } - let len = timeout_res - .expect("errror case is checked above") - .map_err(|e| Error::UdpReceive(Arc::new(e)))?; - buf.truncate(len); - - // We ignore garbage since there is a timer on this whole thing. - let answer = match Message::from_octets(buf.into()) { - Ok(answer) => answer, - Err(_) => continue, - }; - - // Unfortunately we cannot pass query_msg to is_answer - // because is_answer requires Octets, which is not - // implemented by BytesMut. Make a copy. - let query_msg = Message::from_octets( - self.query_msg.as_slice(), - ) - .expect( - "Message failed to parse contents of another Message", - ); - if !is_answer(answer, &query_msg) { - continue; - } - self.sock = None; - self._permit = None; - return Ok(answer); - } - } - } - } - - /// Bind to a local UDP port. - /// - /// This should explicitly pick a random number in a suitable range of - /// ports. - async fn udp_bind(v4: bool) -> Result { - let mut i = 0; - loop { - let local: SocketAddr = if v4 { - ([0u8; 4], 0).into() - } else { - ([0u16; 8], 0).into() - }; - match UdpSocket::bind(&local).await { - Ok(sock) => return Ok(sock), - Err(err) => { - if i == RETRY_RANDOM_PORT { - return Err(Error::UdpBind(Arc::new(err))); - } else { - i += 1 - } - } - } - } - } -} - -impl GetResult for Query { - fn get_result( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_result_impl()) - } -} - -*/ - -//------------ ReqResp -------------------------------------------------------- - -/// The state of a DNS request. -pub struct ReqResp { - /// Future that does the actual work of GetResponse. - get_response_fut: - Pin, Error>> + Send>>, -} - -impl ReqResp { - /// Create new ReqResp object. - fn new( - config: Config, - request_msg: &CR, - remote_addr: SocketAddr, - conn: Connection, - udp_payload_size: Option, - ) -> Self { - Self { - get_response_fut: Box::pin(Self::get_response_impl2( - config, - request_msg.clone(), - remote_addr, - conn, - udp_payload_size, - )), - } - } - - /// Async function that waits for the future stored in Query to complete. - async fn get_response_impl(&mut self) -> Result, Error> { - (&mut self.get_response_fut).await - } - - /// Get the response of a DNS request. - /// - /// This function is not cancel safe. - async fn get_response_impl2( - config: Config, - mut request_bmb: CR, - remote_addr: SocketAddr, - conn: Connection, - udp_payload_size: Option, - ) -> Result, Error> { - let recv_size = 2000; // Should be configurable. - - let mut retries: u8 = 0; - - // We need to get past the semaphore that limits the - // number of concurrent sockets we can use. - let _permit = conn.get_permit().await; - - loop { - let sock = Some(Self::udp_bind(remote_addr.is_ipv4()).await?); - - sock.as_ref() - .expect("socket should be present") - .connect(remote_addr) - .await - .map_err(|e| Error::UdpConnect(Arc::new(e)))?; - - // Set random ID in header - let header = request_bmb.header_mut(); - header.set_random_id(); - // Set UDP payload size - if let Some(size) = udp_payload_size { - request_bmb.set_udp_payload_size(size) - } - let request_msg = request_bmb.to_message(); - let dgram = request_msg.as_slice(); - - let sent = sock - .as_ref() - .expect("socket should be present") - .send(dgram) - .await - .map_err(|e| Error::UdpSend(Arc::new(e)))?; - if sent != dgram.len() { - return Err(Error::UdpShortSend); - } - - let start = Instant::now(); - - loop { - let elapsed = start.elapsed(); - if elapsed > config.read_timeout { - // Break out of the receive loop and continue in the - // transmit loop. - break; - } - let remain = config.read_timeout - elapsed; - - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let timeout_res = timeout( - remain, - sock.as_ref() - .expect("socket should be present") - .recv(&mut buf), - ) - .await; - if timeout_res.is_err() { - retries += 1; - if retries < config.max_retries { - // Break out of the receive loop and continue in the - // transmit loop. - break; - } - return Err(Error::UdpTimeoutNoResponse); - } - let len = timeout_res - .expect("errror case is checked above") - .map_err(|e| Error::UdpReceive(Arc::new(e)))?; - buf.truncate(len); - - // We ignore garbage since there is a timer on this whole - // thing. - let answer = match Message::from_octets(buf.into()) { - // Just go back to receiving. - Ok(answer) => answer, - Err(_) => continue, - }; - - if !is_answer(&answer, &request_msg) { - // Wrong answer, go back to receiving - continue; - } - return Ok(answer); - } - retries += 1; - if retries < config.max_retries { - continue; - } - break; - } - Err(Error::UdpTimeoutNoResponse) - } - - /// Bind to a local UDP port. - /// - /// This should explicitly pick a random number in a suitable range of - /// ports. - async fn udp_bind(v4: bool) -> Result { - let mut i = 0; - loop { - let local: SocketAddr = if v4 { - ([0u8; 4], 0).into() - } else { - ([0u16; 8], 0).into() - }; - match UdpSocket::bind(&local).await { - Ok(sock) => return Ok(sock), - Err(err) => { - if i == RETRY_RANDOM_PORT { - return Err(Error::UdpBind(Arc::new(err))); - } else { - i += 1 - } - } - } - } - } -} - -impl Debug for ReqResp { - fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { - todo!() - } -} - -impl GetResponse for ReqResp { - fn get_response( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_response_impl()) - } -} - -//------------ InnerConnection ------------------------------------------------ - -/// Actual implementation of the UDP transport connection. -#[derive(Debug)] -struct InnerConnection { - /// User configuration variables. - config: Config, - - /// Address of the remote server. - remote_addr: SocketAddr, - - /// Semaphore to limit access to UDP sockets. - semaphore: Arc, -} - -impl InnerConnection { - /// Create new InnerConnection object. - fn new( - config: Config, - remote_addr: SocketAddr, - ) -> Result { - let max_parallel = config.max_parallel; - Ok(Self { - config, - remote_addr, - semaphore: Arc::new(Semaphore::new(max_parallel)), - }) - } - - /// Return a Query object that contains the query state. - async fn request( - &self, - request_msg: &CR, - conn: Connection, - ) -> Result { - Ok(ReqResp::new( - self.config.clone(), - request_msg, - self.remote_addr, - conn, - self.config.udp_payload_size, - )) - } - - /// Return a permit for a our semaphore. - async fn get_permit(&self) -> OwnedSemaphorePermit { - self.semaphore - .clone() - .acquire_owned() - .await - .expect("the semaphore has not been closed") - } -} - -//------------ Utility -------------------------------------------------------- - -/// Check if config is valid. -fn check_config(config: &Config) -> Result<(), Error> { - if config.max_parallel < MIN_MAX_PARALLEL - || config.max_parallel > MAX_MAX_PARALLEL - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "max_parallel", - )))); - } - - if config.read_timeout < MIN_READ_TIMEOUT - || config.read_timeout > MAX_READ_TIMEOUT - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "read_timeout", - )))); - } - - if config.max_retries < MIN_MAX_RETRIES - || config.max_retries > MAX_MAX_RETRIES - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "max_retries", - )))); - } - Ok(()) -} - -/// Check if a message is a valid reply for a query. Allow the question section -/// to be empty if there is an error or if the reply is truncated. -fn is_answer< - QueryOcts: AsRef<[u8]> + Octets, - ReplyOcts: AsRef<[u8]> + Octets, ->( - reply: &Message, - query: &Message, -) -> bool { - let reply_header = reply.header(); - let reply_hcounts = reply.header_counts(); - - // First check qr and id - if !reply_header.qr() || reply_header.id() != query.header().id() { - return false; - } - - // If either tc is set or the result is an error, then the question - // section can be empty. In that case we require all other sections - // to be empty as well. - if (reply_header.tc() || reply_header.rcode() != Rcode::NoError) - && reply_hcounts.qdcount() == 0 - && reply_hcounts.ancount() == 0 - && reply_hcounts.nscount() == 0 - && reply_hcounts.arcount() == 0 - { - // We can accept this as a valid reply. - return true; - } - - // Remaining checks. The question section in the reply has to be the - // same as in the query. - if reply_hcounts.qdcount() != query.header_counts().qdcount() { - false - } else { - reply.question() == query.question() - } -} diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs deleted file mode 100644 index f11e3cc37..000000000 --- a/src/net/client/udp_tcp.rs +++ /dev/null @@ -1,255 +0,0 @@ -//! A UDP transport that falls back to TCP if the reply is truncated - -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] - -// To do: -// - handle shutdown - -use bytes::Bytes; -use std::boxed::Box; -use std::fmt::Debug; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; - -use crate::base::Message; -use crate::net::client::multi_stream; -use crate::net::client::protocol::TcpConnect; -use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, -}; -use crate::net::client::udp; - -//------------ Config --------------------------------------------------------- - -/// Configuration for an octet_stream transport connection. -#[derive(Clone, Debug, Default)] -pub struct Config { - /// Configuration for the UDP transport. - pub udp: Option, - - /// Configuration for the multi_stream (TCP) transport. - pub multi_stream: Option, -} - -//------------ Connection ----------------------------------------------------- - -/// DNS transport connection that first issues a query over a UDP transport and -/// falls back to TCP if the reply is truncated. -#[derive(Clone)] -pub struct Connection { - /// Reference to the real object that provides the connection. - inner: Arc>, -} - -impl Connection { - /// Create a new connection. - pub fn new( - config: Option, - remote_addr: SocketAddr, - ) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config, remote_addr)?; - Ok(Self { - inner: Arc::new(connection), - }) - } - - /// Worker function for a connection object. - pub fn run( - &self, - ) -> Pin> + Send>> { - self.inner.run() - } - - /// Start a request for the Request trait. - async fn request_impl( - &self, - request_msg: &CR, - ) -> Result, Error> { - let gr = self.inner.request(request_msg).await?; - Ok(Box::new(gr)) - } -} - -impl SendRequest - for Connection -{ - fn send_request<'a>( - &'a self, - request_msg: &'a CR, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.request_impl(request_msg)); - } -} - -//------------ ReqResp -------------------------------------------------------- - -/// Object that contains the current state of a query. -#[derive(Debug)] -pub struct ReqResp { - /// Reqeust message. - request_msg: BMB, - - /// UDP transport to be used. - udp_conn: udp::Connection, - - /// TCP transport to be used. - tcp_conn: multi_stream::Connection, - - /// Current state of the request. - state: QueryState, -} - -/// Status of the query. -#[derive(Debug)] -enum QueryState { - /// Start a request over the UDP transport. - StartUdpRequest, - - /// Get the response from the UDP transport. - GetUdpResponse(Box), - - /// Start a request over the TCP transport. - StartTcpRequest, - - /// Get the response from the TCP transport. - GetTcpResponse(Box), -} - -impl ReqResp { - /// Create a new ReqResp object. - /// - /// The initial state is to start with a UDP transport. - fn new( - request_msg: &CR, - udp_conn: udp::Connection, - tcp_conn: multi_stream::Connection, - ) -> ReqResp { - Self { - request_msg: request_msg.clone(), - udp_conn, - tcp_conn, - state: QueryState::StartUdpRequest, - } - } - - /// Get the response of a DNS request. - /// - /// This function is cancel safe. - async fn get_response_impl(&mut self) -> Result, Error> { - loop { - match &mut self.state { - QueryState::StartUdpRequest => { - let msg = self.request_msg.clone(); - let request = self.udp_conn.send_request(&msg).await?; - self.state = QueryState::GetUdpResponse(request); - continue; - } - QueryState::GetUdpResponse(ref mut request) => { - let response = request.get_response().await?; - if response.header().tc() { - self.state = QueryState::StartTcpRequest; - continue; - } - return Ok(response); - } - QueryState::StartTcpRequest => { - let msg = self.request_msg.clone(); - let request = self.tcp_conn.send_request(&msg).await?; - self.state = QueryState::GetTcpResponse(request); - continue; - } - QueryState::GetTcpResponse(ref mut query) => { - let response = query.get_response().await?; - return Ok(response); - } - } - } - } -} - -impl GetResponse - for ReqResp -{ - fn get_response( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_response_impl()) - } -} - -//------------ InnerConnection ------------------------------------------------ - -/// The actual connection object. -struct InnerConnection { - /// The remote address to connect to. - remote_addr: SocketAddr, - - /// The UDP transport connection. - udp_conn: udp::Connection, - - /// The TCP transport connection. - tcp_conn: multi_stream::Connection, -} - -impl InnerConnection { - /// Create a new InnerConnection object. - /// - /// Create the UDP and TCP connections. Store the remote address because - /// run needs it later. - fn new(config: Config, remote_addr: SocketAddr) -> Result { - let udp_conn = udp::Connection::new(config.udp, remote_addr)?; - let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; - - Ok(Self { - remote_addr, - udp_conn, - tcp_conn, - }) - } - - /// Implementation of the worker function. - /// - /// Create a TCP connect object and pass that to run function - /// of the multi_stream object. - fn run(&self) -> Pin> + Send>> { - let tcp_connect = TcpConnect::new(self.remote_addr); - - let fut = self.tcp_conn.run(tcp_connect); - Box::pin(fut) - } - - /// Implementation of the request function. - /// - /// Just create a ReqResp object with the state it needs. - async fn request(&self, request_msg: &CR) -> Result, Error> { - Ok(ReqResp::new( - request_msg, - self.udp_conn.clone(), - self.tcp_conn.clone(), - )) - } -} - -/// Check if config is valid. -fn check_config(_config: &Config) -> Result<(), Error> { - // Nothing to check at the moment. - Ok(()) -} From 5217e16c359e7bb14ca4a9042bcdc0233cc2cb1f Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 15:32:05 +0100 Subject: [PATCH 100/124] Fmt --- src/resolv/stub/mod.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index a4e2ad43d..8fa57884f 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -19,13 +19,13 @@ use crate::base::message_builder::{ }; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; +use crate::net::client::dgram_stream; use crate::net::client::multi_stream; use crate::net::client::protocol::{TcpConnect, UdpConnect}; use crate::net::client::redundant; use crate::net::client::request::{ ComposeRequest, RequestMessage, SendRequest, }; -use crate::net::client::dgram_stream; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError}; @@ -164,10 +164,11 @@ impl StubResolver { } else { for s in &self.servers { if let Transport::Udp = s.transport { - let udp_connect = UdpConnect::new(s.addr); - let tcp_connect = TcpConnect::new(s.addr); + let udp_connect = UdpConnect::new(s.addr); + let tcp_connect = TcpConnect::new(s.addr); let udptcp_conn = - dgram_stream::Connection::new(None, udp_connect).unwrap(); + dgram_stream::Connection::new(None, udp_connect) + .unwrap(); // Start the run function on a separate task. let run_fut = udptcp_conn.run(tcp_connect); fut_list_udp_tcp.push(async move { From eb13944e48580cf011bfe0448ba8512c8357efc1 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 28 Dec 2023 16:17:14 +0100 Subject: [PATCH 101/124] Rename octet_stream to just stream. --- examples/client-transports.rs | 6 +-- src/net/client/mod.rs | 2 +- src/net/client/multi_stream.rs | 38 +++++++++---------- src/net/client/{octet_stream.rs => stream.rs} | 0 tests/net-client.rs | 6 +-- 5 files changed, 26 insertions(+), 26 deletions(-) rename src/net/client/{octet_stream.rs => stream.rs} (100%) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 1cb22b9c2..b27b75606 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -5,7 +5,7 @@ use domain::base::Rtype::Aaaa; use domain::net::client::dgram; use domain::net::client::dgram_stream; use domain::net::client::multi_stream; -use domain::net::client::octet_stream; +use domain::net::client::stream; use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; use domain::net::client::redundant; use domain::net::client::request::{RequestMessage, SendRequest}; @@ -37,7 +37,7 @@ async fn main() { let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); let multi_stream_config = multi_stream::Config { - octet_stream: Some(octet_stream::Config { + stream: Some(stream::Config { response_timeout: Duration::from_millis(100), }), }; @@ -216,7 +216,7 @@ async fn main() { } }; - let tcp = octet_stream::Connection::new(None).unwrap(); + let tcp = stream::Connection::new(None).unwrap(); let run_fut = tcp.run(tcp_conn); tokio::spawn(async move { run_fut.await; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 3720af80b..2b835085b 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -14,7 +14,7 @@ pub mod dgram; pub mod dgram_stream; pub mod multi_stream; -pub mod octet_stream; +pub mod stream; pub mod protocol; pub mod redundant; pub mod request; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index f523e4bf6..b17a07e51 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -30,7 +30,7 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::octet_stream; +use crate::net::client::stream; use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{ ComposeRequest, Error, GetResponse, SendRequest, @@ -45,11 +45,11 @@ const ERR_CONN_CLOSED: &str = "connection closed"; //------------ Config --------------------------------------------------------- -/// Configuration for an octet_stream transport connection. +/// Configuration for an stream transport connection. #[derive(Clone, Debug, Default)] pub struct Config { /// Response timeout. - pub octet_stream: Option, + pub stream: Option, } //------------ Connection ----------------------------------------------------- @@ -155,7 +155,7 @@ pub struct ReqResp { state: QueryState, /// A multi_octet connection object is needed to request new underlying - /// octet_stream transport connections. + /// stream transport connections. conn: Connection, /// id of most recent connection. @@ -170,14 +170,14 @@ pub struct ReqResp { /// Status of a query. Used in [Query]. #[derive(Debug)] enum QueryState { - /// Get a octet_stream transport. + /// Get a stream transport. GetConn(oneshot::Receiver>), /// Start a query using the transport. - StartQuery(octet_stream::Connection), + StartQuery(stream::Connection), /// Get the result of the query. - GetResult(octet_stream::QueryNoCheck), + GetResult(stream::QueryNoCheck), /// Wait until trying again. /// @@ -198,8 +198,8 @@ struct ChanRespOk { /// id of this connection. id: u64, - /// New octet_stream transport. - conn: octet_stream::Connection, + /// New stream transport. + conn: stream::Connection, } impl ReqResp { @@ -355,7 +355,7 @@ struct InnerConnection { } #[derive(Debug)] -/// A request to [Connection::run] either for a new octet_stream or to +/// A request to [Connection::run] either for a new stream or to /// shutdown. struct ChanReq { /// A requests consists of a command. @@ -382,7 +382,7 @@ type ReplySender = oneshot::Sender>; /// the status of the connection. // The types Status and ConnState are only used in InnerConnection struct State3<'a, S, IO, CR> { - /// Underlying octet_stream connection. + /// Underlying stream connection. conn_state: SingleConnState3, /// Current connection id. @@ -392,7 +392,7 @@ struct State3<'a, S, IO, CR> { stream: S, /// Collection of futures for the async run function of the underlying - /// octet_stream. + /// stream. runners: FuturesUnordered< Pin> + Send + 'a>>, >, @@ -401,20 +401,20 @@ struct State3<'a, S, IO, CR> { phantom: PhantomData<&'a IO>, } -/// State of the current underlying octet_stream transport. +/// State of the current underlying stream transport. enum SingleConnState3 { - /// No current octet_stream transport. + /// No current stream transport. None, - /// Current octet_stream transport. - Some(octet_stream::Connection), + /// Current stream transport. + Some(stream::Connection), /// State that deals with an error getting a new octet stream from /// a connection stream. Err(ErrorState), } -/// State associated with a failed attempt to create a new octet_stream +/// State associated with a failed attempt to create a new stream /// transport. #[derive(Clone)] struct ErrorState { @@ -598,7 +598,7 @@ impl InnerConnection { let stream = res_conn .expect("error case is checked before"); - let conn = octet_stream::Connection::new(config.octet_stream.clone())?; + let conn = stream::Connection::new(config.stream.clone())?; let conn_run = conn.clone(); let clo = || async move { @@ -648,7 +648,7 @@ impl InnerConnection { // Avoid new queries drop(receiver); - // Wait for existing octet_stream runners to terminate + // Wait for existing stream runners to terminate while !state.runners.is_empty() { state.runners.next().await; } diff --git a/src/net/client/octet_stream.rs b/src/net/client/stream.rs similarity index 100% rename from src/net/client/octet_stream.rs rename to src/net/client/stream.rs diff --git a/tests/net-client.rs b/tests/net-client.rs index 494d0f431..719fbd96c 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -10,7 +10,7 @@ use crate::net::deckard::parse_deckard::parse_file; use domain::net::client::dgram; use domain::net::client::dgram_stream; use domain::net::client::multi_stream; -use domain::net::client::octet_stream; +use domain::net::client::stream; use domain::net::client::redundant; use std::fs::File; use std::net::IpAddr; @@ -44,7 +44,7 @@ fn single() { let step_value = Arc::new(CurrStepValue::new()); let conn = Connection::new(deckard.clone(), step_value.clone()); - let octstr = octet_stream::Connection::new(None).unwrap(); + let octstr = stream::Connection::new(None).unwrap(); let run_fut = octstr.run(conn); tokio::spawn(async move { run_fut.await; @@ -142,7 +142,7 @@ fn tcp() { } }; - let tcp = octet_stream::Connection::new(None).unwrap(); + let tcp = stream::Connection::new(None).unwrap(); let run_fut = tcp.run(tcp_conn); tokio::spawn(async move { run_fut.await; From 059a779cc32b3e7f0826af857de5b529fe5131db Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 28 Dec 2023 16:32:37 +0100 Subject: [PATCH 102/124] Reflow. --- examples/client-transports.rs | 2 +- src/net/client/mod.rs | 3 +-- src/net/client/multi_stream.rs | 2 +- src/net/mod.rs | 5 ++++- tests/net-client.rs | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index b27b75606..f1632e5d1 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -5,10 +5,10 @@ use domain::base::Rtype::Aaaa; use domain::net::client::dgram; use domain::net::client::dgram_stream; use domain::net::client::multi_stream; -use domain::net::client::stream; use domain::net::client::protocol::{TcpConnect, TlsConnect, UdpConnect}; use domain::net::client::redundant; use domain::net::client::request::{RequestMessage, SendRequest}; +use domain::net::client::stream; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::time::Duration; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 2b835085b..d1d210fb6 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -7,14 +7,13 @@ #![cfg(feature = "unstable-client-transport")] #![cfg_attr(docsrs, doc(cfg(feature = "unstable-client-transport")))] - #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] pub mod dgram; pub mod dgram_stream; pub mod multi_stream; -pub mod stream; pub mod protocol; pub mod redundant; pub mod request; +pub mod stream; diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index b17a07e51..f553bda77 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -30,11 +30,11 @@ use tokio::time::{sleep_until, Instant}; use crate::base::iana::Rcode; use crate::base::Message; -use crate::net::client::stream; use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{ ComposeRequest, Error, GetResponse, SendRequest, }; +use crate::net::client::stream; /// Capacity of the channel that transports [ChanReq]. const DEF_CHAN_CAP: usize = 8; diff --git a/src/net/mod.rs b/src/net/mod.rs index 049aed754..5b7e9435b 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -7,7 +7,10 @@ #![cfg_attr(feature = "unstable-client-transport", doc = " [`client`]")] #![cfg_attr(not(feature = "unstable-client-transport"), doc = " `client`")] //! sub-module intended for sending requests and receiving responses to them. -#![cfg_attr(not(feature = "unstable-client-transport"), doc = " The `unstable-client-transport` feature is necessary to enable this module.")] +#![cfg_attr( + not(feature = "unstable-client-transport"), + doc = " The `unstable-client-transport` feature is necessary to enable this module." +)] //! #![cfg(feature = "net")] #![cfg_attr(docsrs, doc(cfg(feature = "net")))] diff --git a/tests/net-client.rs b/tests/net-client.rs index 719fbd96c..92ab1b1c3 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -10,8 +10,8 @@ use crate::net::deckard::parse_deckard::parse_file; use domain::net::client::dgram; use domain::net::client::dgram_stream; use domain::net::client::multi_stream; -use domain::net::client::stream; use domain::net::client::redundant; +use domain::net::client::stream; use std::fs::File; use std::net::IpAddr; use std::net::SocketAddr; From bb4a0ed31910bf013b93b7a1f6bee62b72585be3 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 28 Dec 2023 22:17:20 +0100 Subject: [PATCH 103/124] Poll-based AsyncDgramRecv and AsyncDgramSend and helper traits for asynchronous functions in AsyncDgramRecvEx and AsyncDgramSendEx. --- Cargo.toml | 1 + src/net/client/dgram.rs | 22 ++--- src/net/client/dgram_stream.rs | 10 +-- src/net/client/protocol.rs | 157 ++++++++++++++++++++++++--------- tests/net/deckard/dgram.rs | 68 ++++++++------ 5 files changed, 172 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index db1f602b9..eaf1350b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ chrono = { version = "0.4.6", optional = true, default-features = false futures-util = { version = "0.3", optional = true } heapless = { version = "0.7", optional = true } #openssl = { version = "0.10", optional = true } +pin-project-lite = "0.2" ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index dcce66304..7f7b793ef 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -20,7 +20,8 @@ use tokio::time::{timeout, Duration, Instant}; use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::protocol::{ - AsyncConnect, AsyncDgramRecv, AsyncDgramSend, + AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend, + AsyncDgramSendEx, }; use crate::net::client::request::{ ComposeRequest, Error, GetResponse, SendRequest, @@ -100,7 +101,7 @@ pub struct Connection { impl< S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, > Connection { /// Create a new datagram transport connection. @@ -140,7 +141,7 @@ impl< impl< S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, CR: ComposeRequest + Clone + Send + Sync + 'static, > SendRequest for Connection { @@ -171,7 +172,7 @@ impl ReqResp { /// Create new ReqResp object. fn new< S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, CR: ComposeRequest + Clone + Send + Sync + 'static, >( config: Config, @@ -201,7 +202,7 @@ impl ReqResp { /// This function is not cancel safe. async fn get_response_impl2< S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, CR: ComposeRequest, >( config: Config, @@ -219,7 +220,7 @@ impl ReqResp { let _permit = conn.get_permit().await; loop { - let sock = connect + let mut sock = connect .connect() .await .map_err(|e| Error::UdpConnect(Arc::new(e)))?; @@ -253,8 +254,8 @@ impl ReqResp { } let remain = config.read_timeout - elapsed; - let buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let timeout_res = timeout(remain, sock.recv(buf)).await; + let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout(remain, sock.recv(&mut buf)).await; if timeout_res.is_err() { retries += 1; if retries < config.max_retries { @@ -264,9 +265,10 @@ impl ReqResp { } return Err(Error::UdpTimeoutNoResponse); } - let buf = timeout_res + let len = timeout_res .expect("errror case is checked above") .map_err(|e| Error::UdpReceive(Arc::new(e)))?; + buf.truncate(len); // We ignore garbage since there is a timer on this whole // thing. @@ -325,7 +327,7 @@ struct InnerConnection { impl< S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + 'static, + C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, > InnerConnection { /// Create new InnerConnection object. diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs index 89b96249d..f07afab33 100644 --- a/src/net/client/dgram_stream.rs +++ b/src/net/client/dgram_stream.rs @@ -52,7 +52,7 @@ impl< CR: ComposeRequest + Clone + 'static, > Connection where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { /// Create a new connection. pub fn new( @@ -98,7 +98,7 @@ impl< CR: ComposeRequest + Clone + 'static, > SendRequest for Connection where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { fn send_request<'a>( &'a self, @@ -174,7 +174,7 @@ impl< /// This function is cancel safe. async fn get_response_impl(&mut self) -> Result, Error> where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { loop { match &mut self.state { @@ -212,7 +212,7 @@ impl< CR: ComposeRequest + Clone + Debug + 'static, > GetResponse for ReqResp where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { fn get_response( &mut self, @@ -245,7 +245,7 @@ impl< /// run needs it later. fn new(config: Config, dgram_connect: S) -> Result where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { let udp_conn = dgram::Connection::new(config.dgram, dgram_connect)?; let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs index 566ea6f81..b909c0ded 100644 --- a/src/net/client/protocol.rs +++ b/src/net/client/protocol.rs @@ -2,11 +2,13 @@ use core::future::Future; use core::pin::Pin; +use pin_project_lite::pin_project; use std::boxed::Box; use std::io; use std::net::SocketAddr; use std::sync::Arc; -use std::vec::Vec; +use std::task::{Context, Poll}; +use tokio::io::ReadBuf; use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket}; use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{ClientConfig, ServerName}; @@ -122,15 +124,66 @@ where //------------ AsyncDgramRecv ------------------------------------------------- -/// Receive a datagram packet asynchronously. +/// Receive a datagram packets asynchronously. /// /// pub trait AsyncDgramRecv { - /// The future performing the receive operation. - type Fut: Future, io::Error>> + Send; + /// Polled receive. + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll>; +} + +/// Convenvience trait to turn poll_recv into an asynchronous function. +pub trait AsyncDgramRecvEx: AsyncDgramRecv { + /// Asynchronous receive function. + fn recv<'a>(&'a mut self, buf: &'a mut [u8]) -> DgramRecv<'a, Self> + where + Self: Unpin, + { + recv(self, buf) + } +} + +impl AsyncDgramRecvEx for R {} + +pin_project! { + /// Return value of recv. This captures the future for recv. + pub struct DgramRecv<'a, R: ?Sized> { + receiver: &'a mut R, + buf: &'a mut [u8], + } +} + +impl Future for DgramRecv<'_, R> { + type Output = io::Result; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let me = self.project(); + let mut buf = ReadBuf::new(me.buf); + match Pin::new(me.receiver).poll_recv(cx, &mut buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(res) => { + if let Err(err) = res { + return Poll::Ready(Err(err)); + } + } + } + Poll::Ready(Ok(buf.filled().len())) + } +} - /// Returns a future that performs the receive operation. - fn recv(&self, buf: Vec) -> Self::Fut; +/// Helper function for the recv method. +fn recv<'a, R: ?Sized>( + receiver: &'a mut R, + buf: &'a mut [u8], +) -> DgramRecv<'a, R> { + DgramRecv { receiver, buf } } //------------ AsyncDgramSend ------------------------------------------------- @@ -139,11 +192,50 @@ pub trait AsyncDgramRecv { /// /// pub trait AsyncDgramSend { - /// The future performing the send operation. - type Fut: Future> + Send; + /// Polled send function. + fn poll_send( + self: Pin<&Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; +} - /// Returns a future that performs the send operation. - fn send(&self, buf: &[u8]) -> Self::Fut; +/// Convenience trait that turns poll_send into an asynchronous function. +pub trait AsyncDgramSendEx: AsyncDgramSend { + /// Asynchronous function to send a packet. + fn send<'a>(&'a self, buf: &'a [u8]) -> DgramSend<'a, Self> + where + Self: Unpin, + { + send(self, buf) + } +} + +impl AsyncDgramSendEx for S {} + +/// This is the return value of send. It captures the future for send. +pub struct DgramSend<'a, S: ?Sized> { + /// The datagram send object. + sender: &'a S, + + /// The buffer that needs to be sent. + buf: &'a [u8], +} + +impl Future for DgramSend<'_, S> { + type Output = io::Result; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.sender).poll_send(cx, self.buf) + } +} + +/// Send helper function to implement the send method of AsyncDgramSendEx. +fn send<'a, S: ?Sized>(sender: &'a S, buf: &'a [u8]) -> DgramSend<'a, S> { + DgramSend { sender, buf } } //------------ UdpConnect -------------------------------------------------- @@ -220,42 +312,21 @@ impl UdpDgram { } impl AsyncDgramRecv for UdpDgram { - type Fut = - Pin, io::Error>> + Send>>; - fn recv(&self, mut buf: Vec) -> Self::Fut { - let sock = self.sock.clone(); - Box::pin(async move { - let len = sock.recv(&mut buf).await?; - buf.truncate(len); - Ok(buf) - }) + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.sock.poll_recv(cx, buf) } } impl AsyncDgramSend for UdpDgram { - type Fut = Pin> + Send>>; - fn send(&self, buf: &[u8]) -> Self::Fut { - let sock = self.sock.clone(); - let buf = buf.to_vec(); - Box::pin(async move { sock.send(&buf).await }) - } -} - -/* -struct Sender { - sock: Arc, - buf: Vec -} - -impl Sender { - fn new() -> Self { Self } -} - -impl Future for Sender { - type Output = Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> - Poll { - self.sock.poll_send(cx, &self.buf) + fn poll_send( + self: Pin<&Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.sock.poll_send(cx, buf) } } -*/ diff --git a/tests/net/deckard/dgram.rs b/tests/net/deckard/dgram.rs index c902e5473..0156f6335 100644 --- a/tests/net/deckard/dgram.rs +++ b/tests/net/deckard/dgram.rs @@ -10,7 +10,9 @@ use domain::net::client::protocol::{ use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use std::sync::Mutex as SyncMutex; +use std::task::{Context, Poll, Waker}; +use tokio::io::ReadBuf; #[derive(Clone, Debug)] pub struct Dgram { @@ -46,51 +48,61 @@ pub struct DgramConnection { deckard: Deckard, step_value: Arc, - sender: mpsc::Sender>>, - receiver: Arc>>>>, + reply: SyncMutex>>>, + waker: SyncMutex>, } impl DgramConnection { fn new(deckard: Deckard, step_value: Arc) -> Self { - let (sender, receiver) = mpsc::channel(2); Self { deckard, step_value, - sender, - receiver: Arc::new(Mutex::new(receiver)), + reply: SyncMutex::new(None), + waker: SyncMutex::new(None), } } } impl AsyncDgramRecv for DgramConnection { - type Fut = - Pin, std::io::Error>> + Send>>; - fn recv(&self, buf: Vec) -> Self::Fut { - let arc_m_rec = self.receiver.clone(); - Box::pin(async move { - let mut rec = arc_m_rec.lock().await; - let msg = (*rec).recv().await.unwrap(); - let msg_octets = msg.into_octets(); - if msg_octets.len() > buf.len() { - panic!("test returned reply that is bigger than buffer"); - } - Ok(msg_octets) - }) + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut reply = self.reply.lock().unwrap(); + if (*reply).is_some() { + let slice = (*reply).as_ref().unwrap().as_slice(); + buf.put_slice(slice); + *reply = None; + return Poll::Ready(Ok(())); + } + *reply = None; + let mut waker = self.waker.lock().unwrap(); + *waker = Some(cx.waker().clone()); + Poll::Pending } } impl AsyncDgramSend for DgramConnection { - type Fut = - Pin> + Send>>; - fn send(&self, buf: &[u8]) -> Self::Fut { + fn poll_send( + self: Pin<&Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let msg = Message::from_octets(buf).unwrap(); let opt_reply = do_server(&msg, &self.deckard, &self.step_value); - let sender = self.sender.clone(); let len = buf.len(); - Box::pin(async move { - if opt_reply.is_some() { - sender.send(opt_reply.unwrap()).await.unwrap(); + if opt_reply.is_some() { + // Do we need to support more than one reply? + let mut reply = self.reply.lock().unwrap(); + *reply = opt_reply; + drop(reply); + let mut waker = self.waker.lock().unwrap(); + let opt_waker = (*waker).take(); + drop(waker); + if let Some(waker) = opt_waker { + waker.wake(); } - Ok(len) - }) + } + Poll::Ready(Ok(len)) } } From bacdca9064d698ae848ec2504c4431509c620320 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 29 Dec 2023 14:40:43 +0100 Subject: [PATCH 104/124] More documentation for net::client. --- src/net/client/mod.rs | 98 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index d1d210fb6..980193045 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -1,4 +1,102 @@ //! Sending requests and receiving responses. +//! +//! This module provides DNS transport protocols that allow sending a DNS +//! request and receiving the corresponding reply. +//! +//! Sending a request and receiving the reply consists of four steps: +//! 1) Creating a request message, +//! 2) Creating a DNS transport, +//! 3) Sending the request, and +//! 4) Receiving the reply. +//! +//! The first and second step are independent and happen in any order. +//! The third step uses the resuts of the first and second step. +//! Finally, the fourth step uses the result of the third step. + +//! # Creating a request message +//! +//! The DNS transport protocols expect a request message that implements +//! [ComposeRequest][request::ComposeRequest] trait. +//! This trait allows transports to add ENDS(0) options, set flags, etc. +//! The [RequestMessage][request::RequestMessage] type implements this trait. +//! The [new][request::RequestMessage::new] method of RequestMessage create +//! a new RequestMessage object based an existing messsage (that implements +//! ```Into>```). +//! +//! For example: +//! ```rust +//! let mut msg = MessageBuilder::new_vec(); +//! msg.header_mut().set_rd(true); +//! let mut msg = msg.question(); +//! msg.push((Dname::vec_from_str("example.com").unwrap(), Aaaa)).unwrap(); +//! let req = RequestMessage::new(msg); +//! ``` + +//! # Creating a DNS transport +//! +//! Creating a DNS transport typically involves creating a configuration +//! object, creating the underlying network connection, creating the +//! DNS transport and running a ```run``` method as a separate task. This +//! is illustrated in the following example: +//! ```rust +//! let multi_stream_config = multi_stream::Config { +//! stream: Some(stream::Config { +//! response_timeout: Duration::from_millis(100), +//! }), +//! }; +//! let tcp_connect = TcpConnect::new(server_addr); +//! let tcp_conn = +//! multi_stream::Connection::new(Some(multi_stream_config.clone())) +//! .unwrap(); +//! let run_fut = tcp_conn.run(tcp_connect); +//! tokio::spawn(async move { +//! let res = run_fut.await; +//! }); +//! ``` +//! Note that the run function ends when the last reference to the DNS +//! transport is dropped. For this reason it is important to avoid having a +//! reference to the transport end up in the task. Only pass the future +//! returned by the run function to the task. +//! +//! The currently implemented DNS transport have the following layering. At +//! the lower layer are [dgram] and [stream]. The dgram transport is used for +//! DNS over UDP, the stream transport is used for DNS over a single TCP or +//! TLS connection. The transport works as long as the connection continuous +//! to exist. +//! The [multi_stream] transport is layered on top of stream, and creates new +//! TCP or TLS connections when old ones terminates. +//! Next, [dgram_stream] combines the dgram transport with the multi_stream +//! transport. This is typically needed because a request over UDP can receive +//! a truncated response, which should be retried over TCP. +//! Finally, the [redundant] transport can select the best transport out of +//! a collection of underlying transports. + +//! # Sending the request +//! +//! A DNS transport implements the [SendRequest][request::SendRequest] trait. +//! This trait provides a single method, +//! [send_request][request::SendRequest::send_request] and returns an object +//! that provides the response. +//! +//! For example: +//! ```rust +//! let mut request = tls_conn.send_request(&req).await.unwrap(); +//! ``` +//! where ```tls_conn``` is a transport connection for DNS over TLS. + +//! # Receiving the request +//! +//! The [send_request][request::SendRequest::send_request] method returns an +//! object that implements the [GetResponse][request::GetResponse] trait. +//! This trait provides a single method, +//! [get_response][request::GetResponse::get_response], which returns the +//! DNS response message or an error. This method is intended to be +//! cancelation safe. +//! +//! For example: +//! ```rust +//! let reply = request.get_response().await; +//! ``` //! # Example with various transport connections //! ``` From 598a886882aedb85963c3bde23b047a7876c97ee Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 2 Jan 2024 15:54:11 +0100 Subject: [PATCH 105/124] Refactor stream transports to provide a use-once run future. --- Cargo.toml | 6 +- examples/client-transports.rs | 72 ++-- src/lib.rs | 2 +- src/net/client/dgram.rs | 221 ++++++------ src/net/client/dgram_stream.rs | 241 +++++++------ src/net/client/multi_stream.rs | 594 +++++++++++++++------------------ src/net/client/stream.rs | 403 +++++++++++----------- src/resolv/stub/mod.rs | 22 +- tests/net-client.rs | 27 +- 9 files changed, 750 insertions(+), 838 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index eaf1350b6..da280d85e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,9 @@ name = "domain" path = "src/lib.rs" [dependencies] -octseq = "0.3" -time = "0.3.1" +octseq = "0.3" +pin-project-lite = "0.2" +time = "0.3.1" rand = { version = "0.8", optional = true } bytes = { version = "1.0", optional = true } @@ -26,7 +27,6 @@ chrono = { version = "0.4.6", optional = true, default-features = false futures-util = { version = "0.3", optional = true } heapless = { version = "0.7", optional = true } #openssl = { version = "0.10", optional = true } -pin-project-lite = "0.2" ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } diff --git a/examples/client-transports.rs b/examples/client-transports.rs index f1632e5d1..793d9c25a 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -36,38 +36,37 @@ async fn main() { // Destination for UDP and TCP let server_addr = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 53); - let multi_stream_config = multi_stream::Config { - stream: Some(stream::Config { - response_timeout: Duration::from_millis(100), - }), - }; + let mut stream_config = stream::Config::new(); + stream_config.set_response_timeout(Duration::from_millis(100)); + let multi_stream_config = + multi_stream::Config::from(stream_config.clone()); // Create a new UDP+TCP transport connection. Pass the destination address // and port as parameter. - let dgram_config = dgram::Config { - max_parallel: 1, - read_timeout: Duration::from_millis(1000), - max_retries: 1, - udp_payload_size: Some(1400), - }; - let dgram_stream_config = dgram_stream::Config { - dgram: Some(dgram_config.clone()), - multi_stream: Some(multi_stream_config.clone()), - }; + let mut dgram_config = dgram::Config::new(); + dgram_config.set_max_parallel(1); + dgram_config.set_read_timeout(Duration::from_millis(1000)); + dgram_config.set_max_retries(1); + dgram_config.set_udp_payload_size(Some(1400)); + let dgram_stream_config = dgram_stream::Config::from_parts( + dgram_config.clone(), + multi_stream_config.clone(), + ); let udp_connect = UdpConnect::new(server_addr); let tcp_connect = TcpConnect::new(server_addr); - let udptcp_conn = - dgram_stream::Connection::new(Some(dgram_stream_config), udp_connect) - .unwrap(); + let (udptcp_conn, transport) = dgram_stream::Connection::with_config( + udp_connect, + tcp_connect, + dgram_stream_config, + ); // Start the run function in a separate task. The run function will // terminate when all references to the connection have been dropped. // Make sure that the task does not accidentally get a reference to the // connection. - let run_fut = udptcp_conn.run(tcp_connect); tokio::spawn(async move { - let res = run_fut.await; - println!("UDP+TCP run exited with {:?}", res); + transport.run().await; + println!("UDP+TCP run exited"); }); // Send a query message. @@ -88,16 +87,16 @@ async fn main() { // A muli_stream transport connection sets up new TCP connections when // needed. - let tcp_conn = - multi_stream::Connection::new(Some(multi_stream_config.clone())) - .unwrap(); + let (tcp_conn, transport) = multi_stream::Connection::with_config( + tcp_connect, + multi_stream_config.clone(), + ); // Get a future for the run function. The run function receives // the connection stream as a parameter. - let run_fut = tcp_conn.run(tcp_connect); tokio::spawn(async move { - let res = run_fut.await; - println!("multi TCP run exited with {:?}", res); + transport.run().await; + println!("multi TCP run exited"); }); // Send a query message. @@ -144,14 +143,15 @@ async fn main() { ); // Again create a multi_stream transport connection. - let tls_conn = - multi_stream::Connection::new(Some(multi_stream_config)).unwrap(); + let (tls_conn, transport) = multi_stream::Connection::with_config( + tls_connect, + multi_stream_config, + ); // Start the run function. - let run_fut = tls_conn.run(tls_connect); tokio::spawn(async move { - let res = run_fut.await; - println!("TLS run exited with {:?}", res); + transport.run().await; + println!("TLS run exited"); }); let mut request = tls_conn.send_request(&req).await.unwrap(); @@ -193,8 +193,7 @@ async fn main() { // reply is truncated. This transport does not have a separate run // function. let udp_connect = UdpConnect::new(server_addr); - let dgram_conn = - dgram::Connection::new(Some(dgram_config), udp_connect).unwrap(); + let dgram_conn = dgram::Connection::new(Some(dgram_config), udp_connect); // Send a query message. let mut request = dgram_conn.send_request(&req).await.unwrap(); @@ -216,10 +215,9 @@ async fn main() { } }; - let tcp = stream::Connection::new(None).unwrap(); - let run_fut = tcp.run(tcp_conn); + let (tcp, transport) = stream::Connection::new(tcp_conn); tokio::spawn(async move { - run_fut.await; + transport.run().await; println!("single TCP run terminated"); }); diff --git a/src/lib.rs b/src/lib.rs index e628d5e7f..af99c537e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,7 +118,7 @@ //! features, it is best to specify a concrete version as a dependency in //! `Cargo.toml` using the `=` operator, e.g.: //! -//! ``` +//! ```text //! [dependencies] //! domain = "=0.9.3" //! ``` diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 7f7b793ef..6f52e7164 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -6,17 +6,6 @@ // To do: // - cookies -use bytes::Bytes; -use octseq::Octets; -use std::boxed::Box; -use std::fmt::{Debug, Formatter}; -use std::future::Future; -use std::io::ErrorKind; -use std::pin::Pin; -use std::sync::Arc; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -use tokio::time::{timeout, Duration, Instant}; - use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::protocol::{ @@ -26,38 +15,33 @@ use crate::net::client::protocol::{ use crate::net::client::request::{ ComposeRequest, Error, GetResponse, SendRequest, }; +use bytes::Bytes; +use core::cmp; +use octseq::Octets; +use std::boxed::Box; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::time::{timeout, Duration, Instant}; -/// Default configuration value for the maximum number of parallel DNS query -/// over a single datagram transport connection. -const DEF_MAX_PARALLEL: usize = 100; - -/// Minimum configuration value for max_parallel. -const MIN_MAX_PARALLEL: usize = 1; - -/// Maximum configuration value for max_parallel. -const MAX_MAX_PARALLEL: usize = 1000; - -/// Default configuration value for the maximum amount of time to wait for a -/// reply. -const DEF_READ_TIMEOUT: Duration = Duration::from_secs(5); - -/// Minimum configuration value for read_timeout. -const MIN_READ_TIMEOUT: Duration = Duration::from_millis(1); - -/// Maximum configuration value for read_timeout. -const MAX_READ_TIMEOUT: Duration = Duration::from_secs(60); +//------------ Configuration Constants ---------------------------------------- -/// Default configuration value for maximum number of retries after timeouts. -const DEF_MAX_RETRIES: u8 = 5; +/// Configuration limits for the maximum number of parallel requests. +const MAX_PARALLEL: DefMinMax = DefMinMax::new(100, 1, 1000); -/// Minimum allowed configuration value for max_retries. -const MIN_MAX_RETRIES: u8 = 1; +/// Configuration limits for the read timeout. +const READ_TIMEOUT: DefMinMax = DefMinMax::new( + Duration::from_secs(5), + Duration::from_millis(1), + Duration::from_secs(60), +); -/// Maximum allowed configuration value for max_retries. -const MAX_MAX_RETRIES: u8 = 100; +/// Configuration limits for the maximum number of retries. +const MAX_RETRIES: DefMinMax = DefMinMax::new(5, 1, 100); -/// Default UDP payload size. See draft-ietf-dnsop-avoid-fragmentation-15 -/// for discussion. +/// Default UDP payload size. const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; //------------ Config --------------------------------------------------------- @@ -66,25 +50,86 @@ const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; #[derive(Clone, Debug)] pub struct Config { /// Maximum number of parallel requests for a transport connection. - pub max_parallel: usize, + max_parallel: usize, /// Read timeout. - pub read_timeout: Duration, + read_timeout: Duration, /// Maimum number of retries. - pub max_retries: u8, + max_retries: u8, /// EDNS(0) UDP payload size. Set this value to None to be able to create /// a DNS request without ENDS(0) option. - pub udp_payload_size: Option, + udp_payload_size: Option, +} + +impl Config { + /// Creates a new config with default values. + pub fn new() -> Self { + Default::default() + } + + /// Returns the maximum number of parallel requests. + /// + /// Once this many number of requests are currently outstanding, + /// additional requests will wait. + pub fn max_parallel(&self) -> usize { + self.max_parallel + } + + /// Sets the maximum number of parallel requests. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_max_parallel(&mut self, value: usize) { + self.max_parallel = MAX_PARALLEL.limit(value) + } + + /// Returns the read timeout. + /// + /// The read timeout is the maximum amount of time to wait for any + /// response after a request was sent. + pub fn read_timeout(&self) -> Duration { + self.read_timeout + } + + /// Sets the read timeout. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_read_timeout(&mut self, value: Duration) { + self.read_timeout = READ_TIMEOUT.limit(value) + } + + /// Returns the maximum number a request is retried before giving up. + pub fn max_retries(&self) -> u8 { + self.max_retries + } + + /// Sets the maximum number of request retries. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_max_retries(&mut self, value: u8) { + self.max_retries = MAX_RETRIES.limit(value) + } + + /// Returns the UDP payload size. + /// + /// See draft-ietf-dnsop-avoid-fragmentation-15 for a discussion. + pub fn udp_payload_size(&self) -> Option { + self.udp_payload_size + } + + /// Sets the UDP payload size. + pub fn set_udp_payload_size(&mut self, value: Option) { + self.udp_payload_size = value; + } } impl Default for Config { fn default() -> Self { Self { - max_parallel: DEF_MAX_PARALLEL, - read_timeout: DEF_READ_TIMEOUT, - max_retries: DEF_MAX_RETRIES, + max_parallel: MAX_PARALLEL.default(), + read_timeout: READ_TIMEOUT.default(), + max_retries: MAX_RETRIES.default(), udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), } } @@ -94,7 +139,7 @@ impl Default for Config { /// A datagram transport connection. #[derive(Clone, Debug)] -pub struct Connection { +pub struct Connection { /// Reference to the actual connection object. inner: Arc>, } @@ -105,21 +150,12 @@ impl< > Connection { /// Create a new datagram transport connection. - pub fn new( - config: Option, - connect: S, - ) -> Result, Error> { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config, connect)?; - Ok(Self { + pub fn new(config: Option, connect: S) -> Connection { + let connection = + InnerConnection::new(config.unwrap_or_default(), connect); + Self { inner: Arc::new(connection), - }) + } } /// Start a new DNS request. @@ -314,7 +350,7 @@ impl GetResponse for ReqResp { /// Actual implementation of the datagram transport connection. #[derive(Debug)] -struct InnerConnection { +struct InnerConnection { /// User configuration variables. config: Config, @@ -331,13 +367,13 @@ impl< > InnerConnection { /// Create new InnerConnection object. - fn new(config: Config, connect: S) -> Result, Error> { + fn new(config: Config, connect: S) -> InnerConnection { let max_parallel = config.max_parallel; - Ok(Self { + Self { config, connect, semaphore: Arc::new(Semaphore::new(max_parallel)), - }) + } } /// Return a Query object that contains the query state. @@ -367,37 +403,6 @@ impl< //------------ Utility -------------------------------------------------------- -/// Check if config is valid. -fn check_config(config: &Config) -> Result<(), Error> { - if config.max_parallel < MIN_MAX_PARALLEL - || config.max_parallel > MAX_MAX_PARALLEL - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "max_parallel", - )))); - } - - if config.read_timeout < MIN_READ_TIMEOUT - || config.read_timeout > MAX_READ_TIMEOUT - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "read_timeout", - )))); - } - - if config.max_retries < MIN_MAX_RETRIES - || config.max_retries > MAX_MAX_RETRIES - { - return Err(Error::UdpConfigError(Arc::new(std::io::Error::new( - ErrorKind::Other, - "max_retries", - )))); - } - Ok(()) -} - /// Check if a message is a valid reply for a query. Allow the question section /// to be empty if there is an error or if the reply is truncated. fn is_answer< @@ -436,3 +441,29 @@ fn is_answer< reply.question() == query.question() } } + +//------------ DefMinMax ----------------------------------------------------- + +#[derive(Clone, Copy)] +struct DefMinMax { + def: T, + min: T, + max: T, +} + +impl DefMinMax { + const fn new(def: T, min: T, max: T) -> Self { + Self { def, min, max } + } + + fn default(self) -> T { + self.def + } + + fn limit(self, value: T) -> T + where + T: Ord, + { + cmp::max(self.min, cmp::min(self.max, value)) + } +} diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs index f07afab33..df59111a6 100644 --- a/src/net/client/dgram_stream.rs +++ b/src/net/client/dgram_stream.rs @@ -6,13 +6,6 @@ // To do: // - handle shutdown -use bytes::Bytes; -use std::boxed::Box; -use std::fmt::Debug; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - use crate::base::Message; use crate::net::client::dgram; use crate::net::client::multi_stream; @@ -22,8 +15,11 @@ use crate::net::client::protocol::{ use crate::net::client::request::{ ComposeRequest, Error, GetResponse, SendRequest, }; - -use tokio::io::{AsyncRead, AsyncWrite}; +use bytes::Bytes; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; //------------ Config --------------------------------------------------------- @@ -31,10 +27,58 @@ use tokio::io::{AsyncRead, AsyncWrite}; #[derive(Clone, Debug, Default)] pub struct Config { /// Configuration for the UDP transport. - pub dgram: Option, + dgram: dgram::Config, /// Configuration for the multi_stream (TCP) transport. - pub multi_stream: Option, + multi_stream: multi_stream::Config, +} + +impl Config { + /// Creates a new config with default values. + pub fn new() -> Self { + Default::default() + } + + /// Creates a new config from the two portions. + pub fn from_parts( + dgram: dgram::Config, + multi_stream: multi_stream::Config, + ) -> Self { + Self { + dgram, + multi_stream, + } + } + + /// Returns the datagram config. + pub fn dgram(&self) -> &dgram::Config { + &self.dgram + } + + /// Returns a mutable reference to the datagram config. + pub fn dgram_mut(&mut self) -> &mut dgram::Config { + &mut self.dgram + } + + /// Sets the datagram config. + pub fn set_dgram(&mut self, dgram: dgram::Config) { + self.dgram = dgram + } + + /// Returns the stream config. + pub fn stream(&self) -> &multi_stream::Config { + &self.multi_stream + } + + /// Returns a mutable reference to the stream config. + pub fn stream_mut(&mut self) -> &mut multi_stream::Config { + &mut self.multi_stream + } + + /// Sets the stream config. + pub fn set_stream(&mut self, stream: multi_stream::Config) { + self.multi_stream = stream + } } //------------ Connection ----------------------------------------------------- @@ -42,67 +86,72 @@ pub struct Config { /// DNS transport connection that first issues a query over a UDP transport and /// falls back to TCP if the reply is truncated. #[derive(Clone)] -pub struct Connection { - /// Reference to the real object that provides the connection. - inner: Arc>, +pub struct Connection { + /// The UDP transport connection. + udp_conn: dgram::Connection, + + /// The TCP transport connection. + tcp_conn: multi_stream::Connection, } -impl< - S: AsyncConnect + Clone + Debug + Send + Sync + 'static, - CR: ComposeRequest + Clone + 'static, - > Connection +impl Connection where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + DgramS: AsyncConnect + Clone + Send + Sync + 'static, + DgramS::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, { - /// Create a new connection. - pub fn new( - config: Option, - dgram_connect: S, - ) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config, dgram_connect)?; - Ok(Self { - inner: Arc::new(connection), - }) + /// Creates a new multi-stream transport with default configuration. + pub fn new( + dgram_remote: DgramS, + stream_remote: StreamS, + ) -> (Self, multi_stream::Transport) { + Self::with_config(dgram_remote, stream_remote, Default::default()) } - /// Worker function for a connection object. - pub fn run( - &self, - stream_connect: SC, - ) -> Pin> + Send>> - where - SC::Connection: AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, - { - self.inner.run(stream_connect) + /// Creates a new multi-stream transport. + pub fn with_config( + dgram_remote: DgramS, + stream_remote: StreamS, + config: Config, + ) -> (Self, multi_stream::Transport) { + let udp_conn = + dgram::Connection::new(Some(config.dgram), dgram_remote); + let (tcp_conn, transport) = multi_stream::Connection::with_config( + stream_remote, + config.multi_stream, + ); + (Self { udp_conn, tcp_conn }, transport) } +} +impl Connection +where + DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, + DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, +{ /// Start a request for the Request trait. async fn request_impl( &self, - request_msg: &CR, + request_msg: &Req, ) -> Result, Error> { - let gr = self.inner.request(request_msg).await?; - Ok(Box::new(gr)) + Ok(Box::new(ReqResp::new( + request_msg, + self.udp_conn.clone(), + self.tcp_conn.clone(), + ))) } } -impl< - S: AsyncConnect + Clone + Debug + Send + Sync + 'static, - CR: ComposeRequest + Clone + 'static, - > SendRequest for Connection +impl SendRequest for Connection where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, + DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, { fn send_request<'a>( &'a self, - request_msg: &'a CR, + request_msg: &'a Req, ) -> Pin< Box< dyn Future, Error>> @@ -118,15 +167,15 @@ where /// Object that contains the current state of a query. #[derive(Debug)] -pub struct ReqResp { +pub struct ReqResp { /// Reqeust message. - request_msg: BMB, + request_msg: Req, /// UDP transport to be used. udp_conn: dgram::Connection, /// TCP transport to be used. - tcp_conn: multi_stream::Connection, + tcp_conn: multi_stream::Connection, /// Current state of the request. state: QueryState, @@ -150,17 +199,17 @@ enum QueryState { impl< S: AsyncConnect + Clone + Send + Sync + 'static, - CR: ComposeRequest + Clone + 'static, - > ReqResp + Reg: ComposeRequest + Clone + 'static, + > ReqResp { /// Create a new ReqResp object. /// /// The initial state is to start with a UDP transport. fn new( - request_msg: &CR, + request_msg: &Reg, udp_conn: dgram::Connection, - tcp_conn: multi_stream::Connection, - ) -> ReqResp { + tcp_conn: multi_stream::Connection, + ) -> ReqResp { Self { request_msg: request_msg.clone(), udp_conn, @@ -209,8 +258,8 @@ impl< impl< S: AsyncConnect + Clone + Debug + Send + Sync + 'static, - CR: ComposeRequest + Clone + Debug + 'static, - > GetResponse for ReqResp + Reg: ComposeRequest + Clone + Debug + 'static, + > GetResponse for ReqResp where S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, { @@ -222,69 +271,3 @@ where Box::pin(self.get_response_impl()) } } - -//------------ InnerConnection ------------------------------------------------ - -/// The actual connection object. -struct InnerConnection { - /// The UDP transport connection. - udp_conn: dgram::Connection, - - /// The TCP transport connection. - tcp_conn: multi_stream::Connection, -} - -impl< - S: AsyncConnect + Clone + Send + Sync + 'static, - CR: ComposeRequest + Clone + 'static, - > InnerConnection -{ - /// Create a new InnerConnection object. - /// - /// Create the UDP and TCP connections. Store the remote address because - /// run needs it later. - fn new(config: Config, dgram_connect: S) -> Result - where - S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, - { - let udp_conn = dgram::Connection::new(config.dgram, dgram_connect)?; - let tcp_conn = multi_stream::Connection::new(config.multi_stream)?; - - Ok(Self { udp_conn, tcp_conn }) - } - - /// Implementation of the worker function. - /// - /// Create a TCP connect object and pass that to run function - /// of the multi_stream object. - fn run( - &self, - stream_connect: SC, - ) -> Pin> + Send>> - where - SC::Connection: AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, - { - let fut = self.tcp_conn.run(stream_connect); - Box::pin(fut) - } - - /// Implementation of the request function. - /// - /// Just create a ReqResp object with the state it needs. - async fn request( - &self, - request_msg: &CR, - ) -> Result, Error> { - Ok(ReqResp::new( - request_msg, - self.udp_conn.clone(), - self.tcp_conn.clone(), - )) - } -} - -/// Check if config is valid. -fn check_config(_config: &Config) -> Result<(), Error> { - // Nothing to check at the moment. - Ok(()) -} diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index f553bda77..a9d57747d 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -6,128 +6,154 @@ // To do: // - too many connection errors +use crate::base::iana::Rcode; +use crate::base::Message; +use crate::net::client::protocol::AsyncConnect; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; +use crate::net::client::stream; use bytes::Bytes; - use futures_util::stream::FuturesUnordered; use futures_util::StreamExt; - use octseq::Octets; - use rand::random; - use std::boxed::Box; use std::fmt::Debug; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; - use tokio::io; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::{mpsc, oneshot}; use tokio::time::{sleep_until, Instant}; -use crate::base::iana::Rcode; -use crate::base::Message; -use crate::net::client::protocol::AsyncConnect; -use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, -}; -use crate::net::client::stream; +//------------ Constants ----------------------------------------------------- -/// Capacity of the channel that transports [ChanReq]. +/// Capacity of the channel that transports `ChanReq`. const DEF_CHAN_CAP: usize = 8; -/// Error reported when the connection is closed and -/// [InnerConnection::run] terminated. +/// Error messafe when the connection is closed. const ERR_CONN_CLOSED: &str = "connection closed"; //------------ Config --------------------------------------------------------- -/// Configuration for an stream transport connection. +/// Configuration for an multi-stream transport. #[derive(Clone, Debug, Default)] pub struct Config { - /// Response timeout. - pub stream: Option, + /// Configuration of the underlying stream transport. + stream: stream::Config, +} + +impl Config { + /// Returns the underlying stream config. + pub fn stream(&self) -> &stream::Config { + &self.stream + } + + /// Returns a mutable reference to the underlying stream config. + pub fn stream_mut(&mut self) -> &mut stream::Config { + &mut self.stream + } +} + +impl From for Config { + fn from(stream: stream::Config) -> Self { + Self { stream } + } } //------------ Connection ----------------------------------------------------- +/// A connection to a multi-stream transport. #[derive(Clone, Debug)] -/// A DNS over octect streams transport. -pub struct Connection { - /// Reference counted [InnerConnection]. - inner: Arc>, +pub struct Connection { + /// The sender half of the connection request channel. + sender: Arc>>, } -impl Connection { - /// Constructor for [Connection]. - /// - /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new(config: Option) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config)?; - Ok(Self { - inner: Arc::new(connection), - }) +impl Connection { + /// Creates a new multi-stream transport with default configuration. + pub fn new(remote: Remote) -> (Self, Transport) { + Self::with_config(remote, Default::default()) } - /// Main execution function for [Connection]. - /// - /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [ReqResp::get_response]. - pub fn run< - S: AsyncConnect + Send + 'static, - C: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, - >( - &self, - stream: S, - ) -> Pin> + Send>> { - self.inner.run(stream) + /// Creates a new multi-stream transport. + pub fn with_config( + remote: Remote, + config: Config, + ) -> (Self, Transport) { + let (sender, transport) = Transport::new(remote, config); + ( + Self { + sender: sender.into(), + }, + transport, + ) } +} - /// Start a DNS request. +impl Connection { + /// Starts a request. /// - /// This function takes a precomposed message as a parameter and - /// returns a [ReqResp] object wrapped in a [Result]. - async fn query_impl( + /// This is the future that is returned by the `SendRequest` impl. + async fn _send_request( &self, - query_msg: &CR, + request: &Req, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); - self.inner.new_conn(None, tx).await?; - let gr = ReqResp::::new(self.clone(), query_msg, rx); + self.new_conn(None, tx).await?; + let gr = Query::new(self.clone(), request.clone(), rx); Ok(Box::new(gr)) } - /// Shutdown this transport. - pub async fn shutdown(&self) -> Result<(), &'static str> { - self.inner.shutdown().await - } - /// Request a new connection. async fn new_conn( &self, - id: u64, - tx: oneshot::Sender>, + opt_id: Option, + sender: oneshot::Sender>, ) -> Result<(), Error> { - self.inner.new_conn(Some(id), tx).await + let req = ChanReq { + cmd: ReqCmd::NewConn(opt_id, sender), + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(Error::ConnectionClosed) + } + Ok(_) => Ok(()), + } + } + + /// Request a shutdown. + pub async fn shutdown(&self) -> Result<(), &'static str> { + let req = ChanReq { + cmd: ReqCmd::Shutdown, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(ERR_CONN_CLOSED) + } + Ok(_) => Ok(()), + } } } -impl SendRequest - for Connection +//--- SendRequest + +impl SendRequest for Connection +where + Req: ComposeRequest + Clone + 'static, { fn send_request<'a>( &'a self, - request_msg: &'a CR, + request: &'a Req, ) -> Pin< Box< dyn Future, Error>> @@ -135,90 +161,86 @@ impl SendRequest + '_, >, > { - return Box::pin(self.query_impl(request_msg)); + return Box::pin(self._send_request(request)); } } -//------------ ReqResp -------------------------------------------------------- +//------------ Query -------------------------------------------------------- -/// This struct represent an active DNS request. +/// The connection side of an active request. #[derive(Debug)] -pub struct ReqResp { - /// Request message. +struct Query { + /// The request message. /// - /// The reply message is compared with the request message to see if - /// it matches the query. - // query_msg: Message>, - request_msg: CR, + /// It is kept so we can compare a response with it. + request_msg: Req, /// Current state of the query. - state: QueryState, + state: QueryState, - /// A multi_octet connection object is needed to request new underlying - /// stream transport connections. - conn: Connection, + /// The underlying transport. + conn: Connection, - /// id of most recent connection. + /// The id of the most recent connection. conn_id: u64, - // /// Number of retries without delay. - // imm_retry_count: u16, /// Number of retries with delay. delayed_retry_count: u64, } -/// Status of a query. Used in [Query]. +/// The states of the query state machine. #[derive(Debug)] -enum QueryState { - /// Get a stream transport. - GetConn(oneshot::Receiver>), +enum QueryState { + /// Receive a new connection from the receiver. + GetConn(oneshot::Receiver>), - /// Start a query using the transport. - StartQuery(stream::Connection), + /// Start a query using the given stream transport. + StartQuery(Arc>), /// Get the result of the query. GetResult(stream::QueryNoCheck), /// Wait until trying again. /// - /// The instant represents when the error occured, the duration how + /// The instant represents when the error occurred, the duration how /// long to wait. Delay(Instant, Duration), - /// The response has been received and the query is done. + /// A response has been received and the query is done. Done, } -/// The reply to a NewConn request. -type ChanResp = Result, Arc>; +/// The response to a connection request. +type ChanResp = Result, Arc>; -/// Response to the DNS request sent by [InnerConnection::run] to [Query]. +/// The successful response to a connection request. #[derive(Debug)] -struct ChanRespOk { - /// id of this connection. +struct ChanRespOk { + /// The id of this connection. id: u64, - /// New stream transport. - conn: stream::Connection, + /// The new stream transport to use for sending a request. + conn: Arc>, } -impl ReqResp { - /// Constructor for [ReqResp], takes a DNS request and a receiver for the - /// reply. +impl Query { + /// Creates a new query. fn new( - conn: Connection, - request_msg: &CR, - receiver: oneshot::Receiver>, - ) -> ReqResp { + conn: Connection, + request_msg: Req, + receiver: oneshot::Receiver>, + ) -> Self { Self { conn, - request_msg: request_msg.clone(), + request_msg: request_msg, state: QueryState::GetConn(receiver), conn_id: 0, delayed_retry_count: 0, } } +} +impl Query { /// Get the result of a DNS request. /// /// This function returns the reply to a DNS request wrapped in a @@ -229,13 +251,14 @@ impl ReqResp { loop { match self.state { QueryState::GetConn(ref mut receiver) => { - let res = receiver.await; - if res.is_err() { - // Assume receive error - self.state = QueryState::Done; - return Err(Error::StreamReceiveError); - } - let res = res.expect("error is checked before"); + let res = match receiver.await { + Ok(res) => res, + Err(_) => { + // Assume receive error + self.state = QueryState::Done; + return Err(Error::StreamReceiveError); + } + }; // Another Result. This time from executing the request match res { @@ -264,10 +287,8 @@ impl ReqResp { Err(err) => { if let Error::ConnectionClosed = err { let (tx, rx) = oneshot::channel(); - let res = self - .conn - .new_conn(self.conn_id, tx) - .await; + let res = + self.new_conn(self.conn_id, tx).await; if let Err(err) = res { self.state = QueryState::Done; return Err(err); @@ -305,7 +326,7 @@ impl ReqResp { QueryState::Delay(instant, duration) => { sleep_until(instant + duration).await; let (tx, rx) = oneshot::channel(); - let res = self.conn.new_conn(self.conn_id, tx).await; + let res = self.new_conn(self.conn_id, tx).await; if let Err(err) = res { self.state = QueryState::Done; return Err(err); @@ -319,9 +340,18 @@ impl ReqResp { } } } + + /// Requests a new connection. + async fn new_conn( + &self, + id: u64, + tx: oneshot::Sender>, + ) -> Result<(), Error> { + self.conn.new_conn(Some(id), tx).await + } } -impl GetResponse for ReqResp { +impl GetResponse for Query { fn get_response( &mut self, ) -> Pin< @@ -331,83 +361,59 @@ impl GetResponse for ReqResp { } } -//------------ InnerConnection ------------------------------------------------ +//------------ Transport ------------------------------------------------ /// The actual implementation of [Connection]. #[derive(Debug)] -struct InnerConnection { +pub struct Transport { /// User configuration values. config: Config, - /// [InnerConnection::sender] and [InnerConnection::receiver] are - /// part of a single channel. - /// - /// Used by [ReqResp] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + /// The remote destination. + stream: Remote, - /// receiver part of the channel. - /// - /// Protected by a mutex to allow read/write access by - /// [InnerConnection::run]. - /// The Option is to allow [InnerConnection::run] to signal that the - /// connection is closed. - receiver: Mutex>>>, + /// Underlying stream connection. + conn_state: SingleConnState3, + + /// Current connection id. + conn_id: u64, + + /// Receiver part of the channel. + receiver: mpsc::Receiver>, } #[derive(Debug)] /// A request to [Connection::run] either for a new stream or to /// shutdown. -struct ChanReq { +struct ChanReq { /// A requests consists of a command. - cmd: ReqCmd, + cmd: ReqCmd, } #[derive(Debug)] /// Commands that can be requested. -enum ReqCmd { +enum ReqCmd { /// Request for a (new) connection. /// /// The id of the previous connection (if any) is passed as well as a /// channel to send the reply. - NewConn(Option, ReplySender), + NewConn(Option, ReplySender), /// Shutdown command. Shutdown, } /// This is the type of sender in [ReqCmd]. -type ReplySender = oneshot::Sender>; - -/// Internal datastructure of [InnerConnection::run] to keep track of -/// the status of the connection. -// The types Status and ConnState are only used in InnerConnection -struct State3<'a, S, IO, CR> { - /// Underlying stream connection. - conn_state: SingleConnState3, - - /// Current connection id. - conn_id: u64, - - /// Connection stream for new octet streams. - stream: S, - - /// Collection of futures for the async run function of the underlying - /// stream. - runners: FuturesUnordered< - Pin> + Send + 'a>>, - >, - - /// Phantom data for type IO - phantom: PhantomData<&'a IO>, -} +type ReplySender = oneshot::Sender>; /// State of the current underlying stream transport. -enum SingleConnState3 { +#[derive(Debug)] +enum SingleConnState3 { /// No current stream transport. None, /// Current stream transport. - Some(stream::Connection), + Some(Arc>), /// State that deals with an error getting a new octet stream from /// a connection stream. @@ -416,7 +422,7 @@ enum SingleConnState3 { /// State associated with a failed attempt to create a new stream /// transport. -#[derive(Clone)] +#[derive(Clone, Debug)] struct ErrorState { /// The error we got from the most recent attempt. error: Arc, @@ -431,68 +437,43 @@ struct ErrorState { timeout: Duration, } -impl InnerConnection { - /// Constructor for [InnerConnection]. - /// - /// This is the implementation of [Connection::new]. - pub fn new(config: Config) -> Result { - let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); - Ok(Self { - config, - sender: tx, - receiver: Mutex::new(Some(rx)), - }) - } - - /// Main execution function for [InnerConnection]. - /// - /// This function Gets called by [Connection::run]. - /// This function is not async cancellation safe. - /// Make sure the resulting future does not contain a reference to self. - pub fn run< - S: AsyncConnect + Send + 'static, - C: 'static + AsyncRead + AsyncWrite + Debug + Send + Sync + Unpin, - >( - &self, - stream: S, - ) -> Pin> + Send>> { - let mut receiver = self.receiver.lock().unwrap(); - let opt_receiver = receiver.take(); - drop(receiver); - - Box::pin(Self::run_impl(self.config.clone(), stream, opt_receiver)) +impl Transport { + /// Creates a new transport. + fn new( + stream: Remote, + config: Config, + ) -> (mpsc::Sender>, Self) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + ( + sender, + Self { + config, + stream, + conn_state: SingleConnState3::None, + conn_id: 0, + receiver, + }, + ) } +} - /// Implementation of the run method. This function does not have - /// a reference to self. - #[rustfmt::skip] - async fn run_impl< - 'a, - S: AsyncConnect + Send, - C: 'static + AsyncRead + AsyncWrite + Debug + Send + Unpin, - >( - config: Config, - stream: S, - opt_receiver: Option>> - ) -> Result<(), Error> { - let mut receiver = { - opt_receiver.expect("no receiver present?") - }; - let mut curr_cmd: Option> = None; - - let mut state = State3::<'a, S, C, CR> { - conn_state: SingleConnState3::None, - conn_id: 0, - stream, - runners: FuturesUnordered::< - Pin> + Send>>, - >::new(), - phantom: PhantomData, - }; - +impl Transport +where + Remote: AsyncConnect, + Remote::Connection: AsyncRead + AsyncWrite, + Req: ComposeRequest, +{ + /// Run the transport machinery. + pub async fn run(mut self) { + let mut curr_cmd: Option> = None; let mut do_stream = false; + let mut runners = FuturesUnordered::new(); let mut stream_fut: Pin< - Box> + Send>, + Box< + dyn Future< + Output = Result, + > + Send, + >, > = Box::pin(stream_nop()); let mut opt_chan = None; @@ -503,7 +484,7 @@ impl InnerConnection { match req { ReqCmd::NewConn(opt_id, chan) => { if let SingleConnState3::Err(error_state) = - &state.conn_state + &self.conn_state { if error_state.timer.elapsed() < error_state.timeout @@ -523,20 +504,20 @@ impl InnerConnection { // Check if the command has an id greather than the // current id. if let Some(id) = opt_id { - if id >= state.conn_id { + if id >= self.conn_id { // We need a new connection. Remove the // current one. This is the best place to // increment conn_id. - state.conn_id += 1; - state.conn_state = SingleConnState3::None; + self.conn_id += 1; + self.conn_state = SingleConnState3::None; } } // If we still have a connection then we can reply // immediately. - if let SingleConnState3::Some(conn) = &state.conn_state + if let SingleConnState3::Some(conn) = &self.conn_state { let resp = ChanResp::Ok(ChanRespOk { - id: state.conn_id, + id: self.conn_id, conn: conn.clone(), }); // Ignore errors. We don't care if the receiver @@ -544,7 +525,7 @@ impl InnerConnection { _ = chan.send(resp); } else { opt_chan = Some(chan); - stream_fut = Box::pin(state.stream.connect()); + stream_fut = Box::pin(self.stream.connect()); do_stream = true; } } @@ -553,7 +534,7 @@ impl InnerConnection { } if do_stream { - let runners_empty = state.runners.is_empty(); + let runners_empty = runners.is_empty(); loop { tokio::select! { @@ -561,57 +542,55 @@ impl InnerConnection { do_stream = false; stream_fut = Box::pin(stream_nop()); - if let Err(error) = res_conn { - let error = Arc::new(error); - match state.conn_state { - SingleConnState3::None => - state.conn_state = - SingleConnState3::Err(ErrorState { - error: error.clone(), - retries: 0, - timer: Instant::now(), - timeout: retry_time(0), - }), - SingleConnState3::Some(_) => - panic!("Illegal Some state"), - SingleConnState3::Err(error_state) => { - state.conn_state = - SingleConnState3::Err(ErrorState { - error: error_state.error.clone(), - retries: error_state.retries+1, - timer: Instant::now(), - timeout: retry_time( - error_state.retries+1), - }); + let stream = match res_conn { + Ok(stream) => stream, + Err(error) => { + let error = Arc::new(error); + match self.conn_state { + SingleConnState3::None => + self.conn_state = + SingleConnState3::Err(ErrorState { + error: error.clone(), + retries: 0, + timer: Instant::now(), + timeout: retry_time(0), + }), + SingleConnState3::Some(_) => + panic!("Illegal Some state"), + SingleConnState3::Err(error_state) => { + self.conn_state = + SingleConnState3::Err(ErrorState { + error: + error_state.error.clone(), + retries: error_state.retries+1, + timer: Instant::now(), + timeout: retry_time( + error_state.retries+1), + }); + } } - } - - let resp = ChanResp::Err(error); - let loc_opt_chan = opt_chan.take(); - - // Ignore errors. We don't care if the receiver - // is gone - _ = loc_opt_chan.expect("weird, no channel?") - .send(resp); - break; - } - let stream = res_conn - .expect("error case is checked before"); - let conn = stream::Connection::new(config.stream.clone())?; - let conn_run = conn.clone(); + let resp = ChanResp::Err(error); + let loc_opt_chan = opt_chan.take(); - let clo = || async move { - conn_run.run(stream).await + // Ignore errors. We don't care if the receiver + // is gone + _ = loc_opt_chan.expect("weird, no channel?") + .send(resp); + break; + } }; - let fut = clo(); - state.runners.push(Box::pin(fut)); + let (conn, tran) = stream::Connection::with_config( + stream, self.config.stream.clone() + ); + let conn = Arc::new(conn); + runners.push(Box::pin(tran.run())); let resp = ChanResp::Ok(ChanRespOk { - id: state.conn_id, + id: self.conn_id, conn: conn.clone(), }); - state.conn_state = SingleConnState3::Some(conn); + self.conn_state = SingleConnState3::Some(conn); let loc_opt_chan = opt_chan.take(); @@ -621,7 +600,7 @@ impl InnerConnection { .send(resp); break; } - _ = state.runners.next(), if !runners_empty => { + _ = runners.next(), if !runners_empty => { } } } @@ -629,67 +608,28 @@ impl InnerConnection { } assert!(curr_cmd.is_none()); - let recv_fut = receiver.recv(); - let runners_empty = state.runners.is_empty(); + let recv_fut = self.receiver.recv(); + let runners_empty = runners.is_empty(); tokio::select! { msg = recv_fut => { if msg.is_none() { - // All references to the connection object have been - // dropped. Shutdown. + // All references to the connection object have been + // dropped. Shutdown. break; } curr_cmd = Some(msg.expect("None is checked before").cmd); } - _ = state.runners.next(), if !runners_empty => { + _ = runners.next(), if !runners_empty => { } } } // Avoid new queries - drop(receiver); + drop(self.receiver); // Wait for existing stream runners to terminate - while !state.runners.is_empty() { - state.runners.next().await; - } - - // Done - Ok(()) - } - - /// Request a new connection. - async fn new_conn( - &self, - opt_id: Option, - sender: oneshot::Sender>, - ) -> Result<(), Error> { - let req = ChanReq { - cmd: ReqCmd::NewConn(opt_id, sender), - }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - { - Err(Error::ConnectionClosed) - } - Ok(_) => Ok(()), - } - } - - /// Request a shutdown. - async fn shutdown(&self) -> Result<(), &'static str> { - let req = ChanReq { - cmd: ReqCmd::Shutdown, - }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - { - Err(ERR_CONN_CLOSED) - } - Ok(_) => Ok(()), + while !runners.is_empty() { + runners.next().await; } } } @@ -754,9 +694,3 @@ fn is_answer_ignore_id< async fn stream_nop() -> Result { Err(io::Error::new(io::ErrorKind::Other, "nop")) } - -/// Check if config is valid. -fn check_config(_config: &Config) -> Result<(), Error> { - // Nothing to check at the moment. - Ok(()) -} diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index 2c64983ed..340ce6796 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -1,4 +1,4 @@ -//! A DNS over octet stream transport +//! A client transport using a stream socket. #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] @@ -17,62 +17,80 @@ // - request timeout // - create new connection after end/failure of previous one +use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive}; +use crate::base::Message; +use crate::net::client::request::{ + ComposeRequest, Error, GetResponse, SendRequest, +}; use bytes; use bytes::{Bytes, BytesMut}; +use core::cmp; use core::convert::From; +use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; -use std::io::ErrorKind; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; use std::vec::Vec; - -use crate::base::{ - opt::{AllOptData, OptRecord, TcpKeepalive}, - Message, -}; -use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, -}; -use octseq::Octets; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; use tokio::time::sleep; -/// Default configuration value for the amount of time to wait on a non-idle -/// connection for the other side to send a response on any outstanding query. -// Implement a simple response timer to see if the connection and the server -// are alive. Set the timer when the connection goes from idle to busy. -// Reset the timer each time a reply arrives. Cancel the timer when the -// connection goes back to idle. When the time expires, mark all outstanding -// queries as timed out and shutdown the connection. -// -// Note: nsd has 120 seconds, unbound has 3 seconds. +//------------ Configuration Constants ---------------------------------------- + +/// Default response timeout. +/// +/// Note: nsd has 120 seconds, unbound has 3 seconds. const DEF_RESPONSE_TIMEOUT: Duration = Duration::from_secs(19); -/// Minimum configuration value for response_timeout. +/// Minimum configuration value for the response timeout. const MIN_RESPONSE_TIMEOUT: Duration = Duration::from_millis(1); -/// Maximum configuration value for response_timeout. +/// Maximum configuration value for the response timeout. const MAX_RESPONSE_TIMEOUT: Duration = Duration::from_secs(600); -/// Capacity of the channel that transports [ChanReq]. +/// Capacity of the channel that transports `ChanReq`s. const DEF_CHAN_CAP: usize = 8; -/// Capacity of a private channel between [InnerConnection::reader] and -/// [InnerConnection::run]. +/// Capacity of a private channel dispatching responses. const READ_REPLY_CHAN_CAP: usize = 8; //------------ Config --------------------------------------------------------- -/// Configuration for an octet_stream transport connection. +/// Configuration for a stream transport connection. #[derive(Clone, Debug)] pub struct Config { /// Response timeout. - pub response_timeout: Duration, + response_timeout: Duration, +} + +impl Config { + /// Creates a new, default config. + pub fn new() -> Self { + Default::default() + } + + /// Returns the response timeout. + /// + /// This is the amount of time to wait on a non-idle connection for a + /// response to an outstanding request. + pub fn response_timeout(&self) -> Duration { + self.response_timeout + } + + /// Sets the response timeout. + /// + /// Excessive values are quietly trimmed. + // + // XXX Maybe that’s wrong and we should rather return an error? + pub fn set_response_timeout(&mut self, timeout: Duration) { + self.response_timeout = cmp::max( + cmp::min(timeout, MAX_RESPONSE_TIMEOUT), + MIN_RESPONSE_TIMEOUT, + ) + } } impl Default for Config { @@ -85,53 +103,53 @@ impl Default for Config { //------------ Connection ----------------------------------------------------- -#[derive(Clone, Debug)] -/// A single DNS over octect stream connection. -pub struct Connection { - /// Reference counted [InnerConnection]. - inner: Arc>, +/// A connection to a single stream transport. +#[derive(Debug)] +pub struct Connection { + /// The sender half of the request channel. + sender: mpsc::Sender>, } -impl Connection { - /// Constructor for [Connection]. +impl Connection { + /// Creates a new stream transport with default configuration. /// - /// Returns a [Connection] wrapped in a [Result](io::Result). - pub fn new(config: Option) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = InnerConnection::new(config)?; - Ok(Self { - inner: Arc::new(connection), - }) + /// Returns a connection and a future that drives the transport using + /// the provided stream. This future needs to be run while any queries + /// are active. This is most easly achieved by spawning it into a runtime. + /// It terminates when the last connection is dropped. + pub fn new(stream: Stream) -> (Self, Transport) { + Self::with_config(stream, Default::default()) } - /// Main execution function for [Connection]. + /// Creates a new stream transport with the given configuration. /// - /// This function has to run in the background or together with - /// any calls to [query](Self::query) or [Query::get_result]. - /// Worker function for a connection object. - pub fn run( - &self, - io: IO, - ) -> Pin> + Send>> { - self.inner.run(io) + /// Returns a connection and a future that drives the transport using + /// the provided stream. This future needs to be run while any queries + /// are active. This is most easly achieved by spawning it into a runtime. + /// It terminates when the last connection is dropped. + pub fn with_config( + stream: Stream, + config: Config, + ) -> (Self, Transport) { + let (sender, transport) = Transport::new(stream, config); + let this = Self { + sender: sender.into(), + }; + (this, transport) } +} +impl Connection { /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and /// returns a [ReqRepl] object wrapped in a [Result]. async fn request_impl( &self, - request_msg: &CR, + request_msg: &Req, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); - self.inner.request(tx, request_msg).await?; + self.request(tx, request_msg.clone()).await?; Ok(Box::new(ReqResp::new(request_msg, rx))) } @@ -141,20 +159,41 @@ impl Connection { /// match the request avoids having to keep the request around. pub async fn query_no_check( &self, - query_msg: &CR, + query_msg: &Req, ) -> Result { let (tx, rx) = oneshot::channel(); - self.inner.request(tx, query_msg).await?; + self.request(tx, query_msg.clone()).await?; Ok(QueryNoCheck::new(rx)) } + + /// Sends a request. + async fn request( + &self, + sender: oneshot::Sender, + request_msg: Req, + ) -> Result<(), Error> { + let req = ChanReq { + sender, + msg: request_msg, + }; + match self.sender.send(req).await { + Err(_) => + // Send error. The receiver is gone, this means that the + // connection is closed. + { + Err(Error::ConnectionClosed) + } + Ok(_) => Ok(()), + } + } } -impl SendRequest - for Connection +impl SendRequest + for Connection { fn send_request<'a>( &'a self, - request_msg: &'a CR, + request_msg: &'a Req, ) -> Pin< Box< dyn Future, Error>> @@ -196,8 +235,8 @@ enum QueryState { impl ReqResp { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. - fn new( - request_msg: &CR, + fn new( + request_msg: &Req, receiver: oneshot::Receiver, ) -> ReqResp { let vec = request_msg.to_vec(); @@ -233,8 +272,7 @@ impl ReqResp { return Err(err); } - let resp = res.expect("error case is checked already"); - let msg = resp.reply; + let msg = res.expect("error case is checked already"); if !is_answer_ignore_id(&msg, &self.request_msg) { return Err(Error::WrongReplyForQuery); @@ -300,8 +338,7 @@ impl QueryNoCheck { return Err(err); } - let resp = res.expect("error case is checked already"); - let msg = resp.reply; + let msg = res.expect("error case is checked already"); Ok(msg) } @@ -312,34 +349,26 @@ impl QueryNoCheck { } } -//------------ InnerConnection ------------------------------------------------ +//------------ Transport ------------------------------------------------ -/// The actual implementation of [Connection]. +/// The underlying machinery of a stream transport. #[derive(Debug)] -struct InnerConnection { - /// User configuration variables. - config: Config, +pub struct Transport { + /// The stream socket towards the remove end. + stream: Stream, - /// [InnerConnection::sender] and [InnerConnection::receiver] are - /// part of a single channel. - /// - /// Used by [Query] to send requests to [InnerConnection::run]. - sender: mpsc::Sender>, + /// Transport configuration. + config: Config, - /// receiver part of the channel. - /// - /// Protected by a mutex to allow read/write access by - /// [InnerConnection::run]. - /// The Option is to allow [InnerConnection::run] to signal that the - /// connection is closed. - receiver: Mutex>>>, + /// The receiver half of request channel. + receiver: mpsc::Receiver>, } +/// A message from a `Query` to start a new request. #[derive(Debug)] -/// A request from [Query] to [Connection::run] to start a DNS request. -struct ChanReq { +struct ChanReq { /// DNS request message - msg: CR, + msg: Req, /// Sender to send result back to [Query] sender: ReplySender, @@ -348,17 +377,10 @@ struct ChanReq { /// This is the type of sender in [ChanReq]. type ReplySender = oneshot::Sender; -/// Response to the DNS request sent by [InnerConnection::run] to [Query]. -type ChanResp = Result; - -#[derive(Debug)] -/// a response to a [ChanReq]. -struct Response { - /// The DNS reply message. - reply: Message, -} +/// A message back to `Query` returning a response. +type ChanResp = Result, Error>; -/// Internal datastructure of [InnerConnection::run] to keep track of +/// Internal datastructure of [Transport::run] to keep track of /// outstanding DNS requests. struct Queries { /// The number of elements in [Queries::vec] that are not None. @@ -371,19 +393,18 @@ struct Queries { vec: Vec>, } -/// Internal datastructure of [InnerConnection::run] to keep track of +/// Internal datastructure of [Transport::run] to keep track of /// the status of the connection. -// The types Status and ConnState are only used in InnerConnection +// The types Status and ConnState are only used in Transport struct Status { /// State of the connection. state: ConnState, - /// Boolean if we need to include an edns-tcp-keepalive option in an - /// outogoing request. + /// Do we need to include edns-tcp-keepalive in an outogoing request. /// - /// Typically send_keepalive is true at the start of the connection. - /// it gets cleared when we successfully managed to include the option - /// in a request. + /// Typically this is true at the start of the connection and gets + /// cleared when we successfully managed to include the option in a + /// request. send_keepalive: bool, /// Time we are allow to keep the connection open when idle. @@ -415,64 +436,46 @@ enum ConnState { /// A read error occurred. ReadError(Error), - /// It took too long to receive a (or another) response. + /// It took too long to receive a response. ReadTimeout, /// A write error occurred. WriteError(Error), } -/// A DNS message received to [InnerConnection::reader] and sent to -/// [InnerConnection::run]. -// This type could be local to InnerConnection, but I don't know how -type ReaderChanReply = Message; - -impl InnerConnection { - /// Constructor for [InnerConnection]. - /// - /// This is the implementation of [Connection::new]. - pub fn new(config: Config) -> Result { - let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); - Ok(Self { - config, - sender: tx, - receiver: Mutex::new(Some(rx)), - }) - } - - /// Run method. - /// - /// Make sure the future does not contain a reference to self. - pub fn run( - &self, - io: IO, - ) -> Pin> + Send>> { - let mut receiver = self.receiver.lock().unwrap(); - let opt_receiver = receiver.take(); - drop(receiver); - - Box::pin(Self::run_impl(self.config.clone(), io, opt_receiver)) +impl Transport { + /// Creates a new transport. + fn new( + stream: Stream, + config: Config, + ) -> (mpsc::Sender>, Self) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + ( + sender, + Self { + config, + stream, + receiver, + }, + ) } +} - /// Main execution function for [InnerConnection]. - /// - /// This function Gets called by [Connection::run]. - /// This function is not async cancellation safe - async fn run_impl( - config: Config, - io: IO, - opt_receiver: Option>>, - ) -> Option<()> { +impl Transport +where + Stream: AsyncRead + AsyncWrite, + Req: ComposeRequest, +{ + /// Run the transport machinery. + pub async fn run(mut self) { let (reply_sender, mut reply_receiver) = - mpsc::channel::(READ_REPLY_CHAN_CAP); + mpsc::channel::>(READ_REPLY_CHAN_CAP); - let (mut read_stream, mut write_stream) = tokio::io::split(io); + let (read_stream, mut write_stream) = tokio::io::split(self.stream); - let reader_fut = Self::reader(&mut read_stream, reply_sender); + let reader_fut = Self::reader(read_stream, reply_sender); tokio::pin!(reader_fut); - let mut receiver = { opt_receiver.expect("no receiver present?") }; - let mut status = Status { state: ConnState::Active(None), idle_timeout: None, @@ -491,7 +494,7 @@ impl InnerConnection { ConnState::Active(opt_instant) => { if let Some(instant) = opt_instant { let elapsed = instant.elapsed(); - if elapsed > config.response_timeout { + if elapsed > self.config.response_timeout { Self::error( Error::StreamReadTimeout, &mut query_vec, @@ -499,7 +502,7 @@ impl InnerConnection { status.state = ConnState::ReadTimeout; break; } - Some(config.response_timeout - elapsed) + Some(self.config.response_timeout - elapsed) } else { None } @@ -532,12 +535,12 @@ impl InnerConnection { None => // Just use the response timeout { - config.response_timeout + self.config.response_timeout } }; let sleep_fut = sleep(timeout); - let recv_fut = receiver.recv(); + let recv_fut = self.receiver.recv(); let (do_write, msg) = match &reqmsg { None => { @@ -598,14 +601,16 @@ impl InnerConnection { } res = recv_fut, if !do_write => { match res { - Some(req) => - Self::insert_req(req, &mut status, - &mut reqmsg, &mut query_vec), + Some(req) => { + Self::insert_req( + req, &mut status, &mut reqmsg, &mut query_vec + ) + } None => { - // All references to the connection object have - // been dropped. Shutdown. - break; - } + // All references to the connection object have + // been dropped. Shutdown. + break; + } } } _ = sleep_fut => { @@ -631,43 +636,19 @@ impl InnerConnection { // Send FIN _ = write_stream.shutdown().await; - - None - } - - /// This function sends a DNS request to [InnerConnection::run]. - pub async fn request( - &self, - sender: oneshot::Sender, - request_msg: &CR, - ) -> Result<(), Error> { - let req = ChanReq { - sender, - msg: request_msg.clone(), - }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - { - Err(Error::ConnectionClosed) - } - Ok(_) => Ok(()), - } } /// This function reads a DNS message from the connection and sends - /// it to [InnerConnection::run]. + /// it to [Transport::run]. /// /// Reading has to be done in two steps: first read a two octet value /// the specifies the length of the message, and then read in a loop the /// body of the message. /// /// This function is not async cancellation safe. - async fn reader( - //sock: &mut ReadStream, - mut sock: ReadStream, - sender: mpsc::Sender, + async fn reader( + mut sock: tokio::io::ReadHalf, + sender: mpsc::Sender>, ) -> Result<(), Error> { loop { let read_res = sock.read_u16().await; @@ -724,7 +705,7 @@ impl InnerConnection { } } - /// An error occured, report the error to all outstanding [Query] objects. + /// Reports an error to all outstanding queries. fn error(error: Error, query_vec: &mut Queries) { // Update all requests that are in progress. Don't wait for // any reply that may be on its way. @@ -737,12 +718,16 @@ impl InnerConnection { } } - /// Handle received EDNS options, in particular the edns-tcp-keepalive - /// option. - fn handle_opts>( - opts: &OptRecord, + /// Handles received EDNS options. + /// + /// In particular, it processes the edns-tcp-keepalive option. + fn handle_opts>( + opts: &OptRecord, status: &mut Status, ) { + // XXX This handles _all_ keepalive options. I think just using the + // first option as returned by Opt::tcp_keepalive should be good + // enough? -- M. for option in opts.opt().iter().flatten() { if let AllOptData::TcpKeepalive(tcpkeepalive) = option { Self::handle_keepalive(tcpkeepalive, status); @@ -750,7 +735,7 @@ impl InnerConnection { } } - /// Demultiplex a DNS reply and send it to the right [Query] object. + /// Demultiplexes a response and sends it to the right query. /// /// In addition, the status is updated to IdleTimeout or Idle if there /// are no remaining pending requests. @@ -782,8 +767,7 @@ impl InnerConnection { Some(_) => { let sender = Self::take_query(query_vec, index) .expect("sender should be there"); - let reply = Response { reply: answer }; - _ = sender.send(Ok(reply)); + _ = sender.send(Ok(answer)); } } if query_vec.count == 0 { @@ -810,7 +794,7 @@ impl InnerConnection { /// idle. Addend a edns-tcp-keepalive option if needed. // Note: maybe reqmsg should be a return value. fn insert_req( - mut req: ChanReq, + mut req: ChanReq, status: &mut Status, reqmsg: &mut Option>, query_vec: &mut Queries, @@ -906,7 +890,7 @@ impl InnerConnection { /// Convert the query message to a vector. // This function should return the vector instead of storing it // through a reference. - fn convert_query(msg: &CR, reqmsg: &mut Option>) { + fn convert_query(msg: &Req, reqmsg: &mut Option>) { // Ideally there should be a write_all_vectored. Until there is one, // copy to a new Vec and prepend the length octets. @@ -1010,16 +994,3 @@ fn is_answer_ignore_id< reply.question() == query.question() } } - -/// Check if config is valid. -fn check_config(config: &Config) -> Result<(), Error> { - if config.response_timeout < MIN_RESPONSE_TIMEOUT - || config.response_timeout > MAX_RESPONSE_TIMEOUT - { - return Err(Error::OctetStreamConfigError(Arc::new( - std::io::Error::new(ErrorKind::Other, "response_timeout"), - ))); - } - - Ok(()) -} diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 8fa57884f..b3c0e1f27 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -150,15 +150,15 @@ impl StubResolver { if self.options.use_vc { for s in &self.servers { if let Transport::Tcp = s.transport { - let tcp_connect = TcpConnect::new(s.addr); - let tcp_conn = - multi_stream::Connection::new(None).unwrap(); + let (conn, tran) = multi_stream::Connection::new( + TcpConnect::new(s.addr), + ); // Start the run function on a separate task. - let run_fut = tcp_conn.run(tcp_connect); + let run_fut = tran.run(); fut_list_tcp.push(async move { let _res = run_fut.await; }); - redun.add(Box::new(tcp_conn)).await.unwrap(); + redun.add(Box::new(conn)).await.unwrap(); } } } else { @@ -166,15 +166,15 @@ impl StubResolver { if let Transport::Udp = s.transport { let udp_connect = UdpConnect::new(s.addr); let tcp_connect = TcpConnect::new(s.addr); - let udptcp_conn = - dgram_stream::Connection::new(None, udp_connect) - .unwrap(); + let (conn, tran) = dgram_stream::Connection::new( + udp_connect, + tcp_connect, + ); // Start the run function on a separate task. - let run_fut = udptcp_conn.run(tcp_connect); fut_list_udp_tcp.push(async move { - let _res = run_fut.await; + tran.run().await; }); - redun.add(Box::new(udptcp_conn)).await.unwrap(); + redun.add(Box::new(conn)).await.unwrap(); } } } diff --git a/tests/net-client.rs b/tests/net-client.rs index 92ab1b1c3..1712fabeb 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -30,7 +30,7 @@ fn dgram() { let step_value = Arc::new(CurrStepValue::new()); let conn = Dgram::new(deckard.clone(), step_value.clone()); - let octstr = dgram::Connection::new(None, conn).unwrap(); + let octstr = dgram::Connection::new(None, conn); do_client(&deckard, octstr, &step_value).await; }); @@ -44,10 +44,9 @@ fn single() { let step_value = Arc::new(CurrStepValue::new()); let conn = Connection::new(deckard.clone(), step_value.clone()); - let octstr = stream::Connection::new(None).unwrap(); - let run_fut = octstr.run(conn); + let (octstr, transport) = stream::Connection::new(conn); tokio::spawn(async move { - run_fut.await; + transport.run().await; }); do_client(&deckard, octstr, &step_value).await; @@ -62,10 +61,9 @@ fn multi() { let step_value = Arc::new(CurrStepValue::new()); let multi_conn = Connect::new(deckard.clone(), step_value.clone()); - let ms = multi_stream::Connection::new(None).unwrap(); - let run_fut = ms.run(multi_conn); + let (ms, ms_tran) = multi_stream::Connection::new(multi_conn); tokio::spawn(async move { - run_fut.await.unwrap(); + ms_tran.run().await; println!("multi conn run terminated"); }); @@ -82,10 +80,9 @@ fn dgram_stream() { let step_value = Arc::new(CurrStepValue::new()); let conn = Dgram::new(deckard.clone(), step_value.clone()); let multi_conn = Connect::new(deckard.clone(), step_value.clone()); - let ds = dgram_stream::Connection::new(None, conn).unwrap(); - let run_fut = ds.run(multi_conn); + let (ds, tran) = dgram_stream::Connection::new(conn, multi_conn); tokio::spawn(async move { - run_fut.await.unwrap(); + tran.run().await; println!("dgram_stream conn run terminated"); }); @@ -101,10 +98,9 @@ fn redundant() { let step_value = Arc::new(CurrStepValue::new()); let multi_conn = Connect::new(deckard.clone(), step_value.clone()); - let ms = multi_stream::Connection::new(None).unwrap(); - let run_fut = ms.run(multi_conn); + let (ms, ms_tran) = multi_stream::Connection::new(multi_conn); tokio::spawn(async move { - run_fut.await.unwrap(); + ms_tran.run().await; println!("multi conn run terminated"); }); @@ -142,10 +138,9 @@ fn tcp() { } }; - let tcp = stream::Connection::new(None).unwrap(); - let run_fut = tcp.run(tcp_conn); + let (tcp, transport) = stream::Connection::new(tcp_conn); tokio::spawn(async move { - run_fut.await; + transport.run().await; println!("single TCP run terminated"); }); From b0225f0cf7427b651d3492163ece23bb37486c8d Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 2 Jan 2024 16:12:00 +0100 Subject: [PATCH 106/124] Clippy-suggested fixes. --- src/net/client/dgram.rs | 9 +++++++++ src/net/client/multi_stream.rs | 2 +- src/net/client/stream.rs | 5 +---- tests/net-client.rs | 1 - 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 6f52e7164..399c52924 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -444,22 +444,31 @@ fn is_answer< //------------ DefMinMax ----------------------------------------------------- +/// The default, minimum, and maximum values for a config variable. #[derive(Clone, Copy)] struct DefMinMax { + /// The default value, def: T, + + /// The minimum value, min: T, + + /// The maximum value, max: T, } impl DefMinMax { + /// Creates a new value. const fn new(def: T, min: T, max: T) -> Self { Self { def, min, max } } + /// Returns the default value. fn default(self) -> T { self.def } + /// Trims the given value to fit into the minimum/maximum range. fn limit(self, value: T) -> T where T: Ord, diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index a9d57747d..4fb55efcf 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -232,7 +232,7 @@ impl Query { ) -> Self { Self { conn, - request_msg: request_msg, + request_msg, state: QueryState::GetConn(receiver), conn_id: 0, delayed_retry_count: 0, diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index 340ce6796..672d9c0c3 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -132,10 +132,7 @@ impl Connection { config: Config, ) -> (Self, Transport) { let (sender, transport) = Transport::new(stream, config); - let this = Self { - sender: sender.into(), - }; - (this, transport) + (Self { sender }, transport) } } diff --git a/tests/net-client.rs b/tests/net-client.rs index 1712fabeb..8b65a50fa 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -18,7 +18,6 @@ use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; use tokio::net::TcpStream; -use tokio_test; const TEST_FILE: &str = "test-data/basic.rpl"; From 7703680d6290e0d2c65c94b270bb6ee0e5cb48db Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 2 Jan 2024 16:45:46 +0100 Subject: [PATCH 107/124] Add unstable-client-transport to resolv. --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da280d85e..5508ea6c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ default = ["std", "rand"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] interop = ["bytes", "ring"] -resolv = ["net", "smallvec", "std", "rand"] +resolv = ["net", "smallvec", "std", "rand", "unstable-client-transport"] resolv-sync = ["resolv", "tokio/rt"] serde = ["dep:serde", "octseq/serde"] sign = ["std"] @@ -62,7 +62,7 @@ unstable-client-transport = [] # This feature should include all features that the CI should include for a # test run. Which is everything except interop. -ci-test = ["resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] +ci-test = ["net", "resolv", "resolv-sync", "sign", "std", "serde", "tsig", "validate", "zonefile"] [dev-dependencies] rustls = { version = "0.21.9" } From 0e6ec78087d761d33f6340cb4f97a298b2c1656b Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 2 Jan 2024 17:15:40 +0100 Subject: [PATCH 108/124] Fix doctests. --- src/net/client/mod.rs | 60 +++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 980193045..b169f88ac 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -25,10 +25,14 @@ //! //! For example: //! ```rust +//! # use domain::base::{Dname, MessageBuilder, Rtype}; +//! # use domain::net::client::request::RequestMessage; //! let mut msg = MessageBuilder::new_vec(); //! msg.header_mut().set_rd(true); //! let mut msg = msg.question(); -//! msg.push((Dname::vec_from_str("example.com").unwrap(), Aaaa)).unwrap(); +//! msg.push( +//! (Dname::vec_from_str("example.com").unwrap(), Rtype::Aaaa) +//! ).unwrap(); //! let req = RequestMessage::new(msg); //! ``` @@ -39,19 +43,26 @@ //! DNS transport and running a ```run``` method as a separate task. This //! is illustrated in the following example: //! ```rust -//! let multi_stream_config = multi_stream::Config { -//! stream: Some(stream::Config { -//! response_timeout: Duration::from_millis(100), -//! }), -//! }; +//! # use domain::net::client::multi_stream; +//! # use domain::net::client::protocol::TcpConnect; +//! # use domain::net::client::request::SendRequest; +//! # use std::time::Duration; +//! # async fn _test() { +//! # let server_addr = String::from("127.0.0.1:53"); +//! let mut multi_stream_config = multi_stream::Config::default(); +//! multi_stream_config.stream_mut().set_response_timeout( +//! Duration::from_millis(100), +//! ); //! let tcp_connect = TcpConnect::new(server_addr); -//! let tcp_conn = -//! multi_stream::Connection::new(Some(multi_stream_config.clone())) -//! .unwrap(); -//! let run_fut = tcp_conn.run(tcp_connect); -//! tokio::spawn(async move { -//! let res = run_fut.await; -//! }); +//! let (tcp_conn, transport) = multi_stream::Connection::with_config( +//! tcp_connect, multi_stream_config +//! ); +//! tokio::spawn(transport.run()); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); +//! # let mut request = tcp_conn.send_request(&req).await.unwrap(); +//! # } //! ``` //! Note that the run function ends when the last reference to the DNS //! transport is dropped. For this reason it is important to avoid having a @@ -80,7 +91,18 @@ //! //! For example: //! ```rust +//! # use domain::net::client::request::SendRequest; +//! # async fn _test() { +//! # let (tls_conn, _) = domain::net::client::stream::Connection::new( +//! # domain::net::client::protocol::TcpConnect::new( +//! # String::from("127.0.0.1:53") +//! # ) +//! # ); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); //! let mut request = tls_conn.send_request(&req).await.unwrap(); +//! # } //! ``` //! where ```tls_conn``` is a transport connection for DNS over TLS. @@ -95,7 +117,19 @@ //! //! For example: //! ```rust +//! # use crate::domain::net::client::request::SendRequest; +//! # async fn _test() { +//! # let (tls_conn, _) = domain::net::client::stream::Connection::new( +//! # domain::net::client::protocol::TcpConnect::new( +//! # String::from("127.0.0.1:53") +//! # ) +//! # ); +//! # let req = domain::net::client::request::RequestMessage::new( +//! # domain::base::MessageBuilder::new_vec() +//! # ); +//! # let mut request = tls_conn.send_request(&req).await.unwrap(); //! let reply = request.get_response().await; +//! # } //! ``` //! # Example with various transport connections From acbb06c55cad631195d30d93daeba8bc7f2b1212 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 3 Jan 2024 14:48:45 +0100 Subject: [PATCH 109/124] Simplify stream transport. --- Cargo.toml | 3 +- src/base/message.rs | 6 + src/net/client/multi_stream.rs | 6 +- src/net/client/request.rs | 160 ++++++++----- src/net/client/stream.rs | 417 +++++++++------------------------ tests/net-client.rs | 36 ++- tests/net/deckard/client.rs | 37 ++- 7 files changed, 299 insertions(+), 366 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5508ea6c2..c908f55fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ heapless = { version = "0.7", optional = true } ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } +slab = { version = "0.4.0", optional = true } smallvec = { version = "1", optional = true } tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } tokio-rustls = { version = "0.24", optional = true, features = [] } @@ -52,7 +53,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "futures-util", "std", "tokio", "tokio-rustls"] +net = ["bytes", "futures-util", "slab", "std", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/base/message.rs b/src/base/message.rs index 83cd2b80b..b23b7156e 100644 --- a/src/base/message.rs +++ b/src/base/message.rs @@ -1194,6 +1194,12 @@ impl From for CopyRecordsError { } } +impl From for CopyRecordsError { + fn from(err: PushError) -> Self { + CopyRecordsError::Push(err) + } +} + //--- Display and Error impl fmt::Display for CopyRecordsError { diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 4fb55efcf..94881d670 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -198,7 +198,7 @@ enum QueryState { StartQuery(Arc>), /// Get the result of the query. - GetResult(stream::QueryNoCheck), + GetResult(stream::Query), /// Wait until trying again. /// @@ -282,7 +282,7 @@ impl Query { } QueryState::StartQuery(ref mut conn) => { let msg = self.request_msg.clone(); - let query_res = conn.query_no_check(&msg).await; + let query_res = conn.start_request(msg.clone()).await; match query_res { Err(err) => { if let Error::ConnectionClosed = err { @@ -305,7 +305,7 @@ impl Query { } } QueryState::GetResult(ref mut query) => { - let reply = query.get_result().await; + let reply = query.get_response().await; if reply.is_err() { self.delayed_retry_count += 1; diff --git a/src/net/client/request.rs b/src/net/client/request.rs index 81b11f509..b042880c3 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -3,10 +3,13 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] -use crate::base::opt::{ComposeOptData, LongOptData, OptRecord}; -use crate::base::{ - Header, Message, MessageBuilder, ParsedDname, Rtype, StaticCompressor, +use crate::base::message::CopyRecordsError; +use crate::base::message_builder::{ + AdditionalBuilder, MessageBuilder, PushError, StaticCompressor, }; +use crate::base::opt::{ComposeOptData, LongOptData, OptRecord}; +use crate::base::wire::Composer; +use crate::base::{Header, Message, ParsedDname, Rtype}; use crate::rdata::AllRecordData; use bytes::Bytes; use octseq::Octets; @@ -22,6 +25,12 @@ use std::{error, fmt}; /// A trait that allows composing a request as a series. pub trait ComposeRequest: Debug + Send + Sync { + /// Appends the final message to a provided composer. + fn append_message( + &self, + target: &mut Target, + ) -> Result<(), CopyRecordsError>; + /// Create a message that captures the recorded changes. fn to_message(&self) -> Message>; @@ -40,9 +49,31 @@ pub trait ComposeRequest: Debug + Send + Sync { &mut self, opt: &impl ComposeOptData, ) -> Result<(), LongOptData>; + + /// Returns whether a message is an answer to the request. + fn is_answer(&self, answer: &Message<[u8]>) -> bool; +} + +//------------ HandleRequest ------------------------------------------------- + +/// Trait for handling a DNS request. +pub trait HandleRequest { + /// The response returned upon success. + type Response: AsRef>; + + /// The error returned upon failure. + type Error; + + /// The future producing the response. + type Fut<'s>: Future> + 's + where + Self: 's; + + /// Returns a future processing the request. + fn handle_request(&self, request_msg: Req) -> Self::Fut<'_>; } -//------------ Request ------------------------------------------------------- +//------------ SendRequest --------------------------------------------------- /// Trait for starting a DNS request based on a request composer. /// @@ -113,82 +144,67 @@ impl + Debug + Octets> RequestMessage { self.opt.get_or_insert_with(Default::default) } - /// Create new message based on the changes to the base message. - fn to_message_impl(&self) -> Result>, Error> { + /// Appends the message to a composer. + fn append_message_impl( + &self, + mut target: MessageBuilder, + ) -> Result, CopyRecordsError> { let source = &self.msg; - let mut target = - MessageBuilder::from_target(StaticCompressor::new(Vec::new())) - .expect("Vec is expected to have enough space"); - let target_hdr = target.header_mut(); - target_hdr.set_flags(self.header.flags()); - target_hdr.set_opcode(self.header.opcode()); - target_hdr.set_rcode(self.header.rcode()); - target_hdr.set_id(self.header.id()); + *target.header_mut() = self.header; let source = source.question(); let mut target = target.question(); for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; + target.push(rr?)?; } - let mut source = - source.answer().map_err(|_e| Error::MessageParseError)?; + let mut source = source.answer()?; let mut target = target.answer(); for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? + let rr = rr? + .into_record::>>()? .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; + target.push(rr)?; } - let mut source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); + let mut source = + source.next_section()?.expect("section should be present"); let mut target = target.authority(); for rr in &mut source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? + let rr = rr? + .into_record::>>()? .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; + target.push(rr)?; } - let source = source - .next_section() - .map_err(|_e| Error::MessageParseError)? - .expect("section should be present"); + let source = + source.next_section()?.expect("section should be present"); let mut target = target.additional(); for rr in source { - let rr = rr.map_err(|_e| Error::MessageParseError)?; - if rr.rtype() == Rtype::Opt { - } else { + let rr = rr?; + if rr.rtype() != Rtype::Opt { let rr = rr - .into_record::>>() - .map_err(|_e| Error::MessageParseError)? + .into_record::>>()? .expect("record expected"); - target - .push(rr) - .map_err(|_e| Error::MessageBuilderPushError)?; + target.push(rr)?; } } if let Some(opt) = self.opt.as_ref() { - target - .push(opt.as_record()) - .map_err(|_| Error::MessageBuilderPushError)?; + target.push(opt.as_record())?; } + Ok(target) + } + + /// Create new message based on the changes to the base message. + fn to_message_impl(&self) -> Result>, Error> { + let target = + MessageBuilder::from_target(StaticCompressor::new(Vec::new())) + .expect("Vec is expected to have enough space"); + + let target = self.append_message_impl(target)?; + // It would be nice to use .builder() here. But that one deletes all // section. We have to resort to .as_builder() which gives a // reference and then .clone() @@ -203,6 +219,16 @@ impl + Debug + Octets> RequestMessage { impl + Clone + Debug + Octets + Send + Sync + 'static> ComposeRequest for RequestMessage { + fn append_message( + &self, + target: &mut Target, + ) -> Result<(), CopyRecordsError> { + let target = MessageBuilder::from_target(target) + .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?; + self.append_message_impl(target)?; + Ok(()) + } + fn to_vec(&self) -> Vec { let msg = self.to_message(); msg.as_octets().clone() @@ -226,6 +252,18 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> ) -> Result<(), LongOptData> { self.opt_mut().push(opt).map_err(|e| e.unlimited_buf()) } + + fn is_answer(&self, answer: &Message<[u8]>) -> bool { + if !answer.header().qr() + || answer.header_counts().qdcount() + != self.msg.header_counts().qdcount() + || answer.header().id() != self.header.id() + { + false + } else { + answer.question() == self.msg.for_slice().question() + } + } } //------------ Error --------------------------------------------------------- @@ -254,6 +292,9 @@ pub enum Error { /// Octet sequence too short to be a valid DNS message. ShortMessage, + /// Message too long for stream transport. + StreamLongMessage, + /// Stream transport closed because it was idle (for too long). StreamIdleTimeout, @@ -326,6 +367,9 @@ impl fmt::Display for Error { Error::ShortMessage => { write!(f, "octet sequence to short to be a valid message") } + Error::StreamLongMessage => { + write!(f, "message too long for stream transport") + } Error::StreamIdleTimeout => { write!(f, "stream was idle for too long") } @@ -366,6 +410,15 @@ impl fmt::Display for Error { } } +impl From for Error { + fn from(err: CopyRecordsError) -> Self { + match err { + CopyRecordsError::Parse(_) => Self::MessageParseError, + CopyRecordsError::Push(_) => Self::MessageBuilderPushError, + } + } +} + impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -376,6 +429,7 @@ impl error::Error for Error { Error::OctetStreamConfigError(e) => Some(e), Error::RedundantTransportNotFound => None, Error::ShortMessage => None, + Error::StreamLongMessage => None, Error::StreamIdleTimeout => None, Error::StreamReceiveError => None, Error::StreamReadError(e) => Some(e), diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index 672d9c0c3..d2186afef 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -17,16 +17,19 @@ // - request timeout // - create new connection after end/failure of previous one +use crate::base::message::Message; +use crate::base::message_builder::StreamTarget; use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive}; -use crate::base::Message; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, + ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, }; use bytes; use bytes::{Bytes, BytesMut}; use core::cmp; use core::convert::From; +use futures_util::FutureExt; use octseq::Octets; +use slab::Slab; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -136,52 +139,58 @@ impl Connection { } } -impl Connection { +impl Connection { + /// Sends a request and receives a response. + pub async fn request( + &self, + request: Req, + ) -> Result, Error> { + let receiver = self.send_request(request).await?; + receiver.await.map_err(|_| Error::StreamReceiveError)? + } +} + +impl Connection { /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and /// returns a [ReqRepl] object wrapped in a [Result]. async fn request_impl( &self, - request_msg: &Req, + request: Req, ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.request(tx, request_msg.clone()).await?; - Ok(Box::new(ReqResp::new(request_msg, rx))) + let rx = self.send_request(request).await?; + Ok(Box::new(Query::new(rx))) } /// Start a DNS request but do not check if the reply matches the request. /// /// This function is similar to [Self::query]. Not checking if the reply /// match the request avoids having to keep the request around. - pub async fn query_no_check( + pub async fn start_request( &self, - query_msg: &Req, - ) -> Result { - let (tx, rx) = oneshot::channel(); - self.request(tx, query_msg.clone()).await?; - Ok(QueryNoCheck::new(rx)) + query_msg: Req, + ) -> Result { + let rx = self.send_request(query_msg).await?; + Ok(Query::new(rx)) } /// Sends a request. - async fn request( + async fn send_request( &self, - sender: oneshot::Sender, request_msg: Req, - ) -> Result<(), Error> { + ) -> Result, Error> { + let (sender, receiver) = oneshot::channel(); let req = ChanReq { sender, msg: request_msg, }; - match self.sender.send(req).await { - Err(_) => + self.sender.send(req).await.map_err(|_| { // Send error. The receiver is gone, this means that the // connection is closed. - { - Err(Error::ConnectionClosed) - } - Ok(_) => Ok(()), - } + Error::ConnectionClosed + })?; + Ok(receiver) } } @@ -198,51 +207,36 @@ impl SendRequest + '_, >, > { - return Box::pin(self.request_impl(request_msg)); + return Box::pin(self.request_impl(request_msg.clone())); } } -//------------ ReqResp -------------------------------------------------------- - -/// This struct represent an active DNS request. -#[derive(Debug)] -pub struct ReqResp { - /// Request message. - /// - /// The reply message is compared with the request message to see if - /// it matches the query. - request_msg: Message>, +impl HandleRequest for Connection { + type Response = Message; + type Error = Error; + type Fut<'s> = Pin> + Send + 's + >> where Self: 's; - /// Current state of the query. - state: QueryState, + fn handle_request(&self, request: Req) -> Self::Fut<'_> { + self.request(request).boxed() + } } -/// Status of a query. Used in [Query]. -#[derive(Debug)] -enum QueryState { - /// A request is in progress. - /// - /// The receiver for receiving the response is part of this state. - Busy(oneshot::Receiver), +//------------ Query ---------------------------------------------------------- - /// The response has been received and the query is done. - Done, +/// This struct represent an active DNS request. +#[derive(Debug)] +pub struct Query { + /// The receiver for the response. + receiver: oneshot::Receiver, } -impl ReqResp { +impl Query { /// Constructor for [Query], takes a DNS query and a receiver for the /// reply. - fn new( - request_msg: &Req, - receiver: oneshot::Receiver, - ) -> ReqResp { - let vec = request_msg.to_vec(); - let msg = Message::from_octets(vec) - .expect("Message failed to parse contents of another Message"); - Self { - request_msg: msg, - state: QueryState::Busy(receiver), - } + fn new(receiver: oneshot::Receiver) -> Query { + Self { receiver } } /// Get the result of a DNS request. @@ -252,38 +246,13 @@ impl ReqResp { pub async fn get_response_impl( &mut self, ) -> Result, Error> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Error::StreamReceiveError); - } - let res = res.expect("already check error case"); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let msg = res.expect("error case is checked already"); - - if !is_answer_ignore_id(&msg, &self.request_msg) { - return Err(Error::WrongReplyForQuery); - } - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } + (&mut self.receiver) + .await + .map_err(|_| Error::StreamReceiveError)? } } -impl GetResponse for ReqResp { +impl GetResponse for Query { fn get_response( &mut self, ) -> Pin< @@ -293,60 +262,7 @@ impl GetResponse for ReqResp { } } -//------------ QueryNoCheck --------------------------------------------------- - -/// This represents that state of an active DNS query if there is no need -/// to check that the reply matches the request. The assumption is that the -/// caller will do this check. -#[derive(Debug)] -pub struct QueryNoCheck { - /// Current state of the query. - state: QueryState, -} - -impl QueryNoCheck { - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new(receiver: oneshot::Receiver) -> QueryNoCheck { - Self { - state: QueryState::Busy(receiver), - } - } - - /// Get the result of a DNS query. - /// - /// This function returns the reply to a DNS query wrapped in a - /// [Result]. - pub async fn get_result(&mut self) -> Result, Error> { - match self.state { - QueryState::Busy(ref mut receiver) => { - let res = receiver.await; - self.state = QueryState::Done; - if res.is_err() { - // Assume receive error - return Err(Error::StreamReceiveError); - } - let res = res.expect("error case is checked already"); - - // clippy seems to be wrong here. Replacing - // the following with 'res?;' doesn't work - #[allow(clippy::question_mark)] - if let Err(err) = res { - return Err(err); - } - - let msg = res.expect("error case is checked already"); - - Ok(msg) - } - QueryState::Done => { - panic!("Already done"); - } - } - } -} - -//------------ Transport ------------------------------------------------ +//------------ Transport ----------------------------------------------------- /// The underlying machinery of a stream transport. #[derive(Debug)] @@ -377,19 +293,6 @@ type ReplySender = oneshot::Sender; /// A message back to `Query` returning a response. type ChanResp = Result, Error>; -/// Internal datastructure of [Transport::run] to keep track of -/// outstanding DNS requests. -struct Queries { - /// The number of elements in [Queries::vec] that are not None. - count: usize, - - /// Index in the [Queries::vec] where to look for a space for a new query. - curr: usize, - - /// Vector of senders to forward a DNS reply message (or error) to. - vec: Vec>, -} - /// Internal datastructure of [Transport::run] to keep track of /// the status of the connection. // The types Status and ConnState are only used in Transport @@ -478,11 +381,7 @@ where idle_timeout: None, send_keepalive: true, }; - let mut query_vec = Queries { - count: 0, - curr: 0, - vec: Vec::new(), - }; + let mut query_vec = Slab::new(); let mut reqmsg: Option> = None; @@ -580,13 +479,12 @@ where &mut status); }; drop(opt_record); - Self::demux_reply(answer, - &mut status, &mut query_vec); + Self::demux_reply(answer, &mut status, &mut query_vec); } res = write_stream.write_all(msg), - if do_write => { + if do_write => { if let Err(error) = res { - let error = Error::StreamWriteError(Arc::new(error)); + let error = Error::StreamWriteError(Arc::new(error)); Self::error(error.clone(), &mut query_vec); status.state = ConnState::WriteError(error); @@ -703,15 +601,11 @@ where } /// Reports an error to all outstanding queries. - fn error(error: Error, query_vec: &mut Queries) { + fn error(error: Error, query_vec: &mut Slab>) { // Update all requests that are in progress. Don't wait for // any reply that may be on its way. - for index in 0..query_vec.vec.len() { - if query_vec.vec[index].is_some() { - let sender = Self::take_query(query_vec, index) - .expect("we tested is_none before"); - _ = sender.send(Err(error.clone())); - } + for item in query_vec.drain() { + _ = item.sender.send(Err(error.clone())); } } @@ -739,35 +633,28 @@ where fn demux_reply( answer: Message, status: &mut Status, - query_vec: &mut Queries, + query_vec: &mut Slab>, ) { // We got an answer, reset the timer status.state = ConnState::Active(Some(Instant::now())); - let ind16 = answer.header().id(); - let index: usize = ind16.into(); - - let vec_len = query_vec.vec.len(); - if index >= vec_len { - // Index is out of bouds. We should mark - // the connection as broken - return; - } - - // Do we have a query with this ID? - match &mut query_vec.vec[index] { + // Get the correct query and send it the reply. + let req = match query_vec.try_remove(answer.header().id().into()) { + Some(req) => req, None => { // No query with this ID. We should // mark the connection as broken return; } - Some(_) => { - let sender = Self::take_query(query_vec, index) - .expect("sender should be there"); - _ = sender.send(Ok(answer)); - } - } - if query_vec.count == 0 { + }; + let answer = if req.msg.is_answer(answer.for_slice()) { + Ok(answer) + } else { + Err(Error::WrongReplyForQuery) + }; + _ = req.sender.send(answer); + + if query_vec.is_empty() { // Clear the activity timer. There is no need to do // this because state will be set to either IdleTimeout // or Idle just below. However, it is nicer to keep @@ -791,10 +678,10 @@ where /// idle. Addend a edns-tcp-keepalive option if needed. // Note: maybe reqmsg should be a return value. fn insert_req( - mut req: ChanReq, + req: ChanReq, status: &mut Status, reqmsg: &mut Option>, - query_vec: &mut Queries, + query_vec: &mut Slab>, ) { match &status.state { ConnState::Active(timer) => { @@ -829,15 +716,12 @@ where // Note that insert may fail if there are too many // outstanding queires. First call insert before checking // send_keepalive. - let index = { - let res = Self::insert(req.sender, query_vec); - match res { - Err(_) => { - // insert sends an error reply, so we can just - // return here - return; - } - Ok(index) => index, + let (index, req) = match Self::insert(req, query_vec) { + Ok(index) => index, + Err(_) => { + // insert sends an error reply, so we can just + // return here + return; } }; @@ -856,24 +740,23 @@ where let hdr = req.msg.header_mut(); hdr.set_id(ind16); - if status.send_keepalive { - let res = add_tcp_keepalive(&mut req.msg); + if status.send_keepalive + && req.msg.add_opt(&TcpKeepalive::new(None)).is_ok() + { + status.send_keepalive = false; + } - if let Ok(()) = res { - status.send_keepalive = false; + match Self::convert_query(&req.msg) { + Ok(msg) => { + *reqmsg = Some(msg); + } + Err(err) => { + // Take the sender out again and return the error. + if let Some(req) = query_vec.try_remove(index) { + _ = req.sender.send(Err(err)); + } } } - Self::convert_query(&req.msg, reqmsg); - } - - /// Take an element out of query_vec. - fn take_query( - query_vec: &mut Queries, - index: usize, - ) -> Option { - let query = query_vec.vec[index].take(); - query_vec.count -= 1; - query } /// Handle a received edns-tcp-keepalive option. @@ -885,109 +768,31 @@ where } /// Convert the query message to a vector. - // This function should return the vector instead of storing it - // through a reference. - fn convert_query(msg: &Req, reqmsg: &mut Option>) { - // Ideally there should be a write_all_vectored. Until there is one, - // copy to a new Vec and prepend the length octets. - - let slice = msg.to_vec(); - let len = slice.len(); - - let mut vec = Vec::with_capacity(2 + len); - let len16 = len as u16; - vec.extend_from_slice(&len16.to_be_bytes()); - vec.extend_from_slice(&slice); - - *reqmsg = Some(vec); + fn convert_query(msg: &Req) -> Result, Error> { + let mut target = StreamTarget::new_vec(); + msg.append_message(&mut target) + .map_err(|_| Error::StreamLongMessage)?; + Ok(target.into_target()) } /// Insert a sender (for the reply) in the query_vec and return the index. fn insert( - sender: oneshot::Sender, - query_vec: &mut Queries, - ) -> Result { + req: ChanReq, + query_vec: &mut Slab>, + ) -> Result<(usize, &mut ChanReq), Error> { // Fail if there are to many entries already in this vector // We cannot have more than u16::MAX entries because the // index needs to fit in an u16. For efficiency we want to // keep the vector half empty. So we return a failure if // 2*count > u16::MAX - if 2 * query_vec.count > u16::MAX.into() { - // We own sender. So we need to send the error reply here + if 2 * query_vec.len() > u16::MAX.into() { + // We own the sender. So we need to send the error reply here let error = Error::StreamTooManyOutstandingQueries; - _ = sender.send(Err(error.clone())); + _ = req.sender.send(Err(error.clone())); return Err(error); } - let q = Some(sender); - - let vec_len = query_vec.vec.len(); - - // Append if the amount of empty space in the vector is less - // than half. But limit vec_len to u16::MAX - if vec_len < 2 * (query_vec.count + 1) && vec_len < u16::MAX.into() { - // Just append - query_vec.vec.push(q); - query_vec.count += 1; - let index = query_vec.vec.len() - 1; - return Ok(index); - } - let loc_curr = query_vec.curr; - - for index in loc_curr..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Nothing until the end of the vector. Try for the entire - // vector - for index in 0..vec_len { - if query_vec.vec[index].is_none() { - Self::insert_at(query_vec, index, q); - return Ok(index); - } - } - - // Still nothing, that is not good - panic!("insert failed"); - } - - /// Insert a sender at a specific position in query_vec and update - /// the statistics. - fn insert_at( - query_vec: &mut Queries, - index: usize, - q: Option, - ) { - query_vec.vec[index] = q; - query_vec.count += 1; - query_vec.curr = index + 1; - } -} - -//------------ Utility -------------------------------------------------------- - -/// Add an edns-tcp-keepalive option to a BaseMessageBuilder. -fn add_tcp_keepalive(msg: &mut CR) -> Result<(), Error> { - msg.add_opt(&TcpKeepalive::new(None))?; - Ok(()) -} - -/// Check if a DNS reply match the query. Ignore whether id fields match. -fn is_answer_ignore_id< - Octs1: Octets + AsRef<[u8]>, - Octs2: Octets + AsRef<[u8]>, ->( - reply: &Message, - query: &Message, -) -> bool { - if !reply.header().qr() - || reply.header_counts().qdcount() != query.header_counts().qdcount() - { - false - } else { - reply.question() == query.question() + let entry = query_vec.vacant_entry(); + Ok((entry.key(), entry.insert(req))) } } diff --git a/tests/net-client.rs b/tests/net-client.rs index 8b65a50fa..6c21b54ea 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -1,8 +1,8 @@ #![cfg(feature = "net")] mod net; -use crate::net::deckard::client::do_client; use crate::net::deckard::client::CurrStepValue; +use crate::net::deckard::client::{closure_do_client, do_client}; use crate::net::deckard::connect::Connect; use crate::net::deckard::connection::Connection; use crate::net::deckard::dgram::Dgram; @@ -146,3 +146,37 @@ fn tcp() { do_client(&deckard, tcp, &CurrStepValue::new()).await; }); } + +#[test] +#[ignore] +// Connect directly to the internet. Disabled by default. +fn tcp_async_fn() { + tokio_test::block_on(async { + let file = File::open(TEST_FILE).unwrap(); + let deckard = parse_file(file); + + let server_addr = + SocketAddr::new(IpAddr::from_str("9.9.9.9").unwrap(), 53); + + let tcp_conn = match TcpStream::connect(server_addr).await { + Ok(conn) => conn, + Err(err) => { + println!( + "TCP Connection to {server_addr} failed: {err}, exiting" + ); + return; + } + }; + + let (tcp, transport) = stream::Connection::new(tcp_conn); + tokio::spawn(async move { + transport.run().await; + println!("single TCP run terminated"); + }); + + closure_do_client(&deckard, &CurrStepValue::new(), |req| { + tcp.request(req) + }) + .await; + }) +} diff --git a/tests/net/deckard/client.rs b/tests/net/deckard/client.rs index 80a75d91a..17db8ef55 100644 --- a/tests/net/deckard/client.rs +++ b/tests/net/deckard/client.rs @@ -4,10 +4,43 @@ use crate::net::deckard::parse_query; use bytes::Bytes; use domain::base::{Message, MessageBuilder}; -use domain::net::client::request::RequestMessage; -use domain::net::client::request::SendRequest; +use domain::net::client::request::{Error, RequestMessage, SendRequest}; +use std::future::Future; use std::sync::Mutex; +pub async fn closure_do_client( + deckard: &Deckard, + step_value: &CurrStepValue, + request: F, +) where + F: Fn(RequestMessage>) -> Fut, + Fut: Future, Error>>, +{ + let mut resp: Option> = None; + + // Assume steps are in order. Maybe we need to define that. + for step in &deckard.scenario.steps { + step_value.set(step.step_value); + match step.step_type { + StepType::Query => { + let reqmsg = entry2reqmsg(step.entry.as_ref().unwrap()); + resp = Some(request(reqmsg).await.unwrap()); + } + StepType::CheckAnswer => { + let answer = resp.take().unwrap(); + if !match_msg(step.entry.as_ref().unwrap(), &answer, true) { + panic!("reply failed"); + } + } + StepType::TimePasses + | StepType::Traffic + | StepType::CheckTempfile + | StepType::Assign => todo!(), + } + } + println!("Done"); +} + pub async fn do_client>>>( deckard: &Deckard, request: R, From db618389f0e09e364e76c26bb1c6a9c13b63e9a3 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 3 Jan 2024 14:56:06 +0100 Subject: [PATCH 110/124] We need Slab 0.4.9 not 0.4.0. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c908f55fa..0a43e6175 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ heapless = { version = "0.7", optional = true } ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } -slab = { version = "0.4.0", optional = true } +slab = { version = "0.4.9", optional = true } smallvec = { version = "1", optional = true } tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } tokio-rustls = { version = "0.24", optional = true, features = [] } From 684564fbbc9915cb8490f1efa4db887f89564fb6 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 3 Jan 2024 18:41:35 +0100 Subject: [PATCH 111/124] Bring back our own slab. --- Cargo.toml | 3 +- src/net/client/stream.rs | 205 +++++++++++++++++++++++++++++++++------ 2 files changed, 175 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0a43e6175..5508ea6c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,6 @@ heapless = { version = "0.7", optional = true } ring = { version = "0.17", optional = true } serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } -slab = { version = "0.4.9", optional = true } smallvec = { version = "1", optional = true } tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } tokio-rustls = { version = "0.24", optional = true, features = [] } @@ -53,7 +52,7 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = [] -net = ["bytes", "futures-util", "slab", "std", "tokio", "tokio-rustls"] +net = ["bytes", "futures-util", "std", "tokio", "tokio-rustls"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "std"] diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index d2186afef..add2e402c 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -29,7 +29,6 @@ use core::cmp; use core::convert::From; use futures_util::FutureExt; use octseq::Octets; -use slab::Slab; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -381,7 +380,7 @@ where idle_timeout: None, send_keepalive: true, }; - let mut query_vec = Slab::new(); + let mut query_vec = Queries::new(); let mut reqmsg: Option> = None; @@ -459,10 +458,8 @@ where // error. panic!("reader terminated"), Err(error) => { - Self::error(error.clone(), - &mut query_vec); - status.state = - ConnState::ReadError(error); + Self::error(error.clone(), &mut query_vec); + status.state = ConnState::ReadError(error); // Reader failed. Break // out of loop and // shut down @@ -601,7 +598,7 @@ where } /// Reports an error to all outstanding queries. - fn error(error: Error, query_vec: &mut Slab>) { + fn error(error: Error, query_vec: &mut Queries>) { // Update all requests that are in progress. Don't wait for // any reply that may be on its way. for item in query_vec.drain() { @@ -633,13 +630,13 @@ where fn demux_reply( answer: Message, status: &mut Status, - query_vec: &mut Slab>, + query_vec: &mut Queries>, ) { // We got an answer, reset the timer status.state = ConnState::Active(Some(Instant::now())); // Get the correct query and send it the reply. - let req = match query_vec.try_remove(answer.header().id().into()) { + let req = match query_vec.try_remove(answer.header().id()) { Some(req) => req, None => { // No query with this ID. We should @@ -681,7 +678,7 @@ where req: ChanReq, status: &mut Status, reqmsg: &mut Option>, - query_vec: &mut Slab>, + query_vec: &mut Queries>, ) { match &status.state { ConnState::Active(timer) => { @@ -716,19 +713,17 @@ where // Note that insert may fail if there are too many // outstanding queires. First call insert before checking // send_keepalive. - let (index, req) = match Self::insert(req, query_vec) { - Ok(index) => index, - Err(_) => { - // insert sends an error reply, so we can just - // return here + let (index, req) = match query_vec.insert(req) { + Ok(res) => res, + Err(req) => { + // Send an appropriate error and return. + _ = req.sender.send( + Err(Error::StreamTooManyOutstandingQueries) + ); return; } }; - let ind16: u16 = index - .try_into() - .expect("insert should return a value that fits in u16"); - // We set the ID to the array index. Defense in depth // suggests that a random ID is better because it works // even if TCP sequence numbers could be predicted. However, @@ -738,7 +733,7 @@ where // resilient against forgery by third parties." let hdr = req.msg.header_mut(); - hdr.set_id(ind16); + hdr.set_id(index); if status.send_keepalive && req.msg.add_opt(&TcpKeepalive::new(None)).is_ok() @@ -774,25 +769,173 @@ where .map_err(|_| Error::StreamLongMessage)?; Ok(target.into_target()) } +} + +//------------ Queries ------------------------------------------------------- + +/// Mapping outstanding queries to their ID. +/// +/// This is generic over anything rather than our concrete request time for +/// easier testing. +#[derive(Clone, Debug)] +struct Queries { + /// The number of elements in `vec` that are not None. + count: usize, + + /// Index in `vec? where to look for a space for a new query. + curr: usize, + + /// Vector of senders to forward a DNS reply message (or error) to. + vec: Vec>, +} + +impl Queries { + /// Creates a new empty value. + fn new() -> Self { + Self { + count: 0, + curr: 0, + vec: Vec::new(), + } + } - /// Insert a sender (for the reply) in the query_vec and return the index. + /// Returns whether there are no more outstanding queries. + fn is_empty(&self) -> bool { + self.count == 0 + } + + /// Inserts the given query. + /// + /// Upon success, returns the index and a mutable reference to the stored + /// query. + /// + /// Upon error, which means the set is full, returns the query. fn insert( - req: ChanReq, - query_vec: &mut Slab>, - ) -> Result<(usize, &mut ChanReq), Error> { + &mut self, req: T + ) -> Result<(u16, &mut T), T> { // Fail if there are to many entries already in this vector // We cannot have more than u16::MAX entries because the // index needs to fit in an u16. For efficiency we want to // keep the vector half empty. So we return a failure if // 2*count > u16::MAX - if 2 * query_vec.len() > u16::MAX.into() { - // We own the sender. So we need to send the error reply here - let error = Error::StreamTooManyOutstandingQueries; - _ = req.sender.send(Err(error.clone())); - return Err(error); + if 2 * self.count > u16::MAX.into() { + return Err(req) } - let entry = query_vec.vacant_entry(); - Ok((entry.key(), entry.insert(req))) + // If more than half the vec is empty, we try and find the index of + // an empty slot. + let idx = if self.vec.len() >= 2 * self.count { + let mut found = None; + for idx in self.curr .. self.vec.len() { + if self.vec[idx].is_none() { + found = Some(idx); + break; + } + } + found + } + else { + None + }; + + // If we have an index, we can insert there, otherwise we need to + // append. + let idx = match idx { + Some(idx) => { + self.vec[idx] = Some(req); + idx + } + None => { + let idx = self.vec.len(); + self.vec.push(Some(req)); + idx + } + }; + + self.count += 1; + if idx == self.curr { + self.curr += 1; + } + let req = self.vec[idx].as_mut().expect("no inserted item?"); + let idx = u16::try_from(idx).expect("query vec too large"); + Ok((idx, req)) + } + + /// Tries to remove and return the query at the given index. + /// + /// Returns `None` if there was no query there. + fn try_remove(&mut self, index: u16) -> Option { + let res = self.vec.get_mut(usize::from(index))?.take()?; + self.count = self.count.saturating_sub(1); + self.curr = cmp::min(self.curr, index.into()); + Some(res) + } + + /// Removes all queries and returns an iterator over them. + fn drain(&mut self) -> impl Iterator + '_ { + let res = self.vec.drain(..).flatten(); // Skips all the `None`s. + self.count = 0; + self.curr = 0; + res + } +} + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[allow(clippy::needless_range_loop)] + fn queries_insert_remove() { + // Insert items, remove a few, insert a few more. Check that + // everything looks right. + let mut idxs = [None; 20]; + let mut queries = Queries::new(); + + for i in 0..12 { + let (idx, item) = queries.insert(i).unwrap(); + idxs[i] = Some(idx); + assert_eq!(i, *item); + } + assert_eq!(queries.count, 12); + assert_eq!(queries.vec.iter().flatten().count(), 12); + + for i in [1, 2, 3, 4, 7, 9] { + let item = queries.try_remove(idxs[i].unwrap()).unwrap(); + assert_eq!(i, item); + idxs[i] = None; + } + assert_eq!(queries.count, 6); + assert_eq!(queries.vec.iter().flatten().count(), 6); + + for i in 12..20 { + let (idx, item) = queries.insert(i).unwrap(); + idxs[i] = Some(idx); + assert_eq!(i, *item); + } + assert_eq!(queries.count, 14); + assert_eq!(queries.vec.iter().flatten().count(), 14); + + for i in 0..20 { + if let Some(idx) = idxs[i] { + let item = queries.try_remove(idx).unwrap(); + assert_eq!(i, item); + } + } + assert_eq!(queries.count, 0); + assert_eq!(queries.vec.iter().flatten().count(), 0); + } + + #[test] + fn queries_overrun() { + // This is just a quick check that inserting to much stuff doesn’t + // break. + let mut queries = Queries::new(); + for i in 0..usize::from(u16::MAX) * 2 { + let _ = queries.insert(i); + } } } + From 9ba57d3b7dc11ebbad87dc017c8f44559da369a7 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Wed, 3 Jan 2024 18:41:55 +0100 Subject: [PATCH 112/124] Format. --- src/net/client/stream.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index add2e402c..a8ce3041c 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -717,9 +717,9 @@ where Ok(res) => res, Err(req) => { // Send an appropriate error and return. - _ = req.sender.send( - Err(Error::StreamTooManyOutstandingQueries) - ); + _ = req + .sender + .send(Err(Error::StreamTooManyOutstandingQueries)); return; } }; @@ -810,31 +810,28 @@ impl Queries { /// query. /// /// Upon error, which means the set is full, returns the query. - fn insert( - &mut self, req: T - ) -> Result<(u16, &mut T), T> { + fn insert(&mut self, req: T) -> Result<(u16, &mut T), T> { // Fail if there are to many entries already in this vector // We cannot have more than u16::MAX entries because the // index needs to fit in an u16. For efficiency we want to // keep the vector half empty. So we return a failure if // 2*count > u16::MAX if 2 * self.count > u16::MAX.into() { - return Err(req) + return Err(req); } // If more than half the vec is empty, we try and find the index of // an empty slot. let idx = if self.vec.len() >= 2 * self.count { let mut found = None; - for idx in self.curr .. self.vec.len() { + for idx in self.curr..self.vec.len() { if self.vec[idx].is_none() { found = Some(idx); break; } } found - } - else { + } else { None }; @@ -856,7 +853,7 @@ impl Queries { if idx == self.curr { self.curr += 1; } - let req = self.vec[idx].as_mut().expect("no inserted item?"); + let req = self.vec[idx].as_mut().expect("no inserted item?"); let idx = u16::try_from(idx).expect("query vec too large"); Ok((idx, req)) } @@ -938,4 +935,3 @@ mod test { } } } - From 2a3c61573c960499af28da483894924f9449ebdf Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 4 Jan 2024 15:06:16 +0100 Subject: [PATCH 113/124] Simplify all the other basic transports as well. --- examples/client-transports.rs | 7 +- src/net/client/dgram.rs | 326 ++++++++++++--------------------- src/net/client/dgram_stream.rs | 79 ++++++-- src/net/client/multi_stream.rs | 227 ++++++++++------------- src/net/client/request.rs | 27 ++- src/net/client/stream.rs | 2 +- tests/net-client.rs | 2 +- 7 files changed, 303 insertions(+), 367 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 793d9c25a..a7d97b6fe 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -195,11 +195,8 @@ async fn main() { let udp_connect = UdpConnect::new(server_addr); let dgram_conn = dgram::Connection::new(Some(dgram_config), udp_connect); - // Send a query message. - let mut request = dgram_conn.send_request(&req).await.unwrap(); - - // Get the reply - let reply = request.get_response().await; + // Send a query message and get the reply. + let reply = dgram_conn.request(req.clone()).await.unwrap(); println!("Dgram reply: {:?}", reply); // Create a single TCP transport connection. This is usefull for a diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 399c52924..27a725b41 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -6,24 +6,22 @@ // To do: // - cookies -use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::protocol::{ AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend, AsyncDgramSendEx, }; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, + ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, }; use bytes::Bytes; -use core::cmp; -use octseq::Octets; +use core::{cmp, fmt}; +use futures_util::FutureExt; use std::boxed::Box; -use std::fmt::{Debug, Formatter}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{Semaphore, SemaphorePermit}; use tokio::time::{timeout, Duration, Instant}; //------------ Configuration Constants ---------------------------------------- @@ -138,114 +136,39 @@ impl Default for Config { //------------ Connection ----------------------------------------------------- /// A datagram transport connection. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Connection { - /// Reference to the actual connection object. - inner: Arc>, + /// User configuration variables. + config: Config, + + /// Connections to datagram sockets. + connect: S, + + /// Semaphore to limit access to UDP sockets. + semaphore: Semaphore, } -impl< - S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - > Connection -{ +impl Connection { /// Create a new datagram transport connection. pub fn new(config: Option, connect: S) -> Connection { - let connection = - InnerConnection::new(config.unwrap_or_default(), connect); + let config = config.unwrap_or_default(); Self { - inner: Arc::new(connection), + semaphore: Semaphore::new(config.max_parallel), + config, + connect, } } - - /// Start a new DNS request. - async fn request_impl< - CR: ComposeRequest + Clone + Send + Sync + 'static, - >( - &self, - request_msg: &CR, - ) -> Result, Error> { - let gr = self.inner.request(request_msg, self.clone()).await?; - Ok(Box::new(gr)) - } - - /// Get a permit from the semaphore to start using a socket. - async fn get_permit(&self) -> OwnedSemaphorePermit { - self.inner.get_permit().await - } } -impl< - S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - CR: ComposeRequest + Clone + Send + Sync + 'static, - > SendRequest for Connection +impl Connection +where + S: AsyncConnect, + S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, { - fn send_request<'a>( - &'a self, - request_msg: &'a CR, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.request_impl(request_msg)); - } -} - -//------------ ReqResp -------------------------------------------------------- - -/// The state of a DNS request. -pub struct ReqResp { - /// Future that does the actual work of GetResponse. - get_response_fut: - Pin, Error>> + Send>>, -} - -impl ReqResp { - /// Create new ReqResp object. - fn new< - S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - CR: ComposeRequest + Clone + Send + Sync + 'static, - >( - config: Config, - request_msg: &CR, - conn: Connection, - udp_payload_size: Option, - connect: S, - ) -> Self { - Self { - get_response_fut: Box::pin(Self::get_response_impl2( - config, - request_msg.clone(), - conn, - udp_payload_size, - connect, - )), - } - } - - /// Async function that waits for the future stored in Query to complete. - async fn get_response_impl(&mut self) -> Result, Error> { - (&mut self.get_response_fut).await - } - - /// Get the response of a DNS request. - /// - /// This function is not cancel safe. - async fn get_response_impl2< - S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - CR: ComposeRequest, - >( - config: Config, - mut request_bmb: CR, - conn: Connection, - udp_payload_size: Option, - connect: S, + /// Sends a request and receives a response. + pub async fn request( + &self, + mut request: Req, ) -> Result, Error> { let recv_size = 2000; // Should be configurable. @@ -253,22 +176,23 @@ impl ReqResp { // We need to get past the semaphore that limits the // number of concurrent sockets we can use. - let _permit = conn.get_permit().await; + let _permit = self.get_permit().await; loop { - let mut sock = connect + let mut sock = self + .connect .connect() .await .map_err(|e| Error::UdpConnect(Arc::new(e)))?; // Set random ID in header - let header = request_bmb.header_mut(); - header.set_random_id(); + request.header_mut().set_random_id(); + // Set UDP payload size - if let Some(size) = udp_payload_size { - request_bmb.set_udp_payload_size(size) + if let Some(size) = self.config.udp_payload_size { + request.set_udp_payload_size(size) } - let request_msg = request_bmb.to_message(); + let request_msg = request.to_message(); let dgram = request_msg.as_slice(); let sent = sock @@ -283,18 +207,18 @@ impl ReqResp { loop { let elapsed = start.elapsed(); - if elapsed > config.read_timeout { + if elapsed > self.config.read_timeout { // Break out of the receive loop and continue in the // transmit loop. break; } - let remain = config.read_timeout - elapsed; + let remain = self.config.read_timeout - elapsed; let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. let timeout_res = timeout(remain, sock.recv(&mut buf)).await; if timeout_res.is_err() { retries += 1; - if retries < config.max_retries { + if retries < self.config.max_retries { // Break out of the receive loop and continue in the // transmit loop. break; @@ -314,131 +238,121 @@ impl ReqResp { Err(_) => continue, }; - if !is_answer(&answer, &request_msg) { + if !request.is_answer(answer.for_slice()) { // Wrong answer, go back to receiving continue; } return Ok(answer); } retries += 1; - if retries < config.max_retries { + if retries < self.config.max_retries { continue; } break; } Err(Error::UdpTimeoutNoResponse) } -} -impl Debug for ReqResp { - fn fmt(&self, _: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { - todo!() + /// Return a permit for a our semaphore. + async fn get_permit(&self) -> SemaphorePermit { + self.semaphore + .acquire() + .await + .expect("the semaphore has not been closed") } } -impl GetResponse for ReqResp { - fn get_response( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_response_impl()) +impl Connection +where + S: AsyncConnect + Clone + Send + Sync + 'static, + S::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, +{ + /// Start a new DNS request. + async fn request_impl< + CR: ComposeRequest + Clone + Send + Sync + 'static, + >( + self: Arc, + request_msg: CR, + ) -> Result, Error> { + Ok(Box::new(Query { + get_response_fut: async move { self.request(request_msg).await } + .boxed(), + })) } } -//------------ InnerConnection ------------------------------------------------ +//--- SendRequest and HandleRequest -/// Actual implementation of the datagram transport connection. -#[derive(Debug)] -struct InnerConnection { - /// User configuration variables. - config: Config, - - /// Connections to datagram sockets. - connect: S, - - /// Semaphore to limit access to UDP sockets. - semaphore: Arc, +impl SendRequest for Arc> +where + S: AsyncConnect + Clone + Send + Sync + 'static, + S::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, + Req: ComposeRequest + Clone + Send + Sync + 'static, +{ + fn send_request<'a>( + &'a self, + request_msg: &'a Req, + ) -> Pin< + Box< + dyn Future, Error>> + + Send + + '_, + >, + > { + let this = self.clone(); + Box::pin(this.request_impl(request_msg.clone())) + } } -impl< - S: AsyncConnect + Clone + Send + Sync + 'static, - C: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - > InnerConnection +impl HandleRequest for Connection +where + S: AsyncConnect + Clone + Send + Sync + 'static, + S::Connection: + AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, + Req: ComposeRequest + Clone + Send + Sync + 'static, { - /// Create new InnerConnection object. - fn new(config: Config, connect: S) -> InnerConnection { - let max_parallel = config.max_parallel; - Self { - config, - connect, - semaphore: Arc::new(Semaphore::new(max_parallel)), - } + type Response = Message; + type Error = Error; + type Fut<'s> = Pin> + Send + 's + >> where Self: 's; + + fn handle_request(&self, request: Req) -> Self::Fut<'_> { + self.request(request).boxed() } +} - /// Return a Query object that contains the query state. - async fn request( - &self, - request_msg: &CR, - conn: Connection, - ) -> Result { - Ok(ReqResp::new( - self.config.clone(), - request_msg, - conn, - self.config.udp_payload_size, - self.connect.clone(), - )) - } +//------------ Query -------------------------------------------------------- - /// Return a permit for a our semaphore. - async fn get_permit(&self) -> OwnedSemaphorePermit { - self.semaphore - .clone() - .acquire_owned() - .await - .expect("the semaphore has not been closed") - } +/// The state of a DNS request. +pub struct Query { + /// Future that does the actual work of GetResponse. + get_response_fut: + Pin, Error>> + Send>>, } -//------------ Utility -------------------------------------------------------- - -/// Check if a message is a valid reply for a query. Allow the question section -/// to be empty if there is an error or if the reply is truncated. -fn is_answer< - QueryOcts: AsRef<[u8]> + Octets, - ReplyOcts: AsRef<[u8]> + Octets, ->( - reply: &Message, - query: &Message, -) -> bool { - let reply_header = reply.header(); - let reply_hcounts = reply.header_counts(); - - // First check qr and id - if !reply_header.qr() || reply_header.id() != query.header().id() { - return false; +impl Query { + /// Async function that waits for the future stored in Query to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.get_response_fut).await } +} - // If either tc is set or the result is an error, then the question - // section can be empty. In that case we require all other sections - // to be empty as well. - if (reply_header.tc() || reply_header.rcode() != Rcode::NoError) - && reply_hcounts.qdcount() == 0 - && reply_hcounts.ancount() == 0 - && reply_hcounts.nscount() == 0 - && reply_hcounts.arcount() == 0 - { - // We can accept this as a valid reply. - return true; +impl fmt::Debug for Query { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + todo!() } +} - // Remaining checks. The question section in the reply has to be the - // same as in the query. - if reply_hcounts.qdcount() != query.header_counts().qdcount() { - false - } else { - reply.question() == query.question() +impl GetResponse for Query { + fn get_response( + &mut self, + ) -> Pin< + Box, Error>> + Send + '_>, + > { + Box::pin(self.get_response_impl()) } } diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs index df59111a6..b917504db 100644 --- a/src/net/client/dgram_stream.rs +++ b/src/net/client/dgram_stream.rs @@ -13,13 +13,15 @@ use crate::net::client::protocol::{ AsyncConnect, AsyncDgramRecv, AsyncDgramSend, }; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, + ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, }; use bytes::Bytes; +use futures_util::FutureExt; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; //------------ Config --------------------------------------------------------- @@ -88,7 +90,7 @@ impl Config { #[derive(Clone)] pub struct Connection { /// The UDP transport connection. - udp_conn: dgram::Connection, + udp_conn: Arc>, /// The TCP transport connection. tcp_conn: multi_stream::Connection, @@ -115,7 +117,7 @@ where config: Config, ) -> (Self, multi_stream::Transport) { let udp_conn = - dgram::Connection::new(Some(config.dgram), dgram_remote); + dgram::Connection::new(Some(config.dgram), dgram_remote).into(); let (tcp_conn, transport) = multi_stream::Connection::with_config( stream_remote, config.multi_stream, @@ -124,6 +126,25 @@ where } } +impl Connection +where + DgramS: AsyncConnect, + DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, + Req: ComposeRequest + Clone, +{ + /// Sends a request and receives a response. + pub async fn request( + &self, + request: Req, + ) -> Result, Error> { + let response = self.udp_conn.request(request.clone()).await?; + if !response.header().tc() { + return Ok(response); + } + self.tcp_conn.request(request).await + } +} + impl Connection where DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, @@ -135,7 +156,7 @@ where &self, request_msg: &Req, ) -> Result, Error> { - Ok(Box::new(ReqResp::new( + Ok(Box::new(Query::new( request_msg, self.udp_conn.clone(), self.tcp_conn.clone(), @@ -143,6 +164,8 @@ where } } +//--- SendRequest and HandleRequest + impl SendRequest for Connection where DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, @@ -163,16 +186,33 @@ where } } -//------------ ReqResp -------------------------------------------------------- +impl HandleRequest for Connection +where + DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, + DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, +{ + type Response = Message; + type Error = Error; + type Fut<'s> = Pin> + Send + 's + >> where Self: 's; + + fn handle_request(&self, request: Req) -> Self::Fut<'_> { + self.request(request).boxed() + } +} + +//------------ Query -------------------------------------------------------- /// Object that contains the current state of a query. #[derive(Debug)] -pub struct ReqResp { +pub struct Query { /// Reqeust message. request_msg: Req, /// UDP transport to be used. - udp_conn: dgram::Connection, + udp_conn: Arc>, /// TCP transport to be used. tcp_conn: multi_stream::Connection, @@ -197,19 +237,19 @@ enum QueryState { GetTcpResponse(Box), } -impl< - S: AsyncConnect + Clone + Send + Sync + 'static, - Reg: ComposeRequest + Clone + 'static, - > ReqResp +impl Query +where + S: AsyncConnect + Clone + Send + Sync + 'static, + Req: ComposeRequest + Clone + 'static, { - /// Create a new ReqResp object. + /// Create a new Query object. /// /// The initial state is to start with a UDP transport. fn new( - request_msg: &Reg, - udp_conn: dgram::Connection, - tcp_conn: multi_stream::Connection, - ) -> ReqResp { + request_msg: &Req, + udp_conn: Arc>, + tcp_conn: multi_stream::Connection, + ) -> Query { Self { request_msg: request_msg.clone(), udp_conn, @@ -256,12 +296,11 @@ impl< } } -impl< - S: AsyncConnect + Clone + Debug + Send + Sync + 'static, - Reg: ComposeRequest + Clone + Debug + 'static, - > GetResponse for ReqResp +impl GetResponse for Query where + S: AsyncConnect + Clone + Debug + Send + Sync + 'static, S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, + Req: ComposeRequest + Clone + 'static, { fn get_response( &mut self, diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index 94881d670..e55ec301a 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -6,17 +6,15 @@ // To do: // - too many connection errors -use crate::base::iana::Rcode; use crate::base::Message; use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, SendRequest, + ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, }; use crate::net::client::stream; use bytes::Bytes; use futures_util::stream::FuturesUnordered; -use futures_util::StreamExt; -use octseq::Octets; +use futures_util::{FutureExt, StreamExt}; use rand::random; use std::boxed::Box; use std::fmt::Debug; @@ -67,10 +65,10 @@ impl From for Config { //------------ Connection ----------------------------------------------------- /// A connection to a multi-stream transport. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Connection { /// The sender half of the connection request channel. - sender: Arc>>, + sender: mpsc::Sender>, } impl Connection { @@ -85,26 +83,30 @@ impl Connection { config: Config, ) -> (Self, Transport) { let (sender, transport) = Transport::new(remote, config); - ( - Self { - sender: sender.into(), - }, - transport, - ) + (Self { sender }, transport) } } -impl Connection { +impl Connection { + /// Sends a request and receives a response. + pub async fn request( + &self, + request: Req, + ) -> Result, Error> { + Query::new(self.clone(), request).get_response().await + } + /// Starts a request. /// /// This is the future that is returned by the `SendRequest` impl. async fn _send_request( &self, request: &Req, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.new_conn(None, tx).await?; - let gr = Query::new(self.clone(), request.clone(), rx); + ) -> Result, Error> + where + Req: 'static, + { + let gr = Query::new(self.clone(), request.clone()); Ok(Box::new(gr)) } @@ -112,20 +114,16 @@ impl Connection { async fn new_conn( &self, opt_id: Option, - sender: oneshot::Sender>, - ) -> Result<(), Error> { + ) -> Result>, Error> { + let (sender, receiver) = oneshot::channel(); let req = ChanReq { cmd: ReqCmd::NewConn(opt_id, sender), }; - match self.sender.send(req).await { - Err(_) => - // Send error. The receiver is gone, this means that the - // connection is closed. - { - Err(Error::ConnectionClosed) - } - Ok(_) => Ok(()), - } + self.sender + .send(req) + .await + .map_err(|_| Error::ConnectionClosed)?; + Ok(receiver) } /// Request a shutdown. @@ -145,7 +143,17 @@ impl Connection { } } -//--- SendRequest +//--- Clone + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} + +//--- SendRequest and HandleRequest impl SendRequest for Connection where @@ -165,6 +173,21 @@ where } } +impl HandleRequest for Connection +where + Req: ComposeRequest + Clone + Send, +{ + type Response = Message; + type Error = Error; + type Fut<'s> = Pin> + Send + 's + >> where Self: 's; + + fn handle_request(&self, request: Req) -> Self::Fut<'_> { + self.request(request).boxed() + } +} + //------------ Query -------------------------------------------------------- /// The connection side of an active request. @@ -181,8 +204,8 @@ struct Query { /// The underlying transport. conn: Connection, - /// The id of the most recent connection. - conn_id: u64, + /// The id of the most recent connection, if any. + conn_id: Option, /// Number of retries with delay. delayed_retry_count: u64, @@ -191,8 +214,11 @@ struct Query { /// The states of the query state machine. #[derive(Debug)] enum QueryState { + /// Request a new connection. + RequestConn, + /// Receive a new connection from the receiver. - GetConn(oneshot::Receiver>), + ReceiveConn(oneshot::Receiver>), /// Start a query using the given stream transport. StartQuery(Arc>), @@ -225,32 +251,36 @@ struct ChanRespOk { impl Query { /// Creates a new query. - fn new( - conn: Connection, - request_msg: Req, - receiver: oneshot::Receiver>, - ) -> Self { + fn new(conn: Connection, request_msg: Req) -> Self { Self { conn, request_msg, - state: QueryState::GetConn(receiver), - conn_id: 0, + state: QueryState::RequestConn, + conn_id: None, delayed_retry_count: 0, } } } -impl Query { +impl Query { /// Get the result of a DNS request. /// - /// This function returns the reply to a DNS request wrapped in a - /// [Result]. - pub async fn get_response_impl( - &mut self, - ) -> Result, Error> { + /// This function is cancellation safe. If its future is dropped before + /// it is resolved, you can call it again to get a new future. + pub async fn get_response(&mut self) -> Result, Error> { loop { match self.state { - QueryState::GetConn(ref mut receiver) => { + QueryState::RequestConn => { + let rx = match self.conn.new_conn(self.conn_id).await { + Ok(rx) => rx, + Err(err) => { + self.state = QueryState::Done; + return Err(err); + } + }; + self.state = QueryState::ReceiveConn(rx); + } + QueryState::ReceiveConn(ref mut receiver) => { let res = match receiver.await { Ok(res) => res, Err(_) => { @@ -274,7 +304,7 @@ impl Query { let id = ok_res.id; let conn = ok_res.conn; - self.conn_id = id; + self.conn_id = Some(id); self.state = QueryState::StartQuery(conn); continue; } @@ -286,14 +316,7 @@ impl Query { match query_res { Err(err) => { if let Error::ConnectionClosed = err { - let (tx, rx) = oneshot::channel(); - let res = - self.new_conn(self.conn_id, tx).await; - if let Err(err) = res { - self.state = QueryState::Done; - return Err(err); - } - self.state = QueryState::GetConn(rx); + self.state = QueryState::RequestConn; continue; } return Err(err); @@ -305,34 +328,28 @@ impl Query { } } QueryState::GetResult(ref mut query) => { - let reply = query.get_response().await; - - if reply.is_err() { - self.delayed_retry_count += 1; - let retry_time = retry_time(self.delayed_retry_count); - self.state = - QueryState::Delay(Instant::now(), retry_time); - continue; - } - - let msg = reply.expect("error is checked before"); - let request_msg = self.request_msg.to_message(); - - if !is_answer_ignore_id(&msg, &request_msg) { - return Err(Error::WrongReplyForQuery); + match query.get_response().await { + Ok(reply) => return Ok(reply), + // XXX This replicates the previous behavior. But + // maybe we should have a whole category of + // fatal errors where retrying doesn’t make any + // sense? + Err(Error::WrongReplyForQuery) => { + return Err(Error::WrongReplyForQuery) + } + Err(_) => { + self.delayed_retry_count += 1; + let retry_time = + retry_time(self.delayed_retry_count); + self.state = + QueryState::Delay(Instant::now(), retry_time); + continue; + } } - return Ok(msg); } QueryState::Delay(instant, duration) => { sleep_until(instant + duration).await; - let (tx, rx) = oneshot::channel(); - let res = self.new_conn(self.conn_id, tx).await; - if let Err(err) = res { - self.state = QueryState::Done; - return Err(err); - } - self.state = QueryState::GetConn(rx); - continue; + self.state = QueryState::RequestConn; } QueryState::Done => { panic!("Already done"); @@ -340,15 +357,6 @@ impl Query { } } } - - /// Requests a new connection. - async fn new_conn( - &self, - id: u64, - tx: oneshot::Sender>, - ) -> Result<(), Error> { - self.conn.new_conn(Some(id), tx).await - } } impl GetResponse for Query { @@ -357,7 +365,7 @@ impl GetResponse for Query { ) -> Pin< Box, Error>> + Send + '_>, > { - Box::pin(self.get_response_impl()) + Box::pin(Self::get_response(self)) } } @@ -648,47 +656,6 @@ fn retry_time(retries: u64) -> Duration { Duration::from_micros(to_usecs as u64) } -/// Check if a message is the reply to a query. -/// -/// Avoid checking the id field because the id has been changed in the -/// query that was actually issued. -fn is_answer_ignore_id< - Octs1: Octets + AsRef<[u8]>, - Octs2: Octets + AsRef<[u8]>, ->( - reply: &Message, - query: &Message, -) -> bool { - let reply_header = reply.header(); - let reply_hcounts = reply.header_counts(); - - // First check qr is set - if !reply_header.qr() { - return false; - } - - // If the result is an error, then the question - // section can be empty. In that case we require all other sections - // to be empty as well. - if reply_header.rcode() != Rcode::NoError - && reply_hcounts.qdcount() == 0 - && reply_hcounts.ancount() == 0 - && reply_hcounts.nscount() == 0 - && reply_hcounts.arcount() == 0 - { - // We can accept this as a valid reply. - return true; - } - - // Remaining checks. The question section in the reply has to be the - // same as in the query. - if reply_hcounts.qdcount() != query.header_counts().qdcount() { - false - } else { - reply.question() == query.question() - } -} - /// Helper function to create an empty future that is compatible with the /// future returned by a connection stream. async fn stream_nop() -> Result { diff --git a/src/net/client/request.rs b/src/net/client/request.rs index b042880c3..8eb081e37 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -3,6 +3,7 @@ #![warn(missing_docs)] #![warn(clippy::missing_docs_in_private_items)] +use crate::base::iana::Rcode; use crate::base::message::CopyRecordsError; use crate::base::message_builder::{ AdditionalBuilder, MessageBuilder, PushError, StaticCompressor, @@ -254,11 +255,29 @@ impl + Clone + Debug + Octets + Send + Sync + 'static> } fn is_answer(&self, answer: &Message<[u8]>) -> bool { - if !answer.header().qr() - || answer.header_counts().qdcount() - != self.msg.header_counts().qdcount() - || answer.header().id() != self.header.id() + let answer_header = answer.header(); + let answer_hcounts = answer.header_counts(); + + // First check qr is set and IDs match. + if !answer_header.qr() || answer_header.id() != self.header.id() { + return false; + } + + // If the result is an error, then the question section can be empty. + // In that case we require all other sections to be empty as well. + if answer_header.rcode() != Rcode::NoError + && answer_hcounts.qdcount() == 0 + && answer_hcounts.ancount() == 0 + && answer_hcounts.nscount() == 0 + && answer_hcounts.arcount() == 0 { + // We can accept this as a valid reply. + return true; + } + + // Now the question section in the reply has to be the same as in the + // query. + if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() { false } else { answer.question() == self.msg.for_slice().question() diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index a8ce3041c..c13800750 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -775,7 +775,7 @@ where /// Mapping outstanding queries to their ID. /// -/// This is generic over anything rather than our concrete request time for +/// This is generic over anything rather than our concrete request type for /// easier testing. #[derive(Clone, Debug)] struct Queries { diff --git a/tests/net-client.rs b/tests/net-client.rs index 6c21b54ea..9f3b5af9d 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -29,7 +29,7 @@ fn dgram() { let step_value = Arc::new(CurrStepValue::new()); let conn = Dgram::new(deckard.clone(), step_value.clone()); - let octstr = dgram::Connection::new(None, conn); + let octstr = Arc::new(dgram::Connection::new(None, conn)); do_client(&deckard, octstr, &step_value).await; }); From e4f704988dc82ed9e979e159fd86b2cbd6e334dc Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Thu, 4 Jan 2024 15:30:56 +0100 Subject: [PATCH 114/124] =?UTF-8?q?Don=E2=80=99t=20run=20the=20client=20do?= =?UTF-8?q?c=20examples.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/net/client/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index b169f88ac..69fc52dab 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -90,7 +90,7 @@ //! that provides the response. //! //! For example: -//! ```rust +//! ```no_run //! # use domain::net::client::request::SendRequest; //! # async fn _test() { //! # let (tls_conn, _) = domain::net::client::stream::Connection::new( @@ -116,7 +116,7 @@ //! cancelation safe. //! //! For example: -//! ```rust +//! ```no_run //! # use crate::domain::net::client::request::SendRequest; //! # async fn _test() { //! # let (tls_conn, _) = domain::net::client::stream::Connection::new( From b9b733a7b0980240abb9a1f1e32ae4497e465f71 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Fri, 5 Jan 2024 18:12:44 +0100 Subject: [PATCH 115/124] Further improvements to the datagram client. --- examples/client-transports.rs | 5 +- src/base/message.rs | 12 + src/net/client/dgram.rs | 414 +++++++++++++++++++++++++-------- src/net/client/dgram_stream.rs | 4 +- src/net/client/mod.rs | 1 - src/net/client/protocol.rs | 203 ++++++++-------- src/net/client/redundant.rs | 213 +++++++++-------- src/net/client/request.rs | 44 ++-- src/net/client/stream.rs | 5 +- tests/net-client.rs | 2 +- tests/net/deckard/dgram.rs | 2 +- 11 files changed, 549 insertions(+), 356 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index a7d97b6fe..d64ce0331 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -193,10 +193,11 @@ async fn main() { // reply is truncated. This transport does not have a separate run // function. let udp_connect = UdpConnect::new(server_addr); - let dgram_conn = dgram::Connection::new(Some(dgram_config), udp_connect); + let dgram_conn = + dgram::Connection::with_config(udp_connect, dgram_config); // Send a query message and get the reply. - let reply = dgram_conn.request(req.clone()).await.unwrap(); + let reply = dgram_conn.query(req.clone()).await.unwrap(); println!("Dgram reply: {:?}", reply); // Create a single TCP transport connection. This is usefull for a diff --git a/src/base/message.rs b/src/base/message.rs index b23b7156e..79db7997c 100644 --- a/src/base/message.rs +++ b/src/base/message.rs @@ -171,6 +171,18 @@ impl Message { Ok(unsafe { Self::from_octets_unchecked(octets) }) } + /// Creates a message from octets, returning the octets if it fails. + pub fn try_from_octets(octets: Octs) -> Result + where + Octs: AsRef<[u8]>, + { + if Message::check_slice(octets.as_ref()).is_err() { + Err(octets) + } else { + Ok(unsafe { Self::from_octets_unchecked(octets) }) + } + } + /// Creates a message from a bytes value without checking. /// /// # Safety diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 27a725b41..1838ae2ae 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -1,7 +1,10 @@ -//! A DNS over datagram transport +//! A client over datagram protocols. +//! +//! This module implements a DNS client for use with datagram protocols, i.e., +//! message-oriented, connection-less, unreliable network protocols. In +//! practice, this is pretty much exclusively UDP. #![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] // To do: // - cookies @@ -17,12 +20,14 @@ use crate::net::client::request::{ use bytes::Bytes; use core::{cmp, fmt}; use futures_util::FutureExt; +use octseq::OctetsInto; use std::boxed::Box; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::{Semaphore, SemaphorePermit}; -use tokio::time::{timeout, Duration, Instant}; +use std::{error, io}; +use tokio::sync::{Semaphore, TryAcquireError}; +use tokio::time::{timeout_at, Duration, Instant}; //------------ Configuration Constants ---------------------------------------- @@ -42,9 +47,12 @@ const MAX_RETRIES: DefMinMax = DefMinMax::new(5, 1, 100); /// Default UDP payload size. const DEF_UDP_PAYLOAD_SIZE: u16 = 1232; +/// The default receive buffer size. +const DEF_RECV_SIZE: usize = 2000; + //------------ Config --------------------------------------------------------- -/// Configuration for a datagram transport connection. +/// Configuration of a datagram transport. #[derive(Clone, Debug)] pub struct Config { /// Maximum number of parallel requests for a transport connection. @@ -53,12 +61,16 @@ pub struct Config { /// Read timeout. read_timeout: Duration, - /// Maimum number of retries. + /// Maximum number of retries. max_retries: u8, - /// EDNS(0) UDP payload size. Set this value to None to be able to create - /// a DNS request without ENDS(0) option. + /// EDNS UDP payload size. + /// + /// If this is `None`, no OPT record will be included at all. udp_payload_size: Option, + + /// Receive buffer size. + recv_size: usize, } impl Config { @@ -67,58 +79,83 @@ impl Config { Default::default() } - /// Returns the maximum number of parallel requests. + /// Sets the maximum number of parallel requests. /// /// Once this many number of requests are currently outstanding, /// additional requests will wait. - pub fn max_parallel(&self) -> usize { - self.max_parallel - } - - /// Sets the maximum number of parallel requests. /// /// If this value is too small or too large, it will be caped. pub fn set_max_parallel(&mut self, value: usize) { self.max_parallel = MAX_PARALLEL.limit(value) } - /// Returns the read timeout. + /// Returns the maximum number of parallel requests. + pub fn max_parallel(&self) -> usize { + self.max_parallel + } + + /// Sets the read timeout. /// /// The read timeout is the maximum amount of time to wait for any /// response after a request was sent. + /// + /// If this value is too small or too large, it will be caped. + pub fn set_read_timeout(&mut self, value: Duration) { + self.read_timeout = READ_TIMEOUT.limit(value) + } + + /// Returns the read timeout. pub fn read_timeout(&self) -> Duration { self.read_timeout } - /// Sets the read timeout. + /// Sets the maximum number a request is retried before giving up. /// /// If this value is too small or too large, it will be caped. - pub fn set_read_timeout(&mut self, value: Duration) { - self.read_timeout = READ_TIMEOUT.limit(value) + pub fn set_max_retries(&mut self, value: u8) { + self.max_retries = MAX_RETRIES.limit(value) } - /// Returns the maximum number a request is retried before giving up. + /// Returns the maximum number of request retries. pub fn max_retries(&self) -> u8 { self.max_retries } - /// Sets the maximum number of request retries. + /// Sets the requested UDP payload size. /// - /// If this value is too small or too large, it will be caped. - pub fn set_max_retries(&mut self, value: u8) { - self.max_retries = MAX_RETRIES.limit(value) + /// This value indicates to the server the maximum size of a UDP packet. + /// For UDP on public networks, this value should be left at the default + /// of 1232 to avoid issues rising from packet fragmentation. See + /// [draft-ietf-dnsop-avoid-fragmentation] for a discussion on these + /// issues and recommendations. + /// + /// On private networks or protocols other than UDP, other values can be + /// used. + /// + /// Setting the UDP payload size to `None` currently results in messages + /// that will not include an OPT record. + /// + /// [draft-ietf-dnsop-avoid-fragmentation]: https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/ + pub fn set_udp_payload_size(&mut self, value: Option) { + self.udp_payload_size = value; } /// Returns the UDP payload size. - /// - /// See draft-ietf-dnsop-avoid-fragmentation-15 for a discussion. pub fn udp_payload_size(&self) -> Option { self.udp_payload_size } - /// Sets the UDP payload size. - pub fn set_udp_payload_size(&mut self, value: Option) { - self.udp_payload_size = value; + /// Sets the receive buffer size. + /// + /// This is the amount of memory that is allocated for receiving a + /// response. + pub fn set_recv_size(&mut self, size: usize) { + self.recv_size = size + } + + /// Returns the receive buffer size. + pub fn recv_size(&self) -> usize { + self.recv_size } } @@ -129,13 +166,17 @@ impl Default for Config { read_timeout: READ_TIMEOUT.default(), max_retries: MAX_RETRIES.default(), udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE), + recv_size: DEF_RECV_SIZE, } } } //------------ Connection ----------------------------------------------------- -/// A datagram transport connection. +/// A datagram protocol connection. +/// +/// Because it owns the connection’s resources, this type is not `Clone`. +/// However, it is entirely safe to share it by sticking it into e.g. an arc. #[derive(Debug)] pub struct Connection { /// User configuration variables. @@ -149,9 +190,13 @@ pub struct Connection { } impl Connection { - /// Create a new datagram transport connection. - pub fn new(config: Option, connect: S) -> Connection { - let config = config.unwrap_or_default(); + /// Create a new datagram transport with default configuration. + pub fn new(connect: S) -> Connection { + Self::with_config(connect, Default::default()) + } + + /// Create a new datagram transport with a given configuration. + pub fn with_config(connect: S, config: Config) -> Connection { Self { semaphore: Semaphore::new(config.max_parallel), config, @@ -165,100 +210,117 @@ where S: AsyncConnect, S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, { - /// Sends a request and receives a response. - pub async fn request( + /// Performs a query. + /// + /// Sends the provided and returns either a response or an error. If there + /// are currently too many active queries, the future will wait until the + /// number has dropped below the limit. + pub async fn query( &self, - mut request: Req, - ) -> Result, Error> { - let recv_size = 2000; // Should be configurable. + request: Req, + ) -> Result, QueryError> { + let _ = self.semaphore.acquire().await.expect("semaphore closed"); + self.query_unchecked(request).await + } + /// Tries to perform a query. + /// + /// This is essentially the same as [`request`][Self::query] but returns + /// an error immediately if there are currently too many active queries. + pub async fn try_query( + &self, + request: Req, + ) -> Result, TryQueryError> { + let _ = match self.semaphore.try_acquire() { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => { + return Err(TryQueryError::TooManyQueries(request)) + } + Err(TryAcquireError::Closed) => panic!("semaphore closed"), + }; + self.query_unchecked(request) + .await + .map_err(TryQueryError::Query) + } + + /// Performs a query without acquiring the semaphore. + async fn query_unchecked( + &self, + mut request: Req, + ) -> Result, QueryError> { + // How often we’ve retried the request. let mut retries: u8 = 0; - // We need to get past the semaphore that limits the - // number of concurrent sockets we can use. - let _permit = self.get_permit().await; + // A place to store the receive buffer for reuse. + let mut reuse_buf = None; - loop { - let mut sock = self - .connect - .connect() - .await - .map_err(|e| Error::UdpConnect(Arc::new(e)))?; + // Transmit loop. + while retries < self.config.max_retries { + let mut sock = + self.connect.connect().await.map_err(QueryError::connect)?; // Set random ID in header request.header_mut().set_random_id(); - // Set UDP payload size + // Set UDP payload size if necessary. if let Some(size) = self.config.udp_payload_size { request.set_udp_payload_size(size) } + + // Create the message and send it out. let request_msg = request.to_message(); let dgram = request_msg.as_slice(); - - let sent = sock - .send(dgram) - .await - .map_err(|e| Error::UdpSend(Arc::new(e)))?; + let sent = sock.send(dgram).await.map_err(QueryError::send)?; if sent != dgram.len() { - return Err(Error::UdpShortSend); + return Err(QueryError::short_send()); } - let start = Instant::now(); - - loop { - let elapsed = start.elapsed(); - if elapsed > self.config.read_timeout { - // Break out of the receive loop and continue in the - // transmit loop. - break; - } - let remain = self.config.read_timeout - elapsed; - - let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. - let timeout_res = timeout(remain, sock.recv(&mut buf)).await; - if timeout_res.is_err() { - retries += 1; - if retries < self.config.max_retries { - // Break out of the receive loop and continue in the - // transmit loop. - break; - } - return Err(Error::UdpTimeoutNoResponse); - } - let len = timeout_res - .expect("errror case is checked above") - .map_err(|e| Error::UdpReceive(Arc::new(e)))?; + // Receive loop. It may at most take read_timeout time. + let deadline = Instant::now() + self.config.read_timeout; + while deadline > Instant::now() { + let mut buf = reuse_buf.take().unwrap_or_else(|| { + // XXX use uninit'ed mem here. + vec![0; self.config.recv_size] + }); + let len = + match timeout_at(deadline, sock.recv(&mut buf)).await { + Ok(Ok(len)) => len, + Ok(Err(err)) => { + // Receiving failed. + return Err(QueryError::receive(err)); + } + Err(_) => { + // Timeout. + // XXX Is this extra increase of the retry counter + // here on purpose? If not, we can turn the + // outer loop into a for loop. + retries += 1; + break; + } + }; buf.truncate(len); // We ignore garbage since there is a timer on this whole // thing. - let answer = match Message::from_octets(buf.into()) { - // Just go back to receiving. + let answer = match Message::try_from_octets(buf) { Ok(answer) => answer, - Err(_) => continue, + Err(buf) => { + // Just go back to receiving. + reuse_buf = Some(buf); + continue; + } }; if !request.is_answer(answer.for_slice()) { // Wrong answer, go back to receiving + reuse_buf = Some(answer.into_octets()); continue; } - return Ok(answer); + return Ok(answer.octets_into()); } retries += 1; - if retries < self.config.max_retries { - continue; - } - break; } - Err(Error::UdpTimeoutNoResponse) - } - - /// Return a permit for a our semaphore. - async fn get_permit(&self) -> SemaphorePermit { - self.semaphore - .acquire() - .await - .expect("the semaphore has not been closed") + Err(QueryError::timeout()) } } @@ -276,8 +338,10 @@ where request_msg: CR, ) -> Result, Error> { Ok(Box::new(Query { - get_response_fut: async move { self.request(request_msg).await } - .boxed(), + get_response_fut: async move { + self.query(request_msg).await.map_err(Into::into) + } + .boxed(), })) } } @@ -320,7 +384,7 @@ where >> where Self: 's; fn handle_request(&self, request: Req) -> Self::Fut<'_> { - self.request(request).boxed() + async { self.query(request).await.map_err(Into::into) }.boxed() } } @@ -390,3 +454,157 @@ impl DefMinMax { cmp::max(self.min, cmp::min(self.max, value)) } } + +//============ Errors ======================================================== + +//------------ QueryError ---------------------------------------------------- + +/// A query failed. +#[derive(Debug)] +pub struct QueryError { + /// Which step failed? + kind: QueryErrorKind, + + /// The underlying IO error. + io: std::io::Error, +} + +impl QueryError { + fn new(kind: QueryErrorKind, io: io::Error) -> Self { + Self { kind, io } + } + + fn connect(io: io::Error) -> Self { + Self::new(QueryErrorKind::Connect, io) + } + + fn send(io: io::Error) -> Self { + Self::new(QueryErrorKind::Send, io) + } + + fn short_send() -> Self { + Self::new( + QueryErrorKind::Send, + io::Error::new(io::ErrorKind::Other, "short request sent"), + ) + } + + fn timeout() -> Self { + Self::new( + QueryErrorKind::Timeout, + io::Error::new(io::ErrorKind::TimedOut, "timeout expired"), + ) + } + + fn receive(io: io::Error) -> Self { + Self::new(QueryErrorKind::Receive, io) + } +} + +impl QueryError { + /// Returns information about when the query has failed. + pub fn kind(&self) -> QueryErrorKind { + self.kind + } + + /// Converts the query error into the underlying IO error. + pub fn io_error(self) -> std::io::Error { + self.io + } +} + +impl From for std::io::Error { + fn from(err: QueryError) -> std::io::Error { + err.io + } +} + +impl fmt::Display for QueryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: {}", self.kind.error_str(), self.io) + } +} + +impl error::Error for QueryError {} + +//------------ QueryErrorKind ------------------------------------------------ + +/// Which part of processing the query failed? +#[derive(Copy, Clone, Debug)] +pub enum QueryErrorKind { + /// Failed to connect to the remote. + Connect, + + /// Failed to send the request. + Send, + + /// The request has timed out. + Timeout, + + /// Failed to read the response. + Receive, +} + +impl QueryErrorKind { + /// Returns the string to be used when displaying a query error. + fn error_str(self) -> &'static str { + match self { + Self::Connect => "connecting failed", + Self::Send => "sending request failed", + Self::Timeout | Self::Receive => "reading response failed", + } + } +} + +impl fmt::Display for QueryErrorKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + Self::Connect => "connecting failed", + Self::Send => "sending request failed", + Self::Timeout => "request timeout", + Self::Receive => "reading response failed", + }) + } +} + +//------------ TryQueryError ------------------------------------------------- + +/// An attempted query failed +/// +/// This error is returned by [`Connection::try_query`]. +pub enum TryQueryError { + /// The query has failed with the given error. + Query(QueryError), + + /// There were too many active queries. + /// + /// This variant contains the original request unchanged. + TooManyQueries(Req), +} + +impl fmt::Debug for TryQueryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Query(err) => { + f.debug_tuple("TryQueryError::Query").field(err).finish() + } + Self::TooManyQueries(_) => f + .debug_tuple("TryQueryError::Req") + .field(&format_args!("_")) + .finish(), + } + } +} + +impl fmt::Display for TryQueryError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Query(error) => error.fmt(f), + Self::TooManyQueries(_) => { + f.write_str("too many active requests") + } + } + } +} + +impl error::Error for TryQueryError {} diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs index b917504db..89593a412 100644 --- a/src/net/client/dgram_stream.rs +++ b/src/net/client/dgram_stream.rs @@ -117,7 +117,7 @@ where config: Config, ) -> (Self, multi_stream::Transport) { let udp_conn = - dgram::Connection::new(Some(config.dgram), dgram_remote).into(); + dgram::Connection::with_config(dgram_remote, config.dgram).into(); let (tcp_conn, transport) = multi_stream::Connection::with_config( stream_remote, config.multi_stream, @@ -137,7 +137,7 @@ where &self, request: Req, ) -> Result, Error> { - let response = self.udp_conn.request(request.clone()).await?; + let response = self.udp_conn.query(request.clone()).await?; if !response.header().tc() { return Ok(response); } diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 69fc52dab..a933652ae 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -140,7 +140,6 @@ #![cfg(feature = "unstable-client-transport")] #![cfg_attr(docsrs, doc(cfg(feature = "unstable-client-transport")))] #![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] pub mod dgram; pub mod dgram_stream; diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs index b909c0ded..3a5779dec 100644 --- a/src/net/client/protocol.rs +++ b/src/net/client/protocol.rs @@ -122,11 +122,65 @@ where } } +//------------ UdpConnect -------------------------------------------------- + +/// Create new TCP connections. +#[derive(Clone, Copy, Debug)] +pub struct UdpConnect { + /// Remote address to connect to. + addr: SocketAddr, +} + +impl UdpConnect { + /// Create new UDP connections. + /// + /// addr is the destination address to connect to. + pub fn new(addr: SocketAddr) -> Self { + Self { addr } + } + + /// Bind to a random local UDP port. + async fn bind_and_connect(self) -> Result { + let mut i = 0; + let sock = loop { + let local: SocketAddr = if self.addr.is_ipv4() { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => break sock, + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(err); + } else { + i += 1 + } + } + } + }; + sock.connect(self.addr).await?; + Ok(sock) + } +} + +impl AsyncConnect for UdpConnect { + type Connection = UdpSocket; + type Fut = Pin< + Box< + dyn Future> + + Send, + >, + >; + + fn connect(&self) -> Self::Fut { + Box::pin(self.bind_and_connect()) + } +} + //------------ AsyncDgramRecv ------------------------------------------------- /// Receive a datagram packets asynchronously. -/// -/// pub trait AsyncDgramRecv { /// Polled receive. fn poll_recv( @@ -136,6 +190,18 @@ pub trait AsyncDgramRecv { ) -> Poll>; } +impl AsyncDgramRecv for UdpSocket { + fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + UdpSocket::poll_recv(self, cx, buf) + } +} + +//------------ AsyncDgramRecvEx ----------------------------------------------- + /// Convenvience trait to turn poll_recv into an asynchronous function. pub trait AsyncDgramRecvEx: AsyncDgramRecv { /// Asynchronous receive function. @@ -143,16 +209,21 @@ pub trait AsyncDgramRecvEx: AsyncDgramRecv { where Self: Unpin, { - recv(self, buf) + DgramRecv { + receiver: self, + buf, + } } } impl AsyncDgramRecvEx for R {} +//------------ DgramRecv ----------------------------------------------------- + pin_project! { /// Return value of recv. This captures the future for recv. pub struct DgramRecv<'a, R: ?Sized> { - receiver: &'a mut R, + receiver: &'a R, buf: &'a mut [u8], } } @@ -178,14 +249,6 @@ impl Future for DgramRecv<'_, R> { } } -/// Helper function for the recv method. -fn recv<'a, R: ?Sized>( - receiver: &'a mut R, - buf: &'a mut [u8], -) -> DgramRecv<'a, R> { - DgramRecv { receiver, buf } -} - //------------ AsyncDgramSend ------------------------------------------------- /// Send a datagram packet asynchronously. @@ -194,12 +257,24 @@ fn recv<'a, R: ?Sized>( pub trait AsyncDgramSend { /// Polled send function. fn poll_send( - self: Pin<&Self>, + &self, cx: &mut Context<'_>, buf: &[u8], ) -> Poll>; } +impl AsyncDgramSend for UdpSocket { + fn poll_send( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + UdpSocket::poll_send(self, cx, buf) + } +} + +//------------ AsyncDgramSendEx ---------------------------------------------- + /// Convenience trait that turns poll_send into an asynchronous function. pub trait AsyncDgramSendEx: AsyncDgramSend { /// Asynchronous function to send a packet. @@ -207,12 +282,14 @@ pub trait AsyncDgramSendEx: AsyncDgramSend { where Self: Unpin, { - send(self, buf) + DgramSend { sender: self, buf } } } impl AsyncDgramSendEx for S {} +//------------ DgramSend ----------------------------------------------------- + /// This is the return value of send. It captures the future for send. pub struct DgramSend<'a, S: ?Sized> { /// The datagram send object. @@ -232,101 +309,3 @@ impl Future for DgramSend<'_, S> { Pin::new(self.sender).poll_send(cx, self.buf) } } - -/// Send helper function to implement the send method of AsyncDgramSendEx. -fn send<'a, S: ?Sized>(sender: &'a S, buf: &'a [u8]) -> DgramSend<'a, S> { - DgramSend { sender, buf } -} - -//------------ UdpConnect -------------------------------------------------- - -/// Create new TCP connections. -#[derive(Clone, Copy, Debug)] -pub struct UdpConnect { - /// Remote address to connect to. - addr: SocketAddr, -} - -impl UdpConnect { - /// Create new UDP connections. - /// - /// addr is the destination address to connect to. - pub fn new(addr: SocketAddr) -> Self { - Self { addr } - } -} - -impl AsyncConnect for UdpConnect { - type Connection = UdpDgram; - type Fut = Pin< - Box< - dyn Future> - + Send, - >, - >; - - fn connect(&self) -> Self::Fut { - Box::pin(UdpDgram::new(self.addr)) - } -} - -/// A single UDP 'connection' -pub struct UdpDgram { - /// Underlying UDP socket - sock: Arc, -} - -impl UdpDgram { - /// Create a new UdpDgram object. - async fn new(addr: SocketAddr) -> Result { - let sock = Self::udp_bind(addr.is_ipv4()).await?; - sock.connect(addr).await?; - Ok(Self { - sock: Arc::new(sock), - }) - } - /// Bind to a local UDP port. - /// - /// This should explicitly pick a random number in a suitable range of - /// ports. - async fn udp_bind(v4: bool) -> Result { - let mut i = 0; - loop { - let local: SocketAddr = if v4 { - ([0u8; 4], 0).into() - } else { - ([0u16; 8], 0).into() - }; - match UdpSocket::bind(&local).await { - Ok(sock) => return Ok(sock), - Err(err) => { - if i == RETRY_RANDOM_PORT { - return Err(err); - } else { - i += 1 - } - } - } - } - } -} - -impl AsyncDgramRecv for UdpDgram { - fn poll_recv( - &self, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.sock.poll_recv(cx, buf) - } -} - -impl AsyncDgramSend for UdpDgram { - fn poll_send( - self: Pin<&Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.sock.poll_send(cx, buf) - } -} diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index c172f5d63..3e8dfdd2c 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -83,12 +83,12 @@ pub struct Config { /// This type represents a transport connection. #[derive(Clone, Debug)] -pub struct Connection { +pub struct Connection { /// Reference to the actual implementation of the connection. - inner: Arc>, + inner: Arc>, } -impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { +impl<'a, Req: Clone + Debug + Send + Sync + 'static> Connection { /// Create a new connection. pub fn new(config: Option) -> Result { let config = match config { @@ -98,7 +98,7 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { } None => Default::default(), }; - let connection = InnerConnection::new(config)?; + let connection = Transport::new(config)?; //test_send(connection); Ok(Self { inner: Arc::new(connection), @@ -113,7 +113,7 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { /// Add a transport connection. pub async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { self.inner.add(conn).await } @@ -121,19 +121,19 @@ impl<'a, BMB: Clone + Debug + Send + Sync + 'static> Connection { /// Implementation of the request function. async fn request_impl( &self, - request_msg: &BMB, + request_msg: &Req, ) -> Result, Error> { let request = self.inner.request(request_msg.clone()).await?; Ok(Box::new(request)) } } -impl SendRequest - for Connection +impl SendRequest + for Connection { fn send_request<'a>( &'a self, - request_msg: &'a BMB, + request_msg: &'a Req, ) -> Pin< Box< dyn Future, Error>> @@ -145,11 +145,11 @@ impl SendRequest } } -//------------ ReqResp -------------------------------------------------------- +//------------ Query -------------------------------------------------------- /// This type represents an active query request. #[derive(Debug)] -pub struct ReqResp { +pub struct Query { /// User configuration. config: Config, @@ -157,13 +157,13 @@ pub struct ReqResp { state: QueryState, /// The reuqest message - request_msg: BMB, + request_msg: Req, /// List of connections identifiers and estimated response times. conn_rt: Vec, /// Channel to send requests to the run function. - sender: mpsc::Sender>, + sender: mpsc::Sender>, /// List of futures for outstanding requests. fut_list: @@ -201,15 +201,15 @@ enum QueryState { } /// The commands that can be sent to the run function. -enum ChanReq { +enum ChanReq { /// Add a connection - Add(AddReq), + Add(AddReq), /// Get the list of estimated response times for all connections GetRT(RTReq), /// Start a query - Query(RequestReq), + Query(RequestReq), /// Report how long it took to get a response Report(TimeReport), @@ -218,16 +218,16 @@ enum ChanReq { Failure(TimeReport), } -impl Debug for ChanReq { +impl Debug for ChanReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("ChanReq").finish() } } /// Request to add a new connection -struct AddReq { +struct AddReq { /// New connection to add - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, /// Channel to send the reply to tx: oneshot::Sender, @@ -246,18 +246,18 @@ struct RTReq /**/ { type RTReply = Result, Error>; /// Request to start a request -struct RequestReq { +struct RequestReq { /// Identifier of connection id: u64, /// Request message - request_msg: BMB, + request_msg: Req, /// Channel to send the reply to tx: oneshot::Sender, } -impl Debug for RequestReq { +impl Debug for RequestReq { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("RequestReq") .field("id", &self.id) @@ -304,13 +304,13 @@ struct ConnRT { /// Result of the futures in fut_list. type FutListOutput = (usize, Result, Error>); -impl ReqResp { +impl Query { /// Create a new query object. fn new( config: Config, - request_msg: BMB, + request_msg: Req, mut conn_rt: Vec, - sender: mpsc::Sender>, + sender: mpsc::Sender>, ) -> Self { let conn_rt_len = conn_rt.len(); conn_rt.sort_unstable_by(conn_rt_cmp); @@ -362,79 +362,78 @@ impl ReqResp { loop { tokio::select! { res = self.fut_list.next() => { - let res = res.expect("res should not be empty"); - match res.1 { - Err(ref err) => { - if self.config.defer_transport_error { - if self.deferred_transport_error.is_none() { - self.deferred_transport_error = Some(err.clone()); - } - if res.0 == ind { - // The current upstream finished, - // try the next one, if any. - self.state = - if ind+1 < self.conn_rt.len() { - QueryState::Probe(ind+1) + let res = res.expect("res should not be empty"); + match res.1 { + Err(ref err) => { + if self.config.defer_transport_error { + if self.deferred_transport_error.is_none() { + self.deferred_transport_error = Some(err.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; } - else - { - QueryState::Wait - }; - // Break out of receive loop - break; - } - // Just continue receiving - continue; - } - // Return error to the user. - } - Ok(ref msg) => { - if skip(msg, &self.config) { - if self.deferred_reply.is_none() { - self.deferred_reply = Some(msg.clone()); - } - if res.0 == ind { - // The current upstream finished, - // try the next one, if any. - self.state = - if ind+1 < self.conn_rt.len() { - QueryState::Probe(ind+1) + // Return error to the user. } - else - { - QueryState::Wait - }; - // Break out of receive loop - break; + Ok(ref msg) => { + if skip(msg, &self.config) { + if self.deferred_reply.is_none() { + self.deferred_reply = Some(msg.clone()); + } + if res.0 == ind { + // The current upstream finished, + // try the next one, if any. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else + { + QueryState::Wait + }; + // Break out of receive loop + break; + } + // Just continue receiving + continue; + } + // Now we have a reply that can be + // returned to the user. } - // Just continue receiving - continue; - } - // Now we have a reply that can be - // returned to the user. } - } - self.result = Some(res.1); - self.res_index= res.0; + self.result = Some(res.1); + self.res_index= res.0; - self.state = QueryState::Report(0); - // Break out of receive loop - break; + self.state = QueryState::Report(0); + // Break out of receive loop + break; } _ = sleep_until(timeout) => { - // Move to the next Probe state if there - // are more upstreams to try, otherwise - // move to the Wait state. - self.state = - if ind+1 < self.conn_rt.len() { - QueryState::Probe(ind+1) - } - else - { - QueryState::Wait - }; - // Break out of receive loop - break; + // Move to the next Probe state if there + // are more upstreams to try, otherwise + // move to the Wait state. + self.state = + if ind+1 < self.conn_rt.len() { + QueryState::Probe(ind+1) + } + else { + QueryState::Wait + }; + // Break out of receive loop + break; } } } @@ -538,9 +537,7 @@ impl ReqResp { } } -impl GetResponse - for ReqResp -{ +impl GetResponse for Query { fn get_response( &mut self, ) -> Pin< @@ -550,22 +547,22 @@ impl GetResponse } } -//------------ InnerConnection ------------------------------------------------ +//------------ Transport ----------------------------------------------------- /// Type that actually implements the connection. #[derive(Debug)] -struct InnerConnection { +struct Transport { /// User configuation. config: Config, /// Receive side of the channel used by the runner. - receiver: Mutex>>>, + receiver: Mutex>>>, /// To send a request to the runner. - sender: mpsc::Sender>, + sender: mpsc::Sender>, } -impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { +impl<'a, Req: Clone + Send + Sync + 'static> Transport { /// Implementation of the new method. fn new(config: Config) -> Result { let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); @@ -588,11 +585,11 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { } /// Implementation of the run method. - async fn run_impl(opt_receiver: Option>>) { + async fn run_impl(opt_receiver: Option>>) { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); - let mut conns: Vec + Send + Sync>> = + let mut conns: Vec + Send + Sync>> = Vec::new(); let mut receiver = @@ -693,7 +690,7 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Implementation of the add method. async fn add( &self, - conn: Box + Send + Sync>, + conn: Box + Send + Sync>, ) -> Result<(), Error> { let (tx, rx) = oneshot::channel(); self.sender @@ -706,15 +703,15 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Implementation of the query method. async fn request( &'a self, - request_msg: BMB, - ) -> Result, Error> { + request_msg: Req, + ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.sender .send(ChanReq::GetRT(RTReq { tx })) .await .expect("send should not fail"); let conn_rt = rx.await.expect("receive should not fail")?; - Ok(ReqResp::new( + Ok(Query::new( self.config.clone(), request_msg, conn_rt, @@ -728,11 +725,11 @@ impl<'a, BMB: Clone + Send + Sync + 'static> InnerConnection { /// Async function to send a request and wait for the reply. /// /// This gives a single future that we can put in a list. -async fn start_request( +async fn start_request( index: usize, id: u64, - sender: mpsc::Sender>, - request_msg: BMB, + sender: mpsc::Sender>, + request_msg: Req, ) -> (usize, Result, Error>) { let (tx, rx) = oneshot::channel(); sender diff --git a/src/net/client/request.rs b/src/net/client/request.rs index 8eb081e37..dfcda592a 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -302,9 +302,6 @@ pub enum Error { /// ParseError from Message. MessageParseError, - /// octet_stream configuration error. - OctetStreamConfigError(Arc), - /// Underlying transport not found in redundant connection RedundantTransportNotFound, @@ -318,6 +315,7 @@ pub enum Error { StreamIdleTimeout, /// Error receiving a reply. + // StreamReceiveError, /// Reading from stream gave an error. @@ -335,32 +333,14 @@ pub enum Error { /// Reading for a stream ended unexpectedly. StreamUnexpectedEndOfData, - /// Binding a UDP socket gave an error. - UdpBind(Arc), - - /// UDP configuration error. - UdpConfigError(Arc), - - /// Connecting a UDP socket gave an error. - UdpConnect(Arc), - - /// Receiving from a UDP socket gave an error. - UdpReceive(Arc), - - /// Sending over a UDP socket gaven an error. - UdpSend(Arc), - - /// Sending over a UDP socket gave a partial result. - UdpShortSend, - - /// Timeout receiving a response over a UDP socket. - UdpTimeoutNoResponse, - /// Reply does not match the query. WrongReplyForQuery, /// No transport available to transmit request. NoTransportAvailable, + + /// An error happened in the datagram transport. + Dgram(Arc), } impl From for Error { @@ -369,8 +349,16 @@ impl From for Error { } } +impl From for Error { + fn from(err: super::dgram::QueryError) -> Self { + Self::Dgram(err.into()) + } +} + impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + unimplemented!() + /* match self { Error::ConnectionClosed => write!(f, "connection closed"), Error::OptTooLong => write!(f, "OPT record is too long"), @@ -378,7 +366,6 @@ impl fmt::Display for Error { write!(f, "PushError from MessageBuilder") } Error::MessageParseError => write!(f, "ParseError from Message"), - Error::OctetStreamConfigError(_) => write!(f, "bad config value"), Error::RedundantTransportNotFound => write!( f, "Underlying transport not found in redundant connection" @@ -426,6 +413,7 @@ impl fmt::Display for Error { write!(f, "no transport available") } } + */ } } @@ -440,12 +428,13 @@ impl From for Error { impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { + unimplemented!() + /* match self { Error::ConnectionClosed => None, Error::OptTooLong => None, Error::MessageBuilderPushError => None, Error::MessageParseError => None, - Error::OctetStreamConfigError(e) => Some(e), Error::RedundantTransportNotFound => None, Error::ShortMessage => None, Error::StreamLongMessage => None, @@ -466,5 +455,6 @@ impl error::Error for Error { Error::WrongReplyForQuery => None, Error::NoTransportAvailable => None, } + */ } } diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index c13800750..6360ddc8d 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -162,10 +162,7 @@ impl Connection { Ok(Box::new(Query::new(rx))) } - /// Start a DNS request but do not check if the reply matches the request. - /// - /// This function is similar to [Self::query]. Not checking if the reply - /// match the request avoids having to keep the request around. + /// Start a DNS request. pub async fn start_request( &self, query_msg: Req, diff --git a/tests/net-client.rs b/tests/net-client.rs index 9f3b5af9d..ebd9341ef 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -29,7 +29,7 @@ fn dgram() { let step_value = Arc::new(CurrStepValue::new()); let conn = Dgram::new(deckard.clone(), step_value.clone()); - let octstr = Arc::new(dgram::Connection::new(None, conn)); + let octstr = Arc::new(dgram::Connection::new(conn)); do_client(&deckard, octstr, &step_value).await; }); diff --git a/tests/net/deckard/dgram.rs b/tests/net/deckard/dgram.rs index 0156f6335..6fd0eb33f 100644 --- a/tests/net/deckard/dgram.rs +++ b/tests/net/deckard/dgram.rs @@ -84,7 +84,7 @@ impl AsyncDgramRecv for DgramConnection { impl AsyncDgramSend for DgramConnection { fn poll_send( - self: Pin<&Self>, + &self, _: &mut Context<'_>, buf: &[u8], ) -> Poll> { From 31a1117df6162dd9a0e602d56284abe7972b5513 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Fri, 5 Jan 2024 18:29:05 +0100 Subject: [PATCH 116/124] Mark the big example as no_run. --- src/net/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index a933652ae..2d5b8c365 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -133,7 +133,7 @@ //! ``` //! # Example with various transport connections -//! ``` +//! ```no_run #![doc = include_str!("../../../examples/client-transports.rs")] //! ``` From 8d2da50e47c66c8c7e99002372b5a4795be39792 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Mon, 8 Jan 2024 11:36:26 +0100 Subject: [PATCH 117/124] Change the datagram transmit loop into a for loop. --- src/net/client/dgram.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 1838ae2ae..590d8d353 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -248,14 +248,11 @@ where &self, mut request: Req, ) -> Result, QueryError> { - // How often we’ve retried the request. - let mut retries: u8 = 0; - // A place to store the receive buffer for reuse. let mut reuse_buf = None; // Transmit loop. - while retries < self.config.max_retries { + for _ in 0..self.config.max_retries { let mut sock = self.connect.connect().await.map_err(QueryError::connect)?; @@ -291,10 +288,6 @@ where } Err(_) => { // Timeout. - // XXX Is this extra increase of the retry counter - // here on purpose? If not, we can turn the - // outer loop into a for loop. - retries += 1; break; } }; @@ -318,7 +311,6 @@ where } return Ok(answer.octets_into()); } - retries += 1; } Err(QueryError::timeout()) } From a3257843a862c78b4a201bf8173f837cbff290bc Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 9 Jan 2024 14:41:07 +0100 Subject: [PATCH 118/124] Impl Display and Error for request::Error. --- src/net/client/request.rs | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/src/net/client/request.rs b/src/net/client/request.rs index dfcda592a..985d98569 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -356,9 +356,7 @@ impl From for Error { } impl fmt::Display for Error { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { - unimplemented!() - /* + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::ConnectionClosed => write!(f, "connection closed"), Error::OptTooLong => write!(f, "OPT record is too long"), @@ -395,25 +393,14 @@ impl fmt::Display for Error { Error::StreamUnexpectedEndOfData => { write!(f, "unexpected end of data") } - Error::UdpBind(_) => write!(f, "error binding UDP socket"), - Error::UdpConfigError(_) => write!(f, "bad config value"), - Error::UdpConnect(_) => write!(f, "error connecting UDP socket"), - Error::UdpReceive(_) => { - write!(f, "error receiving from UDP socket") - } - Error::UdpSend(_) => write!(f, "error sending to UDP socket"), - Error::UdpShortSend => write!(f, "partial sent to UDP socket"), - Error::UdpTimeoutNoResponse => { - write!(f, "timeout waiting for response") - } Error::WrongReplyForQuery => { write!(f, "reply does not match query") } Error::NoTransportAvailable => { write!(f, "no transport available") } + Error::Dgram(err) => fmt::Display::fmt(err, f) } - */ } } @@ -428,8 +415,6 @@ impl From for Error { impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { - unimplemented!() - /* match self { Error::ConnectionClosed => None, Error::OptTooLong => None, @@ -445,16 +430,9 @@ impl error::Error for Error { Error::StreamTooManyOutstandingQueries => None, Error::StreamWriteError(e) => Some(e), Error::StreamUnexpectedEndOfData => None, - Error::UdpBind(e) => Some(e), - Error::UdpConfigError(e) => Some(e), - Error::UdpConnect(e) => Some(e), - Error::UdpReceive(e) => Some(e), - Error::UdpSend(e) => Some(e), - Error::UdpShortSend => None, - Error::UdpTimeoutNoResponse => None, Error::WrongReplyForQuery => None, Error::NoTransportAvailable => None, + Error::Dgram(err) => Some(err), } - */ } } From 548d2396ab550ad3a519f287c2e1e09c7074393c Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 9 Jan 2024 16:49:42 +0100 Subject: [PATCH 119/124] Revert back to the previous design. --- examples/client-transports.rs | 17 ++-- src/net/client/dgram.rs | 176 ++++++++++++--------------------- src/net/client/dgram_stream.rs | 93 +++-------------- src/net/client/mod.rs | 6 +- src/net/client/multi_stream.rs | 71 ++++--------- src/net/client/redundant.rs | 80 +++++++++------ src/net/client/request.rs | 32 +----- src/net/client/stream.rs | 137 +++++++++---------------- src/resolv/stub/mod.rs | 2 +- tests/net-client.rs | 38 +------ tests/net/deckard/client.rs | 38 +------ 11 files changed, 214 insertions(+), 476 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index d64ce0331..1ef379c92 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -70,7 +70,7 @@ async fn main() { }); // Send a query message. - let mut request = udptcp_conn.send_request(&req).await.unwrap(); + let mut request = udptcp_conn.send_request(req.clone()); // Get the reply println!("Wating for UDP+TCP reply"); @@ -100,7 +100,7 @@ async fn main() { }); // Send a query message. - let mut request = tcp_conn.send_request(&req).await.unwrap(); + let mut request = tcp_conn.send_request(req.clone()); // Get the reply. A multi_stream connection does not have any timeout. // Wrap get_result in a timeout. @@ -154,7 +154,7 @@ async fn main() { println!("TLS run exited"); }); - let mut request = tls_conn.send_request(&req).await.unwrap(); + let mut request = tls_conn.send_request(req.clone()); println!("Wating for TLS reply"); let reply = timeout(Duration::from_millis(500), request.get_response()).await; @@ -179,7 +179,7 @@ async fn main() { // Start a few queries. for i in 1..10 { - let mut request = redun.send_request(&req).await.unwrap(); + let mut request = redun.send_request(req.clone()); let reply = request.get_response().await; if i == 2 { println!("redundant connection reply: {:?}", reply); @@ -196,8 +196,11 @@ async fn main() { let dgram_conn = dgram::Connection::with_config(udp_connect, dgram_config); - // Send a query message and get the reply. - let reply = dgram_conn.query(req.clone()).await.unwrap(); + // Send a message. + let mut request = dgram_conn.send_request(req.clone()); + // + // Get the reply + let reply = request.get_response().await; println!("Dgram reply: {:?}", reply); // Create a single TCP transport connection. This is usefull for a @@ -220,7 +223,7 @@ async fn main() { }); // Send a request message. - let mut request = tcp.send_request(&req).await.unwrap(); + let mut request = tcp.send_request(req); // Get the reply let reply = request.get_response().await; diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index 590d8d353..b57d9c9b3 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -15,18 +15,17 @@ use crate::net::client::protocol::{ AsyncDgramSendEx, }; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, + ComposeRequest, Error, GetResponse, SendRequest, }; use bytes::Bytes; use core::{cmp, fmt}; -use futures_util::FutureExt; use octseq::OctetsInto; use std::boxed::Box; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::{error, io}; -use tokio::sync::{Semaphore, TryAcquireError}; +use tokio::sync::Semaphore; use tokio::time::{timeout_at, Duration, Instant}; //------------ Configuration Constants ---------------------------------------- @@ -179,6 +178,11 @@ impl Default for Config { /// However, it is entirely safe to share it by sticking it into e.g. an arc. #[derive(Debug)] pub struct Connection { + state: Arc>, +} + +#[derive(Debug)] +struct ConnectionState { /// User configuration variables. config: Config, @@ -191,16 +195,18 @@ pub struct Connection { impl Connection { /// Create a new datagram transport with default configuration. - pub fn new(connect: S) -> Connection { + pub fn new(connect: S) -> Self { Self::with_config(connect, Default::default()) } /// Create a new datagram transport with a given configuration. - pub fn with_config(connect: S, config: Config) -> Connection { + pub fn with_config(connect: S, config: Config) -> Self { Self { - semaphore: Semaphore::new(config.max_parallel), - config, - connect, + state: Arc::new(ConnectionState { + semaphore: Semaphore::new(config.max_parallel), + config, + connect, + }), } } } @@ -210,57 +216,40 @@ where S: AsyncConnect, S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, { - /// Performs a query. + /// Performs a request. /// /// Sends the provided and returns either a response or an error. If there /// are currently too many active queries, the future will wait until the /// number has dropped below the limit. - pub async fn query( - &self, - request: Req, - ) -> Result, QueryError> { - let _ = self.semaphore.acquire().await.expect("semaphore closed"); - self.query_unchecked(request).await - } - - /// Tries to perform a query. - /// - /// This is essentially the same as [`request`][Self::query] but returns - /// an error immediately if there are currently too many active queries. - pub async fn try_query( - &self, - request: Req, - ) -> Result, TryQueryError> { - let _ = match self.semaphore.try_acquire() { - Ok(permit) => permit, - Err(TryAcquireError::NoPermits) => { - return Err(TryQueryError::TooManyQueries(request)) - } - Err(TryAcquireError::Closed) => panic!("semaphore closed"), - }; - self.query_unchecked(request) + pub async fn handle_request_impl( + self, + mut request: Req, + ) -> Result, Error> { + // Acquire the semaphore or wait for it. + let _ = self + .state + .semaphore + .acquire() .await - .map_err(TryQueryError::Query) - } + .expect("semaphore closed"); - /// Performs a query without acquiring the semaphore. - async fn query_unchecked( - &self, - mut request: Req, - ) -> Result, QueryError> { // A place to store the receive buffer for reuse. let mut reuse_buf = None; // Transmit loop. - for _ in 0..self.config.max_retries { - let mut sock = - self.connect.connect().await.map_err(QueryError::connect)?; + for _ in 0..self.state.config.max_retries { + let mut sock = self + .state + .connect + .connect() + .await + .map_err(QueryError::connect)?; // Set random ID in header request.header_mut().set_random_id(); // Set UDP payload size if necessary. - if let Some(size) = self.config.udp_payload_size { + if let Some(size) = self.state.config.udp_payload_size { request.set_udp_payload_size(size) } @@ -269,22 +258,22 @@ where let dgram = request_msg.as_slice(); let sent = sock.send(dgram).await.map_err(QueryError::send)?; if sent != dgram.len() { - return Err(QueryError::short_send()); + return Err(QueryError::short_send().into()); } // Receive loop. It may at most take read_timeout time. - let deadline = Instant::now() + self.config.read_timeout; + let deadline = Instant::now() + self.state.config.read_timeout; while deadline > Instant::now() { let mut buf = reuse_buf.take().unwrap_or_else(|| { // XXX use uninit'ed mem here. - vec![0; self.config.recv_size] + vec![0; self.state.config.recv_size] }); let len = match timeout_at(deadline, sock.recv(&mut buf)).await { Ok(Ok(len)) => len, Ok(Err(err)) => { // Receiving failed. - return Err(QueryError::receive(err)); + return Err(QueryError::receive(err).into()); } Err(_) => { // Timeout. @@ -312,97 +301,58 @@ where return Ok(answer.octets_into()); } } - Err(QueryError::timeout()) + Err(QueryError::timeout().into()) } } -impl Connection -where - S: AsyncConnect + Clone + Send + Sync + 'static, - S::Connection: - AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, -{ - /// Start a new DNS request. - async fn request_impl< - CR: ComposeRequest + Clone + Send + Sync + 'static, - >( - self: Arc, - request_msg: CR, - ) -> Result, Error> { - Ok(Box::new(Query { - get_response_fut: async move { - self.query(request_msg).await.map_err(Into::into) - } - .boxed(), - })) - } -} +//--- Clone -//--- SendRequest and HandleRequest - -impl SendRequest for Arc> -where - S: AsyncConnect + Clone + Send + Sync + 'static, - S::Connection: - AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, - Req: ComposeRequest + Clone + Send + Sync + 'static, -{ - fn send_request<'a>( - &'a self, - request_msg: &'a Req, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - let this = self.clone(); - Box::pin(this.request_impl(request_msg.clone())) +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } } } -impl HandleRequest for Connection +//--- SendRequest + +impl SendRequest for Connection where S: AsyncConnect + Clone + Send + Sync + 'static, S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static, Req: ComposeRequest + Clone + Send + Sync + 'static, { - type Response = Message; - type Error = Error; - type Fut<'s> = Pin> + Send + 's - >> where Self: 's; - - fn handle_request(&self, request: Req) -> Self::Fut<'_> { - async { self.query(request).await.map_err(Into::into) }.boxed() + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request { + fut: Box::pin(self.clone().handle_request_impl(request_msg)), + }) } } -//------------ Query -------------------------------------------------------- +//------------ Request ------------------------------------------------------ /// The state of a DNS request. -pub struct Query { +pub struct Request { /// Future that does the actual work of GetResponse. - get_response_fut: - Pin, Error>> + Send>>, + fut: Pin, Error>> + Send>>, } -impl Query { - /// Async function that waits for the future stored in Query to complete. +impl Request { + /// Async function that waits for the future stored in Request to complete. async fn get_response_impl(&mut self) -> Result, Error> { - (&mut self.get_response_fut).await + (&mut self.fut).await } } -impl fmt::Debug for Query { +impl fmt::Debug for Request { fn fmt(&self, _: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { todo!() } } -impl GetResponse for Query { +impl GetResponse for Request { fn get_response( &mut self, ) -> Pin< @@ -566,7 +516,7 @@ impl fmt::Display for QueryErrorKind { /// This error is returned by [`Connection::try_query`]. pub enum TryQueryError { /// The query has failed with the given error. - Query(QueryError), + Request(QueryError), /// There were too many active queries. /// @@ -577,8 +527,8 @@ pub enum TryQueryError { impl fmt::Debug for TryQueryError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Query(err) => { - f.debug_tuple("TryQueryError::Query").field(err).finish() + Self::Request(err) => { + f.debug_tuple("TryQueryError::Request").field(err).finish() } Self::TooManyQueries(_) => f .debug_tuple("TryQueryError::Req") @@ -591,7 +541,7 @@ impl fmt::Debug for TryQueryError { impl fmt::Display for TryQueryError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Query(error) => error.fmt(f), + Self::Request(error) => error.fmt(f), Self::TooManyQueries(_) => { f.write_str("too many active requests") } diff --git a/src/net/client/dgram_stream.rs b/src/net/client/dgram_stream.rs index 89593a412..42b007ade 100644 --- a/src/net/client/dgram_stream.rs +++ b/src/net/client/dgram_stream.rs @@ -13,10 +13,9 @@ use crate::net::client::protocol::{ AsyncConnect, AsyncDgramRecv, AsyncDgramSend, }; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, + ComposeRequest, Error, GetResponse, SendRequest, }; use bytes::Bytes; -use futures_util::FutureExt; use std::boxed::Box; use std::fmt::Debug; use std::future::Future; @@ -126,88 +125,28 @@ where } } -impl Connection -where - DgramS: AsyncConnect, - DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin, - Req: ComposeRequest + Clone, -{ - /// Sends a request and receives a response. - pub async fn request( - &self, - request: Req, - ) -> Result, Error> { - let response = self.udp_conn.query(request.clone()).await?; - if !response.header().tc() { - return Ok(response); - } - self.tcp_conn.request(request).await - } -} +//--- SendRequest -impl Connection +impl SendRequest for Connection where DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, Req: ComposeRequest + Clone + 'static, { - /// Start a request for the Request trait. - async fn request_impl( - &self, - request_msg: &Req, - ) -> Result, Error> { - Ok(Box::new(Query::new( + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request::new( request_msg, self.udp_conn.clone(), self.tcp_conn.clone(), - ))) + )) } } -//--- SendRequest and HandleRequest - -impl SendRequest for Connection -where - DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, - DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, - Req: ComposeRequest + Clone + 'static, -{ - fn send_request<'a>( - &'a self, - request_msg: &'a Req, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.request_impl(request_msg)); - } -} - -impl HandleRequest for Connection -where - DgramS: AsyncConnect + Clone + Debug + Send + Sync + 'static, - DgramS::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, - Req: ComposeRequest + Clone + 'static, -{ - type Response = Message; - type Error = Error; - type Fut<'s> = Pin> + Send + 's - >> where Self: 's; - - fn handle_request(&self, request: Req) -> Self::Fut<'_> { - self.request(request).boxed() - } -} - -//------------ Query -------------------------------------------------------- +//------------ Request -------------------------------------------------------- /// Object that contains the current state of a query. #[derive(Debug)] -pub struct Query { +pub struct Request { /// Reqeust message. request_msg: Req, @@ -237,21 +176,21 @@ enum QueryState { GetTcpResponse(Box), } -impl Query +impl Request where S: AsyncConnect + Clone + Send + Sync + 'static, Req: ComposeRequest + Clone + 'static, { - /// Create a new Query object. + /// Create a new Request object. /// /// The initial state is to start with a UDP transport. fn new( - request_msg: &Req, + request_msg: Req, udp_conn: Arc>, tcp_conn: multi_stream::Connection, - ) -> Query { + ) -> Request { Self { - request_msg: request_msg.clone(), + request_msg, udp_conn, tcp_conn, state: QueryState::StartUdpRequest, @@ -269,7 +208,7 @@ where match &mut self.state { QueryState::StartUdpRequest => { let msg = self.request_msg.clone(); - let request = self.udp_conn.send_request(&msg).await?; + let request = self.udp_conn.send_request(msg); self.state = QueryState::GetUdpResponse(request); continue; } @@ -283,7 +222,7 @@ where } QueryState::StartTcpRequest => { let msg = self.request_msg.clone(); - let request = self.tcp_conn.send_request(&msg).await?; + let request = self.tcp_conn.send_request(msg); self.state = QueryState::GetTcpResponse(request); continue; } @@ -296,7 +235,7 @@ where } } -impl GetResponse for Query +impl GetResponse for Request where S: AsyncConnect + Clone + Debug + Send + Sync + 'static, S::Connection: AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin, diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 2d5b8c365..b4d595690 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -61,7 +61,7 @@ //! # let req = domain::net::client::request::RequestMessage::new( //! # domain::base::MessageBuilder::new_vec() //! # ); -//! # let mut request = tcp_conn.send_request(&req).await.unwrap(); +//! # let mut request = tcp_conn.send_request(req); //! # } //! ``` //! Note that the run function ends when the last reference to the DNS @@ -101,7 +101,7 @@ //! # let req = domain::net::client::request::RequestMessage::new( //! # domain::base::MessageBuilder::new_vec() //! # ); -//! let mut request = tls_conn.send_request(&req).await.unwrap(); +//! let mut request = tls_conn.send_request(req); //! # } //! ``` //! where ```tls_conn``` is a transport connection for DNS over TLS. @@ -127,7 +127,7 @@ //! # let req = domain::net::client::request::RequestMessage::new( //! # domain::base::MessageBuilder::new_vec() //! # ); -//! # let mut request = tls_conn.send_request(&req).await.unwrap(); +//! # let mut request = tls_conn.send_request(req); //! let reply = request.get_response().await; //! # } //! ``` diff --git a/src/net/client/multi_stream.rs b/src/net/client/multi_stream.rs index e55ec301a..a5690a464 100644 --- a/src/net/client/multi_stream.rs +++ b/src/net/client/multi_stream.rs @@ -9,12 +9,12 @@ use crate::base::Message; use crate::net::client::protocol::AsyncConnect; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, + ComposeRequest, Error, GetResponse, SendRequest, }; use crate::net::client::stream; use bytes::Bytes; use futures_util::stream::FuturesUnordered; -use futures_util::{FutureExt, StreamExt}; +use futures_util::StreamExt; use rand::random; use std::boxed::Box; use std::fmt::Debug; @@ -87,13 +87,13 @@ impl Connection { } } -impl Connection { +impl Connection { /// Sends a request and receives a response. pub async fn request( &self, request: Req, ) -> Result, Error> { - Query::new(self.clone(), request).get_response().await + Request::new(self.clone(), request).get_response().await } /// Starts a request. @@ -106,7 +106,7 @@ impl Connection { where Req: 'static, { - let gr = Query::new(self.clone(), request.clone()); + let gr = Request::new(self.clone(), request.clone()); Ok(Box::new(gr)) } @@ -153,46 +153,22 @@ impl Clone for Connection { } } -//--- SendRequest and HandleRequest +//--- SendRequest impl SendRequest for Connection where Req: ComposeRequest + Clone + 'static, { - fn send_request<'a>( - &'a self, - request: &'a Req, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self._send_request(request)); - } -} - -impl HandleRequest for Connection -where - Req: ComposeRequest + Clone + Send, -{ - type Response = Message; - type Error = Error; - type Fut<'s> = Pin> + Send + 's - >> where Self: 's; - - fn handle_request(&self, request: Req) -> Self::Fut<'_> { - self.request(request).boxed() + fn send_request(&self, request: Req) -> Box { + Box::new(Request::new(self.clone(), request)) } } -//------------ Query -------------------------------------------------------- +//------------ Request -------------------------------------------------------- /// The connection side of an active request. #[derive(Debug)] -struct Query { +struct Request { /// The request message. /// /// It is kept so we can compare a response with it. @@ -224,7 +200,7 @@ enum QueryState { StartQuery(Arc>), /// Get the result of the query. - GetResult(stream::Query), + GetResult(stream::Request), /// Wait until trying again. /// @@ -249,7 +225,7 @@ struct ChanRespOk { conn: Arc>, } -impl Query { +impl Request { /// Creates a new query. fn new(conn: Connection, request_msg: Req) -> Self { Self { @@ -262,7 +238,7 @@ impl Query { } } -impl Query { +impl Request { /// Get the result of a DNS request. /// /// This function is cancellation safe. If its future is dropped before @@ -311,21 +287,10 @@ impl Query { } } QueryState::StartQuery(ref mut conn) => { - let msg = self.request_msg.clone(); - let query_res = conn.start_request(msg.clone()).await; - match query_res { - Err(err) => { - if let Error::ConnectionClosed = err { - self.state = QueryState::RequestConn; - continue; - } - return Err(err); - } - Ok(query) => { - self.state = QueryState::GetResult(query); - continue; - } - } + self.state = QueryState::GetResult( + conn.get_request(self.request_msg.clone()), + ); + continue; } QueryState::GetResult(ref mut query) => { match query.get_response().await { @@ -359,7 +324,7 @@ impl Query { } } -impl GetResponse for Query { +impl GetResponse for Request { fn get_response( &mut self, ) -> Pin< diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 3e8dfdd2c..336396da4 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -82,7 +82,7 @@ pub struct Config { //------------ Connection ----------------------------------------------------- /// This type represents a transport connection. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Connection { /// Reference to the actual implementation of the connection. inner: Arc>, @@ -120,28 +120,61 @@ impl<'a, Req: Clone + Debug + Send + Sync + 'static> Connection { /// Implementation of the request function. async fn request_impl( - &self, - request_msg: &Req, - ) -> Result, Error> { - let request = self.inner.request(request_msg.clone()).await?; - Ok(Box::new(request)) + self, + request_msg: Req, + ) -> Result, Error> { + self.inner.request(request_msg).await?.get_response().await + } +} + +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } } } impl SendRequest for Connection { - fn send_request<'a>( - &'a self, - request_msg: &'a Req, + fn send_request(&self, request_msg: Req) -> Box { + Box::new(Request { + fut: Box::pin(self.clone().request_impl(request_msg)), + }) + } +} + +//------------ Request ------------------------------------------------------- + +/// An active request. +pub struct Request { + /// The underlying future. + fut: Pin, Error>> + Send>>, +} + +impl Request { + /// Async function that waits for the future stored in Query to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.fut).await + } +} + +impl GetResponse for Request { + fn get_response( + &mut self, ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, + Box, Error>> + Send + '_>, > { - return Box::pin(self.request_impl(request_msg)); + Box::pin(self.get_response_impl()) + } +} + +impl Debug for Request { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("Request") + .field("fut", &format_args!("_")) + .finish() } } @@ -339,7 +372,7 @@ impl Query { } /// Implementation of get_response. - async fn get_response_impl(&mut self) -> Result, Error> { + async fn get_response(&mut self) -> Result, Error> { loop { match self.state { QueryState::Init => { @@ -537,16 +570,6 @@ impl Query { } } -impl GetResponse for Query { - fn get_response( - &mut self, - ) -> Pin< - Box, Error>> + Send + '_>, - > { - Box::pin(self.get_response_impl()) - } -} - //------------ Transport ----------------------------------------------------- /// Type that actually implements the connection. @@ -628,10 +651,9 @@ impl<'a, Req: Clone + Send + Sync + 'static> Transport { match opt_ind { Some(ind) => { let query = conns[ind] - .send_request(&request_req.request_msg) - .await; + .send_request(request_req.request_msg); // Don't care if send fails - let _ = request_req.tx.send(query); + let _ = request_req.tx.send(Ok(query)); } None => { // Don't care if send fails diff --git a/src/net/client/request.rs b/src/net/client/request.rs index 985d98569..b89bf0ee5 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -55,25 +55,6 @@ pub trait ComposeRequest: Debug + Send + Sync { fn is_answer(&self, answer: &Message<[u8]>) -> bool; } -//------------ HandleRequest ------------------------------------------------- - -/// Trait for handling a DNS request. -pub trait HandleRequest { - /// The response returned upon success. - type Response: AsRef>; - - /// The error returned upon failure. - type Error; - - /// The future producing the response. - type Fut<'s>: Future> + 's - where - Self: 's; - - /// Returns a future processing the request. - fn handle_request(&self, request_msg: Req) -> Self::Fut<'_>; -} - //------------ SendRequest --------------------------------------------------- /// Trait for starting a DNS request based on a request composer. @@ -82,18 +63,9 @@ pub trait HandleRequest { /// However, the use of 'dyn Request' in redundant currently prevents that. pub trait SendRequest { /// Request function that takes a ComposeRequest type. - /// - /// This function is intended to be cancel safe. - fn send_request<'a>( - &'a self, - request_msg: &'a CR, - ) -> Pin + Send + '_>>; + fn send_request(&self, request_msg: CR) -> Box; } -/// This type is the actual result type of the future returned by the -/// request function in the Request trait. -type RequestResultOutput = Result, Error>; - //------------ GetResponse --------------------------------------------------- /// Trait for getting the result of a DNS query. @@ -399,7 +371,7 @@ impl fmt::Display for Error { Error::NoTransportAvailable => { write!(f, "no transport available") } - Error::Dgram(err) => fmt::Display::fmt(err, f) + Error::Dgram(err) => fmt::Display::fmt(err, f), } } } diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index 6360ddc8d..c679e6c58 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -21,13 +21,12 @@ use crate::base::message::Message; use crate::base::message_builder::StreamTarget; use crate::base::opt::{AllOptData, OptRecord, TcpKeepalive}; use crate::net::client::request::{ - ComposeRequest, Error, GetResponse, HandleRequest, SendRequest, + ComposeRequest, Error, GetResponse, SendRequest, }; use bytes; use bytes::{Bytes, BytesMut}; use core::cmp; use core::convert::From; -use futures_util::FutureExt; use octseq::Octets; use std::boxed::Box; use std::fmt::Debug; @@ -138,117 +137,65 @@ impl Connection { } } -impl Connection { - /// Sends a request and receives a response. - pub async fn request( - &self, - request: Req, - ) -> Result, Error> { - let receiver = self.send_request(request).await?; - receiver.await.map_err(|_| Error::StreamReceiveError)? - } -} - -impl Connection { +impl Connection { /// Start a DNS request. /// /// This function takes a precomposed message as a parameter and /// returns a [ReqRepl] object wrapped in a [Result]. - async fn request_impl( - &self, - request: Req, - ) -> Result, Error> { - let rx = self.send_request(request).await?; - Ok(Box::new(Query::new(rx))) - } - - /// Start a DNS request. - pub async fn start_request( - &self, - query_msg: Req, - ) -> Result { - let rx = self.send_request(query_msg).await?; - Ok(Query::new(rx)) - } - - /// Sends a request. - async fn send_request( - &self, - request_msg: Req, - ) -> Result, Error> { + async fn handle_request_impl( + self, + msg: Req, + ) -> Result, Error> { let (sender, receiver) = oneshot::channel(); - let req = ChanReq { - sender, - msg: request_msg, - }; + let req = ChanReq { sender, msg }; self.sender.send(req).await.map_err(|_| { // Send error. The receiver is gone, this means that the // connection is closed. Error::ConnectionClosed })?; - Ok(receiver) + receiver.await.map_err(|_| Error::StreamReceiveError)? } -} -impl SendRequest - for Connection -{ - fn send_request<'a>( - &'a self, - request_msg: &'a Req, - ) -> Pin< - Box< - dyn Future, Error>> - + Send - + '_, - >, - > { - return Box::pin(self.request_impl(request_msg.clone())); + /// Returns a request handler for this connection. + pub fn get_request(&self, request_msg: Req) -> Request { + Request { + fut: Box::pin(self.clone().handle_request_impl(request_msg)), + } } } -impl HandleRequest for Connection { - type Response = Message; - type Error = Error; - type Fut<'s> = Pin> + Send + 's - >> where Self: 's; +impl Clone for Connection { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} - fn handle_request(&self, request: Req) -> Self::Fut<'_> { - self.request(request).boxed() +impl SendRequest + for Connection +{ + fn send_request(&self, request_msg: Req) -> Box { + Box::new(self.get_request(request_msg)) } } -//------------ Query ---------------------------------------------------------- +//------------ Request ------------------------------------------------------- -/// This struct represent an active DNS request. -#[derive(Debug)] -pub struct Query { - /// The receiver for the response. - receiver: oneshot::Receiver, +/// An active request. +pub struct Request { + /// The underlying future. + fut: Pin, Error>> + Send>>, } -impl Query { - /// Constructor for [Query], takes a DNS query and a receiver for the - /// reply. - fn new(receiver: oneshot::Receiver) -> Query { - Self { receiver } - } - - /// Get the result of a DNS request. - /// - /// This function returns the reply to a DNS request wrapped in a - /// [Result]. - pub async fn get_response_impl( - &mut self, - ) -> Result, Error> { - (&mut self.receiver) - .await - .map_err(|_| Error::StreamReceiveError)? +impl Request { + /// Async function that waits for the future stored in Request to complete. + async fn get_response_impl(&mut self) -> Result, Error> { + (&mut self.fut).await } } -impl GetResponse for Query { +impl GetResponse for Request { fn get_response( &mut self, ) -> Pin< @@ -258,6 +205,14 @@ impl GetResponse for Query { } } +impl Debug for Request { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Request") + .field("fut", &format_args!("_")) + .finish() + } +} + //------------ Transport ----------------------------------------------------- /// The underlying machinery of a stream transport. @@ -273,20 +228,20 @@ pub struct Transport { receiver: mpsc::Receiver>, } -/// A message from a `Query` to start a new request. +/// A message from a `Request` to start a new request. #[derive(Debug)] struct ChanReq { /// DNS request message msg: Req, - /// Sender to send result back to [Query] + /// Sender to send result back to [Request] sender: ReplySender, } /// This is the type of sender in [ChanReq]. type ReplySender = oneshot::Sender; -/// A message back to `Query` returning a response. +/// A message back to `Request` returning a response. type ChanResp = Result, Error>; /// Internal datastructure of [Transport::run] to keep track of diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index b3c0e1f27..29e9ea078 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -413,7 +413,7 @@ impl<'a> Query<'a> { let request_msg = RequestMessage::new(msg); let transport = self.resolver.get_transport().await; - let mut gr_fut = transport.send_request(&request_msg).await.unwrap(); + let mut gr_fut = transport.send_request(request_msg); let reply = timeout(self.resolver.options.timeout, gr_fut.get_response()) .await diff --git a/tests/net-client.rs b/tests/net-client.rs index ebd9341ef..095450ac5 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -1,8 +1,8 @@ #![cfg(feature = "net")] mod net; +use crate::net::deckard::client::do_client; use crate::net::deckard::client::CurrStepValue; -use crate::net::deckard::client::{closure_do_client, do_client}; use crate::net::deckard::connect::Connect; use crate::net::deckard::connection::Connection; use crate::net::deckard::dgram::Dgram; @@ -29,7 +29,7 @@ fn dgram() { let step_value = Arc::new(CurrStepValue::new()); let conn = Dgram::new(deckard.clone(), step_value.clone()); - let octstr = Arc::new(dgram::Connection::new(conn)); + let octstr = dgram::Connection::new(conn); do_client(&deckard, octstr, &step_value).await; }); @@ -146,37 +146,3 @@ fn tcp() { do_client(&deckard, tcp, &CurrStepValue::new()).await; }); } - -#[test] -#[ignore] -// Connect directly to the internet. Disabled by default. -fn tcp_async_fn() { - tokio_test::block_on(async { - let file = File::open(TEST_FILE).unwrap(); - let deckard = parse_file(file); - - let server_addr = - SocketAddr::new(IpAddr::from_str("9.9.9.9").unwrap(), 53); - - let tcp_conn = match TcpStream::connect(server_addr).await { - Ok(conn) => conn, - Err(err) => { - println!( - "TCP Connection to {server_addr} failed: {err}, exiting" - ); - return; - } - }; - - let (tcp, transport) = stream::Connection::new(tcp_conn); - tokio::spawn(async move { - transport.run().await; - println!("single TCP run terminated"); - }); - - closure_do_client(&deckard, &CurrStepValue::new(), |req| { - tcp.request(req) - }) - .await; - }) -} diff --git a/tests/net/deckard/client.rs b/tests/net/deckard/client.rs index 17db8ef55..5929c9d7e 100644 --- a/tests/net/deckard/client.rs +++ b/tests/net/deckard/client.rs @@ -4,43 +4,9 @@ use crate::net::deckard::parse_query; use bytes::Bytes; use domain::base::{Message, MessageBuilder}; -use domain::net::client::request::{Error, RequestMessage, SendRequest}; -use std::future::Future; +use domain::net::client::request::{RequestMessage, SendRequest}; use std::sync::Mutex; -pub async fn closure_do_client( - deckard: &Deckard, - step_value: &CurrStepValue, - request: F, -) where - F: Fn(RequestMessage>) -> Fut, - Fut: Future, Error>>, -{ - let mut resp: Option> = None; - - // Assume steps are in order. Maybe we need to define that. - for step in &deckard.scenario.steps { - step_value.set(step.step_value); - match step.step_type { - StepType::Query => { - let reqmsg = entry2reqmsg(step.entry.as_ref().unwrap()); - resp = Some(request(reqmsg).await.unwrap()); - } - StepType::CheckAnswer => { - let answer = resp.take().unwrap(); - if !match_msg(step.entry.as_ref().unwrap(), &answer, true) { - panic!("reply failed"); - } - } - StepType::TimePasses - | StepType::Traffic - | StepType::CheckTempfile - | StepType::Assign => todo!(), - } - } - println!("Done"); -} - pub async fn do_client>>>( deckard: &Deckard, request: R, @@ -54,7 +20,7 @@ pub async fn do_client>>>( match step.step_type { StepType::Query => { let reqmsg = entry2reqmsg(step.entry.as_ref().unwrap()); - let mut req = request.send_request(&reqmsg).await.unwrap(); + let mut req = request.send_request(reqmsg); resp = Some(req.get_response().await.unwrap()); } StepType::CheckAnswer => { From 5a9f862c98289660f85d42e46ee2a0d0202edfa2 Mon Sep 17 00:00:00 2001 From: Martin Hoffmann Date: Tue, 9 Jan 2024 17:10:56 +0100 Subject: [PATCH 120/124] Change creation of redundant transport to new model. --- examples/client-transports.rs | 4 +- src/net/client/redundant.rs | 131 ++++++++++------------------------ src/resolv/stub/mod.rs | 4 +- tests/net-client.rs | 4 +- 4 files changed, 43 insertions(+), 100 deletions(-) diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 1ef379c92..acda31f91 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -163,10 +163,10 @@ async fn main() { drop(request); // Create a transport connection for redundant connections. - let redun = redundant::Connection::new(None).unwrap(); + let (redun, transp) = redundant::Connection::new(); // Start the run function on a separate task. - let run_fut = redun.run(); + let run_fut = transp.run(); tokio::spawn(async move { run_fut.await; println!("redundant run terminated"); diff --git a/src/net/client/redundant.rs b/src/net/client/redundant.rs index 336396da4..d0e9a4d95 100644 --- a/src/net/client/redundant.rs +++ b/src/net/client/redundant.rs @@ -17,8 +17,6 @@ use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; -use std::sync::Mutex; use std::vec::Vec; use tokio::sync::{mpsc, oneshot}; @@ -67,7 +65,7 @@ const PROBE_RT: Duration = Duration::from_millis(1); //------------ Config --------------------------------------------------------- /// User configuration variables. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Copy, Debug, Default)] pub struct Config { /// Defer transport errors. pub defer_transport_error: bool, @@ -84,30 +82,23 @@ pub struct Config { /// This type represents a transport connection. #[derive(Debug)] pub struct Connection { - /// Reference to the actual implementation of the connection. - inner: Arc>, + /// User configuation. + config: Config, + + /// To send a request to the runner. + sender: mpsc::Sender>, } -impl<'a, Req: Clone + Debug + Send + Sync + 'static> Connection { +impl Connection { /// Create a new connection. - pub fn new(config: Option) -> Result { - let config = match config { - Some(config) => { - check_config(&config)?; - config - } - None => Default::default(), - }; - let connection = Transport::new(config)?; - //test_send(connection); - Ok(Self { - inner: Arc::new(connection), - }) + pub fn new() -> (Self, Transport) { + Self::with_config(Default::default()) } - /// Runner function for a connection. - pub fn run(&self) -> Pin + Send>> { - self.inner.run() + /// Create a new connection with a given config. + pub fn with_config(config: Config) -> (Self, Transport) { + let (sender, receiver) = mpsc::channel(DEF_CHAN_CAP); + (Self { config, sender }, Transport::new(receiver)) } /// Add a transport connection. @@ -115,22 +106,36 @@ impl<'a, Req: Clone + Debug + Send + Sync + 'static> Connection { &self, conn: Box + Send + Sync>, ) -> Result<(), Error> { - self.inner.add(conn).await + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::Add(AddReq { conn, tx })) + .await + .expect("send should not fail"); + rx.await.expect("receive should not fail") } - /// Implementation of the request function. + /// Implementation of the query method. async fn request_impl( self, request_msg: Req, ) -> Result, Error> { - self.inner.request(request_msg).await?.get_response().await + let (tx, rx) = oneshot::channel(); + self.sender + .send(ChanReq::GetRT(RTReq { tx })) + .await + .expect("send should not fail"); + let conn_rt = rx.await.expect("receive should not fail")?; + Query::new(self.config, request_msg, conn_rt, self.sender.clone()) + .get_response() + .await } } impl Clone for Connection { fn clone(&self) -> Self { Self { - inner: self.inner.clone(), + config: self.config, + sender: self.sender.clone(), } } } @@ -574,51 +579,27 @@ impl Query { /// Type that actually implements the connection. #[derive(Debug)] -struct Transport { - /// User configuation. - config: Config, - +pub struct Transport { /// Receive side of the channel used by the runner. - receiver: Mutex>>>, - - /// To send a request to the runner. - sender: mpsc::Sender>, + receiver: mpsc::Receiver>, } impl<'a, Req: Clone + Send + Sync + 'static> Transport { /// Implementation of the new method. - fn new(config: Config) -> Result { - let (tx, rx) = mpsc::channel(DEF_CHAN_CAP); - Ok(Self { - config, - receiver: Mutex::new(Some(rx)), - sender: tx, - }) + fn new(receiver: mpsc::Receiver>) -> Self { + Self { receiver } } /// Run method. - /// - /// Make sure the future does not contain a reference to self. - fn run(&self) -> Pin + Send>> { - let mut receiver = self.receiver.lock().unwrap(); - let opt_receiver = receiver.take(); - drop(receiver); - - Box::pin(Self::run_impl(opt_receiver)) - } - - /// Implementation of the run method. - async fn run_impl(opt_receiver: Option>>) { + pub async fn run(mut self) { let mut next_id: u64 = 10; let mut conn_stats: Vec = Vec::new(); let mut conn_rt: Vec = Vec::new(); let mut conns: Vec + Send + Sync>> = Vec::new(); - let mut receiver = - opt_receiver.expect("receiver should not be empty"); loop { - let req = match receiver.recv().await { + let req = match self.receiver.recv().await { Some(req) => req, None => break, // All references to connection objects are // dropped. Shutdown. @@ -708,38 +689,6 @@ impl<'a, Req: Clone + Send + Sync + 'static> Transport { } } } - - /// Implementation of the add method. - async fn add( - &self, - conn: Box + Send + Sync>, - ) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - self.sender - .send(ChanReq::Add(AddReq { conn, tx })) - .await - .expect("send should not fail"); - rx.await.expect("receive should not fail") - } - - /// Implementation of the query method. - async fn request( - &'a self, - request_msg: Req, - ) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - self.sender - .send(ChanReq::GetRT(RTReq { tx })) - .await - .expect("send should not fail"); - let conn_rt = rx.await.expect("receive should not fail")?; - Ok(Query::new( - self.config.clone(), - request_msg, - conn_rt, - self.sender.clone(), - )) - } } //------------ Utility -------------------------------------------------------- @@ -811,9 +760,3 @@ fn get_opt_rcode(msg: &Message) -> OptRcode { } } } - -/// Check if config is valid. -fn check_config(_config: &Config) -> Result<(), Error> { - // Nothing to check at the moment. - Ok(()) -} diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index 29e9ea078..c405558fa 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -126,10 +126,10 @@ impl StubResolver { &self, ) -> redundant::Connection { // Create a redundant transport and fill it with the right transports - let redun = redundant::Connection::new(None).unwrap(); + let (redun, transp) = redundant::Connection::new(); // Start the run function on a separate task. - let redun_run_fut = redun.run(); + let redun_run_fut = transp.run(); // It would be nice to have just one task. However redun.run() has to // execute before we can call redun.add(). However, we need to know diff --git a/tests/net-client.rs b/tests/net-client.rs index 095450ac5..10a7ab467 100644 --- a/tests/net-client.rs +++ b/tests/net-client.rs @@ -104,8 +104,8 @@ fn redundant() { }); // Redundant add previous connection. - let redun = redundant::Connection::new(None).unwrap(); - let run_fut = redun.run(); + let (redun, transp) = redundant::Connection::new(); + let run_fut = transp.run(); tokio::spawn(async move { run_fut.await; println!("redundant conn run terminated"); From 1ed6117be1f8a0fc94331ac17a010bb58f2c6773 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 11 Jan 2024 15:22:27 +0100 Subject: [PATCH 121/124] write_all is not cancel safe, use write instead. Change deckard to accept one byte at a time. --- src/net/client/stream.rs | 30 +++++++++++++++++++----------- tests/net/deckard/connection.rs | 18 +++++++++++++----- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/net/client/stream.rs b/src/net/client/stream.rs index c679e6c58..3e3b05817 100644 --- a/src/net/client/stream.rs +++ b/src/net/client/stream.rs @@ -335,6 +335,7 @@ where let mut query_vec = Queries::new(); let mut reqmsg: Option> = None; + let mut reqmsg_offset = 0; loop { let opt_timeout = match status.state { @@ -430,18 +431,25 @@ where drop(opt_record); Self::demux_reply(answer, &mut status, &mut query_vec); } - res = write_stream.write_all(msg), + res = write_stream.write(&msg[reqmsg_offset..]), if do_write => { - if let Err(error) = res { - let error = Error::StreamWriteError(Arc::new(error)); - Self::error(error.clone(), &mut query_vec); - status.state = - ConnState::WriteError(error); - break; - } - else { - reqmsg = None; - } + match res { + Err(error) => { + let error = + Error::StreamWriteError(Arc::new(error)); + Self::error(error.clone(), &mut query_vec); + status.state = + ConnState::WriteError(error); + break; + } + Ok(len) => { + reqmsg_offset += len; + if reqmsg_offset >= msg.len() { + reqmsg = None; + reqmsg_offset = 0; + } + } + } } res = recv_fut, if !do_write => { match res { diff --git a/tests/net/deckard/connection.rs b/tests/net/deckard/connection.rs index ff459141b..d106a5395 100644 --- a/tests/net/deckard/connection.rs +++ b/tests/net/deckard/connection.rs @@ -18,6 +18,8 @@ pub struct Connection { waker: Option, reply: Option>>, send_body: bool, + + tmpbuf: Vec, } impl Connection { @@ -31,6 +33,7 @@ impl Connection { waker: None, reply: None, send_body: false, + tmpbuf: Vec::new(), } } } @@ -67,14 +70,19 @@ impl AsyncWrite for Connection { _: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let buflen = buf.len(); + self.tmpbuf.push(buf[0]); + let buflen = self.tmpbuf.len(); + if buflen < 2 { + return Poll::Ready(Ok(1)); + } let mut len_str: [u8; 2] = [0; 2]; - len_str.copy_from_slice(&buf[0..2]); + len_str.copy_from_slice(&self.tmpbuf[0..2]); let len = u16::from_be_bytes(len_str) as usize; if buflen != 2 + len { - panic!("expecting one complete message per write"); + return Poll::Ready(Ok(1)); } - let msg = Message::from_octets(buf[2..].to_vec()).unwrap(); + let msg = Message::from_octets(self.tmpbuf[2..].to_vec()).unwrap(); + self.tmpbuf = Vec::new(); let opt_reply = do_server(&msg, &self.deckard, &self.step_value); if opt_reply.is_some() { // Do we need to support more than one reply? @@ -84,7 +92,7 @@ impl AsyncWrite for Connection { waker.wake(); } } - Poll::Ready(Ok(buflen)) + Poll::Ready(Ok(1)) } fn poll_flush( self: Pin<&mut Self>, From bfb5399376b19388d67e3e4b849a369e50382e8b Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Thu, 11 Jan 2024 17:35:45 +0100 Subject: [PATCH 122/124] Get rid of unwrap. --- src/resolv/stub/mod.rs | 73 +++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index c405558fa..2f6fad373 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -14,9 +14,7 @@ use self::conf::{ }; use crate::base::iana::Rcode; use crate::base::message::Message; -use crate::base::message_builder::{ - AdditionalBuilder, MessageBuilder, StreamTarget, -}; +use crate::base::message_builder::{AdditionalBuilder, MessageBuilder}; use crate::base::name::{ToDname, ToRelativeDname}; use crate::base::question::Question; use crate::net::client::dgram_stream; @@ -24,7 +22,7 @@ use crate::net::client::multi_stream; use crate::net::client::protocol::{TcpConnect, UdpConnect}; use crate::net::client::redundant; use crate::net::client::request::{ - ComposeRequest, RequestMessage, SendRequest, + ComposeRequest, Error, RequestMessage, SendRequest, }; use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs}; use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts}; @@ -38,6 +36,7 @@ use std::fmt::Debug; use std::future::Future; use std::net::IpAddr; use std::pin::Pin; +use std::string::ToString; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::vec::Vec; @@ -124,7 +123,7 @@ impl StubResolver { CR: Clone + Debug + ComposeRequest + Send + Sync + 'static, >( &self, - ) -> redundant::Connection { + ) -> Result, Error> { // Create a redundant transport and fill it with the right transports let (redun, transp) = redundant::Connection::new(); @@ -156,9 +155,9 @@ impl StubResolver { // Start the run function on a separate task. let run_fut = tran.run(); fut_list_tcp.push(async move { - let _res = run_fut.await; + run_fut.await; }); - redun.add(Box::new(conn)).await.unwrap(); + redun.add(Box::new(conn)).await?; } } } else { @@ -174,7 +173,7 @@ impl StubResolver { fut_list_udp_tcp.push(async move { tran.run().await; }); - redun.add(Box::new(conn)).await.unwrap(); + redun.add(Box::new(conn)).await?; } } } @@ -183,20 +182,22 @@ impl StubResolver { run(fut_list_tcp, fut_list_udp_tcp).await; }); - redun + Ok(redun) } async fn get_transport( &self, - ) -> redundant::Connection>> { + ) -> Result>>, Error> { let mut opt_transport = self.transport.lock().await; - if opt_transport.is_none() { - let transport = self.setup_transport().await; - *opt_transport = Some(transport); + match &*opt_transport { + Some(transport) => Ok(transport.clone()), + None => { + let transport = self.setup_transport().await?; + *opt_transport = Some(transport.clone()); + Ok(transport) + } } - - (*opt_transport).as_ref().unwrap().clone() } } @@ -269,10 +270,10 @@ impl StubResolver { /// The only argument is a closure taking a reference to a `StubResolver` /// and returning a future. Whatever that future resolves to will be /// returned. - pub fn run(op: F) -> R::Output + pub fn run(op: F) -> R::Output where - R: Future + Send + 'static, - R::Output: Send + 'static, + R: Future> + Send + 'static, + E: From, F: FnOnce(StubResolver) -> R + Send + 'static, { Self::run_with_conf(ResolvConf::default(), op) @@ -284,17 +285,16 @@ impl StubResolver { /// tailor-making your own resolver. /// /// [`run()`]: #method.run - pub fn run_with_conf(conf: ResolvConf, op: F) -> R::Output + pub fn run_with_conf(conf: ResolvConf, op: F) -> R::Output where - R: Future + Send + 'static, - R::Output: Send + 'static, + R: Future> + Send + 'static, + E: From, F: FnOnce(StubResolver) -> R + Send + 'static, { let resolver = Self::from_conf(conf); let runtime = runtime::Builder::new_current_thread() .enable_all() - .build() - .unwrap(); + .build()?; runtime.block_on(op(resolver)) } } @@ -391,13 +391,11 @@ impl<'a> Query<'a> { } fn create_message(question: Question) -> QueryMessage { - let mut message = MessageBuilder::from_target( - StreamTarget::new(Default::default()).unwrap(), - ) - .unwrap(); + let mut message = MessageBuilder::from_target(Default::default()) + .expect("MessageBuilder should not fail"); message.header_mut().set_rd(true); let mut message = message.question(); - message.push(question).unwrap(); + message.push(question).expect("push should not fail"); message.additional() } @@ -405,20 +403,21 @@ impl<'a> Query<'a> { &mut self, message: &mut QueryMessage, ) -> Result { - let msg = Message::from_octets( - message.as_target().as_dgram_slice().to_vec(), - ) - .unwrap(); + let msg = Message::from_octets(message.as_target().to_vec()) + .expect("Message::from_octets should not fail"); let request_msg = RequestMessage::new(msg); - let transport = self.resolver.get_transport().await; + let transport = self.resolver.get_transport().await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, e.to_string()) + })?; let mut gr_fut = transport.send_request(request_msg); let reply = timeout(self.resolver.options.timeout, gr_fut.get_response()) - .await - .unwrap() - .unwrap(); + .await? + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e.to_string()) + })?; Ok(Answer { message: reply }) } @@ -447,7 +446,7 @@ impl<'a> Query<'a> { //------------ QueryMessage -------------------------------------------------- // XXX This needs to be re-evaluated if we start adding OPTtions to the query. -pub(super) type QueryMessage = AdditionalBuilder>>; +pub(super) type QueryMessage = AdditionalBuilder>; //------------ Answer -------------------------------------------------------- From 54021f1bc81c339df17bce52d72099a95dd5195e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 12 Jan 2024 14:29:43 +0100 Subject: [PATCH 123/124] Improve documentation --- src/net/client/mod.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index b4d595690..85e162094 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -9,13 +9,13 @@ //! 3) Sending the request, and //! 4) Receiving the reply. //! -//! The first and second step are independent and happen in any order. +//! The first and second step are independent and can happen in any order. //! The third step uses the resuts of the first and second step. //! Finally, the fourth step uses the result of the third step. //! # Creating a request message //! -//! The DNS transport protocols expect a request message that implements +//! The DNS transport protocols expect a request message that implements the //! [ComposeRequest][request::ComposeRequest] trait. //! This trait allows transports to add ENDS(0) options, set flags, etc. //! The [RequestMessage][request::RequestMessage] type implements this trait. @@ -64,13 +64,8 @@ //! # let mut request = tcp_conn.send_request(req); //! # } //! ``` -//! Note that the run function ends when the last reference to the DNS -//! transport is dropped. For this reason it is important to avoid having a -//! reference to the transport end up in the task. Only pass the future -//! returned by the run function to the task. -//! -//! The currently implemented DNS transport have the following layering. At -//! the lower layer are [dgram] and [stream]. The dgram transport is used for +//! The currently implemented DNS transports have the following layering. At +//! the lowest layer are [dgram] and [stream]. The dgram transport is used for //! DNS over UDP, the stream transport is used for DNS over a single TCP or //! TLS connection. The transport works as long as the connection continuous //! to exist. From a08366e246a4bc777e36bc191502f18a9c451c1e Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Fri, 12 Jan 2024 14:31:12 +0100 Subject: [PATCH 124/124] Remove TryQueryError and remove pub from handle_request_impl. --- src/net/client/dgram.rs | 44 +---------------------------------------- 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/src/net/client/dgram.rs b/src/net/client/dgram.rs index b57d9c9b3..a6197e639 100644 --- a/src/net/client/dgram.rs +++ b/src/net/client/dgram.rs @@ -221,7 +221,7 @@ where /// Sends the provided and returns either a response or an error. If there /// are currently too many active queries, the future will wait until the /// number has dropped below the limit. - pub async fn handle_request_impl( + async fn handle_request_impl( self, mut request: Req, ) -> Result, Error> { @@ -508,45 +508,3 @@ impl fmt::Display for QueryErrorKind { }) } } - -//------------ TryQueryError ------------------------------------------------- - -/// An attempted query failed -/// -/// This error is returned by [`Connection::try_query`]. -pub enum TryQueryError { - /// The query has failed with the given error. - Request(QueryError), - - /// There were too many active queries. - /// - /// This variant contains the original request unchanged. - TooManyQueries(Req), -} - -impl fmt::Debug for TryQueryError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Request(err) => { - f.debug_tuple("TryQueryError::Request").field(err).finish() - } - Self::TooManyQueries(_) => f - .debug_tuple("TryQueryError::Req") - .field(&format_args!("_")) - .finish(), - } - } -} - -impl fmt::Display for TryQueryError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Request(error) => error.fmt(f), - Self::TooManyQueries(_) => { - f.write_str("too many active requests") - } - } - } -} - -impl error::Error for TryQueryError {}