Skip to content

Commit

Permalink
Traits for queries and an error type.
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip-NLnetLabs committed Aug 9, 2023
1 parent fc3100d commit 38b11a7
Show file tree
Hide file tree
Showing 7 changed files with 369 additions and 148 deletions.
126 changes: 126 additions & 0 deletions src/net/client/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
//! Error type for client transports.

#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]

use std::error;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

/// Error type for client transports.
#[derive(Clone, Debug)]
pub enum Error {
/// Connection was already closed.
ConnectionClosed,

/// Octet sequence too short to be a valid DNS message.
ShortMessage,

/// Stream transport closed because it was idle (for too long).
StreamIdleTimeout,

/// Error receiving a reply.
StreamReceiveError,

/// Reading from stream gave an error.
StreamReadError(Arc<std::io::Error>),

/// Reading from stream took too long.
StreamReadTimeout,

/// Too many outstand queries on a single stream transport.
StreamTooManyOutstandingQueries,

/// Writing to a stream gave an error.
StreamWriteError(Arc<std::io::Error>),

/// Reading for a stream ended unexpectedly.
StreamUnexpectedEndOfData,

/// Binding a UDP socket gave an error.
UdpBind(Arc<std::io::Error>),

/// Connecting a UDP socket gave an error.
UdpConnect(Arc<std::io::Error>),

/// Receiving from a UDP socket gave an error.
UdpReceive(Arc<std::io::Error>),

/// Sending over a UDP socket gaven an error.
UdpSend(Arc<std::io::Error>),

/// Sending over a UDP socket gave a partial result.
UdpShortSend,

/// Timeout receiving a response over a UDP socket.
UdpTimeoutNoResponse,

/// Reply does not match the query.
WrongReplyForQuery,
}

impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Error::ConnectionClosed => write!(f, "connection closed"),
Error::ShortMessage => {
write!(f, "octet sequence to short to be a valid message")
}
Error::StreamIdleTimeout => {
write!(f, "stream was idle for too long")
}
Error::StreamReceiveError => write!(f, "error receiving a reply"),
Error::StreamReadError(_) => {
write!(f, "error reading from stream")
}
Error::StreamReadTimeout => {
write!(f, "timeout reading from stream")
}
Error::StreamTooManyOutstandingQueries => {
write!(f, "too many outstanding queries on stream")
}
Error::StreamWriteError(_) => {
write!(f, "error writing to stream")
}
Error::StreamUnexpectedEndOfData => {
write!(f, "unexpected end of data")
}
Error::UdpBind(_) => write!(f, "error binding UDP socket"),
Error::UdpConnect(_) => write!(f, "error connecting UDP socket"),
Error::UdpReceive(_) => {
write!(f, "error receiving from UDP socket")
}
Error::UdpSend(_) => write!(f, "error sending to UDP socket"),
Error::UdpShortSend => write!(f, "partial sent to UDP socket"),
Error::UdpTimeoutNoResponse => {
write!(f, "timeout waiting for response")
}
Error::WrongReplyForQuery => {
write!(f, "reply does not match query")
}
}
}
}

impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::ConnectionClosed => None,
Error::ShortMessage => None,
Error::StreamIdleTimeout => None,
Error::StreamReceiveError => None,
Error::StreamReadError(e) => Some(e),
Error::StreamReadTimeout => None,
Error::StreamTooManyOutstandingQueries => None,
Error::StreamWriteError(e) => Some(e),
Error::StreamUnexpectedEndOfData => None,
Error::UdpBind(e) => Some(e),
Error::UdpConnect(e) => Some(e),
Error::UdpReceive(e) => Some(e),
Error::UdpSend(e) => Some(e),
Error::UdpShortSend => None,
Error::UdpTimeoutNoResponse => None,
Error::WrongReplyForQuery => None,
}
}
}
2 changes: 2 additions & 0 deletions src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
#![cfg(feature = "net")]
#![cfg_attr(docsrs, doc(cfg(feature = "net")))]

pub mod error;
pub mod factory;
pub mod multi_stream;
pub mod octet_stream;
pub mod query;
pub mod tcp_channel;
pub mod tcp_factory;
pub mod tcp_mutex;
Expand Down
78 changes: 49 additions & 29 deletions src/net/client/multi_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]

// To do:
// - too many connection errors

use bytes::Bytes;

use futures::lock::Mutex as Futures_mutex;
Expand All @@ -28,9 +31,11 @@ use tokio::time::{sleep_until, Instant};

use crate::base::wire::Composer;
use crate::base::{Message, MessageBuilder, StaticCompressor, StreamTarget};
use crate::net::client::error::Error;
use crate::net::client::factory::ConnFactory;
use crate::net::client::octet_stream::Connection as SingleConnection;
use crate::net::client::octet_stream::QueryNoCheck as SingleQuery;
use crate::net::client::query::{GetResult, QueryMessage};

