diff --git a/examples/simple-proxy.rs b/examples/simple-proxy.rs index 603ac300f..c75eed01f 100644 --- a/examples/simple-proxy.rs +++ b/examples/simple-proxy.rs @@ -7,14 +7,19 @@ use clap::Parser; use domain::base::iana::Rtype; use domain::base::message_builder::PushError; use domain::base::opt::{Opt, OptRecord}; -use domain::base::{Message, MessageBuilder, ParsedDname, StaticCompressor, StreamTarget}; -use domain::net::client::multi_stream::Connection; +use domain::base::{ + Message, MessageBuilder, ParsedDname, StaticCompressor, StreamTarget, +}; +use domain::net::client::multi_stream; use domain::net::client::tcp_factory::TcpConnFactory; use domain::net::client::tls_factory::TlsConnFactory; +use domain::net::client::udp_tcp; use domain::rdata::AllRecordData; use domain::serve::buf::BufSource; use domain::serve::dgram::DgramServer; -use domain::serve::service::{CallResult, Service, ServiceError, Transaction}; +use domain::serve::service::{ + CallResult, Service, ServiceError, Transaction, +}; use futures::Stream; use octseq::octets::OctetsFrom; use octseq::Octets; @@ -39,7 +44,7 @@ struct Args { #[arg(short = 'p', long = "port", value_parser = clap::value_parser!(u16))] port: Option, - /// Option for the destination TCP port. + /// Option for the local port. #[arg(long = "locport", value_parser = clap::value_parser!(u16))] locport: Option, @@ -50,14 +55,21 @@ struct Args { /// Server name for TLS. #[arg(long = "servername", group = "tls-params-group")] servername: Option, + + /// Flag to use UDP+TCP for upstream connections. + #[arg(long = "udp")] + do_udp: bool, } /// Convert a Message into a MessageBuilder. fn to_builder( source: &Message, -) -> Result>>>, PushError> { - let mut target = - MessageBuilder::from_target(StaticCompressor::new(StreamTarget::new_vec())).unwrap(); +) -> Result>>>, PushError> +{ + let mut target = MessageBuilder::from_target(StaticCompressor::new( + StreamTarget::new_vec(), + )) + .unwrap(); let header = source.header(); *target.header_mut() = header; @@ -99,7 +111,8 @@ fn to_builder( let opt_record = OptRecord::from_record(rr); target .opt(|newopt| { - newopt.set_udp_payload_size(opt_record.udp_payload_size()); + newopt + .set_udp_payload_size(opt_record.udp_payload_size()); newopt.set_version(opt_record.version()); newopt.set_dnssec_ok(opt_record.dnssec_ok()); @@ -142,11 +155,63 @@ fn to_stream_target( Ok(builder.as_target().as_target().clone()) } +// We need a query trait to merge these into one service function. + +/// Function that returns a Service trait. +/// +/// This is a trick to capture the Future by an async block into a type. +fn stream_service< + RequestOctets: AsRef<[u8]> + Octets + Send + Sync + 'static, +>( + conn: multi_stream::Connection>, +) -> impl Service +where + for<'a> &'a RequestOctets: AsRef<[u8]>, +{ + /// Basic query function for Service. + fn query + Octets, ReplyOcts>( + message: Message, + conn: multi_stream::Connection>, + ) -> Transaction< + impl Future>, ServiceError<()>>>, + impl Stream>, ServiceError<()>>>, + > + where + for<'a> &'a RequestOctets: AsRef<[u8]>, + { + Transaction::<_, NoStream>>::Single(async move { + // Extract the ID. We need to set it in the reply. + let id = message.header().id(); + // We get a Message, but the client transport needs a + // MessageBuilder. Convert. + println!("request {:?}", message); + let mut msg_builder = to_builder(&message).unwrap(); + println!("request {:?}", msg_builder); + let mut query = conn.query(&mut msg_builder).await.unwrap(); + let reply = query.get_result().await.unwrap(); + println!("got reply {:?}", reply); + + // Set the ID + let mut reply: Message> = OctetsFrom::octets_from(reply); + reply.header_mut().set_id(id); + + // We get the reply as Message from the client transport but + // we need to return a StreamTarget. Convert. + let stream = to_stream_target::<_, Vec>(&reply).unwrap(); + Ok(CallResult::new(stream)) + }) + } + + move |message| Ok(query::>(message, conn.clone())) +} + /// Function that returns a Service trait. /// /// This is a trick to capture the Future by an async block into a type. -fn service + Octets + Send + Sync + 'static>( - conn: Connection>, +fn udptcp_service< + RequestOctets: AsRef<[u8]> + Octets + Send + Sync + 'static, +>( + conn: udp_tcp::Connection>, ) -> impl Service where for<'a> &'a RequestOctets: AsRef<[u8]>, @@ -154,7 +219,7 @@ where /// Basic query function for Service. fn query + Octets, ReplyOcts>( message: Message, - conn: Connection>, + conn: udp_tcp::Connection>, ) -> Transaction< impl Future>, ServiceError<()>>>, impl Stream>, ServiceError<()>>>, @@ -197,7 +262,10 @@ struct NoStream { impl Stream for NoStream { type Item = Result, ServiceError<()>>; - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { todo!() } } @@ -223,7 +291,10 @@ struct VecSingle(Option>>); impl Future for VecSingle { type Output = Result>, ServiceError<()>>; - fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + fn poll( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll { Poll::Ready(Ok(self.0.take().unwrap())) } } @@ -236,19 +307,49 @@ async fn main() { let locport = args.locport.unwrap_or_else(|| "8053".parse().unwrap()); - let udpsocket = UdpSocket::bind(SocketAddr::new("127.0.0.1".parse().unwrap(), locport)).await.unwrap(); + let udpsocket = UdpSocket::bind(SocketAddr::new( + "127.0.0.1".parse().unwrap(), + locport, + )) + .await + .unwrap(); + + if args.do_udp { + let port = args.port.unwrap_or_else(|| "53".parse().unwrap()); - if args.do_tls { + let conn = + udp_tcp::Connection::new(SocketAddr::new(server, port)).unwrap(); + let conn_run = conn.clone(); + + tokio::spawn(async move { + conn_run.run().await; + println!("run terminated"); + }); + + let svc = udptcp_service(conn); + + let buf_source = Arc::new(VecBufSource); + let srv = Arc::new(DgramServer::new( + udpsocket, + buf_source.clone(), + Arc::new(svc), + )); + let udp_join_handle = tokio::spawn(srv.run()); + + udp_join_handle.await.unwrap().unwrap(); + } else if args.do_tls { let port = args.port.unwrap_or_else(|| "853".parse().unwrap()); let mut root_store = rustls::RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }), + ); let client_config = Arc::new( ClientConfig::builder() .with_safe_defaults() @@ -262,7 +363,7 @@ async fn main() { SocketAddr::new(server, port), ); - let conn = Connection::new().unwrap(); + let conn = multi_stream::Connection::new().unwrap(); let conn_run = conn.clone(); tokio::spawn(async move { @@ -270,7 +371,7 @@ async fn main() { println!("run terminated"); }); - let svc = service(conn); + let svc = stream_service(conn); let buf_source = Arc::new(VecBufSource); let srv = Arc::new(DgramServer::new( @@ -285,7 +386,7 @@ async fn main() { let port = args.port.unwrap_or_else(|| "53".parse().unwrap()); let tcp_factory = TcpConnFactory::new(SocketAddr::new(server, port)); - let conn = Connection::new().unwrap(); + let conn = multi_stream::Connection::new().unwrap(); let conn_run = conn.clone(); tokio::spawn(async move { @@ -293,7 +394,7 @@ async fn main() { println!("run terminated"); }); - let svc = service(conn); + let svc = stream_service(conn); let buf_source = Arc::new(VecBufSource); let srv = Arc::new(DgramServer::new(