From fc3100dca9733dfb1e473ccc855528b00ac6da89 Mon Sep 17 00:00:00 2001 From: Philip Homburg Date: Mon, 7 Aug 2023 10:29:58 +0200 Subject: [PATCH] UDP and UDP+TCP client transports --- src/net/client/mod.rs | 2 + src/net/client/udp.rs | 290 ++++++++++++++++++++++++++++++++++++++ src/net/client/udp_tcp.rs | 211 +++++++++++++++++++++++++++ 3 files changed, 503 insertions(+) create mode 100644 src/net/client/udp.rs create mode 100644 src/net/client/udp_tcp.rs diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 8a429ed80..61cf20ebb 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -9,3 +9,5 @@ pub mod tcp_channel; pub mod tcp_factory; pub mod tcp_mutex; pub mod tls_factory; +pub mod udp; +pub mod udp_tcp; diff --git a/src/net/client/udp.rs b/src/net/client/udp.rs new file mode 100644 index 000000000..5cb94e81f --- /dev/null +++ b/src/net/client/udp.rs @@ -0,0 +1,290 @@ +//! A DNS over UDP transport + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - cookies +// - random port + +use bytes::Bytes; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::UdpSocket; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::time::{timeout, Duration, Instant}; + +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; + +/// How many times do we try a new random port if we get ‘address in use.’ +const RETRY_RANDOM_PORT: usize = 10; + +/// Maximum number of parallel DNS query over a single UDP transport +/// connection. +const MAX_PARALLEL: usize = 100; + +/// Maximum amount of time to wait for a reply. +const READ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Maximum number of retries after timeouts. +const MAX_RETRIES: u8 = 5; + +/// A UDP transport connection. +#[derive(Clone)] +pub struct Connection { + /// Reference to the actual connection object. + inner: Arc, +} + +impl Connection { + /// Create a new UDP transport connection. + pub fn new(remote_addr: SocketAddr) -> io::Result { + let connection = InnerConnection::new(remote_addr)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Start a new DNS query. + pub async fn query + Clone>( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + self.inner.query(query_msg, self.clone()).await + } + + /// Get a permit from the semaphore to start using a socket. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.inner.get_permit().await + } +} + +/// State of the DNS query. +enum QueryState { + /// Get a semaphore permit. + GetPermit(Connection), + + /// Get a UDP socket. + GetSocket, + + /// Connect the socket. + Connect, + + /// Send the request. + Send, + + /// Receive the reply. + Receive(Instant), +} + +/// The state of a DNS query. +pub struct Query { + /// Address of remote server to connect to. + remote_addr: SocketAddr, + + /// DNS request message. + query_msg: MessageBuilder>>, + + /// Semaphore permit that allow use of socket. + _permit: Option, + + /// UDP socket for communication. + sock: Option, + + /// Current number of retries. + retries: u8, + + /// State of query. + state: QueryState, +} + +impl + Clone> Query { + /// Create new Query object. + fn new( + query_msg: &mut MessageBuilder>>, + remote_addr: SocketAddr, + conn: Connection, + ) -> Query { + Query { + query_msg: query_msg.clone(), + remote_addr, + _permit: None, + sock: None, + retries: 0, + state: QueryState::GetPermit(conn), + } + } + + /// Get the result of a DNS Query. + /// + /// This function is cancel safe. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + let recv_size = 2000; // Should be configurable. + + loop { + match &self.state { + QueryState::GetPermit(conn) => { + // We need to get past the semaphore that limits the + // number of concurrent sockets we can use. + let permit = conn.get_permit().await; + self._permit = Some(permit); + self.state = QueryState::GetSocket; + continue; + } + QueryState::GetSocket => { + self.sock = Some( + Self::udp_bind(self.remote_addr.is_ipv4()).await?, + ); + self.state = QueryState::Connect; + continue; + } + QueryState::Connect => { + self.sock + .as_ref() + .expect("socket should be present") + .connect(self.remote_addr) + .await?; + self.state = QueryState::Send; + continue; + } + QueryState::Send => { + let sent = self + .sock + .as_ref() + .expect("socket should be present") + .send( + self.query_msg + .as_target() + .as_target() + .as_dgram_slice(), + ) + .await?; + if sent + != self + .query_msg + .as_target() + .as_target() + .as_dgram_slice() + .len() + { + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "short UDP send", + ))); + } + self.state = QueryState::Receive(Instant::now()); + continue; + } + QueryState::Receive(start) => { + let elapsed = start.elapsed(); + if elapsed > READ_TIMEOUT { + todo!(); + } + let remain = READ_TIMEOUT - elapsed; + + let mut buf = vec![0; recv_size]; // XXX use uninit'ed mem here. + let timeout_res = timeout( + remain, + self.sock + .as_ref() + .expect("socket should be present") + .recv(&mut buf), + ) + .await; + if timeout_res.is_err() { + self.retries += 1; + if self.retries < MAX_RETRIES { + self.sock = None; + self.state = QueryState::GetSocket; + continue; + } + return Err(Arc::new(io::Error::new( + io::ErrorKind::Other, + "no response", + ))); + } + let len = + timeout_res.expect("errror case is checked above")?; + buf.truncate(len); + + // We ignore garbage since there is a timer on this whole thing. + let answer = match Message::from_octets(buf.into()) { + Ok(answer) => answer, + Err(_) => continue, + }; + if !answer.is_answer(&self.query_msg.as_message()) { + continue; + } + self.sock = None; + self._permit = None; + return Ok(answer); + } + } + } + } + + /// Bind to a local UDP port. + /// + /// This should explicitly pick a random number in a suitable range of + /// ports. + async fn udp_bind(v4: bool) -> Result { + let mut i = 0; + loop { + let local: SocketAddr = if v4 { + ([0u8; 4], 0).into() + } else { + ([0u16; 8], 0).into() + }; + match UdpSocket::bind(&local).await { + Ok(sock) => return Ok(sock), + Err(err) => { + if i == RETRY_RANDOM_PORT { + return Err(err); + } else { + i += 1 + } + } + } + } + } +} + +/// Actual implementation of the UDP transport connection. +struct InnerConnection { + /// Address of the remote server. + remote_addr: SocketAddr, + + /// Semaphore to limit access to UDP sockets. + semaphore: Arc, +} + +impl InnerConnection { + /// Create new InnerConnection object. + fn new(remote_addr: SocketAddr) -> io::Result { + Ok(Self { + remote_addr, + semaphore: Arc::new(Semaphore::new(MAX_PARALLEL)), + }) + } + + /// Return a Query object that contains the query state. + async fn query + Clone>( + &self, + query_msg: &mut MessageBuilder>>, + conn: Connection, + ) -> Result, &'static str> { + Ok(Query::new(query_msg, self.remote_addr, conn)) + } + + /// Return a permit for a our semaphore. + async fn get_permit(&self) -> OwnedSemaphorePermit { + self.semaphore + .clone() + .acquire_owned() + .await + .expect("the semaphore has not been closed") + } +} diff --git a/src/net/client/udp_tcp.rs b/src/net/client/udp_tcp.rs new file mode 100644 index 000000000..69fa63c87 --- /dev/null +++ b/src/net/client/udp_tcp.rs @@ -0,0 +1,211 @@ +//! A UDP transport that falls back to TCP if the reply is truncated + +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] + +// To do: +// - handle shutdown + +use bytes::Bytes; +use octseq::OctetsBuilder; +use std::fmt::Debug; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::base::wire::Composer; +use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget}; +use crate::net::client::multi_stream; +use crate::net::client::tcp_factory::TcpConnFactory; +use crate::net::client::udp; + +/// DNS transport connection that first issue a query over a UDP transport and +/// falls back to TCP if the reply is truncated. +#[derive(Clone)] +pub struct Connection { + /// Reference to the real object that provides the connection. + inner: Arc>, +} + +impl + Connection +{ + /// Create a new connection. + pub fn new(remote_addr: SocketAddr) -> io::Result> { + let connection = InnerConnection::new(remote_addr)?; + Ok(Self { + inner: Arc::new(connection), + }) + } + + /// Worker function for a connection object. + pub async fn run(&self) -> Option<()> { + self.inner.run().await + } + + /// Start a query. + pub async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + self.inner.query(query_msg).await + } +} + +/// Object that contains the current state of a query. +pub struct Query { + /// Reqeust message. + query_msg: MessageBuilder>>, + + /// UDP transport to be used. + udp_conn: udp::Connection, + + /// TCP transport to be used. + tcp_conn: multi_stream::Connection, + + /// Current state of the query. + state: QueryState, +} + +/// Status of the query. +enum QueryState { + /// Start a query over the UDP transport. + StartUdpQuery, + + /// Get the result from the UDP transport. + GetUdpResult(udp::Query), + + /// Start a query over the TCP transport. + StartTcpQuery, + + /// Get the result from the TCP transport. + GetTcpResult(multi_stream::Query), +} + +impl< + Octs: AsMut<[u8]> + + AsRef<[u8]> + + Clone + + Composer + + Debug + + OctetsBuilder + + Send + + 'static, + > Query +{ + /// Create a new Query object. + /// + /// The initial state is to start with a UDP transport. + fn new( + query_msg: &mut MessageBuilder>>, + udp_conn: udp::Connection, + tcp_conn: multi_stream::Connection, + ) -> Query { + Query { + query_msg: query_msg.clone(), + udp_conn, + tcp_conn, + state: QueryState::StartUdpQuery, + } + } + + /// Get the result of a DNS query. + /// + /// This function is cancel safe. + pub async fn get_result( + &mut self, + ) -> Result, Arc> { + loop { + match &mut self.state { + QueryState::StartUdpQuery => { + let query = self + .udp_conn + .query(&mut self.query_msg.clone()) + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + })?; + self.state = QueryState::GetUdpResult(query); + continue; + } + QueryState::GetUdpResult(ref mut query) => { + let reply = query.get_result().await?; + if reply.header().tc() { + self.state = QueryState::StartTcpQuery; + continue; + } + return Ok(reply); + } + QueryState::StartTcpQuery => { + let query = self + .tcp_conn + .query(&mut self.query_msg.clone()) + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + })?; + self.state = QueryState::GetTcpResult(query); + continue; + } + QueryState::GetTcpResult(ref mut query) => { + let reply = query.get_result().await?; + return Ok(reply); + } + } + } + } +} + +/// The actual connection object. +struct InnerConnection { + /// The remote address to connect to. + remote_addr: SocketAddr, + + /// The UDP transport connection. + udp_conn: udp::Connection, + + /// The TCP transport connection. + tcp_conn: multi_stream::Connection, +} + +impl + InnerConnection +{ + /// Create a new InnerConnection object. + /// + /// Create the UDP and TCP connections. Store the remote address because + /// run needs it later. + fn new(remote_addr: SocketAddr) -> io::Result> { + let udp_conn = udp::Connection::new(remote_addr)?; + let tcp_conn = multi_stream::Connection::new()?; + + Ok(Self { + remote_addr, + udp_conn, + tcp_conn, + }) + } + + /// Implementation of the worker function. + /// + /// Create a TCP connection factory and pass that to worker function + /// of the multi_stream object. + pub async fn run(&self) -> Option<()> { + let tcp_factory = TcpConnFactory::new(self.remote_addr); + self.tcp_conn.run(tcp_factory).await + } + + /// Implementation of the query function. + /// + /// Just create a Query object with the state it needs. + async fn query( + &self, + query_msg: &mut MessageBuilder>>, + ) -> Result, &'static str> { + Ok(Query::new( + query_msg, + self.udp_conn.clone(), + self.tcp_conn.clone(), + )) + } +}