diff --git a/Cargo.lock b/Cargo.lock index 91b5e953..e86183df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -536,7 +536,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", @@ -1524,7 +1524,7 @@ checksum = "e1b84d32b18d9a256d81e4fec2e4cfd0ab6dde5e5ff49be1713ae0adbd0060c2" dependencies = [ "heck 0.5.0", "indexmap 2.5.0", - "itertools 0.10.5", + "itertools 0.13.0", "proc-macro-crate 3.2.0", "proc-macro2", "quote", @@ -4013,7 +4013,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -4033,7 +4033,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.77", diff --git a/clash_lib/src/app/dns/config.rs b/clash_lib/src/app/dns/config.rs index 868e3a86..46169593 100644 --- a/clash_lib/src/app/dns/config.rs +++ b/clash_lib/src/app/dns/config.rs @@ -45,7 +45,7 @@ pub struct FallbackFilter { pub domain: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "kebab-case")] pub struct DoHConfig { pub addr: SocketAddr, @@ -54,7 +54,7 @@ pub struct DoHConfig { pub hostname: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "kebab-case")] pub struct DoH3Config { pub addr: SocketAddr, @@ -63,7 +63,7 @@ pub struct DoH3Config { pub hostname: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] #[serde(rename_all = "kebab-case")] pub struct DoTConfig { pub addr: SocketAddr, @@ -74,7 +74,7 @@ pub struct DoTConfig { pub type DnsServerKey = Option; pub type DnsServerCert = Option; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub 
struct DNSListenAddr { pub udp: Option, pub tcp: Option, diff --git a/clash_lib/src/app/dns/resolver/enhanced.rs b/clash_lib/src/app/dns/resolver/enhanced.rs index 3426a79c..ea46952e 100644 --- a/clash_lib/src/app/dns/resolver/enhanced.rs +++ b/clash_lib/src/app/dns/resolver/enhanced.rs @@ -84,7 +84,7 @@ impl EnhancedResolver { } pub async fn new( - cfg: &Config, + cfg: Config, store: ThreadSafeCacheFile, mmdb: Arc, ) -> Self { @@ -110,7 +110,7 @@ impl EnhancedResolver { Some(default_resolver.clone()), ) .await, - hosts: cfg.hosts.clone(), + hosts: cfg.hosts, fallback: if !cfg.fallback.is_empty() { Some( make_clients( diff --git a/clash_lib/src/app/dns/resolver/mod.rs b/clash_lib/src/app/dns/resolver/mod.rs index c1378b2a..d5aabf26 100644 --- a/clash_lib/src/app/dns/resolver/mod.rs +++ b/clash_lib/src/app/dns/resolver/mod.rs @@ -17,7 +17,7 @@ use crate::{app::profile::ThreadSafeCacheFile, common::mmdb::Mmdb}; use super::{Config, ThreadSafeDNSResolver}; pub async fn new( - cfg: &Config, + cfg: Config, store: Option, mmdb: Option>, ) -> ThreadSafeDNSResolver { diff --git a/clash_lib/src/app/remote_content_manager/providers/rule_provider/provider.rs b/clash_lib/src/app/remote_content_manager/providers/rule_provider/provider.rs index 2108b904..e26d934e 100644 --- a/clash_lib/src/app/remote_content_manager/providers/rule_provider/provider.rs +++ b/clash_lib/src/app/remote_content_manager/providers/rule_provider/provider.rs @@ -20,7 +20,9 @@ use crate::{ }, router::{map_rule_type, RuleMatcher}, }, - common::{errors::map_io_error, geodata::GeoData, mmdb::Mmdb, trie}, + common::{ + errors::map_io_error, geodata::GeoData, mmdb::Mmdb, succinct_set, trie, + }, config::internal::rule::RuleType, session::Session, Error, @@ -52,7 +54,8 @@ impl Display for RuleSetBehavior { } enum RuleContent { - Domain(trie::StringTrie), + // rules are first built as a `trie::StringTrie`, then converted into a `DomainSet` + Domain(succinct_set::DomainSet), Ipcidr(Box), Classical(Vec>), } @@ -91,7 +94,7 @@ impl RuleProviderImpl { let inner 
= Arc::new(tokio::sync::RwLock::new(Inner { content: match behovior { RuleSetBehavior::Domain => { - RuleContent::Domain(trie::StringTrie::new()) + RuleContent::Domain(succinct_set::DomainSet::default()) } RuleSetBehavior::Ipcidr => { RuleContent::Ipcidr(Box::new(CidrTrie::new())) @@ -150,9 +153,7 @@ impl RuleProvider for RuleProviderImpl { match inner { Ok(inner) => match &inner.content { - RuleContent::Domain(trie) => { - trie.search(&sess.destination.host()).is_some() - } + RuleContent::Domain(set) => set.has(&sess.destination.host()), RuleContent::Ipcidr(trie) => trie.contains( sess.destination .ip() @@ -243,7 +244,8 @@ fn make_rules( ) -> Result { match behavior { RuleSetBehavior::Domain => { - Ok(RuleContent::Domain(make_domain_rules(rules)?)) + let s = make_domain_rules(rules)?; + Ok(RuleContent::Domain(s.into())) } RuleSetBehavior::Ipcidr => { Ok(RuleContent::Ipcidr(Box::new(make_ip_cidr_rules(rules)?))) diff --git a/clash_lib/src/common/mod.rs b/clash_lib/src/common/mod.rs index 661a0ce9..bf787f36 100644 --- a/clash_lib/src/common/mod.rs +++ b/clash_lib/src/common/mod.rs @@ -6,6 +6,7 @@ pub mod geodata; pub mod http; pub mod io; pub mod mmdb; +pub mod succinct_set; pub mod timed_future; pub mod tls; pub mod trie; diff --git a/clash_lib/src/common/succinct_set.rs b/clash_lib/src/common/succinct_set.rs new file mode 100644 index 00000000..5d3df7ff --- /dev/null +++ b/clash_lib/src/common/succinct_set.rs @@ -0,0 +1,424 @@ +//! idea: https://github.com/openacid/succinct +//! impl: https://github.com/MetaCubeX/mihomo/blob/Meta/component/trie/domain_set.go +//! I have no idea what's going on here; the code is just copied from the link above. 
+ +use super::trie::StringTrie; + +static COMPLEX_WILDCARD: u8 = b'+'; +static WILDCARD: u8 = b'*'; +static DOMAIN_STEP: u8 = b'.'; + +#[derive(Default)] +pub struct DomainSet { + leaves: Vec, + label_bit_map: Vec, + labels: Vec, + ranks: Vec, + selects: Vec, +} + +impl DomainSet { + pub fn has(&self, key: &str) -> bool { + let key = key + .chars() + .rev() + .map(|x| x.to_ascii_lowercase()) + .collect::>(); + let mut node_id = 0; + let mut bm_idx = 0; + + struct Cursor { + bm_idx: usize, + index: usize, + } + + let mut stack = vec![]; + + #[derive(PartialEq)] + enum State { + Restart, + Done, + } + + let mut i: usize = 0; + + while i < key.len() + // i++ + { + let mut state = State::Restart; + + 'ctrl: while state == State::Restart { + state = State::Done; + + let c = key[i]; + loop + // bm_idx++ + { + if get_bit(&self.label_bit_map, bm_idx) { + if !stack.is_empty() { + let cursor: Cursor = stack.pop().unwrap(); + let next_node_id = count_zeros( + &self.label_bit_map, + &self.ranks, + cursor.bm_idx + 1, + ); + let mut next_bm_idx = select_ith_one( + &self.label_bit_map, + &self.ranks, + &self.selects, + next_node_id - 1, + ) + 1; + + let mut j = cursor.index; + while j < key.len() && key[j] != DOMAIN_STEP as char { + j += 1; + } + if j == key.len() { + if get_bit(&self.leaves, next_node_id as isize) { + return true; + } else { + state = State::Restart; + continue 'ctrl; + } + } + + while next_bm_idx - next_node_id < self.labels.len() { + if self.labels[next_bm_idx - next_node_id] + == DOMAIN_STEP + { + bm_idx = next_bm_idx as isize; + node_id = next_node_id; + i = j; + + state = State::Restart; + continue 'ctrl; + } + next_bm_idx += 1; + } + } + return false; + } + + if self.labels[bm_idx as usize - node_id] == COMPLEX_WILDCARD { + return true; + } else if self.labels[bm_idx as usize - node_id] == WILDCARD { + let cursor = Cursor { + bm_idx: bm_idx as usize, + index: i, + }; + stack.push(cursor); + } else if self.labels[bm_idx as usize - node_id] == c as u8 { + 
break; + } + + bm_idx += 1; + } + + node_id = count_zeros( + &self.label_bit_map, + &self.ranks, + bm_idx as usize + 1, + ); + bm_idx = select_ith_one( + &self.label_bit_map, + &self.ranks, + &self.selects, + node_id - 1, + ) as isize + + 1; + + i += 1; + } + } + + get_bit(&self.leaves, node_id as isize) + } + + #[cfg(test)] + pub fn traverse(&self, mut f: F) + where + F: FnMut(&String) -> bool, + { + self.keys(|x| f(&x.chars().rev().collect::())); + } +} + +impl DomainSet { + fn init(&mut self) { + self.ranks.push(0); + for i in 0..self.label_bit_map.len() { + let n = self.label_bit_map[i].count_ones(); + self.ranks.push(self.ranks.last().unwrap() + n as i32); + } + + let mut n = 0; + for i in 0..self.label_bit_map.len() << 6 { + let z = self.label_bit_map[i >> 6] >> (i & 63) & 1; + if z == 1 && n & 63 == 0 { + self.selects.push(i as i32); + } + n += z; + } + } + + #[cfg(test)] + fn keys(&self, mut f: F) + where + F: FnMut(&String) -> bool, + { + let mut current_key = vec![]; + + fn traverse( + this: &DomainSet, + current_key: &mut Vec, + node_id: isize, + bm_idx: isize, + f: &mut F, + ) -> bool + where + F: FnMut(&String) -> bool, + { + if get_bit(&this.leaves, node_id) && !f(¤t_key.iter().collect()) { + return false; + } + + let mut bm_idx = bm_idx; + + loop { + if get_bit(&this.label_bit_map, bm_idx) { + return true; + } + + let next_label = this.labels[(bm_idx - node_id) as usize]; + current_key.push(next_label as char); + let next_node_id = count_zeros( + &this.label_bit_map, + &this.ranks, + bm_idx as usize + 1, + ); + let next_bm_idx = select_ith_one( + &this.label_bit_map, + &this.ranks, + &this.selects, + next_node_id - 1, + ) + 1; + + if !traverse( + this, + current_key, + next_node_id as isize, + next_bm_idx as isize, + f, + ) { + return false; + } + + current_key.pop(); + + bm_idx += 1; + } + } + + traverse(self, &mut current_key, 0, 0, &mut f); + } +} + +struct QElt { + s: usize, + e: usize, + col: usize, +} + +/// Convert a `StringTrie` to a 
`DomainSet`. +/// TODO: support loading from a binary file. +/// e.g. the so called 'mrs' file in the MiHoMo project. +impl From> for DomainSet { + fn from(value: StringTrie) -> Self { + let mut keys = vec![]; + value.traverse(|key, _| { + keys.push(key.chars().rev().collect::()); + true + }); + keys.sort(); + + let mut rv = DomainSet::default(); + + let mut l_idx = 0; + + let mut queue = vec![QElt { + s: 0, + e: keys.len(), + col: 0, + }]; + + let mut i = 0; + loop { + let elt = &mut queue[i]; + if elt.col == keys[elt.s].len() { + elt.s += 1; + set_bit(&mut rv.leaves, i, true); + } + + let mut j = elt.s; + let e = elt.e; + let col = elt.col; + while j < e { + let frm = j; + while j < e && keys[j].chars().nth(col) == keys[frm].chars().nth(col) + { + j += 1; + } + + queue.push(QElt { + s: frm, + e: j, + col: col + 1, + }); + rv.labels.push(keys[frm].chars().nth(col).unwrap() as u8); + set_bit(&mut rv.label_bit_map, l_idx, false); + l_idx += 1; + } + + set_bit(&mut rv.label_bit_map, l_idx, true); + l_idx += 1; + + if i == queue.len() - 1 { + break; + } + i += 1; + } + + rv.init(); + + rv + } +} + +fn get_bit(bm: &[u64], i: isize) -> bool { + bm[(i >> 6) as usize] & (1 << (i & 63) as usize) != 0 +} + +fn set_bit(bm: &mut Vec, i: usize, v: bool) { + while i >> 6 >= (bm.len()) { + bm.push(0); + } + bm[i >> 6] |= (v as u64) << (i & 63); +} + +fn count_zeros(bm: &[u64], ranks: &[i32], i: usize) -> usize { + i - ranks[i >> 6] as usize + - (bm[i >> 6] & ((1 << (i & 63)) - 1)).count_ones() as usize +} + +fn select_ith_one(bm: &[u64], ranks: &[i32], selects: &[i32], i: usize) -> usize { + let base = selects[i >> 6] & !63; + let mut find_ith_one = i as isize - ranks[base as usize >> 6] as isize; + for (i, w) in bm.iter().enumerate().skip(base as usize >> 6) { + let mut bit_idx = 0; + let mut w = *w; + while w > 0 { + find_ith_one -= (w & 1) as isize; + if find_ith_one < 0 { + return (i << 6) + bit_idx; + } + + let t0 = (w & !1).trailing_zeros(); + w = w.unbounded_shr(t0); + 
bit_idx += t0 as usize; + } + } + + unreachable!("invalid data"); +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + #[test] + fn test_domain_set_complex_wildcard() { + let mut tree = super::StringTrie::new(); + let domains = vec![ + "baidu.com", + "google.com", + "www.google.com", + "test.a.net", + "test.a.oc", + "mijia cloud", + ".qq.com", + "+.cn", + ]; + + for d in domains { + tree.insert(d, Arc::new(true)); + } + + let mut key_src = vec![]; + tree.traverse(|key, _| { + key_src.push(key.to_owned()); + true + }); + key_src.sort(); + + let set = super::DomainSet::from(tree); + assert!(set.has("test.cn")); + assert!(set.has("cn")); + assert!(set.has("mijia cloud")); + assert!(set.has("test.a.net")); + assert!(set.has("www.qq.com")); + assert!(set.has("google.com")); + assert!(!set.has("qq.com")); + assert!(!set.has("www.baidu.com")); + + test_dump(&key_src, &set); + } + + #[test] + fn test_domain_set_wildcard() { + let mut tree = super::StringTrie::new(); + let domains = vec![ + "*.*.*.baidu.com", + "www.baidu.*", + "stun.*.*", + "*.*.qq.com", + "test.*.baidu.com", + "*.apple.com", + ]; + + for d in domains { + tree.insert(d, Arc::new(true)); + } + + let mut key_src = vec![]; + tree.traverse(|key, _| { + key_src.push(key.to_owned()); + true + }); + key_src.sort(); + + let set = super::DomainSet::from(tree); + + assert!(set.has("www.baidu.com")); + assert!(set.has("test.test.baidu.com")); + assert!(set.has("test.test.qq.com")); + assert!(set.has("stun.ab.cd")); + assert!(!set.has("test.baidu.com")); + assert!(!set.has("www.google.com")); + assert!(!set.has("a.www.google.com")); + assert!(!set.has("test.qq.com")); + assert!(!set.has("test.test.test.qq.com")); + + test_dump(&key_src, &set); + } + + fn test_dump(data_src: &Vec, set: &super::DomainSet) { + let mut data_set = vec![]; + set.traverse(|key| { + data_set.push(key.to_owned()); + true + }); + data_set.sort(); + + assert_eq!(data_src, &data_set); + } +} diff --git a/clash_lib/src/common/trie.rs 
b/clash_lib/src/common/trie.rs index 029baa7c..25fade81 100644 --- a/clash_lib/src/common/trie.rs +++ b/clash_lib/src/common/trie.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use std::{collections::HashMap, sync::Arc}; static DOMAIN_STEP: &str = "."; @@ -7,26 +5,18 @@ static COMPLEX_WILDCARD: &str = "+"; static DOT_WILDCARD: &str = ""; static WILDCARD: &str = "*"; -#[derive(Clone)] -pub struct StringTrie { - root: Node, - __type_holder: PhantomData, -} - -#[derive(Clone)] -pub struct Node { +pub struct Node { children: HashMap>, - // TODO: maybe we only need RefCell here data: Option>, } -impl Default for Node { +impl Default for Node { fn default() -> Self { Self::new() } } -impl Node { +impl Node { pub fn new() -> Self { Node { children: HashMap::new(), @@ -53,20 +43,24 @@ impl Node { pub fn add_child(&mut self, s: &str, child: Node) { self.children.insert(s.to_string(), child); } + + pub fn get_children(&self) -> &HashMap> { + &self.children + } +} +pub struct StringTrie { + root: Node, } -impl Default for StringTrie { +impl Default for StringTrie { fn default() -> Self { Self::new() } } -impl StringTrie { +impl StringTrie { pub fn new() -> Self { - StringTrie { - root: Node::new(), - __type_holder: PhantomData, - } + StringTrie { root: Node::new() } } pub fn insert(&mut self, domain: &str, data: Arc) -> bool { @@ -109,6 +103,53 @@ impl StringTrie { None } + pub fn traverse(&self, mut f: F) + where + F: FnMut(&String, &T) -> bool, + { + for (key, child) in self.root.get_children() { + Self::traverse_inner(&[key], child, &mut f); + if let Some(data) = child.get_data() { + if !f(key, data) { + return; + } + } + } + } + + fn traverse_inner<'a, F>( + keys: &'a [&String], + node: &'a Node, + f: &mut F, + ) -> bool + where + F: FnMut(&String, &T) -> bool, + { + for (key, child) in node.get_children() { + let keys = [&[key], keys].concat(); + + let d = keys.iter().map(|x| x.as_str()).collect::>(); + if let Some(data) = child.get_data() { + let domain = 
d.join(DOMAIN_STEP); + let key = if domain.starts_with(DOMAIN_STEP) { + COMPLEX_WILDCARD.to_string() + domain.as_str() + } else { + domain + }; + + if !f(&key, data) { + return false; + } + } + + if !Self::traverse_inner(&keys, child, f) { + return false; + } + } + + true + } + fn insert_inner(&mut self, parts: &[&str], data: Arc) { let mut node = &mut self.root; diff --git a/clash_lib/src/lib.rs b/clash_lib/src/lib.rs index 59b05879..eed31a3d 100644 --- a/clash_lib/src/lib.rs +++ b/clash_lib/src/lib.rs @@ -1,5 +1,6 @@ #![feature(ip)] #![feature(sync_unsafe_cell)] +#![feature(unbounded_shifts)] #[macro_use] extern crate anyhow; @@ -356,9 +357,10 @@ async fn create_components( config.profile.store_selected, ); + let dns_listen = config.dns.listen.clone(); debug!("initializing dns resolver"); let dns_resolver = dns::new_resolver( - &config.dns, + config.dns, Some(cache_store.clone()), Some(country_mmdb.clone()), ) @@ -444,7 +446,7 @@ async fn create_components( debug!("initializing dns listener"); let dns_listener = - dns::get_dns_listener(config.dns.listen, dns_resolver.clone(), &cwd).await; + dns::get_dns_listener(dns_listen, dns_resolver.clone(), &cwd).await; Ok(RuntimeComponents { cache_store, diff --git a/clash_lib/src/proxy/utils/mod.rs b/clash_lib/src/proxy/utils/mod.rs index a4dea3f9..399e36ca 100644 --- a/clash_lib/src/proxy/utils/mod.rs +++ b/clash_lib/src/proxy/utils/mod.rs @@ -89,6 +89,18 @@ pub fn get_outbound_interface() -> Option { let priority = ["eth", "en", "pdp_ip"]; all_outbounds.sort_by(|left, right| { + match (left.addr_v6, right.addr_v6) { + (Some(_), None) => return std::cmp::Ordering::Less, + (None, Some(_)) => return std::cmp::Ordering::Greater, + (Some(left), Some(right)) => { + if left.is_unicast_global() && !right.is_unicast_global() { + return std::cmp::Ordering::Less; + } else if !left.is_unicast_global() && right.is_unicast_global() { + return std::cmp::Ordering::Greater; + } + } + _ => {} + } let left = priority .iter() .position(|x| 
left.name.contains(x)) diff --git a/clash_lib/src/proxy/utils/test_utils/config_helper.rs b/clash_lib/src/proxy/utils/test_utils/config_helper.rs index 458e97eb..8b910ffd 100644 --- a/clash_lib/src/proxy/utils/test_utils/config_helper.rs +++ b/clash_lib/src/proxy/utils/test_utils/config_helper.rs @@ -24,10 +24,7 @@ pub fn test_config_base_dir() -> PathBuf { // load the config from test dir // and return the dns resolver for the proxy -pub async fn load_config() -> anyhow::Result<( - crate::config::internal::config::Config, - Arc, -)> { +pub async fn build_dns_resolver() -> anyhow::Result> { let root = root_dir(); let test_base_dir = test_config_base_dir(); let config_path = test_base_dir.join("ss.yaml").to_str().unwrap().to_owned(); @@ -51,9 +48,9 @@ pub async fn load_config() -> anyhow::Result<( ); let dns_resolver = Arc::new( - dns::EnhancedResolver::new(&config.dns, cache_store.clone(), mmdb.clone()) + dns::EnhancedResolver::new(config.dns, cache_store.clone(), mmdb.clone()) .await, ); - Ok((config, dns_resolver)) + Ok(dns_resolver) } diff --git a/clash_lib/src/proxy/utils/test_utils/mod.rs b/clash_lib/src/proxy/utils/test_utils/mod.rs index bc3ab71f..c54336b3 100644 --- a/clash_lib/src/proxy/utils/test_utils/mod.rs +++ b/clash_lib/src/proxy/utils/test_utils/mod.rs @@ -46,7 +46,7 @@ pub async fn ping_pong_test( ..Default::default() }; - let (_, resolver) = config_helper::load_config().await?; + let resolver = config_helper::build_dns_resolver().await?; let listener = TcpListener::bind(format!("0.0.0.0:{}", port).as_str()).await?; @@ -167,7 +167,7 @@ pub async fn ping_pong_udp_test( ..Default::default() }; - let (_, resolver) = config_helper::load_config().await?; + let resolver = config_helper::build_dns_resolver().await?; let listener = UdpSocket::bind(format!("0.0.0.0:{}", port).as_str()).await?; info!("target local server started at: {}", listener.local_addr()?); @@ -242,7 +242,7 @@ pub async fn ping_pong_udp_test( pub async fn latency_test( handler: Arc, ) 
-> anyhow::Result<(u16, u16)> { - let (_, resolver) = config_helper::load_config().await?; + let resolver = config_helper::build_dns_resolver().await?; let proxy_manager = ProxyManager::new(resolver.clone()); proxy_manager .url_test(handler, "https://example.com", None) @@ -259,7 +259,7 @@ pub async fn dns_test(handler: Arc) -> anyhow::Result<()> { ..Default::default() }; - let (_, resolver) = config_helper::load_config().await?; + let resolver = config_helper::build_dns_resolver().await?; // we don't need the resolver, so it doesn't matter to create a casual one let stream = handler.connect_datagram(&sess, resolver).await?;