Skip to content

Commit

Permalink
fix: msg id conflict check
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii committed Dec 7, 2024
1 parent 29d322c commit 21bd964
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 13 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ a DNS server in Rust, which is inspired from chinadns/dnsmasq.
ZeroDNS provides similar functionality to dig, but supports more DNS protocols. Here are some examples:

```shell
$ # Simple resolve, will read dns server from /etc/resolv.conf
$ # Simple resolve, will read nameserver from /etc/resolv.conf
$ zerodns resolve www.youtube.com
$ # Use short output, similar with 'dig +short ...'
$ zerodns resolve --short www.youtube.com
$ # Resolve over google UDP
$ zerodns resolve -s 8.8.8.8 www.youtube.com
$ # Resolve over google TCP
$ zerodns resolve -s tcp://8.8.8.8 www.youtube.com
$ # Resolve over google DoT
Expand Down
40 changes: 30 additions & 10 deletions src/client/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,28 +173,51 @@ impl MultiplexUdpClient {

async fn request(
&self,
req: Message,
req: &Message,
remote: SocketAddr,
timeout: Duration,
) -> Result<Message> {
let id = req.id();
let origin_id = req.id();
let mut id = origin_id;

let (tx, rx) = oneshot::channel::<Message>();
{
let mut w = self.handlers.lock().await;
if w.contains_key(&(id, remote)) {
// TODO: how to check id conflict???
for seq in 1..u16::MAX {
let key = (seq, remote);
if !w.contains_key(&key) {
id = seq;
break;
}
}
}

w.insert((id, remote), tx);
}

let res: Result<Message> = async move {
self.queue.send((req, remote)).await?;
let mut cloned_req = Clone::clone(req);
if origin_id != id {
cloned_req.set_id(id);
}
let mut res: Result<Message> = async move {
self.queue.send((cloned_req, remote)).await?;
let res = tokio::time::timeout(timeout, rx).await??;
Ok(res)
}
.await;

// clean handler if enqueue failed
if res.is_err() {
self.handlers.lock().await.remove(&(id, remote));
match &mut res {
Ok(v) => {
if origin_id != id {
v.set_id(origin_id);
}
}
Err(_) => {
self.handlers.lock().await.remove(&(id, remote));
}
}

res
Expand All @@ -205,9 +228,7 @@ impl MultiplexUdpClient {
impl Client for UdpClient {
async fn request(&self, req: &Message) -> Result<Message> {
let w = requester().await?;
let res = w
.request(Clone::clone(req), self.addr, self.timeout)
.await?;
let res = w.request(req, self.addr, self.timeout).await?;
Ok(res)
}
}
Expand All @@ -229,7 +250,6 @@ impl UdpClientBuilder {

#[cfg(test)]
mod tests {

use crate::client::Client;
use crate::protocol::{Class, Flags, Kind, Message, OpCode};

Expand Down
4 changes: 2 additions & 2 deletions src/cmds/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn print_resolve_result(
req: &Message,
res: &Message,
begin: DateTime<Local>,
) -> anyhow::Result<()> {
) -> Result<()> {
let cost = Local::now() - begin;

println!();
Expand All @@ -114,7 +114,7 @@ fn print_resolve_result(
println!(";; global options: +cmd");
println!(";; Got answer:");
println!(
";; ->>HEADER<<- opcode: {:?}, status: {:?}, id: {}",
";; ->>HEADER<<- opcode: {}, status: {}, id: {}",
res.flags().opcode(),
res.flags().response_code(),
res.id()
Expand Down
18 changes: 18 additions & 0 deletions src/protocol/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ pub enum RCode {
NotZone = 10,
}

impl Display for RCode {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
RCode::NoError => "NOERROR",
RCode::FormatError => "FORMATERROR",
RCode::ServerFailure => "SERVERFAILURE",
RCode::NameError => "NAMEERROR",
RCode::NotImplemented => "NOTIMPLEMENTED",
RCode::Refused => "REFUSED",
RCode::YXDomain => "YXDOMAIN",
RCode::YXRRSet => "YXRRSET",
RCode::NXRRSet => "NXRRSET",
RCode::NotAuth => "NOTAUTH",
RCode::NotZone => "NOTZONE",
})
}
}

/// dns record types, see also https://en.wikipedia.org/wiki/List_of_DNS_record_types
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, EnumIter, Hash)]
Expand Down

0 comments on commit 21bd964

Please sign in to comment.