From 299fc89a2ebd1fe65104965dd77ada3dfbf8a411 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Fri, 3 Jan 2025 20:58:14 +0800 Subject: [PATCH] feat: add cache for default lookup --- src/client/mod.rs | 84 +++++++++++++++++++++++++++++++++----------- src/cmds/resolve.rs | 2 +- src/misc/mod.rs | 5 +-- src/server/helper.rs | 54 ++++++++++++++-------------- src/server/tcp.rs | 33 +++++++++++++++-- src/server/udp.rs | 33 +++++++++++++++-- 6 files changed, 158 insertions(+), 53 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index a507621..646f12d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,12 @@ +use crate::cachestr::Cachestr; use crate::client::doh::DoHClient; use crate::client::dot::DoTClient; +use crate::error::Error; use crate::protocol::*; use crate::Result; +use moka::future::Cache; +use once_cell::sync::Lazy; +use smallvec::SmallVec; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; @@ -17,6 +22,14 @@ mod udp; static DEFAULT_DNS: OnceCell> = OnceCell::const_new(); +static DEFAULT_LOOKUPS: Lazy = Lazy::new(|| { + let cache = Cache::builder() + .max_capacity(4096) + .time_to_live(Duration::from_secs(30)) + .build(); + LookupCache(cache) +}); + pub async fn default_dns() -> Result> { let v = DEFAULT_DNS .get_or_try_init(|| async { @@ -98,7 +111,7 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result< } Address::HostAddr(host_addr) => { let domain = &host_addr.host; - let ip = lookup(domain, timeout).await?; + let ip = DEFAULT_LOOKUPS.lookup(domain, timeout).await?; let addr = SocketAddr::new(IpAddr::V4(ip), host_addr.port); let c = DoTClient::builder(addr) .sni(domain.as_ref()) @@ -112,7 +125,7 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result< Address::SocketAddr(addr) => DoHClient::builder(*addr).https(doh_addr.https), Address::HostAddr(addr) => { let domain = &addr.host; - let ip = lookup(domain, timeout).await?; + let ip = DEFAULT_LOOKUPS.lookup(domain, timeout).await?; let mut bu = DoHClient::builder(SocketAddr::new(IpAddr::V4(ip), addr.port)) .host(domain) .https(doh_addr.https); @@ -130,29 +143,60 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result< } } -#[inline] -async fn lookup(host: &str, timeout: Duration) -> Result { - // TODO: add cache - let flags = Flags::builder() - .request() - .recursive_query(true) - .opcode(OpCode::StandardQuery) - .build(); - let req0 = Message::builder() - .id(1234) - .flags(flags) - .question(host, Kind::A, Class::IN) - .build()?; +struct LookupCache(Cache>); - let v = DefaultClient.request(&req0).await?; +impl LookupCache { + async fn lookup(&self, host: &str, timeout: Duration) -> Result { + let key = Cachestr::from(host); - for next in v.answers() { - if let Ok(RData::A(a)) = next.rdata() { - return Ok(a.ipaddr()); + let res = self + .0 + .try_get_with(key, Self::lookup_(host, timeout)) + .await + .map_err(|e| anyhow!("lookup failed: {}", e))?; + + if let Some(first) = res.first() { + return Ok(Clone::clone(first)); } + + bail!(Error::ResolveNothing) } - bail!(crate::Error::ResolveNothing) + #[inline] + async fn lookup_(host: &str, timeout: Duration) -> Result> { + let flags = Flags::builder() + .request() + .recursive_query(true) + .opcode(OpCode::StandardQuery) + .build(); + + let id = { + use rand::prelude::*; + + let mut rng = thread_rng(); + rng.gen_range(1..u16::MAX) + }; + + let req0 = Message::builder() + .id(id) + .flags(flags) + .question(host, Kind::A, Class::IN) + .build()?; + + let mut ret = SmallVec::<[Ipv4Addr; 2]>::new(); + let v = DefaultClient.request(&req0).await?; + for next in v.answers() { + if let Ok(RData::A(a)) = next.rdata() { + ret.push(a.ipaddr()); + } + } + + if !ret.is_empty() { + return Ok(ret); + } + + bail!(Error::ResolveNothing) + } } #[cfg(test)] diff --git a/src/cmds/resolve.rs b/src/cmds/resolve.rs index d6dd4c0..16292a7 100644 --- a/src/cmds/resolve.rs +++ b/src/cmds/resolve.rs @@ -151,7 +151,7 @@ fn print_resolve_result( println!(); println!(";; Query time: {} msec", cost.num_milliseconds()); println!(";; SERVER: {}", &dns); - println!(";; WHEN: {}", &begin); + println!(";; WHEN: {}", begin.to_rfc2822()); println!(";; MSG SIZE\trcvd: {}", res.len()); Ok(()) diff --git a/src/misc/mod.rs b/src/misc/mod.rs index c4b2fe2..e0f46a2 100644 --- a/src/misc/mod.rs +++ b/src/misc/mod.rs @@ -9,8 +9,9 @@ pub(crate) fn is_valid_domain(domain: &str) -> bool { return true; } - static RE: Lazy = - Lazy::new(|| regex::Regex::new("^([a-z0-9_-]{1,63})(\\.[a-z0-9_-]{1,63})+\\.?$").unwrap()); + static RE: Lazy = Lazy::new(|| { + regex::Regex::new("^([a-zA-Z0-9_-]{1,63})(\\.[a-zA-Z0-9_-]{1,63})+\\.?$").unwrap() + }); RE.is_match(domain) } diff --git a/src/server/helper.rs b/src/server/helper.rs index 172cc2b..a19017b 100644 --- a/src/server/helper.rs +++ b/src/server/helper.rs @@ -30,8 +30,6 @@ fn convert_error_to_message( err: anyhow::Error, attach_questions: bool, ) -> Message { - error!("failed to handle dns request: {:?}", err); - let rid = request.id(); let rflags = request.flags(); @@ -45,6 +43,25 @@ fn convert_error_to_message( } } + // log those internal server failure: + match rcode { + RCode::ServerFailure => match request.questions().next() { + Some(question) => { + let name = question.name(); + error!("failed to handle dns request {}: {:?}", name, err); + } + None => error!("failed to handle dns request: {:?}", err), + }, + RCode::NameError => match request.questions().next() { + Some(question) => { + let name = question.name(); + warn!("failed to handle dns request {}: {:?}", name, err); + } + None => warn!("failed to handle dns request: {:?}", err), + }, + _ => (), + } + let flags = { let mut bu = Flags::builder() .response() @@ -68,13 +85,17 @@ fn convert_error_to_message( bu.build().unwrap() } -pub(super) async fn handle(mut req: Message, h: Arc, cache: Option>) -> Message +pub(super) async fn handle( + mut req: Message, + h: Arc, + cache: Option>, +) -> (Message, bool) where H: Handler, C: CacheStore, { if let Err(e) = validate_request(&req) { - return convert_error_to_message(&req, e, false); + return (convert_error_to_message(&req, e, false), false); } if let Some(cache) = cache.as_deref() { @@ -85,8 +106,7 @@ where if let Some(mut exist) = cached { exist.set_id(id); - debug!("use dns cache"); - return exist; + return (exist, true); } } @@ -110,29 +130,11 @@ where req.set_id(0); cache.set(&req, &msg).await; req.set_id(id); - - debug!("set dns cache ok"); - } - } - - if msg.answer_count() > 0 { - for next in msg.answers() { - if let Ok(rdata) = next.rdata() { - info!( - "0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}", - msg.id(), - next.name(), - next.time_to_live(), - next.class(), - next.kind(), - rdata, - ); - } } } - msg + (msg, false) } - Err(e) => convert_error_to_message(&req, e, true), + Err(e) => (convert_error_to_message(&req, e, true), false), } } diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 114444c..10f4139 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -90,8 +90,37 @@ where let req = next?; let handler = Clone::clone(&handler); let cache = Clone::clone(&cache); - let msg = super::helper::handle(req, handler, cache).await; - w.send(&msg).await?; + let (res, cached) = super::helper::handle(req, handler, cache).await; + + if res.answer_count() > 0 { + for next in res.answers() { + if let Ok(rdata) = next.rdata() { + if cached { + info!( + "0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}\t", + res.id(), + next.name(), + next.time_to_live(), + next.class(), + next.kind(), + rdata, + ); + } else { + info!( + "0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}", + res.id(), + next.name(), + next.time_to_live(), + next.class(), + next.kind(), + rdata, + ); + } + } + } + } + + w.send(&res).await?; } Ok(()) diff --git a/src/server/udp.rs b/src/server/udp.rs index 8e5db0e..9f8af16 100644 --- a/src/server/udp.rs +++ b/src/server/udp.rs @@ -41,8 +41,37 @@ where h: Arc, cache: Option>, ) { - let result = helper::handle(req, h, cache).await; - if let Err(e) = socket.send_to(result.as_ref(), peer).await { + let (res, cached) = helper::handle(req, h, cache).await; + + if res.answer_count() > 0 { + for next in res.answers() { + if let Ok(rdata) = next.rdata() { + if cached { + info!( + "0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}\t", + res.id(), + next.name(), + next.time_to_live(), + next.class(), + next.kind(), + rdata, + ); + } else { + info!( + "0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}", + res.id(), + next.name(), + next.time_to_live(), + next.class(), + next.kind(), + rdata, + ); + } + } + } + } + + if let Err(e) = socket.send_to(res.as_ref(), peer).await { error!("failed to reply dns response: {:?}", e); } }