Skip to content

Commit

Permalink
feat: add cache for default lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii committed Jan 3, 2025
1 parent 4f420e0 commit 299fc89
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 53 deletions.
84 changes: 64 additions & 20 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use crate::cachestr::Cachestr;
use crate::client::doh::DoHClient;
use crate::client::dot::DoTClient;
use crate::error::Error;
use crate::protocol::*;
use crate::Result;
use moka::future::Cache;
use once_cell::sync::Lazy;
use smallvec::SmallVec;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -17,6 +22,14 @@ mod udp;

static DEFAULT_DNS: OnceCell<Arc<dyn Client>> = OnceCell::const_new();

static DEFAULT_LOOKUPS: Lazy<LookupCache> = Lazy::new(|| {
let cache = Cache::builder()
.max_capacity(4096)
.time_to_live(Duration::from_secs(30))
.build();
LookupCache(cache)
});

pub async fn default_dns() -> Result<Arc<dyn Client>> {
let v = DEFAULT_DNS
.get_or_try_init(|| async {
Expand Down Expand Up @@ -98,7 +111,7 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result<
}
Address::HostAddr(host_addr) => {
let domain = &host_addr.host;
let ip = lookup(domain, timeout).await?;
let ip = DEFAULT_LOOKUPS.lookup(domain, timeout).await?;
let addr = SocketAddr::new(IpAddr::V4(ip), host_addr.port);
let c = DoTClient::builder(addr)
.sni(domain.as_ref())
Expand All @@ -112,7 +125,7 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result<
Address::SocketAddr(addr) => DoHClient::builder(*addr).https(doh_addr.https),
Address::HostAddr(addr) => {
let domain = &addr.host;
let ip = lookup(domain, timeout).await?;
let ip = DEFAULT_LOOKUPS.lookup(domain, timeout).await?;
let mut bu = DoHClient::builder(SocketAddr::new(IpAddr::V4(ip), addr.port))
.host(domain)
.https(doh_addr.https);
Expand All @@ -130,29 +143,60 @@ pub async fn request(dns: &DNS, request: &Message, timeout: Duration) -> Result<
}
}

#[inline]
async fn lookup(host: &str, timeout: Duration) -> Result<Ipv4Addr> {
// TODO: add cache
let flags = Flags::builder()
.request()
.recursive_query(true)
.opcode(OpCode::StandardQuery)
.build();
let req0 = Message::builder()
.id(1234)
.flags(flags)
.question(host, Kind::A, Class::IN)
.build()?;
struct LookupCache(Cache<Cachestr, SmallVec<[Ipv4Addr; 2]>>);

let v = DefaultClient.request(&req0).await?;
impl LookupCache {
async fn lookup(&self, host: &str, timeout: Duration) -> Result<Ipv4Addr> {
let key = Cachestr::from(host);

for next in v.answers() {
if let Ok(RData::A(a)) = next.rdata() {
return Ok(a.ipaddr());
let res = self
.0
.try_get_with(key, Self::lookup_(host, timeout))
.await
.map_err(|e| anyhow!("lookup failed: {}", e))?;

if let Some(first) = res.first() {
return Ok(Clone::clone(first));
}

bail!(Error::ResolveNothing)
}

bail!(crate::Error::ResolveNothing)
#[inline]
async fn lookup_(host: &str, timeout: Duration) -> Result<SmallVec<[Ipv4Addr; 2]>> {
let flags = Flags::builder()
.request()
.recursive_query(true)
.opcode(OpCode::StandardQuery)
.build();

let id = {
use rand::prelude::*;

let mut rng = thread_rng();
rng.gen_range(1..u16::MAX)
};

let req0 = Message::builder()
.id(id)
.flags(flags)
.question(host, Kind::A, Class::IN)
.build()?;

let mut ret = SmallVec::<[Ipv4Addr; 2]>::new();
let v = DefaultClient.request(&req0).await?;
for next in v.answers() {
if let Ok(RData::A(a)) = next.rdata() {
ret.push(a.ipaddr());
}
}

if !ret.is_empty() {
return Ok(ret);
}

bail!(Error::ResolveNothing)
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/cmds/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn print_resolve_result(
println!();
println!(";; Query time: {} msec", cost.num_milliseconds());
println!(";; SERVER: {}", &dns);
println!(";; WHEN: {}", &begin);
println!(";; WHEN: {}", begin.to_rfc2822());
println!(";; MSG SIZE\trcvd: {}", res.len());

Ok(())
Expand Down
5 changes: 3 additions & 2 deletions src/misc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ pub(crate) fn is_valid_domain(domain: &str) -> bool {
return true;
}

static RE: Lazy<regex::Regex> =
Lazy::new(|| regex::Regex::new("^([a-z0-9_-]{1,63})(\\.[a-z0-9_-]{1,63})+\\.?$").unwrap());
static RE: Lazy<regex::Regex> = Lazy::new(|| {
regex::Regex::new("^([a-zA-Z0-9_-]{1,63})(\\.[a-zA-Z0-9_-]{1,63})+\\.?$").unwrap()
});

RE.is_match(domain)
}
54 changes: 28 additions & 26 deletions src/server/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ fn convert_error_to_message(
err: anyhow::Error,
attach_questions: bool,
) -> Message {
error!("failed to handle dns request: {:?}", err);

let rid = request.id();
let rflags = request.flags();

Expand All @@ -45,6 +43,25 @@ fn convert_error_to_message(
}
}

// log those internal server failure:
match rcode {
RCode::ServerFailure => match request.questions().next() {
Some(question) => {
let name = question.name();
error!("failed to handle dns request {}: {:?}", name, err);
}
None => error!("failed to handle dns request: {:?}", err),
},
RCode::NameError => match request.questions().next() {
Some(question) => {
let name = question.name();
warn!("failed to handle dns request {}: {:?}", name, err);
}
None => warn!("failed to handle dns request: {:?}", err),
},
_ => (),
}

let flags = {
let mut bu = Flags::builder()
.response()
Expand All @@ -68,13 +85,17 @@ fn convert_error_to_message(
bu.build().unwrap()
}

pub(super) async fn handle<H, C>(mut req: Message, h: Arc<H>, cache: Option<Arc<C>>) -> Message
pub(super) async fn handle<H, C>(
mut req: Message,
h: Arc<H>,
cache: Option<Arc<C>>,
) -> (Message, bool)
where
H: Handler,
C: CacheStore,
{
if let Err(e) = validate_request(&req) {
return convert_error_to_message(&req, e, false);
return (convert_error_to_message(&req, e, false), false);
}

if let Some(cache) = cache.as_deref() {
Expand All @@ -85,8 +106,7 @@ where

if let Some(mut exist) = cached {
exist.set_id(id);
debug!("use dns cache");
return exist;
return (exist, true);
}
}

Expand All @@ -110,29 +130,11 @@ where
req.set_id(0);
cache.set(&req, &msg).await;
req.set_id(id);

debug!("set dns cache ok");
}
}

if msg.answer_count() > 0 {
for next in msg.answers() {
if let Ok(rdata) = next.rdata() {
info!(
"0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}",
msg.id(),
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
rdata,
);
}
}
}

msg
(msg, false)
}
Err(e) => convert_error_to_message(&req, e, true),
Err(e) => (convert_error_to_message(&req, e, true), false),
}
}
33 changes: 31 additions & 2 deletions src/server/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,37 @@ where
let req = next?;
let handler = Clone::clone(&handler);
let cache = Clone::clone(&cache);
let msg = super::helper::handle(req, handler, cache).await;
w.send(&msg).await?;
let (res, cached) = super::helper::handle(req, handler, cache).await;

if res.answer_count() > 0 {
for next in res.answers() {
if let Ok(rdata) = next.rdata() {
if cached {
info!(
"0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}\t<CACHE>",
res.id(),
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
rdata,
);
} else {
info!(
"0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}",
res.id(),
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
rdata,
);
}
}
}
}

w.send(&res).await?;
}

Ok(())
Expand Down
33 changes: 31 additions & 2 deletions src/server/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,37 @@ where
h: Arc<H>,
cache: Option<Arc<C>>,
) {
let result = helper::handle(req, h, cache).await;
if let Err(e) = socket.send_to(result.as_ref(), peer).await {
let (res, cached) = helper::handle(req, h, cache).await;

if res.answer_count() > 0 {
for next in res.answers() {
if let Ok(rdata) = next.rdata() {
if cached {
info!(
"0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}\t<CACHE>",
res.id(),
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
rdata,
);
} else {
info!(
"0x{:04x} <- {}.\t{}\t{:?}\t{:?}\t{}",
res.id(),
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
rdata,
);
}
}
}
}

if let Err(e) = socket.send_to(res.as_ref(), peer).await {
error!("failed to reply dns response: {:?}", e);
}
}
Expand Down

0 comments on commit 299fc89

Please sign in to comment.