Skip to content

Commit

Permalink
Support for UDP+TCP
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip-NLnetLabs committed Aug 7, 2023
1 parent f645c9d commit 4e0012b
Showing 1 changed file with 127 additions and 26 deletions.
153 changes: 127 additions & 26 deletions examples/simple-proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ use clap::Parser;
use domain::base::iana::Rtype;
use domain::base::message_builder::PushError;
use domain::base::opt::{Opt, OptRecord};
use domain::base::{Message, MessageBuilder, ParsedDname, StaticCompressor, StreamTarget};
use domain::net::client::multi_stream::Connection;
use domain::base::{
Message, MessageBuilder, ParsedDname, StaticCompressor, StreamTarget,
};
use domain::net::client::multi_stream;
use domain::net::client::tcp_factory::TcpConnFactory;
use domain::net::client::tls_factory::TlsConnFactory;
use domain::net::client::udp_tcp;
use domain::rdata::AllRecordData;
use domain::serve::buf::BufSource;
use domain::serve::dgram::DgramServer;
use domain::serve::service::{CallResult, Service, ServiceError, Transaction};
use domain::serve::service::{
CallResult, Service, ServiceError, Transaction,
};
use futures::Stream;
use octseq::octets::OctetsFrom;
use octseq::Octets;
Expand All @@ -39,7 +44,7 @@ struct Args {
#[arg(short = 'p', long = "port", value_parser = clap::value_parser!(u16))]
port: Option<u16>,

/// Option for the destination TCP port.
/// Option for the local port.
#[arg(long = "locport", value_parser = clap::value_parser!(u16))]
locport: Option<u16>,

Expand All @@ -50,14 +55,21 @@ struct Args {
/// Server name for TLS.
#[arg(long = "servername", group = "tls-params-group")]
servername: Option<String>,

/// Flag to use UDP+TCP for upstream connections.
#[arg(long = "udp")]
do_udp: bool,
}

/// Convert a Message into a MessageBuilder.
fn to_builder<Octs1: Octets>(
source: &Message<Octs1>,
) -> Result<MessageBuilder<StaticCompressor<StreamTarget<Vec<u8>>>>, PushError> {
let mut target =
MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap();
) -> Result<MessageBuilder<StaticCompressor<StreamTarget<Vec<u8>>>>, PushError>
{
let mut target = MessageBuilder::from_target(StaticCompressor::new(
StreamTarget::new_vec(),
))
.unwrap();

let header = source.header();
*target.header_mut() = header;
Expand Down Expand Up @@ -99,7 +111,8 @@ fn to_builder<Octs1: Octets>(
let opt_record = OptRecord::from_record(rr);
target
.opt(|newopt| {
newopt.set_udp_payload_size(opt_record.udp_payload_size());
newopt
.set_udp_payload_size(opt_record.udp_payload_size());
newopt.set_version(opt_record.version());
newopt.set_dnssec_ok(opt_record.dnssec_ok());

Expand Down Expand Up @@ -142,19 +155,71 @@ fn to_stream_target<Octs1: Octets, OctsOut>(
Ok(builder.as_target().as_target().clone())
}

// We need a query trait to merge these into one service function.

/// Function that returns a Service trait.
///
/// This is a trick to capture the Future by an async block into a type.
fn stream_service<
RequestOctets: AsRef<[u8]> + Octets + Send + Sync + 'static,
>(
conn: multi_stream::Connection<Vec<u8>>,
) -> impl Service<RequestOctets>
where
for<'a> &'a RequestOctets: AsRef<[u8]>,
{
/// Basic query function for Service.
fn query<RequestOctets: AsRef<[u8]> + Octets, ReplyOcts>(
message: Message<RequestOctets>,
conn: multi_stream::Connection<Vec<u8>>,
) -> Transaction<
impl Future<Output = Result<CallResult<Vec<u8>>, ServiceError<()>>>,
impl Stream<Item = Result<CallResult<Vec<u8>>, ServiceError<()>>>,
>
where
for<'a> &'a RequestOctets: AsRef<[u8]>,
{
Transaction::<_, NoStream<Vec<u8>>>::Single(async move {
// Extract the ID. We need to set it in the reply.
let id = message.header().id();
// We get a Message, but the client transport needs a
// MessageBuilder. Convert.
println!("request {:?}", message);
let mut msg_builder = to_builder(&message).unwrap();
println!("request {:?}", msg_builder);
let mut query = conn.query(&mut msg_builder).await.unwrap();
let reply = query.get_result().await.unwrap();
println!("got reply {:?}", reply);

// Set the ID
let mut reply: Message<Vec<u8>> = OctetsFrom::octets_from(reply);
reply.header_mut().set_id(id);

// We get the reply as Message from the client transport but
// we need to return a StreamTarget. Convert.
let stream = to_stream_target::<_, Vec<u8>>(&reply).unwrap();
Ok(CallResult::new(stream))
})
}

move |message| Ok(query::<RequestOctets, Vec<u8>>(message, conn.clone()))
}

/// Function that returns a Service trait.
///
/// This is a trick to capture the Future by an async block into a type.
fn service<RequestOctets: AsRef<[u8]> + Octets + Send + Sync + 'static>(
conn: Connection<Vec<u8>>,
fn udptcp_service<
RequestOctets: AsRef<[u8]> + Octets + Send + Sync + 'static,
>(
conn: udp_tcp::Connection<Vec<u8>>,
) -> impl Service<RequestOctets>
where
for<'a> &'a RequestOctets: AsRef<[u8]>,
{
/// Basic query function for Service.
fn query<RequestOctets: AsRef<[u8]> + Octets, ReplyOcts>(
message: Message<RequestOctets>,
conn: Connection<Vec<u8>>,
conn: udp_tcp::Connection<Vec<u8>>,
) -> Transaction<
impl Future<Output = Result<CallResult<Vec<u8>>, ServiceError<()>>>,
impl Stream<Item = Result<CallResult<Vec<u8>>, ServiceError<()>>>,
Expand Down Expand Up @@ -197,7 +262,10 @@ struct NoStream<Octs> {
impl<Octs> Stream for NoStream<Octs> {
type Item = Result<CallResult<Octs>, ServiceError<()>>;

fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
todo!()
}
}
Expand All @@ -223,7 +291,10 @@ struct VecSingle(Option<CallResult<Vec<u8>>>);
impl Future for VecSingle {
type Output = Result<CallResult<Vec<u8>>, ServiceError<()>>;

fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Self::Output> {
Poll::Ready(Ok(self.0.take().unwrap()))
}
}
Expand All @@ -236,19 +307,49 @@ async fn main() {

let locport = args.locport.unwrap_or_else(|| "8053".parse().unwrap());

let udpsocket = UdpSocket::bind(SocketAddr::new("127.0.0.1".parse().unwrap(), locport)).await.unwrap();
let udpsocket = UdpSocket::bind(SocketAddr::new(
"127.0.0.1".parse().unwrap(),
locport,
))
.await
.unwrap();

if args.do_udp {
let port = args.port.unwrap_or_else(|| "53".parse().unwrap());

if args.do_tls {
let conn =
udp_tcp::Connection::new(SocketAddr::new(server, port)).unwrap();
let conn_run = conn.clone();

tokio::spawn(async move {
conn_run.run().await;
println!("run terminated");
});

let svc = udptcp_service(conn);

let buf_source = Arc::new(VecBufSource);
let srv = Arc::new(DgramServer::new(
udpsocket,
buf_source.clone(),
Arc::new(svc),
));
let udp_join_handle = tokio::spawn(srv.run());

udp_join_handle.await.unwrap().unwrap();
} else if args.do_tls {
let port = args.port.unwrap_or_else(|| "853".parse().unwrap());

let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let client_config = Arc::new(
ClientConfig::builder()
.with_safe_defaults()
Expand All @@ -262,15 +363,15 @@ async fn main() {
SocketAddr::new(server, port),
);

let conn = Connection::new().unwrap();
let conn = multi_stream::Connection::new().unwrap();
let conn_run = conn.clone();

tokio::spawn(async move {
conn_run.run(tls_factory).await;
println!("run terminated");
});

let svc = service(conn);
let svc = stream_service(conn);

let buf_source = Arc::new(VecBufSource);
let srv = Arc::new(DgramServer::new(
Expand All @@ -285,15 +386,15 @@ async fn main() {
let port = args.port.unwrap_or_else(|| "53".parse().unwrap());
let tcp_factory = TcpConnFactory::new(SocketAddr::new(server, port));

let conn = Connection::new().unwrap();
let conn = multi_stream::Connection::new().unwrap();
let conn_run = conn.clone();

tokio::spawn(async move {
conn_run.run(tcp_factory).await;
println!("run terminated");
});

let svc = service(conn);
let svc = stream_service(conn);

let buf_source = Arc::new(VecBufSource);
let srv = Arc::new(DgramServer::new(
Expand Down

0 comments on commit 4e0012b

Please sign in to comment.