From 21bd9640a636dc905558054169311bf4a2ed3e5b Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sat, 7 Dec 2024 21:41:47 +0800 Subject: [PATCH] fix: msg id conflict check --- README.md | 4 +++- src/client/udp.rs | 40 ++++++++++++++++++++++++++++++---------- src/cmds/resolve.rs | 4 ++-- src/protocol/frame.rs | 18 ++++++++++++++++++ 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index ec45710..766dd72 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,12 @@ a DNS server in Rust, which is inspired from chinadns/dnsmasq. ZeroDNS provides similar functionality to dig, but supports more DNS protocols. Here are some examples: ```shell -$ # Simple resolve, will read dns server from /etc/resolv.conf +$ # Simple resolve, will read nameserver from /etc/resolv.conf $ zerodns resolve www.youtube.com $ # Use short output, similar with 'dig +short ...' $ zerodns resolve --short www.youtube.com +$ # Resolve over google UDP +$ zerodns resolve -s 8.8.8.8 www.youtube.com $ # Resolve over google TCP $ zerodns resolve -s tcp://8.8.8.8 www.youtube.com $ # Resolve over google DoT diff --git a/src/client/udp.rs b/src/client/udp.rs index 9a7abb2..a6885f3 100644 --- a/src/client/udp.rs +++ b/src/client/udp.rs @@ -173,28 +173,51 @@ impl MultiplexUdpClient { async fn request( &self, - req: Message, + req: &Message, remote: SocketAddr, timeout: Duration, ) -> Result { - let id = req.id(); + let origin_id = req.id(); + let mut id = origin_id; let (tx, rx) = oneshot::channel::(); { let mut w = self.handlers.lock().await; + if w.contains_key(&(id, remote)) { + // TODO: how to check id conflict??? + for seq in 1..u16::MAX { + let key = (seq, remote); + if !w.contains_key(&key) { + id = seq; + break; + } + } + } + w.insert((id, remote), tx); } - let res: Result = async move { - self.queue.send((req, remote)).await?; + let mut cloned_req = Clone::clone(req); + if origin_id != id { + cloned_req.set_id(id); + } + let mut res: Result = async move { + self.queue.send((cloned_req, remote)).await?; let res = tokio::time::timeout(timeout, rx).await??; Ok(res) } .await; // clean handler if enqueue failed - if res.is_err() { - self.handlers.lock().await.remove(&(id, remote)); + match &mut res { + Ok(v) => { + if origin_id != id { + v.set_id(origin_id); + } + } + Err(_) => { + self.handlers.lock().await.remove(&(id, remote)); + } } res @@ -205,9 +228,7 @@ impl MultiplexUdpClient { impl Client for UdpClient { async fn request(&self, req: &Message) -> Result { let w = requester().await?; - let res = w - .request(Clone::clone(req), self.addr, self.timeout) - .await?; + let res = w.request(req, self.addr, self.timeout).await?; Ok(res) } } @@ -229,7 +250,6 @@ impl UdpClientBuilder { #[cfg(test)] mod tests { - use crate::client::Client; use crate::protocol::{Class, Flags, Kind, Message, OpCode}; diff --git a/src/cmds/resolve.rs b/src/cmds/resolve.rs index 80624c6..f320b89 100644 --- a/src/cmds/resolve.rs +++ b/src/cmds/resolve.rs @@ -100,7 +100,7 @@ fn print_resolve_result( req: &Message, res: &Message, begin: DateTime, -) -> anyhow::Result<()> { +) -> Result<()> { let cost = Local::now() - begin; println!(); @@ -114,7 +114,7 @@ fn print_resolve_result( println!(";; global options: +cmd"); println!(";; Got answer:"); println!( - ";; ->>HEADER<<- opcode: {:?}, status: {:?}, id: {}", + ";; ->>HEADER<<- opcode: {}, status: {}, id: {}", res.flags().opcode(), res.flags().response_code(), res.id() diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs index cf50225..5865984 100644 --- a/src/protocol/frame.rs +++ b/src/protocol/frame.rs @@ -52,6 +52,24 @@ pub enum RCode { NotZone = 10, } +impl Display for RCode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + RCode::NoError => "NOERROR", + RCode::FormatError => "FORMATERROR", + RCode::ServerFailure => "SERVERFAILURE", + RCode::NameError => "NAMEERROR", + RCode::NotImplemented => "NOTIMPLEMENTED", + RCode::Refused => "REFUSED", + RCode::YXDomain => "YXDOMAIN", + RCode::YXRRSet => "YXRRSET", + RCode::NXRRSet => "NXRRSET", + RCode::NotAuth => "NOTAUTH", + RCode::NotZone => "NOTZONE", + }) + } +} + /// dns record types, see also https://en.wikipedia.org/wiki/List_of_DNS_record_types #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Copy, Clone, PartialEq, Eq, EnumIter, Hash)]