Skip to content

Commit

Permalink
feat: bump some deps to newest version
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii committed Jan 3, 2025
1 parent 156367e commit 4f420e0
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 42 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "zerodns"
version = "0.1.0-alpha.7"
version = "0.1.0-alpha.8"
edition = "2021"
license = "MIT"
readme = "README.md"
Expand Down Expand Up @@ -55,10 +55,10 @@ regex = "1.10"
hex = "0.4"
strum = { version = "0.26", default-features = false, features = ["strum_macros", "derive"] }
strum_macros = "0.26"
deadpool = "0.10"
deadpool = "0.12"
socket2 = "0.5"
mlua = { version = "0.9.9", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "parking_lot"] }
garde = { version = "0.20", features = ["serde", "derive", "regex"] }
mlua = { version = "0.10", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "anyhow"] }
garde = { version = "0.21", features = ["serde", "derive", "regex"] }
rustls = "0.23"
webpki-roots = "0.26"
tokio-rustls = "0.26"
Expand Down
31 changes: 31 additions & 0 deletions src/client/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,37 @@ mod tests {
pretty_env_logger::try_init_timed().ok();
}

#[tokio::test]
async fn test_wield() -> anyhow::Result<()> {
init();

let req = {
let s = "a4e5010000010000000000000a68747470733a2f2f696d0864696e6774616c6b03636f6d0000010001";
let b = hex::decode(s)?;
Message::from(b)
};

let res = UdpClient::aliyun().request(&req).await?;

info!("questions: {}", res.question_count());
info!("answers: {}", res.answer_count());
info!("additional: {}", res.additional_count());
info!("authority: {}", res.authority_count());

for next in res.answers() {
info!(
"{}.\t{}\t{:?}\t{:?}\t{}",
next.name(),
next.time_to_live(),
next.class(),
next.kind(),
next.rdata().unwrap()
);
}

Ok(())
}

