From 4f420e05af8bff9f1fa3cfc3777695f47977cb99 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Fri, 3 Jan 2025 20:08:35 +0800 Subject: [PATCH] feat: bump some deps to newest version --- Cargo.toml | 8 ++--- src/client/udp.rs | 31 ++++++++++++++++ src/filter/lua.rs | 25 ++++++------- src/misc/tcp.rs | 84 ++++++++++++++++++++++++++++++++++++------- src/misc/tls.rs | 37 ++++++++++++------- src/protocol/frame.rs | 14 ++++++++ 6 files changed, 157 insertions(+), 42 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fcb07a3..2891329 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zerodns" -version = "0.1.0-alpha.7" +version = "0.1.0-alpha.8" edition = "2021" license = "MIT" readme = "README.md" @@ -55,10 +55,10 @@ regex = "1.10" hex = "0.4" strum = { version = "0.26", default-features = false, features = ["strum_macros", "derive"] } strum_macros = "0.26" -deadpool = "0.10" +deadpool = "0.12" socket2 = "0.5" -mlua = { version = "0.9.9", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "parking_lot"] } -garde = { version = "0.20", features = ["serde", "derive", "regex"] } +mlua = { version = "0.10", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "anyhow"] } +garde = { version = "0.21", features = ["serde", "derive", "regex"] } rustls = "0.23" webpki-roots = "0.26" tokio-rustls = "0.26" diff --git a/src/client/udp.rs b/src/client/udp.rs index 386b36d..789ff3f 100644 --- a/src/client/udp.rs +++ b/src/client/udp.rs @@ -302,6 +302,37 @@ mod tests { pretty_env_logger::try_init_timed().ok(); } + #[tokio::test] + async fn test_wield() -> anyhow::Result<()> { + init(); + + let req = { + let s = "a4e5010000010000000000000a68747470733a2f2f696d0864696e6774616c6b03636f6d0000010001"; + let b = hex::decode(s)?; + Message::from(b) + }; + + let res = UdpClient::aliyun().request(&req).await?; + + info!("questions: {}", res.question_count()); + info!("answers: {}", res.answer_count()); + info!("additional: {}", res.additional_count()); + info!("authority: {}", res.authority_count()); + + for next in res.answers() { + info!( + "{}.\t{}\t{:?}\t{:?}\t{}", + next.name(), + next.time_to_live(), + next.class(), + next.kind(), + next.rdata().unwrap() + ); + } + + Ok(()) + } + #[tokio::test] async fn test_request() -> anyhow::Result<()> { init(); diff --git a/src/filter/lua.rs b/src/filter/lua.rs index 1dc6f86..42fe367 100644 --- a/src/filter/lua.rs +++ b/src/filter/lua.rs @@ -4,7 +4,7 @@ use crate::filter::{handle_next, Context, ContextFlags, FilterFactory, Options}; use crate::protocol::{Class, Flags, Kind, Message, DNS}; use async_trait::async_trait; use mlua::prelude::*; -use mlua::{Function, Lua, UserData, UserDataMethods}; +use mlua::{Function, Lua, UserData}; use once_cell::sync::Lazy; use smallvec::SmallVec; use std::net::Ipv4Addr; @@ -24,7 +24,7 @@ static RUNTIME: Lazy = Lazy::new(|| { struct LuaLoggerModule; impl UserData for LuaLoggerModule { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("debug", |_lua, _this, msg: LuaString| { debug!("{}", msg.to_string_lossy()); Ok(()) @@ -47,7 +47,7 @@ impl UserData for LuaLoggerModule { struct LuaJsonModule; impl UserData for LuaJsonModule { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("encode", |lua, _, value: mlua::Value| { let mut b = SmallVec::<[u8; 512]>::new(); serde_json::to_writer(&mut b, &value).map_err(LuaError::external)?; @@ -55,7 +55,7 @@ impl UserData for LuaJsonModule { }); methods.add_method("decode", |lua, _, input: LuaString| { let s = input.to_str()?; - let v = serde_json::from_str::(s).map_err(LuaError::external)?; + let v = serde_json::from_str::(&s).map_err(LuaError::external)?; lua.to_value(&v) }); } @@ -64,8 +64,8 @@ impl UserData for LuaJsonModule { #[derive(Clone)] struct LuaMessage(Message); -impl<'lua> FromLua<'lua> for LuaMessage { - fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { +impl FromLua for LuaMessage { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { match value { LuaValue::UserData(data) => Ok(Clone::clone(&*data.borrow::()?)), _ => unreachable!(), @@ -74,7 +74,7 @@ impl<'lua> FromLua<'lua> for LuaMessage { } impl UserData for LuaMessage { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("questions_count", |_, this, ()| Ok(this.0.question_count())); methods.add_method("flags", |lua, this, ()| Ok(LuaFlags(this.0.flags()))); methods.add_method("questions", |lua, this, ()| { @@ -138,13 +138,14 @@ impl UserData for LuaMessage { struct LuaContext(*mut Context, *mut Message, *mut Option); impl UserData for LuaContext { - fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { fields.add_field_method_get("request", |lua, this| { let msg = LuaMessage(Clone::clone(unsafe { this.1.as_ref().unwrap() })); Ok(msg) }); } - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + + fn add_methods>(methods: &mut M) { methods.add_method("nocache", |_lua, this, ()| { let ctx = unsafe { this.0.as_mut().unwrap() }; ctx.flags.set(ContextFlags::NO_CACHE, true); @@ -201,7 +202,7 @@ impl UserData for LuaContext { struct LuaFlags(Flags); impl UserData for LuaFlags { - fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { fields.add_field_method_get("opcode", |_lua, this| Ok(this.0.opcode() as u16)); fields.add_field_method_get("response_code", |_lua, this| { Ok(this.0.response_code() as u8) @@ -239,12 +240,12 @@ impl Filter for LuaFilter { let lua = self.vm.lock().await; let globals = lua.globals(); - let handler = globals.get::<_, Function>("handle"); + let handler = globals.get::("handle"); if let Ok(handler) = handler { lua.scope(|scope| { let uctx = scope.create_userdata(LuaContext(ctx, req, res))?; - let _ = handler.call::<_, Option>(uctx)?; + let _ = handler.call::>(uctx)?; Ok(()) })?; } diff --git a/src/misc/tcp.rs b/src/misc/tcp.rs index d7ca80c..6193a42 100644 --- a/src/misc/tcp.rs +++ b/src/misc/tcp.rs @@ -1,9 +1,11 @@ use crate::Result; use deadpool::managed::{self, Metrics, RecycleError, RecycleResult}; +use futures::future; use hashbrown::HashMap; use once_cell::sync::Lazy; use parking_lot::RwLock; use socket2::{Domain, Protocol, SockAddr, Type}; +use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -50,14 +52,8 @@ impl Manager { pub fn key(&self) -> Key { self.key } -} -#[async_trait::async_trait] -impl managed::Manager for Manager { - type Type = (u32, TcpStream); - type Error = anyhow::Error; - - async fn create(&self) -> std::result::Result { + fn connect(&self) -> Result { let stream: std::net::TcpStream = { let dst = SockAddr::from(self.key.0); @@ -82,24 +78,86 @@ impl managed::Manager for Manager { }; let socket = TcpStream::from_std(stream)?; - Ok((0, socket)) + + Ok(socket) } +} - async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult { +#[async_trait::async_trait] +impl managed::Manager for Manager { + type Type = (u32, TcpStream); + type Error = anyhow::Error; + + fn create(&self) -> impl Future> + Send { + match self.connect() { + Ok(socket) => future::ok((0, socket)), + Err(e) => future::err(e), + } + } + + fn recycle( + &self, + obj: &mut Self::Type, + metrics: &Metrics, + ) -> impl Future> + Send { if metrics.created.elapsed() > self.lifetime { - return Err(RecycleError::Backend(anyhow!("exceed max lifetime!"))); + return future::err(RecycleError::Backend(anyhow!("exceed max lifetime!"))); } if obj.0 != 0 { - return Err(RecycleError::Backend(anyhow!("invalid connection!"))); + return future::err(RecycleError::Backend(anyhow!("invalid connection!"))); } if let Err(e) = validate(&obj.1) { - return Err(RecycleError::Backend(e)); + return future::err(RecycleError::Backend(e)); } - Ok(()) + future::ok(()) } + + // async fn create(&self) -> std::result::Result { + // let stream: std::net::TcpStream = { + // let dst = SockAddr::from(self.key.0); + // + // let socket = socket2::Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; + // socket.set_nodelay(true)?; + // socket.set_keepalive(true)?; + // + // if let Some(source) = self.key.1 { + // socket.set_reuse_address(true)?; + // socket.set_reuse_port(true)?; + // let src = SockAddr::from(source); + // socket + // .bind(&src) + // .map_err(|e| crate::Error::NetworkBindFailure(source, e))?; + // } + // + // socket.connect(&dst)?; + // + // socket.set_nonblocking(true)?; + // + // socket.into() + // }; + // + // let socket = TcpStream::from_std(stream)?; + // Ok((0, socket)) + // } + // + // async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult { + // if metrics.created.elapsed() > self.lifetime { + // return Err(RecycleError::Backend(anyhow!("exceed max lifetime!"))); + // } + // + // if obj.0 != 0 { + // return Err(RecycleError::Backend(anyhow!("invalid connection!"))); + // } + // + // if let Err(e) = validate(&obj.1) { + // return Err(RecycleError::Backend(e)); + // } + // + // Ok(()) + // } } #[inline] diff --git a/src/misc/tls.rs b/src/misc/tls.rs index 34b55e5..0c82326 100644 --- a/src/misc/tls.rs +++ b/src/misc/tls.rs @@ -1,10 +1,12 @@ use crate::Result; use deadpool::managed; use deadpool::managed::{Metrics, RecycleError, RecycleResult}; +use futures::{future, FutureExt}; use hashbrown::HashMap; use once_cell::sync::Lazy; use parking_lot::RwLock; use rustls::pki_types::ServerName; +use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -30,35 +32,44 @@ pub(crate) struct Manager { lifetime: Duration, } -#[async_trait::async_trait] -impl managed::Manager for Manager { - type Type = (u32, TlsStream); - type Error = anyhow::Error; - - async fn create(&self) -> std::result::Result { +impl Manager { + #[inline] + async fn connect(&self) -> Result> { let connector = TlsConnector::from(Clone::clone(&*DEFAULT_TLS_CLIENT_CONFIG)); let dnsname = ServerName::try_from(self.key.0.to_string())?; - let stream = TcpStream::connect(self.key.1).await?; let stream = connector.connect(dnsname, stream).await?; + Ok(stream) + } +} + +#[async_trait::async_trait] +impl managed::Manager for Manager { + type Type = (u32, TlsStream); + type Error = anyhow::Error; - Ok((0, stream)) + fn create(&self) -> impl Future> + Send { + self.connect().map(|it| it.map(|it| (0, it))) } - async fn recycle(&self, obj: &mut Self::Type, metrics: &Metrics) -> RecycleResult { + fn recycle( + &self, + obj: &mut Self::Type, + metrics: &Metrics, + ) -> impl Future> + Send { if metrics.created.elapsed() > self.lifetime { - return Err(RecycleError::Backend(anyhow!("exceed max lifetime!"))); + return future::err(RecycleError::Backend(anyhow!("exceed max lifetime!"))); } if obj.0 != 0 { - return Err(RecycleError::Backend(anyhow!("invalid connection!"))); + return future::err(RecycleError::Backend(anyhow!("invalid connection!"))); } if let Err(e) = validate(&obj.1) { - return Err(RecycleError::Backend(e)); + return future::err(RecycleError::Backend(e)); } - Ok(()) + future::ok(()) } } diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs index 2c47f05..e2b9e2a 100644 --- a/src/protocol/frame.rs +++ b/src/protocol/frame.rs @@ -1861,6 +1861,20 @@ mod tests { assert_eq!(RCode::NoError, flags.response_code()); } + #[test] + fn test_addition() { + init(); + + let s = "a4e5808000010000000400120a68747470733a2f2f696d0864696e6774616c6b03636f6d0000010001c017000200010000001e000d036e73360674616f62616fc020c017000200010000001e0006036e7335c039c017000200010000001e0006036e7337c039c017000200010000001e0006036e7334c039c072000100010000001e0004aa211849c072000100010000001e0004aa21184bc072000100010000001e00042f584a21c072000100010000001e00042f584a23c072000100010000001e00042ff1cf0dc072000100010000001e00042ff1cf0fc04e000100010000001e00048ccd7a21c04e000100010000001e00048ccd7a22c035000100010000001e00048ccd7a24c035000100010000001e00048ccd7a23c060000100010000001e00046a0b2996c060000100010000001e00046a0b2319c060000100010000001e00046a0b231ac060000100010000001e00046a0b2995c072001c00010000001e00102401b180410000000000000000000004c04e001c00010000001e00102401b180410000000000000000000005c035001c00010000001e00102401b180410000000000000000000006c060001c00010000001e00102401b180410000000000000000000007"; + let b = hex::decode(s).unwrap(); + let msg = Message::from(b); + + assert_eq!(1, msg.question_count(), "invalid question count"); + assert_eq!(0, msg.answer_count(), "invalid answer count"); + assert_eq!(18, msg.additional_count(), "invalid additional count"); + assert_eq!(4, msg.authority_count(), "invalid authority count"); + } + #[test] fn test_broken() { init();