/// Capacity of the channel that transports [ChanReq].
const DEF_CHAN_CAP: usize = 8;
Expand Down Expand Up @@ -423,7 +428,7 @@ impl<
&self,
opt_id: Option<u64>,
sender: oneshot::Sender<ChanResp<Octs>>,
) -> Result<(), &'static str> {
) -> Result<(), Error> {
let req = ChanReq {
cmd: ReqCmd::NewConn(opt_id, sender),
};
Expand All @@ -432,7 +437,7 @@ impl<
// Send error. The receiver is gone, this means that the
// connection is closed.
{
Err(ERR_CONN_CLOSED)
Err(Error::ConnectionClosed)
}
Ok(_) => Ok(()),
}
Expand Down Expand Up @@ -494,10 +499,10 @@ impl<
///
/// This function takes a precomposed message as a parameter and
/// returns a [Query] object wrapped in a [Result].
pub async fn query(
pub async fn query_impl(
&self,
query_msg: &mut MessageBuilder<StaticCompressor<StreamTarget<Octs>>>,
) -> Result<Query<Octs>, &'static str> {
) -> Result<Query<Octs>, Error> {
let (tx, rx) = oneshot::channel();
self.inner.new_conn(None, tx).await?;
Ok(Query::new(self.clone(), query_msg, rx))
Expand All @@ -513,11 +518,24 @@ impl<
&self,
id: u64,
tx: oneshot::Sender<ChanResp<Octs>>,
) -> Result<(), &'static str> {
) -> Result<(), Error> {
self.inner.new_conn(Some(id), tx).await
}
}

impl<Octs: Clone + Composer + Debug + OctetsBuilder + Send + 'static>
QueryMessage<Query<Octs>, Octs> for Connection<Octs>
{
fn query<'a>(
&'a self,
query_msg: &'a mut MessageBuilder<
StaticCompressor<StreamTarget<Octs>>,
>,
) -> Pin<Box<dyn Future<Output = Result<Query<Octs>, Error>> + '_>> {
return Box::pin(self.query_impl(query_msg));
}
}

impl<
Octs: AsRef<[u8]>
+ AsMut<[u8]>
Expand Down Expand Up @@ -550,20 +568,15 @@ impl<
///
/// This function returns the reply to a DNS query wrapped in a
/// [Result].
pub async fn get_result(
&mut self,
) -> Result<Message<Bytes>, Arc<std::io::Error>> {
pub async fn get_result_impl(&mut self) -> Result<Message<Bytes>, Error> {
loop {
match self.state {
QueryState::GetConn(ref mut receiver) => {
let res = receiver.await;
if res.is_err() {
// Assume receive error
self.state = QueryState::Done;
return Err(Arc::new(io::Error::new(
io::ErrorKind::Other,
"receive error",
)));
return Err(Error::StreamReceiveError);
}
let res = res.expect("error is checked before");

Expand Down Expand Up @@ -592,26 +605,20 @@ impl<
let query_res = conn.query_no_check(&mut msg).await;
match query_res {
Err(err) => {
if err == ERR_CONN_CLOSED {
if let Error::ConnectionClosed = err {
let (tx, rx) = oneshot::channel();
let res = self
.conn
.new_conn(self.conn_id, tx)
.await;
if let Err(err) = res {
self.state = QueryState::Done;
return Err(Arc::new(io::Error::new(
io::ErrorKind::Other,
err,
)));
return Err(err);
}
self.state = QueryState::GetConn(rx);
continue;
}
return Err(Arc::new(io::Error::new(
io::ErrorKind::Other,
err,
)));
return Err(err);
}
Ok(query) => {
self.state = QueryState::GetResult(query);
Expand All @@ -637,10 +644,7 @@ impl<
.expect("how to go from MessageBuild to Message?");

if !is_answer_ignore_id(&msg, &query_msg) {
return Err(Arc::new(io::Error::new(
io::ErrorKind::Other,
"wrong answer",
)));
return Err(Error::WrongReplyForQuery);
}
return Ok(msg);
}
Expand All @@ -650,10 +654,7 @@ impl<
let res = self.conn.new_conn(self.conn_id, tx).await;
if let Err(err) = res {
self.state = QueryState::Done;
return Err(Arc::new(io::Error::new(
io::ErrorKind::Other,
err,
)));
return Err(err);
}
self.state = QueryState::GetConn(rx);
continue;
Expand All @@ -666,6 +667,25 @@ impl<
}
}

impl<
Octs: AsMut<[u8]>
+ AsRef<[u8]>
+ Clone
+ Composer
+ Debug
+ OctetsBuilder
+ Send
+ 'static,
> GetResult for Query<Octs>
{
fn get_result(
&mut self,
) -> Pin<Box<dyn Future<Output = Result<Message<Bytes>, Error>> + '_>>
{
Box::pin(self.get_result_impl())
}
}

/// Compute the retry timeout based on the number of retries so far.
///
/// The computation is a random value (in microseconds) between zero and
Expand Down
Loading

0 comments on commit 38b11a7

Please sign in to comment.