#[tokio::test]
async fn test_request() -> anyhow::Result<()> {
init();
Expand Down
25 changes: 13 additions & 12 deletions src/filter/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::filter::{handle_next, Context, ContextFlags, FilterFactory, Options};
use crate::protocol::{Class, Flags, Kind, Message, DNS};
use async_trait::async_trait;
use mlua::prelude::*;
use mlua::{Function, Lua, UserData, UserDataMethods};
use mlua::{Function, Lua, UserData};
use once_cell::sync::Lazy;
use smallvec::SmallVec;
use std::net::Ipv4Addr;
Expand All @@ -24,7 +24,7 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
struct LuaLoggerModule;

impl UserData for LuaLoggerModule {
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("debug", |_lua, _this, msg: LuaString| {
debug!("{}", msg.to_string_lossy());
Ok(())
Expand All @@ -47,15 +47,15 @@ impl UserData for LuaLoggerModule {
struct LuaJsonModule;

impl UserData for LuaJsonModule {
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("encode", |lua, _, value: mlua::Value| {
let mut b = SmallVec::<[u8; 512]>::new();
serde_json::to_writer(&mut b, &value).map_err(LuaError::external)?;
lua.create_string(&b[..])
});
methods.add_method("decode", |lua, _, input: LuaString| {
let s = input.to_str()?;
let v = serde_json::from_str::<serde_json::Value>(s).map_err(LuaError::external)?;
let v = serde_json::from_str::<serde_json::Value>(&s).map_err(LuaError::external)?;
lua.to_value(&v)
});
}
Expand All @@ -64,8 +64,8 @@ impl UserData for LuaJsonModule {
#[derive(Clone)]
struct LuaMessage(Message);

impl<'lua> FromLua<'lua> for LuaMessage {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
impl FromLua for LuaMessage {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
match value {
LuaValue::UserData(data) => Ok(Clone::clone(&*data.borrow::<Self>()?)),
_ => unreachable!(),
Expand All @@ -74,7 +74,7 @@ impl<'lua> FromLua<'lua> for LuaMessage {
}

impl UserData for LuaMessage {
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("questions_count", |_, this, ()| Ok(this.0.question_count()));
methods.add_method("flags", |lua, this, ()| Ok(LuaFlags(this.0.flags())));
methods.add_method("questions", |lua, this, ()| {
Expand Down Expand Up @@ -138,13 +138,14 @@ impl UserData for LuaMessage {
struct LuaContext(*mut Context, *mut Message, *mut Option<Message>);

impl UserData for LuaContext {
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("request", |lua, this| {
let msg = LuaMessage(Clone::clone(unsafe { this.1.as_ref().unwrap() }));
Ok(msg)
});
}
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {

fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("nocache", |_lua, this, ()| {
let ctx = unsafe { this.0.as_mut().unwrap() };
ctx.flags.set(ContextFlags::NO_CACHE, true);
Expand Down Expand Up @@ -201,7 +202,7 @@ impl UserData for LuaContext {
struct LuaFlags(Flags);

impl UserData for LuaFlags {
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("opcode", |_lua, this| Ok(this.0.opcode() as u16));
fields.add_field_method_get("response_code", |_lua, this| {
Ok(this.0.response_code() as u8)
Expand Down Expand Up @@ -239,12 +240,12 @@ impl Filter for LuaFilter {
let lua = self.vm.lock().await;
let globals = lua.globals();

let handler = globals.get::<_, Function>("handle");
let handler = globals.get::<Function>("handle");

if let Ok(handler) = handler {
lua.scope(|scope| {
let uctx = scope.create_userdata(LuaContext(ctx, req, res))?;
let _ = handler.call::<_, Option<LuaValue>>(uctx)?;
let _ = handler.call::<Option<LuaValue>>(uctx)?;
Ok(())
})?;
}
Expand Down
84 changes: 71 additions & 13 deletions src/misc/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::Result;
use deadpool::managed::{self, Metrics, RecycleError, RecycleResult};
use futures::future;
use hashbrown::HashMap;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use socket2::{Domain, Protocol, SockAddr, Type};
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -50,14 +52,8 @@ impl Manager {
pub fn key(&self) -> Key {
self.key
}
}

#[async_trait::async_trait]
impl managed::Manager for Manager {
type Type = (u32, TcpStream);
type Error = anyhow::Error;

async fn create(&self) -> std::result::Result<Self::Type, Self::Error> {
fn connect(&self) -> Result<TcpStream> {
let stream: std::net::TcpStream = {
let dst = SockAddr::from(self.key.0);

Expand All @@ -82,24 +78,86 @@ impl managed::Manager for Manager {
};

let socket = TcpStream::from_std(stream)?;
Ok((0, socket))

Ok(socket)
}
}

async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult<Self::Error> {
#[async_trait::async_trait]
impl managed::Manager for Manager {
type Type = (u32, TcpStream);
type Error = anyhow::Error;

fn create(&self) -> impl Future<Output = std::result::Result<Self::Type, Self::Error>> + Send {
match self.connect() {
Ok(socket) => future::ok((0, socket)),
Err(e) => future::err(e),
}
}

fn recycle(
&self,
obj: &mut Self::Type,
metrics: &Metrics,
) -> impl Future<Output = RecycleResult<Self::Error>> + Send {
if metrics.created.elapsed() > self.lifetime {
return Err(RecycleError::Backend(anyhow!("exceed max lifetime!")));
return future::err(RecycleError::Backend(anyhow!("exceed max lifetime!")));
}

if obj.0 != 0 {
return Err(RecycleError::Backend(anyhow!("invalid connection!")));
return future::err(RecycleError::Backend(anyhow!("invalid connection!")));
}

if let Err(e) = validate(&obj.1) {
return Err(RecycleError::Backend(e));
return future::err(RecycleError::Backend(e));
}

Ok(())
future::ok(())
}

// async fn create(&self) -> std::result::Result<Self::Type, Self::Error> {
// let stream: std::net::TcpStream = {
// let dst = SockAddr::from(self.key.0);
//
// let socket = socket2::Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
// socket.set_nodelay(true)?;
// socket.set_keepalive(true)?;
//
// if let Some(source) = self.key.1 {
// socket.set_reuse_address(true)?;
// socket.set_reuse_port(true)?;
// let src = SockAddr::from(source);
// socket
// .bind(&src)
// .map_err(|e| crate::Error::NetworkBindFailure(source, e))?;
// }
//
// socket.connect(&dst)?;
//
// socket.set_nonblocking(true)?;
//
// socket.into()
// };
//
// let socket = TcpStream::from_std(stream)?;
// Ok((0, socket))
// }
//
// async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult<Self::Error> {
// if metrics.created.elapsed() > self.lifetime {
// return Err(RecycleError::Backend(anyhow!("exceed max lifetime!")));
// }
//
// if obj.0 != 0 {
// return Err(RecycleError::Backend(anyhow!("invalid connection!")));
// }
//
// if let Err(e) = validate(&obj.1) {
// return Err(RecycleError::Backend(e));
// }
//
// Ok(())
// }
}

#[inline]
Expand Down
37 changes: 24 additions & 13 deletions src/misc/tls.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::Result;
use deadpool::managed;
use deadpool::managed::{Metrics, RecycleError, RecycleResult};
use futures::{future, FutureExt};
use hashbrown::HashMap;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use rustls::pki_types::ServerName;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -30,35 +32,44 @@ pub(crate) struct Manager {
lifetime: Duration,
}

#[async_trait::async_trait]
impl managed::Manager for Manager {
type Type = (u32, TlsStream<TcpStream>);
type Error = anyhow::Error;

async fn create(&self) -> std::result::Result<Self::Type, Self::Error> {
impl Manager {
#[inline]
async fn connect(&self) -> Result<TlsStream<TcpStream>> {
let connector = TlsConnector::from(Clone::clone(&*DEFAULT_TLS_CLIENT_CONFIG));
let dnsname = ServerName::try_from(self.key.0.to_string())?;

let stream = TcpStream::connect(self.key.1).await?;
let stream = connector.connect(dnsname, stream).await?;
Ok(stream)
}
}

#[async_trait::async_trait]
impl managed::Manager for Manager {
type Type = (u32, TlsStream<TcpStream>);
type Error = anyhow::Error;

Ok((0, stream))
fn create(&self) -> impl Future<Output = std::result::Result<Self::Type, Self::Error>> + Send {
self.connect().map(|it| it.map(|it| (0, it)))
}

async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult<Self::Error> {
fn recycle(
&self,
obj: &mut Self::Type,
metrics: &Metrics,
) -> impl Future<Output = RecycleResult<Self::Error>> + Send {
if metrics.created.elapsed() > self.lifetime {
return Err(RecycleError::Backend(anyhow!("exceed max lifetime!")));
return future::err(RecycleError::Backend(anyhow!("exceed max lifetime!")));
}

if obj.0 != 0 {
return Err(RecycleError::Backend(anyhow!("invalid connection!")));
return future::err(RecycleError::Backend(anyhow!("invalid connection!")));
}

if let Err(e) = validate(&obj.1) {
return Err(RecycleError::Backend(e));
return future::err(RecycleError::Backend(e));
}

Ok(())
future::ok(())
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/protocol/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1861,6 +1861,20 @@ mod tests {
assert_eq!(RCode::NoError, flags.response_code());
}

#[test]
fn test_addition() {
init();

let s = "a4e5808000010000000400120a68747470733a2f2f696d0864696e6774616c6b03636f6d0000010001c017000200010000001e000d036e73360674616f62616fc020c017000200010000001e0006036e7335c039c017000200010000001e0006036e7337c039c017000200010000001e0006036e7334c039c072000100010000001e0004aa211849c072000100010000001e0004aa21184bc072000100010000001e00042f584a21c072000100010000001e00042f584a23c072000100010000001e00042ff1cf0dc072000100010000001e00042ff1cf0fc04e000100010000001e00048ccd7a21c04e000100010000001e00048ccd7a22c035000100010000001e00048ccd7a24c035000100010000001e00048ccd7a23c060000100010000001e00046a0b2996c060000100010000001e00046a0b2319c060000100010000001e00046a0b231ac060000100010000001e00046a0b2995c072001c00010000001e00102401b180410000000000000000000004c04e001c00010000001e00102401b180410000000000000000000005c035001c00010000001e00102401b180410000000000000000000006c060001c00010000001e00102401b180410000000000000000000007";
let b = hex::decode(s).unwrap();
let msg = Message::from(b);

assert_eq!(1, msg.question_count(), "invalid question count");
assert_eq!(0, msg.answer_count(), "invalid answer count");
assert_eq!(18, msg.additional_count(), "invalid additional count");
assert_eq!(4, msg.authority_count(), "invalid authority count");
}

#[test]
fn test_broken() {
init();
Expand Down

0 comments on commit 4f420e0

Please sign in to comment.