diff --git a/Cargo.toml b/Cargo.toml index fb66cbd9b..047f22127 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,8 @@ futures = { version = "0.3.22", optional = true } # Force futures to at l futures-util = { version = "0.3", optional = true } heapless = { version = "0.8", optional = true } hex = { version = "0.4", optional = true } -libc = { version = "0.2.79", default-features = false, optional = true } # 0.2.79 is the first version that has IP_PMTUDISC_OMIT +libc = { version = "0.2.153", default-features = false, optional = true } # 0.2.79 is the first version that has IP_PMTUDISC_OMIT +parking_lot = { version = "0.11.2", optional = true } moka = { version = "0.12.3", optional = true, features = ["future"] } proc-macro2 = { version = "1.0.69", optional = true } # Force proc-macro2 to at least 1.0.69 for minimal-version build ring = { version = "0.17", optional = true } @@ -45,7 +46,7 @@ mock_instant = { version = "0.3.2", optional = true, features = ["sync"] } [target.'cfg(macos)'.dependencies] # specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work -libc = { version = "0.2.71", default-features = false, optional = true } +libc = { version = "0.2.153", default-features = false, optional = true } [features] default = ["std", "rand"] @@ -65,6 +66,7 @@ zonefile = ["bytes", "serde", "std"] # Unstable features unstable-client-transport = [ "moka", "tracing" ] unstable-server-transport = ["arc-swap", "chrono/clock", "hex", "libc", "tracing"] +unstable-zonetree = ["futures", "parking_lot", "serde", "tokio", "tracing"] # Test features # Commented out as using --all-features to build would cause mock time to also @@ -92,14 +94,13 @@ tokio-tfo = { version = "0.2.0" } lazy_static = { version = "1.4.0" } # Force lazy_static to > 1.0.0 for https://github.com/rust-lang-nursery/lazy-static.rs/pull/107 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +# For the "mysql-zone" example +#sqlx = { version = 
"0.6", features = [ "runtime-tokio-native-tls", "mysql" ] } + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] -[[example]] -name = "readzone" -required-features = ["zonefile"] - [[example]] name = "download-rust-lang" required-features = ["resolv"] @@ -123,3 +124,26 @@ required-features = ["net", "unstable-client-transport"] [[example]] name = "server-transports" required-features = ["net", "unstable-server-transport"] + +[[example]] +name = "read-zone" +required-features = ["zonefile"] + +[[example]] +name = "query-zone" +required-features = ["zonefile", "unstable-zonetree"] + +[[example]] +name = "serve-zone" +required-features = ["zonefile", "net", "unstable-server-transport", "unstable-zonetree"] + +# This example is commented out because it is difficult, if not impossible, +# when including the sqlx dependency, to make the dependency tree compatible +# with both `cargo +nightly update -Z minimal versions` and the crate minimum +# supported Rust version (1.67 at the time of writing), both of which are +# tested by our CI setup. To try this example, uncomment the lines below and +# the sqlx dependency above, then run `cargo run --example mysql-zone`. +#[[example]] +#name = "mysql-zone" +#path = "examples/other/mysql-zone.rs" +#required-features = ["zonefile", "net", "unstable-server-transport", "unstable-zonetree"] diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 6f3ab9522..682f033c3 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -76,7 +76,7 @@ async fn main() { // Get the reply println!("Wating for UDP+TCP reply"); let reply = request.get_response().await; - println!("UDP+TCP reply: {:?}", reply); + println!("UDP+TCP reply: {reply:?}"); // The query may have a reference to the connection. Drop the query // when it is no longer needed. 
@@ -94,7 +94,7 @@ async fn main() { // Get the reply println!("Wating for cache reply"); let reply = request.get_response().await; - println!("Cache reply: {:?}", reply); + println!("Cache reply: {reply:?}"); // Send the request message again. let mut request = cache.send_request(req.clone()); @@ -102,7 +102,7 @@ async fn main() { // Get the reply println!("Wating for cached reply"); let reply = request.get_response().await; - println!("Cached reply: {:?}", reply); + println!("Cached reply: {reply:?}"); // Create a new TCP connections object. Pass the destination address and // port as parameter. @@ -130,7 +130,7 @@ async fn main() { println!("Wating for multi TCP reply"); let reply = timeout(Duration::from_millis(500), request.get_response()).await; - println!("multi TCP reply: {:?}", reply); + println!("multi TCP reply: {reply:?}"); drop(request); @@ -181,7 +181,7 @@ async fn main() { println!("Wating for TLS reply"); let reply = timeout(Duration::from_millis(500), request.get_response()).await; - println!("TLS reply: {:?}", reply); + println!("TLS reply: {reply:?}"); drop(request); @@ -205,7 +205,7 @@ async fn main() { let mut request = redun.send_request(req.clone()); let reply = request.get_response().await; if i == 2 { - println!("redundant connection reply: {:?}", reply); + println!("redundant connection reply: {reply:?}"); } } @@ -224,7 +224,7 @@ async fn main() { // // Get the reply let reply = request.get_response().await; - println!("Dgram reply: {:?}", reply); + println!("Dgram reply: {reply:?}"); // Create a single TCP transport connection. This is usefull for a // single request or a small burst of requests. 
@@ -232,8 +232,7 @@ async fn main() { Ok(conn) => conn, Err(err) => { println!( - "TCP Connection to {} failed: {}, exiting", - server_addr, err + "TCP Connection to {server_addr} failed: {err}, exiting", ); return; } @@ -250,7 +249,7 @@ async fn main() { // Get the reply let reply = request.get_response().await; - println!("TCP reply: {:?}", reply); + println!("TCP reply: {reply:?}"); drop(tcp); } diff --git a/examples/client.rs b/examples/client.rs index 508c75103..1b3cb90cf 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -92,9 +92,9 @@ fn main() { for option in response.opt().unwrap().opt().iter::>() { let opt = option.unwrap(); match opt { - AllOptData::Nsid(nsid) => println!("{}", nsid), + AllOptData::Nsid(nsid) => println!("{nsid}"), AllOptData::ExtendedError(extendederror) => { - println!("{}", extendederror) + println!("{extendederror}") } _ => println!("NO OPT!"), } diff --git a/examples/common/serve-utils.rs b/examples/common/serve-utils.rs new file mode 100644 index 000000000..c1becae82 --- /dev/null +++ b/examples/common/serve-utils.rs @@ -0,0 +1,119 @@ +use bytes::Bytes; +use domain::base::{Dname, Message, MessageBuilder, ParsedDname, Rtype}; +use domain::rdata::ZoneRecordData; +use domain::zonetree::Answer; + +pub fn generate_wire_query( + qname: &Dname, + qtype: Rtype, +) -> Message> { + let query = MessageBuilder::new_vec(); + let mut query = query.question(); + query.push((qname, qtype)).unwrap(); + query.into() +} + +pub fn generate_wire_response( + wire_query: &Message>, + zone_answer: Answer, +) -> Message> { + let builder = MessageBuilder::new_vec(); + let response = zone_answer.to_message(wire_query, builder); + response.into() +} + +pub fn print_dig_style_response( + query: &Message>, + response: &Message>, + short: bool, +) { + if !short { + let qh = query.header(); + let rh = response.header(); + println!("; (1 server found)"); + println!(";; global options:"); + println!(";; Got answer:"); + println!( + ";; ->>HEADER<<- opcode: 
{}, status: {}, id: {}", + qh.opcode(), + rh.rcode(), + rh.id() + ); + print!(";; flags: "); + if rh.aa() { + print!("aa "); + } + if rh.ad() { + print!("ad "); + } + if rh.cd() { + print!("cd "); + } + if rh.qr() { + print!("qr "); + } + if rh.ra() { + print!("ra "); + } + if rh.rd() { + print!("rd "); + } + if rh.tc() { + print!("tc "); + } + let counts = response.header_counts(); + println!( + "; QUERY: {}, ANSWER: {}, AUTHORITY: {}, ADDITIONAL: {}", + counts.qdcount(), + counts.ancount(), + counts.arcount(), + counts.adcount() + ); + + // TODO: add OPT PSEUDOSECTION + + if let Ok(question) = query.sole_question() { + println!(";; QUESTION SECTION:"); + println!( + ";{} {} {}", + question.qname(), + question.qclass(), + question.qtype() + ); + println!(); + } + } + + let sections = [ + ("ANSWER", response.answer()), + ("AUTHORITY", response.authority()), + ("ADDITIONAL", response.additional()), + ]; + for (name, section) in sections { + if let Ok(section) = section { + if section.count() > 0 { + if !short { + println!(";; {name} SECTION:"); + } + + for record in section { + let record = record + .unwrap() + .into_record::>>() + .unwrap() + .unwrap(); + + if short { + println!("{}", record.data()); + } else { + println!("{record}"); + } + } + + if !short { + println!(); + } + } + } + } +} diff --git a/examples/download-rust-lang.rs b/examples/download-rust-lang.rs index 3b5a75dff..06d126b25 100644 --- a/examples/download-rust-lang.rs +++ b/examples/download-rust-lang.rs @@ -16,7 +16,7 @@ async fn main() { { Ok(addr) => addr, Err(err) => { - eprintln!("DNS query failed: {}", err); + eprintln!("DNS query failed: {err}"); return; } }; @@ -30,7 +30,7 @@ async fn main() { let mut socket = match TcpStream::connect(&addr).await { Ok(socket) => socket, Err(err) => { - eprintln!("Failed to connect to {}: {}", addr, err); + eprintln!("Failed to connect to {addr}: {err}"); return; } }; @@ -45,12 +45,12 @@ async fn main() { ) .await { - eprintln!("Failed to send request: 
{}", err); + eprintln!("Failed to send request: {err}"); return; }; let mut response = Vec::new(); if let Err(err) = socket.read_to_end(&mut response).await { - eprintln!("Failed to read response: {}", err); + eprintln!("Failed to read response: {err}"); return; } diff --git a/examples/lookup.rs b/examples/lookup.rs index 81b5d95af..898d6fb23 100644 --- a/examples/lookup.rs +++ b/examples/lookup.rs @@ -20,14 +20,14 @@ async fn forward(resolver: &StubResolver, name: UncertainDname>) { } let canon = answer.canonical_name(); if canon != answer.qname() { - println!("{} is an alias for {}", answer.qname(), canon); + println!("{} is an alias for {canon}", answer.qname()); } for addr in answer.iter() { - println!("{} has address {}", canon, addr); + println!("{canon} has address {addr}"); } } Err(err) => { - println!("Query failed: {}", err); + println!("Query failed: {err}"); } } } @@ -36,10 +36,10 @@ async fn reverse(resolver: &StubResolver, addr: IpAddr) { match resolver.lookup_addr(addr).await { Ok(answer) => { for name in answer.iter() { - println!("Host {} has domain name pointer {}", addr, name); + println!("Host {addr} has domain name pointer {name}"); } } - Err(err) => println!("Query failed: {}", err), + Err(err) => println!("Query failed: {err}"), } } @@ -58,7 +58,7 @@ async fn main() { } else if let Ok(name) = UncertainDname::from_str(&name) { forward(&resolver, name).await; } else { - println!("Not a domain name: {}", name); + println!("Not a domain name: {name}"); } } } diff --git a/examples/other/mysql-zone.rs b/examples/other/mysql-zone.rs new file mode 100644 index 000000000..c85e7c96f --- /dev/null +++ b/examples/other/mysql-zone.rs @@ -0,0 +1,332 @@ +//! MySQL backed zone serving minimal proof of concept. +// +// This example extends `domain` with a new `ZoneStore` impl adding support for +// MySQL backed zones. 
This demonstration only implements the `ReadableZone` +// trait, it doesn't implement the `WritableZone` trait, so database access is +// read-only. Write access could be implemented, it just isn't in this +// example. The same approach can be used to implement access for any kind of +// backed, e.g. invoking shell commands to get the answers even ;-) +// +// Warning: This example needs a lot of setup and has several prerequisites. +// +// =========================================================================== +// A big shout out to PowerDNS as this example uses their MySQL database +// schema and their zone2sql tool. And also to the sqlx project for making +// database access via Rust so easy. +// +// For more information about the PowerDNS MySQL support see: +// https://doc.powerdns.com/authoritative/backends/generic-mysql.html +// +// For more information about SQLX see: https://github.com/launchbadge/sqlx +// =========================================================================== +// +// # Prerequisites +// +// You need: +// - A Linux machine (the instructions below have only been tested on Fedora +// 39). +// - A MySQL server (tested with "Ver 8.0.35 for Linux on x86_64). +// - A MySQL user with the right to create a database. Note: You may also +// need sufficient rights to disable restrictions concerning maximum +// column length. +// - The PowerDNS zone2sql command line tool for converting a zone file to +// SQL insert operations compatible with the PowerDNS MySQL schema. +// - The sqlx-cli command line tool, for automating database and table +// creation and data import. +// +// # Database access +// +// Connecting to the database users settings provided in an environment +// variable called DATABASE_URL. When using a connection URL the password has +// to be URL encoded. The environment variable value must have the following +// format. +// +// DATABASE_URL='mysql://:>@[:]/' +// +// Note: The PowerDNS MySQL schema uses large column sizes. 
If you see an +// error like "Column length too big for column" when running the initial sqlx +// database migration step to create the schema, disable the default MySQL +// restrictions with a command similar to the following in the MySQL shell: +// +// $ mysql -u root -p mysql> SET GLOBAL sql_mode = ''; +// +// A quick tip for viewing the MySQL queries issued by this example: +// +// $ mysql -u root -p mysql> SET GLOBAL log_output = 'table'; mysql> SET +// GLOBAL general_log = 'on'; mysql> SELECT CONVERT(argument USING utf8) +// FROM mysql.general_log; +// +// # Preparation +// +// Note: dnf is the Fedora package manager, and pdns is the name of the +// Fedora PowerDNS package. Adjust the commands and values below to match +// your O/S. +// +// - cargo install sqlx-cli +// +// - sudo dnf install -y pdns +// +// - export DATABASE_URL='.....' Make sure the user specified in +// DATABASE_URL has the right to create databases. +// +// - cargo sqlx database create +// +// - cargo sqlx migrate add make_tables Note: This will output a +// migrations/..._make_tables.sql path. We will refer to this below as +// MAKE_TABLES_PATH +// +// - wget -O${MAKE_TABLES_PATH} +// https://raw.githubusercontent.com/PowerDNS/pdns/master/modules/gmysqlbackend/schema.mysql.sql +// +// - cargo sqlx migrate run +// +// - cargo sqlx migrate add import_data Note: This will output a +// migrations/..._import_data.sql path. We will refer to this below as +// IMPORT_DATA_PATH +// +// - zone2sql --gmysql --zone=test-data/zonefiles/nsd-example.txt > +// ${IMPORT_DATA_PATH} +// +// - cargo sqlx migrate run +// +// - cargo sqlx prepare -- --example mysql-zone --features +// zonefile,net,unstable-server-transport +// +// # Running the example +// +// Now you can run the example with the following command and should see +// output similar to that shown below: +// +// $ cargo run --example mysql-zone --features +// zonefile,net,unstable-server-transport ... 
; (1 server found) ;; global +// options: ;; Got answer: ;; ->>HEADER<<- opcode: QUERY, status: NOERROR, +// id: 0 ;; flags: qr ; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 ;; +// QUESTION SECTION: ;example.com IN A +// +// ;; ANSWER SECTION: example.com. 86400 IN A 192.0.2.1 +// +// # A note about SQLX and Rust versions +// +// The database query strings used and handling of query results in this code +// would be simpler if SQLX >= v7 were used (via its `query!` macro). We don't +// use SQLX >= v7 because it requires Rust 1.75.0 which exceeds our project +// Rust MSRV of 1.67.0. We can't downgrade to older SQLX because SQLX < 7.0 +// don't correctly respect saved queries in .sqlx/ (generated by `cargo sqlx +// prepare``) and so compilation fails, but sqlx >= 7.0 has a Rust MSRV of +// 1.75.0 which exceeds our Rust MSRV of 1.67.0. + +use std::{future::Future, pin::Pin, str::FromStr, sync::Arc}; + +use bytes::Bytes; +use domain::base::iana::{Class, Rcode}; +use domain::base::scan::IterScanner; +use domain::base::{Dname, Rtype, Ttl}; +use domain::rdata::ZoneRecordData; +use domain::zonetree::{ + Answer, ReadableZone, Rrset, SharedRrset, StoredDname, WalkOp, + WritableZone, Zone, ZoneStore, ZoneTree, +}; +use sqlx::Row; +use sqlx::{mysql::MySqlConnectOptions, MySqlPool}; +use domain::zonefile::error::OutOfZone; + +#[path = "../common/serve-utils.rs"] +mod common; + +#[tokio::main] +async fn main() { + // Create a zone whose queries will be satisfied by querying the database + // defined by the DATABASE_URL environment variable. + let mut zones = ZoneTree::new(); + let db_zone = DatabaseZoneBuilder::mk_test_zone("example.com").await; + zones.insert_zone(db_zone).unwrap(); + + // Setup a mock query. + let qname = Dname::bytes_from_str("example.com").unwrap(); + let qclass = Class::IN; + let qtype = Rtype::A; + + // Execute the query. The steps we take are: + // 1. Find the zone in the zone set that matches the query name. + // 2. 
Get a read interface to it via `.read()`. + // 3. Query the zone, synchronously or asynchronously, based on what + // the zone says it supports. For stock `domain` zones the + // `.is_async()` call will return false, but for our MySQL backed + // zone it returns true, as the DB calls are asynchronous. + let zone = zones.find_zone(&qname, qclass).unwrap().read(); + let zone_answer = match zone.is_async() { + true => zone.query_async(qname.clone(), qtype).await.unwrap(), + false => zone.query(qname.clone(), qtype).unwrap(), + }; + + // Render the response in dig style output. + let wire_query = common::generate_wire_query(&qname, qtype); + let wire_response = + common::generate_wire_response(&wire_query, zone_answer); + common::print_dig_style_response(&wire_query, &wire_response, false); +} + +//----------- DatbaseZoneBuilder --------------------------------------------- + +pub struct DatabaseZoneBuilder; + +impl DatabaseZoneBuilder { + pub async fn mk_test_zone(apex_name: &str) -> Zone { + let opts: MySqlConnectOptions = + std::env::var("DATABASE_URL").unwrap().parse().unwrap(); + let pool = MySqlPool::connect_with(opts).await.unwrap(); + let apex_name = StoredDname::from_str(apex_name).unwrap(); + let node = DatabaseNode::new(pool, apex_name); + Zone::new(node) + } +} + +//----------- DatbaseNode ---------------------------------------------------- + +#[derive(Debug)] +struct DatabaseNode { + db_pool: sqlx::MySqlPool, + apex_name: StoredDname, +} + +impl DatabaseNode { + fn new(db_pool: sqlx::MySqlPool, apex_name: StoredDname) -> Self { + Self { db_pool, apex_name } + } +} + +//--- impl ZoneStore + +impl ZoneStore for DatabaseNode { + fn class(&self) -> Class { + Class::IN + } + + fn apex_name(&self) -> &StoredDname { + &self.apex_name + } + + fn read(self: Arc) -> Box { + Box::new(DatabaseReadZone::new( + self.db_pool.clone(), + self.apex_name.clone(), + )) + } + + fn write( + self: Arc, + ) -> Pin>>> { + todo!() + } +} + +//----------- DatbaseReadZone 
------------------------------------------------ + +struct DatabaseReadZone { + db_pool: sqlx::MySqlPool, + apex_name: StoredDname, +} + +impl DatabaseReadZone { + fn new(db_pool: sqlx::MySqlPool, apex_name: StoredDname) -> Self { + Self { db_pool, apex_name } + } +} + +//--- impl ReadableZone + +impl ReadableZone for DatabaseReadZone { + fn is_async(&self) -> bool { + true + } + + fn query_async( + &self, + qname: Dname, + qtype: Rtype, + ) -> Pin> + Send>> { + let db_pool = self.db_pool.clone(); + let apex_name = self.apex_name.to_string(); + let fut = async move { + let answer = if let Ok(row) = sqlx::query( + r#"SELECT R.content, R.ttl FROM domains D, records R WHERE D.name = ? AND D.id = R.domain_id AND R.name = ? AND R.type = ? LIMIT 1"#) + .bind(apex_name) + .bind(qname.to_string()) + .bind(qtype.to_string()) + .fetch_one(&db_pool) + .await + { + let mut answer = Answer::new(Rcode::NOERROR); + let ttl = row.try_get("ttl").unwrap(); + let mut rrset = Rrset::new(qtype, Ttl::from_secs(ttl)); + let content: String = row.try_get("content").unwrap(); + let content_strings = content.split_ascii_whitespace().collect::>(); + let mut scanner = IterScanner::new(&content_strings); + match ZoneRecordData::scan(qtype, &mut scanner) { + Ok(data) => { + rrset.push_data(data); + let rrset = SharedRrset::new(rrset); + answer.add_answer(rrset); + answer + } + Err(err) => { + eprintln!("Unable to parse DB record of type {qtype}: {err}"); + Answer::new(Rcode::SERVFAIL) + } + } + } else { + Answer::new(Rcode::NXDOMAIN) + }; + Ok(answer) + }; + Box::pin(fut) + } + + fn walk_async( + &self, + op: WalkOp, + ) -> Pin + Send>> { + let db_pool = self.db_pool.clone(); + let apex_name = self.apex_name.to_string(); + let fut = async move { + for row in sqlx::query( + r#"SELECT R.name, R.type AS rtype, R.content, R.ttl FROM domains D, records R WHERE D.name = ? 
AND D.id = R.domain_id"#) + .bind(apex_name) + .fetch_all(&db_pool) + .await + .unwrap() { + let owner: String = row.try_get("name").unwrap(); + let owner = Dname::bytes_from_str(&owner).unwrap(); + let rtype: String = row.try_get("rtype").unwrap(); + let rtype = Rtype::from_str(&rtype).unwrap(); + let ttl = row.try_get("ttl").unwrap(); + let mut rrset = Rrset::new(rtype, Ttl::from_secs(ttl)); + let content: String = row.try_get("content").unwrap(); + let content_strings = content.split_ascii_whitespace().collect::>(); + let mut scanner = IterScanner::new(&content_strings); + match ZoneRecordData::scan(rtype, &mut scanner) { + Ok(data) => { + rrset.push_data(data); + op(owner, &rrset); + } + Err(err) => { + eprintln!("Unable to parse DB record of type {rtype}: {err}"); + } + } + }; + }; + Box::pin(fut) + } + + fn query( + &self, + _qname: Dname, + _qtype: Rtype, + ) -> Result { + unimplemented!() + } + + fn walk(&self, _walkop: WalkOp) { + unimplemented!() + } +} diff --git a/examples/query-zone.rs b/examples/query-zone.rs new file mode 100644 index 000000000..2313f5d92 --- /dev/null +++ b/examples/query-zone.rs @@ -0,0 +1,210 @@ +//! Reads a zone file into memory and queries it. +//! Command line argument and response style emulate that of dig. + +use std::env; +use std::fs::File; +use std::{process::exit, str::FromStr}; + +use bytes::Bytes; +use domain::base::iana::{Class, Rcode}; +use domain::base::record::ComposeRecord; +use domain::base::{Dname, ParsedDname, Rtype}; +use domain::base::{ParsedRecord, Record}; +use domain::rdata::ZoneRecordData; +use domain::zonefile::inplace; +use domain::zonetree::{Answer, Rrset}; +use domain::zonetree::{Zone, ZoneTree}; +use octseq::Parser; +use tracing_subscriber::EnvFilter; + +#[path = "common/serve-utils.rs"] +mod common; + +#[derive(PartialEq, Eq)] +enum Verbosity { + Quiet, + Normal, + Verbose(u8), +} + +fn main() { + // Initialize tracing based logging. Override with env var RUST_LOG, e.g. + // RUST_LOG=trace. 
+ tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + + let mut args = env::args(); + let prog_name = args.next().unwrap(); // SAFETY: O/S always passes our name as the first argument. + let usage = format!( + "Usage: {prog_name} [-q|--quiet|-v|--verbose] [+short] [ ..] ", + ); + + // Process command line arguments. + let (verbosity, zone_files, qtype, qname, short) = + process_dig_style_args(args).unwrap_or_else(|err| { + eprintln!("{usage}"); + eprintln!("{err}"); + exit(2); + }); + + // Go! + let mut zones = ZoneTree::new(); + + for (zone_file_path, mut zone_file) in zone_files { + if verbosity != Verbosity::Quiet { + println!("Reading zone file '{zone_file_path}'..."); + } + let reader = inplace::Zonefile::load(&mut zone_file).unwrap(); + + if verbosity != Verbosity::Quiet { + println!("Constructing zone..."); + } + let zone = Zone::try_from(reader).unwrap_or_else(|err| { + eprintln!("Error while constructing zone: {err}"); + exit(1); + }); + + if verbosity != Verbosity::Quiet { + println!( + "Inserting zone for {} class {}...", + zone.apex_name(), + zone.class() + ); + } + zones.insert_zone(zone).unwrap_or_else(|err| { + eprintln!("Error while inserting zone: {err}"); + exit(1); + }); + } + + if let Verbosity::Verbose(level) = verbosity { + for zone in zones.iter_zones() { + println!( + "Dumping zone {} class {}...", + zone.apex_name(), + zone.class() + ); + zone.read().walk(Box::new(move |owner, rrset| { + dump_rrset(owner, rrset); + })); + println!("Dump complete."); + + if level > 0 { + println!("Debug dumping zone..."); + dbg!(zone); + } + } + } + + // Find the zone to query + let qclass = Class::IN; + if verbosity != Verbosity::Quiet { + println!("Finding zone for qname {qname} class {qclass}..."); + } + let zone_answer = if let Some(zone) = zones.find_zone(&qname, qclass) { + // Query the built zone for the requested records. 
+ if verbosity != Verbosity::Quiet { + println!("Querying zone {} class {} for qname {qname} with qtype {qtype}...", zone.apex_name(), zone.class()); + } + zone.read().query(qname.clone(), qtype).unwrap() + } else { + Answer::new(Rcode::NXDOMAIN) + }; + + // Emulate a DIG style response by generating a complete DNS wire response + // from the zone answer, which requires that we fake a DNS wire query to + // respond to. + if verbosity != Verbosity::Quiet { + println!("Preparing dig style response...\n"); + } + let wire_query = common::generate_wire_query(&qname, qtype); + let wire_response = + common::generate_wire_response(&wire_query, zone_answer); + common::print_dig_style_response(&wire_query, &wire_response, short); +} + +#[allow(clippy::type_complexity)] +fn process_dig_style_args( + args: env::Args, +) -> Result<(Verbosity, Vec<(String, File)>, Rtype, Dname, bool), String> +{ + let mut abort_with_usage = false; + let mut verbosity = Verbosity::Normal; + let mut short = false; + let mut zone_files = vec![]; + + let args: Vec<_> = args + .filter(|arg| { + if arg.starts_with(['-', '+']) { + match arg.as_str() { + "-q" | "--quiet" => verbosity = Verbosity::Quiet, + "-v" | "--verbose" => { + if let Verbosity::Verbose(level) = verbosity { + verbosity = Verbosity::Verbose(level + 1) + } else { + verbosity = Verbosity::Verbose(0) + } + } + "+short" => { + short = true; + if verbosity == Verbosity::Normal { + verbosity = Verbosity::Quiet + } + } + _ => abort_with_usage = true, + } + false // discard the argument + } else { + true // keep the argument + } + }) + .collect(); + + if args.len() >= 3 { + let mut i = 0; + while i < args.len() - 2 { + let zone_file = File::open(&args[i]).map_err(|err| { + format!("Cannot open zone file '{}': {err}", args[i]) + })?; + zone_files.push((args[i].to_string(), zone_file)); + i += 1; + } + + let qtype = Rtype::from_str(&args[i]) + .map_err(|err| format!("Cannot parse qtype: {err}"))?; + i += 1; + + let qname = 
Dname::::from_str(&args[i]) + .map_err(|err| format!("Cannot parse qname: {err}"))?; + + Ok((verbosity, zone_files, qtype, qname, short)) + } else { + Err("Insufficient arguments".to_string()) + } +} + +fn dump_rrset(owner: Dname, rrset: &Rrset) { + // + // The following code renders an owner + rrset (IN class, TTL, RDATA) + // into zone presentation format. This can be used for diagnostic + // dumping. + // + let mut target = Vec::::new(); + for item in rrset.data() { + let record = Record::new(owner.clone(), Class::IN, rrset.ttl(), item); + if record.compose_record(&mut target).is_ok() { + let mut parser = Parser::from_ref(&target); + if let Ok(parsed_record) = ParsedRecord::parse(&mut parser) { + if let Ok(Some(record)) = parsed_record + .into_record::>>() + { + println!("> {record}"); + } + } + } + } +} diff --git a/examples/read-zone.rs b/examples/read-zone.rs new file mode 100644 index 000000000..110aaa804 --- /dev/null +++ b/examples/read-zone.rs @@ -0,0 +1,73 @@ +//! Reads a zone file. + +use std::env; +use std::fs::File; +use std::process::exit; +use std::time::SystemTime; + +use domain::zonefile::inplace::Entry; +use domain::zonefile::inplace::Zonefile; + +fn main() { + let mut args = env::args(); + let prog_name = args.next().unwrap(); // SAFETY: O/S always passes our name as the first argument. 
+ let zone_files: Vec<_> = args.collect(); + + if zone_files.is_empty() { + eprintln!("Usage: {prog_name} [, , , , , ...]"); + exit(2); + } + + for zone_file in zone_files { + print!("Processing {zone_file}: "); + let start = SystemTime::now(); + let mut reader = + Zonefile::load(&mut File::open(&zone_file).unwrap()).unwrap(); + println!( + "Data loaded ({:.03}s).", + start.elapsed().unwrap().as_secs_f32() + ); + + let mut i = 0; + let mut last_entry = None; + loop { + match reader.next_entry() { + Ok(entry) if entry.is_some() => { + last_entry = entry; + } + Ok(_) => break, // EOF + Err(err) => { + eprintln!( + "\nAn error occurred while reading {zone_file}:" + ); + eprintln!(" Error: {err}"); + if let Some(entry) = &last_entry { + if let Entry::Record(record) = &entry { + eprintln!( + "\nThe last record read was:\n{record}." + ); + } else { + eprintln!("\nThe last record read was:\n{last_entry:#?}."); + } + eprintln!("\nTry commenting out the line after that record with a leading ; (semi-colon) character.") + } + exit(1); + } + } + i += 1; + if i % 100_000_000 == 0 { + println!( + "Processed {}M records ({:.03}s)", + i / 1_000_000, + start.elapsed().unwrap().as_secs_f32() + ); + } + } + + println!( + "Complete with {} records ({:.03}s)\n", + i, + start.elapsed().unwrap().as_secs_f32() + ); + } +} diff --git a/examples/readzone.rs b/examples/readzone.rs deleted file mode 100644 index e07cafc89..000000000 --- a/examples/readzone.rs +++ /dev/null @@ -1,34 +0,0 @@ -//! Reads a zone file. 
- -fn main() { - use domain::zonefile::inplace::Zonefile; - use std::env; - use std::fs::File; - use std::time::SystemTime; - - for arg in env::args().skip(1) { - print!("Processing {}: ", arg); - let start = SystemTime::now(); - let mut zone = Zonefile::load(&mut File::open(arg).unwrap()).unwrap(); - println!( - "Data loaded ({:.03}s).", - start.elapsed().unwrap().as_secs_f32() - ); - let mut i = 0; - while zone.next_entry().unwrap().is_some() { - i += 1; - if i % 100_000_000 == 0 { - eprintln!( - "Processed {}M records ({:.03}s)", - i / 1_000_000, - start.elapsed().unwrap().as_secs_f32() - ); - } - } - eprintln!( - "Complete with {} records ({:.03}s)\n", - i, - start.elapsed().unwrap().as_secs_f32() - ); - } -} diff --git a/examples/resolv-sync.rs b/examples/resolv-sync.rs index 6d2fdffa9..eed325d3b 100644 --- a/examples/resolv-sync.rs +++ b/examples/resolv-sync.rs @@ -26,6 +26,6 @@ fn main() { let res = res.answer().unwrap().limit_to::>(); for record in res { let record = record.unwrap(); - println!("{}", record); + println!("{record}"); } } diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs new file mode 100644 index 000000000..dbe082d02 --- /dev/null +++ b/examples/serve-zone.rs @@ -0,0 +1,318 @@ +//! Loads a zone file and serves it over localhost UDP and TCP. +//! +//! Try queries such as: +//! +//! dig @127.0.0.1 -p 8053 NS example.com +//! dig @127.0.0.1 -p 8053 A example.com +//! dig @127.0.0.1 -p 8053 AAAA example.com +//! dig @127.0.0.1 -p 8053 CNAME example.com +//! +//! Also try with TCP, e.g.: +//! +//! dig @127.0.0.1 -p 8053 +tcp A example.com +//! +//! Also try AXFR, e.g.: +//! +//! 
dig @127.0.0.1 -p 8053 AXFR example.com + +use domain::base::iana::{Opcode, Rcode}; +use domain::base::message_builder::AdditionalBuilder; +use domain::base::{Dname, Message, Rtype, ToDname}; +use domain::net::server::buf::VecBufSource; +use domain::net::server::dgram::DgramServer; +use domain::net::server::message::Request; +use domain::net::server::service::{ + CallResult, ServiceError, Transaction, TransactionStream, +}; +use domain::net::server::stream::StreamServer; +use domain::net::server::util::{mk_builder_for_target, service_fn}; +use domain::zonefile::inplace; +use domain::zonetree::{Answer, Rrset}; +use domain::zonetree::{Zone, ZoneTree}; +use octseq::OctetsBuilder; +use std::future::{pending, ready, Future}; +use std::io::BufReader; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::net::{TcpListener, UdpSocket}; +use tracing_subscriber::EnvFilter; + +#[tokio::main()] +async fn main() { + // Initialize tracing based logging. Override with env var RUST_LOG, e.g. + // RUST_LOG=trace. + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + + // Populate a zone tree with test data + let mut zones = ZoneTree::new(); + let zone_bytes = include_bytes!("../test-data/zonefiles/nsd-example.txt"); + let mut zone_bytes = BufReader::new(&zone_bytes[..]); + + // We're reading from static data so this cannot fail due to I/O error. + // Don't handle errors that shouldn't happen, keep the example focused + // on what we want to demonstrate. 
+ let reader = inplace::Zonefile::load(&mut zone_bytes).unwrap(); + let zone = Zone::try_from(reader).unwrap(); + zones.insert_zone(zone).unwrap(); + let zones = Arc::new(zones); + + let addr = "127.0.0.1:8053"; + let svc = Arc::new(service_fn(my_service, zones)); + + let sock = UdpSocket::bind(addr).await.unwrap(); + let sock = Arc::new(sock); + let mut udp_metrics = vec![]; + let num_cores = std::thread::available_parallelism().unwrap().get(); + for _i in 0..num_cores { + let udp_srv = + DgramServer::new(sock.clone(), VecBufSource, svc.clone()); + let metrics = udp_srv.metrics(); + udp_metrics.push(metrics); + tokio::spawn(async move { udp_srv.run().await }); + } + + let sock = TcpListener::bind(addr).await.unwrap(); + let tcp_srv = StreamServer::new(sock, VecBufSource, svc); + let tcp_metrics = tcp_srv.metrics(); + + tokio::spawn(async move { tcp_srv.run().await }); + + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(5000)).await; + for (i, metrics) in udp_metrics.iter().enumerate() { + eprintln!( + "Server status: UDP[{i}]: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + metrics.num_connections(), + metrics.num_inflight_requests(), + metrics.num_pending_writes(), + metrics.num_received_requests(), + metrics.num_sent_responses(), + ); + } + eprintln!( + "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + tcp_metrics.num_connections(), + tcp_metrics.num_inflight_requests(), + tcp_metrics.num_pending_writes(), + tcp_metrics.num_received_requests(), + tcp_metrics.num_sent_responses(), + ); + } + }); + + pending::<()>().await; +} + +#[allow(clippy::type_complexity)] +fn my_service( + request: Request>, + zones: Arc, +) -> Result< + Transaction< + Vec, + impl Future>, ServiceError>> + Send, + >, + ServiceError, +> { + let qtype = request.message().sole_question().unwrap().qtype(); + match qtype { + Rtype::AXFR if request.transport_ctx().is_non_udp() => { + 
let fut = handle_axfr_request(request, zones); + Ok(Transaction::stream(Box::pin(fut))) + } + _ => { + let fut = handle_non_axfr_request(request, zones); + Ok(Transaction::single(fut)) + } + } +} + +async fn handle_non_axfr_request( + request: Request>, + zones: Arc, +) -> Result>, ServiceError> { + let question = request.message().sole_question().unwrap(); + let zone = zones + .find_zone(question.qname(), question.qclass()) + .map(|zone| zone.read()); + let answer = match zone { + Some(zone) => { + let qname = question.qname().to_bytes(); + let qtype = question.qtype(); + zone.query(qname, qtype).unwrap() + } + None => Answer::new(Rcode::NXDOMAIN), + }; + + let builder = mk_builder_for_target(); + let additional = answer.to_message(request.message(), builder); + Ok(CallResult::new(additional)) +} + +async fn handle_axfr_request( + request: Request>, + zones: Arc, +) -> TransactionStream>, ServiceError>> { + let mut stream = TransactionStream::default(); + + // Look up the zone for the queried name. + let question = request.message().sole_question().unwrap(); + let zone = zones + .find_zone(question.qname(), question.qclass()) + .map(|zone| zone.read()); + + // If not found, return an NXDOMAIN error response. + let Some(zone) = zone else { + let answer = Answer::new(Rcode::NXDOMAIN); + add_to_stream(answer, request.message(), &mut stream); + return stream; + }; + + // https://datatracker.ietf.org/doc/html/rfc5936#section-2.2 + // 2.2: AXFR Response + // + // "An AXFR response that is transferring the zone's contents + // will consist of a series (which could be a series of + // length 1) of DNS messages. In such a series, the first + // message MUST begin with the SOA resource record of the + // zone, and the last message MUST conclude with the same SOA + // resource record. Intermediate messages MUST NOT contain + // the SOA resource record. 
The AXFR server MUST copy the + // Question section from the corresponding AXFR query message + // into the first response message's Question section. For + // subsequent messages, it MAY do the same or leave the + // Question section empty." + + // Get the SOA record as AXFR transfers must start and end with the SOA + // record. If not found, return a SERVFAIL error response. + let qname = question.qname().to_bytes(); + let Ok(soa_answer) = zone.query(qname, Rtype::SOA) else { + let answer = Answer::new(Rcode::SERVFAIL); + add_to_stream(answer, request.message(), &mut stream); + return stream; + }; + + // Push the begin SOA response message into the stream + add_to_stream(soa_answer.clone(), request.message(), &mut stream); + + // "The AXFR protocol treats the zone contents as an unordered + // collection (or to use the mathematical term, a "set") of + // RRs. Except for the requirement that the transfer must + // begin and end with the SOA RR, there is no requirement to + // send the RRs in any particular order or grouped into + // response messages in any particular way. Although servers + // typically do attempt to send related RRs (such as the RRs + // forming an RRset, and the RRsets of a name) as a + // contiguous group or, when message space allows, in the + // same response message, they are not required to do so, and + // clients MUST accept any ordering and grouping of the + // non-SOA RRs. Each RR SHOULD be transmitted only once, and + // AXFR clients MUST ignore any duplicate RRs received. + // + // Each AXFR response message SHOULD contain a sufficient + // number of RRs to reasonably amortize the per-message + // overhead, up to the largest number that will fit within a + // DNS message (taking the required content of the other + // sections into account, as described below). + // + // Some old AXFR clients expect each response message to + // contain only a single RR. 
To interoperate with such + // clients, the server MAY restrict response messages to a + // single RR. As there is no standard way to automatically + // detect such clients, this typically requires manual + // configuration at the server." + + let stream = Arc::new(Mutex::new(stream)); + let cloned_stream = stream.clone(); + let cloned_msg = request.message().clone(); + + let op = Box::new(move |owner: Dname<_>, rrset: &Rrset| { + if rrset.rtype() != Rtype::SOA { + let builder = mk_builder_for_target(); + let mut answer = + builder.start_answer(&cloned_msg, Rcode::NOERROR).unwrap(); + for item in rrset.data() { + answer.push((owner.clone(), rrset.ttl(), item)).unwrap(); + } + + let additional = answer.additional(); + let mut stream = cloned_stream.lock().unwrap(); + add_additional_to_stream(additional, &cloned_msg, &mut stream); + } + }); + zone.walk(op); + + let mutex = Arc::try_unwrap(stream).unwrap(); + let mut stream = mutex.into_inner().unwrap(); + + // Push the end SOA response message into the stream + add_to_stream(soa_answer, request.message(), &mut stream); + + stream +} + +#[allow(clippy::type_complexity)] +fn add_to_stream( + answer: Answer, + msg: &Message>, + stream: &mut TransactionStream>, ServiceError>>, +) { + let builder = mk_builder_for_target(); + let additional = answer.to_message(msg, builder); + add_additional_to_stream(additional, msg, stream); +} + +#[allow(clippy::type_complexity)] +fn add_additional_to_stream( + mut additional: AdditionalBuilder>>, + msg: &Message>, + stream: &mut TransactionStream>, ServiceError>>, +) { + set_axfr_header(msg, &mut additional); + stream.push(ready(Ok(CallResult::new(additional)))); +} + +fn set_axfr_header( + msg: &Message>, + additional: &mut AdditionalBuilder, +) where + Target: AsMut<[u8]>, + Target: OctetsBuilder, +{ + // https://datatracker.ietf.org/doc/html/rfc5936#section-2.2.1 + // 2.2.1: Header Values + // + // "These are the DNS message header values for AXFR responses. 
+ // + // ID MUST be copied from request -- see Note a) + // + // QR MUST be 1 (Response) + // + // OPCODE MUST be 0 (Standard Query) + // + // Flags: + // AA normally 1 -- see Note b) + // TC MUST be 0 (Not truncated) + // RD RECOMMENDED: copy request's value; MAY be set to 0 + // RA SHOULD be 0 -- see Note c) + // Z "mbz" -- see Note d) + // AD "mbz" -- see Note d) + // CD "mbz" -- see Note d)" + let header = additional.header_mut(); + header.set_id(msg.header().id()); + header.set_qr(true); + header.set_opcode(Opcode::QUERY); + header.set_aa(true); + header.set_tc(false); + header.set_rd(msg.header().rd()); + header.set_ra(false); + header.set_z(false); + header.set_ad(false); + header.set_cd(false); +} diff --git a/examples/server-transports.rs b/examples/server-transports.rs index 073e714a4..7d62993b2 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -78,8 +78,8 @@ struct MyService; /// This example shows how to implement the [`Service`] trait directly. /// -/// See [`query()`] and [`name_to_ip()`] for ways of implementing the -/// [`Service`] trait for a function instead of a struct. +/// See [`query`] and [`name_to_ip`] for ways of implementing the [`Service`] +/// trait for a function instead of a struct. impl Service> for MyService { type Target = Vec; type Future = Ready, ServiceError>>; @@ -87,13 +87,7 @@ impl Service> for MyService { fn call( &self, request: Request>, - ) -> Result< - Transaction< - Result, ServiceError>, - Self::Future, - >, - ServiceError, - > { + ) -> Result, ServiceError> { let builder = mk_builder_for_target(); let additional = mk_answer(&request, builder)?; let item = ready(Ok(CallResult::new(additional))); @@ -108,13 +102,13 @@ impl Service> for MyService { /// function signature required by the [`Service`] trait. /// /// The function signature is slightly more complex than when using -/// [`service_fn()`] (see the [`query()`] example below). +/// [`service_fn`] (see the [`query`] example below). 
#[allow(clippy::type_complexity)] fn name_to_ip( request: Request>, ) -> Result< Transaction< - Result, ServiceError>, + Target, impl Future, ServiceError>> + Send, >, ServiceError, @@ -166,10 +160,10 @@ where //--- query() /// This function shows how to implement [`Service`] logic by matching the -/// function signature required by [`service_fn()`]. +/// function signature required by [`service_fn`]. /// /// The function signature is slightly simpler to write than when not using -/// [`service_fn()`] and supports passing in meta data without any extra +/// [`service_fn`] and supports passing in meta data without any extra /// boilerplate. #[allow(clippy::type_complexity)] fn query( @@ -177,7 +171,7 @@ fn query( count: Arc, ) -> Result< Transaction< - Result>, ServiceError>, + Vec, impl Future>, ServiceError>> + Send, >, ServiceError, @@ -188,9 +182,9 @@ fn query( }) .unwrap(); - // This fn blocks the server until it returns. By returning a future - // that handles the request we allow the server to execute the future - // in the background without blocking the server. + // This fn blocks the server until it returns. By returning a future that + // handles the request we allow the server to execute the future in the + // background without blocking the server. let fut = async move { eprintln!("Sleeping for 100ms"); tokio::time::sleep(Duration::from_millis(100)).await; @@ -229,7 +223,8 @@ impl DoubleListener { } } -/// Combine two streams into one by interleaving the output of both as it is produced. +/// Combine two streams into one by interleaving the output of both as it is +/// produced. impl AsyncAccept for DoubleListener { type Error = io::Error; type StreamType = TcpStream; @@ -684,8 +679,8 @@ async fn main() { let listener = BufferedTcpListener(listener); let count = Arc::new(AtomicU8::new(5)); - // Make our service from the `query()` function with the help of the - // `service_fn()` function. 
+ // Make our service from the `query` function with the help of the + // `service_fn` function. let fn_svc = service_fn(query, count); // Show that we don't have to use the same middleware with every server by @@ -760,7 +755,7 @@ async fn main() { let mut interval = tokio::time::interval(Duration::from_secs(15)); loop { interval.tick().await; - println!("Statistics report: {}", stats); + println!("Statistics report: {stats}"); } }); diff --git a/src/base/header.rs b/src/base/header.rs index 29f1bb261..4c4de11d5 100644 --- a/src/base/header.rs +++ b/src/base/header.rs @@ -71,8 +71,8 @@ impl Header { /// Creates a new header. /// /// The new header has all fields as either zero or false. Thus, the - /// opcode will be [`Opcode::Query`] and the response code will be - /// [`Rcode::NoError`]. + /// opcode will be [`Opcode::QUERY`] and the response code will be + /// [`Rcode::NOERROR`]. #[must_use] pub fn new() -> Self { Self::default() @@ -155,7 +155,7 @@ impl Header { /// /// This field specifies the kind of query a message contains. See /// the [`Opcode`] type for more information on the possible values and - /// their meaning. Normal queries have the variant [`Opcode::Query`] + /// their meaning. Normal queries have the variant [`Opcode::QUERY`] /// which is also the default value when creating a new header. #[must_use] pub fn opcode(self) -> Opcode { diff --git a/src/base/iana/rcode.rs b/src/base/iana/rcode.rs index 0a4b52005..96a83c530 100644 --- a/src/base/iana/rcode.rs +++ b/src/base/iana/rcode.rs @@ -301,10 +301,10 @@ impl<'de> serde::Deserialize<'de> for Rcode { /// This type offers several functions to ease working with the separate parts /// and the combined value of an extended RCODE: /// -/// - [`OptRcode::rcode()`]: the RFC 1035 header RCODE part. -/// - [`OptRcode::ext()`]`: the RFC 6891 ENDS OPT extended RCODE part. -/// - [`OptRcode::to_parts()`]`: to access both parts at once. 
-/// - [`OptRcode::to_int()`]`: the IANA number for the RCODE combining both +/// - [`OptRcode::rcode`]: the RFC 1035 header RCODE part. +/// - [`OptRcode::ext`]`: the RFC 6891 ENDS OPT extended RCODE part. +/// - [`OptRcode::to_parts`]`: to access both parts at once. +/// - [`OptRcode::to_int`]`: the IANA number for the RCODE combining both /// parts. /// /// [Rcode]: enum.Rcode.html diff --git a/src/base/message_builder.rs b/src/base/message_builder.rs index 5f01492ed..a6831100f 100644 --- a/src/base/message_builder.rs +++ b/src/base/message_builder.rs @@ -1703,15 +1703,15 @@ impl<'a, Target: Composer + ?Sized> OptBuilder<'a, Target> { /// A builder target for sending messages on stream transports. /// /// TODO: Rename this type and adjust the doc comments as it is usable both -/// for datagram AND stream transports via [`as_dgram_slice()`] and -/// [`as_stream_slice()`]. +/// for datagram AND stream transports via [`as_dgram_slice`] and +/// [`as_stream_slice`]. /// /// When messages are sent over stream-oriented transports such as TCP, a DNS -/// message is preceded by a 16 bit length value in order to determine the -/// end of a message. This type transparently adds this length value as the -/// first two octets of an octets builder and itself presents an octets -/// builder interface for building the actual message. Whenever data is pushed -/// to that builder interface, the type will update the length value. +/// message is preceded by a 16 bit length value in order to determine the end +/// of a message. This type transparently adds this length value as the first +/// two octets of an octets builder and itself presents an octets builder +/// interface for building the actual message. Whenever data is pushed to that +/// builder interface, the type will update the length value. 
/// /// Because the length is 16 bits long, the assembled message can be at most /// 65536 octets long, independently of the maximum length the underlying diff --git a/src/base/name/dname.rs b/src/base/name/dname.rs index 145ac261c..0505914c4 100644 --- a/src/base/name/dname.rs +++ b/src/base/name/dname.rs @@ -472,16 +472,16 @@ impl + ?Sized> Dname { /// /// # Panics /// - /// The method panics if either position is not the start of a label or - /// is out of bounds. + /// The method panics if either position is not the start of a label or is + /// out of bounds. /// /// Because the returned domain name is relative, the method will also - /// panic if the end is equal to the length of the name. If you - /// want to slice the entire end of the name including the final root - /// label, you can use [`slice_from()`] instead. + /// panic if the end is equal to the length of the name. If you want to + /// slice the entire end of the name including the final root label, you + /// can use [`slice_from`] instead. /// /// [`range`]: #method.range - /// [`slice_from()`]: #method.slice_from + /// [`slice_from`]: #method.slice_from pub fn slice( &self, range: impl RangeBounds, @@ -528,9 +528,9 @@ impl + ?Sized> Dname { /// Because the returned domain name is relative, the method will also /// panic if the end is equal to the length of the name. If you /// want to slice the entire end of the name including the final root - /// label, you can use [`range_from()`] instead. + /// label, you can use [`range_from`] instead. /// - /// [`range_from()`]: #method.range_from + /// [`range_from`]: #method.range_from pub fn range( &self, range: impl RangeBounds, diff --git a/src/base/opt/cookie.rs b/src/base/opt/cookie.rs index ee08c79a4..ce012b7f2 100644 --- a/src/base/opt/cookie.rs +++ b/src/base/opt/cookie.rs @@ -274,7 +274,7 @@ impl<'a, Target: Composer> OptBuilder<'a, Target> { /// [`from_octets`][ClientCookie::from_octets]. 
Similarly, the `Default` /// implementation will create a random cookie and is thus only available if /// the `rand` feature is enabled. -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct ClientCookie([u8; 8]); impl ClientCookie { diff --git a/src/base/scan.rs b/src/base/scan.rs index 0cb89cc4a..62605c586 100644 --- a/src/base/scan.rs +++ b/src/base/scan.rs @@ -703,6 +703,7 @@ impl Symbol { Symbol::Char(ch) => { ch != ' ' && ch != '\t' + && ch != '\r' && ch != '\n' && ch != '(' && ch != ')' @@ -1055,7 +1056,7 @@ impl SymbolCharsError { #[must_use] pub fn as_str(self) -> &'static str { match self.0 { - SymbolCharsEnum::BadEscape => "illegale escape sequence", + SymbolCharsEnum::BadEscape => "illegal escape sequence", SymbolCharsEnum::ShortInput => "unexpected end of input", } } diff --git a/src/lib.rs b/src/lib.rs index 9025d50d5..448486db4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,8 +45,11 @@ //! Experimental support for DNSSEC validation. #![cfg_attr(feature = "zonefile", doc = "* [zonefile]:")] #![cfg_attr(not(feature = "zonefile"), doc = "* zonefile:")] -//! Experimental reading and writing of zone files, i.e., the textual +//! Experimental reading and writing of zone files, i.e. the textual //! representation of DNS data. +#![cfg_attr(feature = "unstable-zonetree", doc = "* [zonetree]:")] +#![cfg_attr(not(feature = "unstable-zonetree"), doc = "* zonetree:")] +//! Experimental storing and querying of zone trees. //! //! Finally, the [dep] module contains re-exports of some important //! dependencies to help avoid issues with multiple versions of a crate. @@ -129,6 +132,8 @@ //! a client perspective; primarily the `net::client` module. //! * `unstable-server-transport`: receiving and sending DNS messages from //! a server perspective; primarily the `net::server` module. +//! * `unstable-zonetree`: building & querying zone trees; primarily the +//! `zonetree` module. //! 
//! Note: Some functionality is currently informally marked as //! “experimental” since it was introduced before adoption of the concept @@ -160,3 +165,4 @@ pub mod tsig; pub mod utils; pub mod validate; pub mod zonefile; +pub mod zonetree; diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index eb4b4b952..173079fc1 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -90,11 +90,7 @@ const MAX_QUEUED_RESPONSES: DefMinMax = DefMinMax::new(10, 0, 1024); //----------- Config --------------------------------------------------------- /// Configuration for a stream server connection. -pub struct Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +pub struct Config { /// Limit on the amount of time to allow between client requests. /// /// This setting can be overridden on a per connection basis by a @@ -155,15 +151,15 @@ where /// /// # Reconfigure /// - /// On [`StreamServer::reconfigure()`] the current idle period will NOT - /// be affected. Subsequent idle periods (after the next message is - /// received or response is sent, assuming that happens within the current - /// idle period) will use the new timeout value. + /// On [`StreamServer::reconfigure`] the current idle period will NOT be + /// affected. Subsequent idle periods (after the next message is received + /// or response is sent, assuming that happens within the current idle + /// period) will use the new timeout value. 
/// /// [RFC 7766]: /// https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.3 /// - /// [`StreamServer::reconfigure()`]: + /// [`StreamServer::reconfigure`]: /// super::stream::StreamServer::reconfigure() #[allow(dead_code)] pub fn set_idle_timeout(&mut self, value: Duration) { @@ -180,11 +176,11 @@ where /// /// # Reconfigure /// - /// On [`StreamServer::reconfigure()`] any responses currently being + /// On [`StreamServer::reconfigure`] any responses currently being /// written will NOT use the new timeout, it will only apply to responses /// that start being sent after the timeout is changed. /// - /// [`StreamServer::reconfigure()`]: + /// [`StreamServer::reconfigure`]: /// super::stream::StreamServer::reconfigure() #[allow(dead_code)] pub fn set_response_write_timeout(&mut self, value: Duration) { @@ -202,11 +198,11 @@ where /// /// # Reconfigure /// - /// On [`StreamServer::reconfigure()`] only new connections created after + /// On [`StreamServer::reconfigure`] only new connections created after /// this setting is changed will use the new value, existing connections /// will continue to use their exisitng queue at its existing size. /// - /// [`StreamServer::reconfigure()`]: + /// [`StreamServer::reconfigure`]: /// super::stream::StreamServer::reconfigure() #[allow(dead_code)] pub fn set_max_queued_responses(&mut self, value: usize) { @@ -218,12 +214,12 @@ where /// /// # Reconfigure /// - /// On [`StreamServer::reconfigure()`] only new connections created after + /// On [`StreamServer::reconfigure`] only new connections created after /// this setting is changed will use the new value, existing connections /// and in-flight requests (and their responses) will continue to use /// their current middleware chain. 
/// - /// [`StreamServer::reconfigure()`]: + /// [`StreamServer::reconfigure`]: /// super::stream::StreamServer::reconfigure() pub fn set_middleware_chain( &mut self, @@ -252,11 +248,7 @@ where //--- Clone -impl Clone for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Clone for Config { fn clone(&self) -> Self { Self { idle_timeout: self.idle_timeout, @@ -272,10 +264,8 @@ where /// A handler for a single stream connection between client and server. pub struct Connection where - Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, - Buf: BufSource + Send + Sync + Clone + 'static, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + Clone + 'static, + Buf: BufSource, + Svc: Service, { /// Flag used by the Drop impl to track if the metric count has to be /// decreased or not. @@ -324,10 +314,11 @@ where /// impl Connection where - Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, - Buf: BufSource + Send + Sync + Clone + 'static, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + Clone + 'static, + Stream: AsyncRead + AsyncWrite, + Buf: BufSource, + Buf::Output: Octets, + Svc: Service, + Svc::Target: Composer + Default, { /// Creates a new handler for an accepted stream connection. #[must_use] @@ -394,9 +385,10 @@ where impl Connection where Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Buf: BufSource + Send + Sync + Clone + 'static, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { /// Start reading requests and writing responses to the stream. 
/// @@ -431,9 +423,11 @@ where impl Connection where Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Buf: BufSource + Send + Sync + Clone + 'static, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + 'static, + Svc::Future: Send, + Svc::Target: Send + Composer + Default, { /// Connection handler main loop. async fn run_until_error( @@ -441,12 +435,10 @@ where mut command_rx: watch::Receiver< ServerCommand>, >, - ) where - Svc::Future: Send, - { + ) { // SAFETY: This unwrap is safe because we always put a Some value into - // self.stream_rx in [`Self::with_config()`] above (and thus also in - // [`Self::new()`] which calls [`Self::with_config()`]). + // self.stream_rx in [`Self::with_config`] above (and thus also in + // [`Self::new`] which calls [`Self::with_config`]). let stream_rx = self.stream_rx.take().unwrap(); let mut dns_msg_receiver = @@ -721,7 +713,10 @@ where async fn process_read_request( &mut self, res: Result, - ) -> Result<(), ConnectionEvent> { + ) -> Result<(), ConnectionEvent> + where + Svc::Future: Send, + { res.and_then(|msg| { let received_at = Instant::now(); @@ -754,10 +749,8 @@ where impl Drop for Connection where - Stream: AsyncRead + AsyncWrite + Send + Sync, - Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + Clone, + Buf: BufSource, + Svc: Service, { fn drop(&mut self) { if self.active { @@ -772,10 +765,10 @@ where impl CommonMessageFlow for Connection where - Stream: AsyncRead + AsyncWrite + Send + Sync, - Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + Clone, + Buf: BufSource, + Buf::Output: Octets + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static, + Svc::Target: Send, { type Meta = Sender>; @@ -814,12 +807,12 @@ where 
Err(TrySendError::Closed(_msg)) => { // TODO: How should we properly communicate this to the operator? - error!("StreamServer: Unable to queue message for sending: server is shutting down."); + error!("Unable to queue message for sending: server is shutting down."); } Err(TrySendError::Full(_msg)) => { // TODO: How should we properly communicate this to the operator? - error!("StreamServer: Unable to queue message for sending: queue is full."); + error!("Unable to queue message for sending: queue is full."); } } } @@ -849,11 +842,7 @@ enum Status { /// ensures that any part of the request already received is not lost if the /// read operation is cancelled by Tokio and then a new read operation is /// started. -struct DnsMessageReceiver -where - Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, -{ +struct DnsMessageReceiver { /// A buffer to record the total expected size of the message currently /// being received. DNS TCP streams preceed the DNS message by bytes /// indicating the length of the message that follows. @@ -957,7 +946,6 @@ where io::ErrorKind::UnexpectedEof => { // The client disconnected. Per RFC 7766 6.2.4 pending // responses MUST NOT be sent to the client. - error!("I/O error: {}", err); ControlFlow::Break(ConnectionEvent::DisconnectWithoutFlush) } io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => { diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 83598cd8a..59356c36d 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -84,11 +84,7 @@ const MAX_RESPONSE_SIZE: DefMinMax = DefMinMax::new(1232, 512, 4096); /// Configuration for a datagram server. #[derive(Debug)] -pub struct Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +pub struct Config { /// Limit suggested to [`Service`] on maximum response size to create. 
max_response_size: Option, @@ -123,9 +119,9 @@ where /// /// # Reconfigure /// - /// On [`DgramServer::reconfigure()`]` any change to this setting will - /// only affect requests received after the setting is changed, in - /// progress requests will be unaffected. + /// On [`DgramServer::reconfigure`]` any change to this setting will only + /// affect requests received after the setting is changed, in progress + /// requests will be unaffected. /// /// [2020 DNS Flag Day]: http://www.dnsflagday.net/2020/ /// [RFC 6891]: @@ -134,15 +130,16 @@ where self.max_response_size = value; } - /// Sets the time to wait for a complete message to be written to the client. + /// Sets the time to wait for a complete message to be written to the + /// client. /// /// The value has to be between 1ms and 60 seconds. The default value is 5 /// seconds. /// /// # Reconfigure /// - /// On [`DgramServer::reconfigure()`]` any change to this setting will - /// only affect responses sent after the setting is changed, in-flight + /// On [`DgramServer::reconfigure`]` any change to this setting will only + /// affect responses sent after the setting is changed, in-flight /// responses will be unaffected. pub fn set_write_timeout(&mut self, value: Duration) { self.write_timeout = value; @@ -153,9 +150,9 @@ where /// /// # Reconfigure /// - /// On [`DgramServer::reconfigure()`]` any change to this setting will - /// only affect requests (and their responses) received after the setting - /// is changed, in progress requests will be unaffected. + /// On [`DgramServer::reconfigure`]` any change to this setting will only + /// affect requests (and their responses) received after the setting is + /// changed, in progress requests will be unaffected. pub fn set_middleware_chain( &mut self, value: MiddlewareChain, @@ -226,7 +223,7 @@ type CommandReceiver = watch::Receiver>; /// /// A socket is anything that implements the [`AsyncDgramSock`] trait. 
This /// crate provides an implementation for [`tokio::net::UdpSocket`]. When -/// wrapped inside an [`Arc`] the same `UdpSocket` can be [`Arc::clone()`]d to +/// wrapped inside an [`Arc`] the same `UdpSocket` can be [`Arc::clone`]d to /// multiple instances of [`DgramServer`] potentially increasing throughput. /// /// # Examples @@ -255,8 +252,7 @@ type CommandReceiver = watch::Receiver>; /// /// fn my_service(msg: Request>, _meta: ()) /// -> Result< -/// Transaction< -/// Result>, ServiceError>, +/// Transaction, /// Pin>, ServiceError> /// > + Send>>, @@ -300,8 +296,9 @@ pub struct DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static, + Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { /// The configuration of the server. config: Arc>>, @@ -335,14 +332,15 @@ where /// impl DgramServer where - Sock: AsyncDgramSock + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, + Sock: AsyncDgramSock + Send + Sync, + Buf: BufSource + Send + Sync, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync, + Svc::Target: Send + Composer + Default, { /// Constructs a new [`DgramServer`] with default configuration. /// - /// See [`Self::with_config()`]. + /// See [`Self::with_config`]. #[must_use] pub fn new(sock: Sock, buf: Buf, service: Svc) -> Self { Self::with_config(sock, buf, service, Config::default()) @@ -354,12 +352,13 @@ where /// - A socket which must implement [`AsyncDgramSock`] and is responsible /// receiving new messages and send responses back to the client. /// - A [`BufSource`] for creating buffers on demand. - /// - A [`Service`] for handling received requests and generating responses. + /// - A [`Service`] for handling received requests and generating + /// responses. 
/// - A [`Config`] with settings to control the server behaviour. /// - /// Invoke [`run()`] to receive and process incoming messages. + /// Invoke [`run`] to receive and process incoming messages. /// - /// [`run()`]: Self::run() + /// [`run`]: Self::run() #[must_use] pub fn with_config( sock: Sock, @@ -392,7 +391,7 @@ where Buf: BufSource + Send + Sync + 'static, Buf::Output: Octets + Send + Sync + 'static + Debug, Svc: Service + Send + Sync + 'static, - Svc::Target: Debug, + Svc::Target: Send + Composer + Debug + Default, { /// Get a reference to the network source being used to receive messages. #[must_use] @@ -412,23 +411,24 @@ where impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static, + Buf: BufSource + Send + Sync, Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { /// Start the server. /// /// # Drop behaviour /// - /// When dropped [`shutdown()`] will be invoked. + /// When dropped [`shutdown`] will be invoked. /// - /// [`shutdown()`]: Self::shutdown + /// [`shutdown`]: Self::shutdown pub async fn run(&self) where Svc::Future: Send, { if let Err(err) = self.run_until_error().await { - error!("DgramServer: {err}"); + error!("Server stopped due to error: {err}"); } } @@ -453,11 +453,10 @@ where /// socket that was given to the server when it was created remains /// operational. /// - /// [`Self::is_shutdown()`] can be used to dertermine if shutdown is + /// [`Self::is_shutdown`] can be used to dertermine if shutdown is /// complete. /// - /// [`Self::await_shutdown()`] can be used to wait for shutdown to - /// complete. + /// [`Self::await_shutdown`] can be used to wait for shutdown to complete. pub fn shutdown(&self) -> Result<(), Error> { self.command_tx .lock() @@ -480,7 +479,7 @@ where /// Returns true if the server shutdown in the given time period, false /// otherwise. 
/// - /// To start the shutdown process first call [`Self::shutdown()`] then use + /// To start the shutdown process first call [`Self::shutdown`] then use /// this method to wait for the shutdown process to complete. pub async fn await_shutdown(&self, duration: Duration) -> bool { timeout(duration, async { @@ -500,9 +499,10 @@ where impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static, + Buf: BufSource + Send + Sync, Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { /// Receive incoming messages until shutdown or fatal error. async fn run_until_error(&self) -> Result<(), String> @@ -643,7 +643,7 @@ where /// Helper function to package references to key parts of our server state /// into a [`RequestState`] ready for passing through the /// [`CommonMessageFlow`] call chain and ultimately back to ourselves at - /// [`process_call_reusult()`]. + /// [`process_call_result`]. fn mk_state_for_request( &self, ) -> RequestState { @@ -664,6 +664,7 @@ where Buf: BufSource + Send + Sync + 'static, Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { type Meta = RequestState; @@ -740,8 +741,9 @@ impl Drop for DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static, + Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static, + Svc::Target: Send + Composer + Default, { fn drop(&mut self) { // Shutdown the DgramServer. Don't handle the failure case here as @@ -753,13 +755,9 @@ where //----------- RequestState --------------------------------------------------- -/// Data needed by [`DgramServer::process_call_result()`] which needs to be +/// Data needed by [`DgramServer::process_call_result`] which needs to be /// passed through the [`CommonMessageFlow`] call chain. 
-pub struct RequestState -where - RequestOctets: Octets, - Target: Composer + Default, -{ +pub struct RequestState { /// The network socket over which this request was received and over which /// the response should be sent. sock: Arc, @@ -774,11 +772,7 @@ where write_timeout: Duration, } -impl RequestState -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl RequestState { /// Creates a new instance of [`RequestState`]. fn new( sock: Arc, @@ -797,9 +791,6 @@ where impl Clone for RequestState -where - RequestOctets: Octets, - Target: Composer + Default, { fn clone(&self) -> Self { Self { diff --git a/src/net/server/message.rs b/src/net/server/message.rs index ece6f100b..bf1d71dbc 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -17,6 +17,7 @@ use crate::net::server::middleware::chain::MiddlewareChain; use super::service::{CallResult, Service, ServiceError, Transaction}; use super::util::start_reply; +use crate::base::wire::Composer; //------------ UdpTransportContext ------------------------------------------- @@ -233,7 +234,7 @@ impl> Clone for Request { /// Servers implement this trait to benefit from the common processing /// required while still handling aspects specific to the server themselves. /// -/// Processing starts at [`process_request()`]. +/// Processing starts at [`process_request`]. /// ///
/// @@ -244,12 +245,12 @@ impl> Clone for Request { /// ///
/// -/// [`process_request()`]: Self::process_request() +/// [`process_request`]: Self::process_request() pub trait CommonMessageFlow where - Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, + Buf: BufSource, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync, { /// Server-specific data that it chooses to pass along with the request in /// order that it may receive it when `process_call_result()` is @@ -260,7 +261,7 @@ where /// /// This function consumes the given message buffer and processes the /// contained message, if any, to completion, possibly resulting in a - /// response being passed to [`Self::process_call_result()`]. + /// response being passed to [`Self::process_call_result`]. /// /// The request message is a given as a seqeuence of bytes in `buf` /// originating from client address `addr`. @@ -286,7 +287,10 @@ where meta: Self::Meta, ) -> Result<(), ServiceError> where + Svc: 'static, + Svc::Target: Send + Composer + Default, Svc::Future: Send, + Buf::Output: 'static, { boomerang( self, @@ -344,16 +348,22 @@ fn boomerang( meta: Server::Meta, ) -> Result<(), ServiceError> where - Buf: BufSource + Send + Sync + 'static, + Buf: BufSource, Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static, + Svc::Future: Send, + Svc::Target: Send + Composer + Default, Server: CommonMessageFlow + ?Sized, { - let (request, preprocessing_result) = do_middleware_preprocessing( - server, - buf, - received_at, - addr, + let message = Message::from_octets(buf).map_err(|err| { + warn!("Failed while parsing request message: {err}"); + ServiceError::InternalError + })?; + + let request = server.add_context_to_request(message, received_at, addr); + + let preprocessing_result = do_middleware_preprocessing::( + &request, &middleware_chain, &metrics, )?; @@ -375,28 +385,23 @@ where /// Pass a pre-processed request to the [`Service`] to handle. 
/// -/// If [`Service::call()`] returns an error this function will produce a DNS +/// If [`Service::call`] returns an error this function will produce a DNS /// ServFail error response. If the returned error is /// [`ServiceError::InternalError`] it will also be logged. #[allow(clippy::type_complexity)] fn do_service_call( preprocessing_result: ControlFlow<( - Transaction< - Result, ServiceError>, - Svc::Future, - >, + Transaction, usize, )>, request: &Request<::Output>, svc: &Svc, -) -> ( - Transaction, ServiceError>, Svc::Future>, - Option, -) +) -> (Transaction, Option) where - Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, + Buf: BufSource, + Buf::Output: Octets, + Svc: Service, + Svc::Target: Composer + Default, { match preprocessing_result { ControlFlow::Continue(()) => { @@ -439,47 +444,28 @@ where /// pre-processing it via any supplied [`MiddlewareChain`]. /// /// On success the result is an immutable request message and a -/// [`ControlFlow`] decision about whether to continue with further -/// processing or to break early with a possible response. If processing -/// failed the result will be a [`ServiceError`]. +/// [`ControlFlow`] decision about whether to continue with further processing +/// or to break early with a possible response. If processing failed the +/// result will be a [`ServiceError`]. /// -/// On break the result will be one ([`Transaction::single()`]) or more -/// ([`Transaction::stream()`]) to post-process. +/// On break the result will be one ([`Transaction::single`]) or more +/// ([`Transaction::stream`]) to post-process. 
#[allow(clippy::type_complexity)] -fn do_middleware_preprocessing( - server: &Server, - buf: Buf::Output, - received_at: Instant, - addr: SocketAddr, +fn do_middleware_preprocessing( + request: &Request, middleware_chain: &MiddlewareChain, metrics: &Arc, ) -> Result< - ( - Request, - ControlFlow<( - Transaction< - Result, ServiceError>, - Svc::Future, - >, - usize, - )>, - ), + ControlFlow<(Transaction, usize)>, ServiceError, > where - Buf: BufSource + Send + Sync + 'static, + Buf: BufSource, Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, - Server: CommonMessageFlow + ?Sized, + Svc: Service + Send + Sync, + Svc::Future: Send, + Svc::Target: Send + Composer + Default + 'static, { - let message = Message::from_octets(buf).map_err(|err| { - warn!("Failed while parsing request message: {err}"); - ServiceError::InternalError - })?; - - let mut request = - server.add_context_to_request(message, received_at, addr); - let span = info_span!("pre-process", msg_id = request.message().header().id(), client = %request.client_addr(), @@ -488,37 +474,34 @@ where metrics.inc_num_inflight_requests(); - let pp_res = middleware_chain.preprocess(&mut request); + let pp_res = middleware_chain.preprocess(request); - Ok((request, pp_res)) + Ok(pp_res) } /// Post-process a response in the context of its originating request. /// -/// Each response is post-processed in its own Tokio task. Note that there -/// is no guarantee about the order in which responses will be -/// post-processed. If the order of a seqence of responses is important it -/// should be provided as a [`Transaction::stream()`] rather than -/// [`Transaction::single()`]. +/// Each response is post-processed in its own Tokio task. Note that there is +/// no guarantee about the order in which responses will be post-processed. If +/// the order of a sequence of responses is important it should be provided as +/// a [`Transaction::stream`] rather than [`Transaction::single`]. 
/// -/// Responses are first post-processed by the [`MiddlewareChain`] -/// provided, if any, then passed to [`Self::process_call_result()`] for -/// final processing. +/// Responses are first post-processed by the [`MiddlewareChain`] provided, if +/// any, then passed to [`Self::process_call_result`] for final processing. #[allow(clippy::type_complexity)] fn do_middleware_postprocessing( request: Request, meta: Server::Meta, middleware_chain: MiddlewareChain, - mut response_txn: Transaction< - Result, ServiceError>, - Svc::Future, - >, + mut response_txn: Transaction, last_processor_id: Option, metrics: Arc, ) where - Buf: BufSource + Send + Sync + 'static, + Buf: BufSource, Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static, + Svc::Future: Send, + Svc::Target: Send + Composer + Default, Server: CommonMessageFlow + ?Sized, { tokio::spawn(async move { diff --git a/src/net/server/middleware/builder.rs b/src/net/server/middleware/builder.rs index 5e1f75723..91b8b64b0 100644 --- a/src/net/server/middleware/builder.rs +++ b/src/net/server/middleware/builder.rs @@ -17,16 +17,12 @@ use super::processors::mandatory::MandatoryMiddlewareProcessor; /// [`MiddlewareProcessor`] at a time. /// /// This builder allows you to add [`MiddlewareProcessor`]s sequentially using -/// [`push()`] before finally calling [`build()`] to turn the builder into an +/// [`push`] before finally calling [`build`] to turn the builder into an /// immutable [`MiddlewareChain`]. /// -/// [`push()`]: Self::push() -/// [`build()`]: Self::build() -pub struct MiddlewareBuilder, Target = Vec> -where - RequestOctets: Octets, - Target: Composer + Default, -{ +/// [`push`]: Self::push() +/// [`build`]: Self::build() +pub struct MiddlewareBuilder, Target = Vec> { /// The ordered set of processors which will pre-process requests and then /// in reverse order will post-process responses. processors: Vec< @@ -49,12 +45,12 @@ where ///
Warning: /// /// When building a standards compliant DNS server you should probably use - /// [`MiddlewareBuilder::minimal()`] or [`MiddlewareBuilder::standard()`] + /// [`MiddlewareBuilder::minimal`] or [`MiddlewareBuilder::standard`] /// instead. ///
/// - /// [`MiddlewareBuilder::minimal()`]: Self::minimal() - /// [`MiddlewareBuilder::standard()`]: Self::standard() + /// [`MiddlewareBuilder::minimal`]: Self::minimal() + /// [`MiddlewareBuilder::standard`]: Self::standard() #[must_use] pub fn new() -> Self { Self { processors: vec![] } @@ -131,12 +127,12 @@ where impl Default for MiddlewareBuilder where - RequestOctets: AsRef<[u8]> + Octets, + RequestOctets: Octets, Target: Composer + Default, { /// Create a middleware builder with default, aka "standard", processors. /// - /// See [`Self::standard()`]. + /// See [`Self::standard`]. fn default() -> Self { Self::standard() } diff --git a/src/net/server/middleware/chain.rs b/src/net/server/middleware/chain.rs index 9b87f8acc..57d4c9956 100644 --- a/src/net/server/middleware/chain.rs +++ b/src/net/server/middleware/chain.rs @@ -25,11 +25,7 @@ use super::processor::MiddlewareProcessor; /// A [`MiddlewareChain`] is immutable. Requests should not be post-processed /// by a different or modified chain than they were pre-processed by. #[derive(Default)] -pub struct MiddlewareChain -where - RequestOctets: AsRef<[u8]>, - Target: Composer + Default, -{ +pub struct MiddlewareChain { /// The ordered set of processors which will pre-process requests and then /// in reverse order will post-process responses. processors: Arc< @@ -39,11 +35,7 @@ where >, } -impl MiddlewareChain -where - RequestOctets: AsRef<[u8]>, - Target: Composer + Default, -{ +impl MiddlewareChain { /// Create a new _empty_ chain of processors. /// ///
Warning: @@ -56,11 +48,11 @@ where /// perform such processing yourself. /// /// Most users should **NOT** use this function but should instead use - /// [`MiddlewareBuilder::default()`] which constructs a chain that starts + /// [`MiddlewareBuilder::default`] which constructs a chain that starts /// with [`MandatoryMiddlewareProcessor`]. ///
/// - /// [`MiddlewareBuilder::default()`]: + /// [`MiddlewareBuilder::default`]: /// super::builder::MiddlewareBuilder::default() /// [`MandatoryMiddlewareProcessor`]: /// super::processors::mandatory::MandatoryMiddlewareProcessor @@ -93,9 +85,9 @@ where /// a pre-processor decided to terminate processing of the request. /// /// On [`ControlFlow::Break`] the caller should pass the given result to - /// [`postprocess()`][Self::postprocess]. If processing terminated early - /// the result includes the index of the pre-processor which terminated - /// the processing. + /// [`postprocess`][Self::postprocess]. If processing terminated early the + /// result includes the index of the pre-processor which terminated the + /// processing. /// /// # Performance /// @@ -107,11 +99,8 @@ where #[allow(clippy::type_complexity)] pub fn preprocess( &self, - request: &mut Request, - ) -> ControlFlow<( - Transaction, ServiceError>, Future>, - usize, - )> + request: &Request, + ) -> ControlFlow<(Transaction, usize)> where Future: std::future::Future< Output = Result, ServiceError>, @@ -149,7 +138,7 @@ where /// was recieved. /// /// The optional `last_processor_idx` value should come from an earlier - /// call to [`preprocess()`][Self::preprocess]. Post-processing will start + /// call to [`preprocess`][Self::preprocess]. Post-processing will start /// with this processor and walk backward from there, post-processors /// further down the chain will not be invoked. 
pub fn postprocess( @@ -172,11 +161,7 @@ where //--- Clone -impl Clone for MiddlewareChain -where - RequestOctets: AsRef<[u8]>, - Target: Composer + Default, -{ +impl Clone for MiddlewareChain { fn clone(&self) -> Self { Self { processors: self.processors.clone(), @@ -186,11 +171,7 @@ where //--- Debug -impl Debug for MiddlewareChain -where - RequestOctets: AsRef<[u8]>, - Target: Composer + Default, -{ +impl Debug for MiddlewareChain { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("MiddlewareChain") .field("processors", &self.processors.len()) diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index 2d0b408cf..2f5176866 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -13,16 +13,16 @@ //! //! Mandatory functionality and logic required by all standards compliant DNS //! servers can be incorporated into your server by building a middleware -//! chain starting from [`MiddlewareBuilder::default()`]. +//! chain starting from [`MiddlewareBuilder::default`]. //! //! A selection of additional functionality relating to server behaviour and //! DNS standards (as opposed to your own application logic) is provided which -//! you can incorporate into your DNS server via -//! [`MiddlewareBuilder::push()`]. See the various implementations of -//! [`MiddlewareProcessor`] for more information. +//! you can incorporate into your DNS server via [`MiddlewareBuilder::push`]. +//! See the various implementations of [`MiddlewareProcessor`] for more +//! information. //! -//! [`MiddlewareBuilder::default()`]: builder::MiddlewareBuilder::default() -//! [`MiddlewareBuilder::push()`]: builder::MiddlewareBuilder::push() +//! [`MiddlewareBuilder::default`]: builder::MiddlewareBuilder::default() +//! [`MiddlewareBuilder::push`]: builder::MiddlewareBuilder::push() //! [`MiddlewareChain`]: chain::MiddlewareChain //! [`MiddlewareProcessor`]: processor::MiddlewareProcessor //! 
[`Service`]: crate::net::server::service::Service diff --git a/src/net/server/middleware/processor.rs b/src/net/server/middleware/processor.rs index a04b64b19..1b40d85da 100644 --- a/src/net/server/middleware/processor.rs +++ b/src/net/server/middleware/processor.rs @@ -2,7 +2,6 @@ use core::ops::ControlFlow; use crate::base::message_builder::AdditionalBuilder; -use crate::base::wire::Composer; use crate::base::StreamTarget; use crate::net::server::message::Request; @@ -14,13 +13,13 @@ use crate::net::server::message::Request; pub trait MiddlewareProcessor where RequestOctets: AsRef<[u8]>, - Target: Composer + Default, { /// Apply middleware pre-processing rules to a request. /// - /// See [`MiddlewareChain::preprocess()`] for more information. + /// See [`MiddlewareChain::preprocess`] for more information. /// - /// [`MiddlewareChain::preprocess()`]: crate::net::server::middleware::chain::MiddlewareChain::preprocess() + /// [`MiddlewareChain::preprocess`]: + /// crate::net::server::middleware::chain::MiddlewareChain::preprocess() fn preprocess( &self, request: &Request, @@ -28,9 +27,10 @@ where /// Apply middleware post-processing rules to a response. /// - /// See [`MiddlewareChain::postprocess()`] for more information. + /// See [`MiddlewareChain::postprocess`] for more information. 
/// - /// [`MiddlewareChain::postprocess()`]: crate::net::server::middleware::chain::MiddlewareChain::postprocess() + /// [`MiddlewareChain::postprocess`]: + /// crate::net::server::middleware::chain::MiddlewareChain::postprocess() fn postprocess( &self, request: &Request, diff --git a/src/net/server/middleware/processors/cookies.rs b/src/net/server/middleware/processors/cookies.rs index 2425746d9..6f0a245a4 100644 --- a/src/net/server/middleware/processors/cookies.rs +++ b/src/net/server/middleware/processors/cookies.rs @@ -4,7 +4,7 @@ use core::ops::ControlFlow; use std::net::IpAddr; use std::vec::Vec; -use octseq::{Octets, OctetsBuilder}; +use octseq::Octets; use rand::RngCore; use tracing::{debug, trace, warn}; @@ -129,7 +129,7 @@ impl CookiesMiddlewareProcessor { ) -> AdditionalBuilder> where RequestOctets: Octets, - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { let mut additional = start_reply(request).additional(); @@ -169,7 +169,7 @@ impl CookiesMiddlewareProcessor { ) -> AdditionalBuilder> where RequestOctets: Octets, - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 // "If the server responds [ed: by sending a BADCOOKIE error @@ -189,7 +189,7 @@ impl CookiesMiddlewareProcessor { ) -> AdditionalBuilder> where RequestOctets: Octets, - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 // Querying for a Server Cookie: @@ -253,7 +253,7 @@ impl MiddlewareProcessor for CookiesMiddlewareProcessor where RequestOctets: Octets, - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { fn preprocess( &self, diff --git a/src/net/server/middleware/processors/edns.rs b/src/net/server/middleware/processors/edns.rs index 081b1f82c..54ce92f44 100644 --- a/src/net/server/middleware/processors/edns.rs +++ b/src/net/server/middleware/processors/edns.rs 
@@ -280,10 +280,10 @@ where // using the edns-tcp-keepalive EDNS(0) option [RFC7828]." if let TransportSpecificContext::NonUdp(ctx) = request.transport_ctx() { - if let Ok(additional) = request.message().additional() { - let mut iter = additional.limit_to::>(); - if iter.next().is_some() { - if let Some(idle_timeout) = ctx.idle_timeout() { + if let Some(idle_timeout) = ctx.idle_timeout() { + if let Ok(additional) = request.message().additional() { + let mut iter = additional.limit_to::>(); + if iter.next().is_some() { match IdleTimeout::try_from(idle_timeout) { Ok(timeout) => { // Request has an OPT RR and server idle diff --git a/src/net/server/middleware/processors/mandatory.rs b/src/net/server/middleware/processors/mandatory.rs index 674e9ff20..ac0bdc366 100644 --- a/src/net/server/middleware/processors/mandatory.rs +++ b/src/net/server/middleware/processors/mandatory.rs @@ -92,8 +92,8 @@ impl MandatoryMiddlewareProcessor { response: &mut AdditionalBuilder>, ) -> Result<(), TruncateError> where - RequestOctets: Octets, Target: Composer + Default, + RequestOctets: AsRef<[u8]>, { if let TransportSpecificContext::Udp(ctx) = request.transport_ctx() { // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 8f21b725a..13c8c34a7 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -44,7 +44,7 @@ //! - Tune the server behaviour via builder functions such as //! `with_middleware()`. //! - `run()` the server. -//! - `shutdown()` the server, explicitly or on [`drop()`]. +//! - `shutdown()` the server, explicitly or on [`drop`]. //! //! See [`DgramServer`] and [`StreamServer`] for example code to help you get //! started. @@ -95,11 +95,11 @@ //! With Middleware mandatory functionality and logic required by all //! standards compliant DNS servers can be incorporated into your server by //! building a [`MiddlewareChain`] starting from -//! [`MiddlewareBuilder::default()`]. +//! 
[`MiddlewareBuilder::default`]. //! //! You can also opt to incorporate additional behaviours into your DNS server //! from a selection of pre-supplied implementations via -//! [`MiddlewareBuilder::push()`]. See the various implementations of +//! [`MiddlewareBuilder::push`]. See the various implementations of //! [`MiddlewareProcessor`] for more information. //! //! And if the existing middleware processors don't meet your needs, maybe you @@ -132,18 +132,18 @@ //! ## Performance //! //! Both [`DgramServer`] and [`StreamServer`] use [`CommonMessageFlow`] to -//! pre-process the request, invoke [`Service::call()`], and post-process the +//! pre-process the request, invoke [`Service::call`], and post-process the //! response. //! -//! - Pre-processing and [`Service::call()`] invocation are done from the +//! - Pre-processing and [`Service::call`] invocation are done from the //! Tokio task handling the request. For [`DgramServer`] this is the main //! task that receives incoming messages. For [`StreamServer`] this is a //! dedicated task per accepted connection. //! - Post-processing is done in a new task request within which each future -//! resulting from invoking [`Service::call()`] is awaited and the -//! resulting response is post-processed. +//! resulting from invoking [`Service::call`] is awaited and the resulting +//! response is post-processed. //! -//! The initial work done by [`Service::call()`] should therefore complete as +//! The initial work done by [`Service::call`] should therefore complete as //! quickly as possible, delegating as much of the work as it can to the //! future(s) it returns. Until then it blocks the server from receiving new //! messages, or in the case of [`StreamServer`], new messages for the @@ -173,25 +173,25 @@ //! | # | Difficulty | Summary | Description | //! |---|------------|---------|-------------| //! | 1 | Easy | `#[derive(Clone)]` | Add `#[derive(Clone)]` to your [`Service`] impl. 
If your [`Service`] impl has no state that needs to be shared amongst instances of itself then this may be good enough for you. | -//! | 2 | Medium | [`Arc`] wrapper | Wrap your [`Service`] impl instance inside an [`Arc`] via [`Arc::new()`]. This crate implements the [`Service`] trait for `Arc` so you can pass an `Arc` to both [`DgramServer`] and [`StreamServer`] and they will [`Clone`] the [`Arc`] rather than the [`Service`] instance itself. | +//! | 2 | Medium | [`Arc`] wrapper | Wrap your [`Service`] impl instance inside an [`Arc`] via [`Arc::new`]. This crate implements the [`Service`] trait for `Arc` so you can pass an `Arc` to both [`DgramServer`] and [`StreamServer`] and they will [`Clone`] the [`Arc`] rather than the [`Service`] instance itself. | //! | 3 | Hard | Do it yourself | Manually implement [`Clone`] and/or your own locking and interior mutability strategy for your [`Service`] impl, giving you complete control over how state is shared by your server instances. | //! //! [`Arc`]: std::sync::Arc -//! [`Arc::new()`]: std::sync::Arc::new() +//! [`Arc::new`]: std::sync::Arc::new() //! [`AsyncAccept`]: sock::AsyncAccept //! [`AsyncDgramSock`]: sock::AsyncDgramSock //! [`BufSource`]: buf::BufSource //! [`DgramServer`]: dgram::DgramServer //! [`CommonMessageFlow`]: message::CommonMessageFlow //! [Middleware]: middleware -//! [`MiddlewareBuilder::default()`]: +//! [`MiddlewareBuilder::default`]: //! middleware::builder::MiddlewareBuilder::default() -//! [`MiddlewareBuilder::push()`]: +//! [`MiddlewareBuilder::push`]: //! middleware::builder::MiddlewareBuilder::push() //! [`MiddlewareChain`]: middleware::chain::MiddlewareChain //! [`MiddlewareProcessor`]: middleware::processor::MiddlewareProcessor //! [`Service`]: service::Service -//! [`Service::call()`]: service::Service::call() +//! [`Service::call`]: service::Service::call() //! [`StreamServer`]: stream::StreamServer //! [`TcpServer`]: stream::TcpServer //! 
[`UdpServer`]: dgram::UdpServer diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 852b9131a..57625025f 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -1,8 +1,8 @@ //! The application logic of a DNS server. //! -//! The [`Service::call()`] function defines how the service should respond to -//! a given DNS request. resulting in a [`Transaction`] containing a -//! transaction that yields one or more future DNS responses, and/or a +//! The [`Service::call`] function defines how the service should respond to a +//! given DNS request, resulting in a [`Transaction`] containing a transaction +//! that yields one or more future DNS responses, and/or a //! [`ServiceFeedback`]. use core::fmt::Display; use core::ops::Deref; @@ -16,14 +16,13 @@ use std::vec::Vec; use futures_util::stream::FuturesOrdered; use futures_util::{FutureExt, StreamExt}; -use octseq::{OctetsBuilder, ShortBuf}; +use super::message::Request; use crate::base::iana::Rcode; use crate::base::message_builder::{AdditionalBuilder, PushError}; -use crate::base::wire::{Composer, ParseError}; +use crate::base::wire::ParseError; use crate::base::StreamTarget; - -use super::message::Request; +use octseq::{OctetsBuilder, ShortBuf}; //------------ Service ------------------------------------------------------- @@ -37,9 +36,9 @@ use super::message::Request; /// For an overview of how services fit into the total flow of request and /// response handling see the [net::server module documentation]. /// -/// Each [`Service`] implementation defines a [`call()`] function which takes -/// a [`Request`] DNS request as input and returns either a -/// [`Transaction`] on success, or a [`ServiceError`] on failure, as output. 
/// /// Each [`Transaction`] contains either a single DNS response message, or a /// stream of DNS response messages (e.g. for a zone transfer). Each response @@ -52,12 +51,12 @@ use super::message::Request; /// /// 1. Implement the [`Service`] trait on a struct. /// 2. Define a function compatible with the [`Service`] trait. -/// 3. Define a function compatible with [`service_fn()`]. +/// 3. Define a function compatible with [`service_fn`]. /// ///
/// /// Whichever approach you choose it is important to minimize the work done -/// before returning from [`Service::call()`], as time spent here blocks the +/// before returning from [`Service::call`], as time spent here blocks the /// caller. Instead as much work as possible should be delegated to the /// futures returned as a [`Transaction`]. /// @@ -102,13 +101,7 @@ use super::message::Request; /// fn call( /// &self, /// msg: Request>, -/// ) -> Result< -/// Transaction< -/// Result, ServiceError>, -/// Self::Future, -/// >, -/// ServiceError, -/// > { +/// ) -> Result, ServiceError> { /// let builder = mk_builder_for_target(); /// let additional = mk_answer(&msg, builder)?; /// let item = ready(Ok(CallResult::new(additional))); @@ -138,8 +131,7 @@ use super::message::Request; /// fn name_to_ip( /// msg: Request>, /// ) -> Result< -/// Transaction< -/// Result, ServiceError>, +/// Transaction, ServiceError> /// > + Send, @@ -195,39 +187,34 @@ use super::message::Request; /// let srv = DgramServer::new(sock, buf, name_to_ip); /// ``` /// -/// # Define a function compatible with [`service_fn()`] +/// # Define a function compatible with [`service_fn`] /// -/// See [`service_fn()`] for an example of how to use it to create a -/// [`Service`] impl from a funciton. +/// See [`service_fn`] for an example of how to use it to create a [`Service`] +/// impl from a function. 
/// -/// [`MiddlewareChain`]: crate::net::server::middleware::chain::MiddlewareChain +/// [`MiddlewareChain`]: +/// crate::net::server::middleware::chain::MiddlewareChain /// [`DgramServer`]: crate::net::server::dgram::DgramServer /// [`StreamServer`]: crate::net::server::stream::StreamServer /// [net::server module documentation]: crate::net::server -/// [`call()`]: Self::call() -/// [`service_fn()`]: crate::net::server::util::service_fn() +/// [`call`]: Self::call() +/// [`service_fn`]: crate::net::server::util::service_fn() pub trait Service = Vec> { /// The type of buffer in which response messages are stored. - type Target: Composer + Default + Send + Sync + 'static; + type Target; - /// The type of future returned by [`Service::call()`] via - /// [`Transaction::single()`]. + /// The type of future returned by [`Service::call`] via + /// [`Transaction::single`]. type Future: std::future::Future< - Output = Result, ServiceError>, - > + Send; + Output = Result, ServiceError>, + >; /// Generate a response to a fully pre-processed request. #[allow(clippy::type_complexity)] fn call( &self, request: Request, - ) -> Result< - Transaction< - Result, ServiceError>, - Self::Future, - >, - ServiceError, - >; + ) -> Result, ServiceError>; } /// Helper trait impl to treat an [`Arc`] as a [`Service`]. 
@@ -240,13 +227,7 @@ impl, T: Service> fn call( &self, request: Request, - ) -> Result< - Transaction< - Result, ServiceError>, - Self::Future, - >, - ServiceError, - > { + ) -> Result, ServiceError> { Arc::deref(self).call(request) } } @@ -256,14 +237,11 @@ impl Service for F where F: Fn( Request, - ) -> Result< - Transaction, ServiceError>, Future>, - ServiceError, - >, + ) -> Result, ServiceError>, RequestOctets: AsRef<[u8]>, - Target: Composer + Default + Send + Sync + 'static, - Future: std::future::Future, ServiceError>> - + Send, + Future: std::future::Future< + Output = Result, ServiceError>, + >, { type Target = Target; type Future = Future; @@ -271,13 +249,7 @@ where fn call( &self, request: Request, - ) -> Result< - Transaction< - Result, ServiceError>, - Self::Future, - >, - ServiceError, - > { + ) -> Result, ServiceError> { (*self)(request) } } @@ -357,15 +329,15 @@ pub enum ServiceFeedback { //------------ CallResult ---------------------------------------------------- -/// The result of processing a DNS request via [`Service::call()`]. +/// The result of processing a DNS request via [`Service::call`]. /// /// Directions to a server on how to respond to a request. /// /// In most cases a [`CallResult`] will be a DNS response message. /// /// If needed a [`CallResult`] can instead, or additionally, contain a -/// [`ServiceFeedback`] directing the server or connection handler handling the -/// request to adjust its own configuration, or even to terminate the +/// [`ServiceFeedback`] directing the server or connection handler handling +/// the request to adjust its own configuration, or even to terminate the /// connection. #[derive(Clone, Debug)] pub struct CallResult { @@ -444,26 +416,32 @@ where /// # Usage /// /// Either: -/// - Construct a transaction for a [`single()`] response future, OR -/// - Construct a transaction [`stream()`] and [`push()`] response futures -/// into it. 
+/// - Construct a transaction for a [`single`] response future, OR +/// - Construct a transaction [`stream`] and [`push`] response futures into +/// it. /// -/// Then iterate over the response futures one at a time using [`next()`]. +/// Then iterate over the response futures one at a time using [`next`]. /// -/// [`single()`]: Self::single() -/// [`stream()`]: Self::stream() -/// [`push()`]: TransactionStream::push() -/// [`next()`]: Self::next() -pub struct Transaction(TransactionInner) +/// [`single`]: Self::single() +/// [`stream`]: Self::stream() +/// [`push`]: TransactionStream::push() +/// [`next`]: Self::next() +pub struct Transaction(TransactionInner) where - Future: std::future::Future + Send; + Future: std::future::Future< + Output = Result, ServiceError>, + >; -impl Transaction +impl Transaction where - Future: std::future::Future + Send, + Future: std::future::Future< + Output = Result, ServiceError>, + >, { /// Construct a transaction for a single immediate response. - pub(crate) fn immediate(item: Item) -> Self { + pub(crate) fn immediate( + item: Result, ServiceError>, + ) -> Self { Self(TransactionInner::Immediate(Some(item))) } @@ -480,17 +458,14 @@ where /// Construct a transaction for a future stream of response futures. /// /// The given future should build the stream of response futures that will - /// eventually be resolved by [`Self::next()`]. + /// eventually be resolved by [`Self::next`]. /// /// This takes a future instead of a [`TransactionStream`] because the /// caller may not yet know how many futures they need to push into the /// stream and we don't want them to block us while they work that out. pub fn stream( fut: Pin< - Box< - dyn std::future::Future> - + Send, - >, + Box> + Send>, >, ) -> Self { Self(TransactionInner::PendingStream(fut)) @@ -503,7 +478,9 @@ where /// /// Returns None if there are no (more) responses to take, Some(future) /// otherwise. 
- pub async fn next(&mut self) -> Option { + pub async fn next( + &mut self, + ) -> Option, ServiceError>> { match &mut self.0 { TransactionInner::Immediate(item) => item.take(), @@ -533,35 +510,36 @@ where /// [`Service`] impl should return, and (b) to control the interface offered /// to consumers of this type and avoid them having to work with the enum /// variants directly. -enum TransactionInner +enum TransactionInner where - Future: std::future::Future + Send, + Future: std::future::Future< + Output = Result, ServiceError>, + >, { /// The transaction will result in a single immediate response. /// /// This variant is for internal use only when aborting Middleware /// processing early. - Immediate(Option), + Immediate(Option, ServiceError>>), /// The transaction will result in at most a single response future. Single(Option), /// The transaction will result in stream of multiple response futures. PendingStream( - Pin< - Box< - dyn std::future::Future> - + Send, - >, - >, + Pin> + Send>>, ), /// The transaction is a stream of multiple response futures. - Stream(TransactionStream), + Stream(Stream), } //------------ TransacationStream -------------------------------------------- +/// A [`TransactionStream`] of [`Service`] results. +type Stream = + TransactionStream, ServiceError>>; + /// A stream of zero or more DNS response futures relating to a single DNS request. 
pub struct TransactionStream { /// An ordered sequence of futures that will resolve to responses to be @@ -591,3 +569,9 @@ impl Default for TransactionStream { } } } + +impl std::fmt::Debug for TransactionStream { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("TransactionStream").finish() + } +} diff --git a/src/net/server/sock.rs b/src/net/server/sock.rs index 729114bc4..ee322260e 100644 --- a/src/net/server/sock.rs +++ b/src/net/server/sock.rs @@ -7,7 +7,7 @@ use std::boxed::Box; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::ReadBuf; use tokio::net::{TcpListener, TcpStream, UdpSocket}; //------------ AsyncDgramSock ------------------------------------------------ @@ -17,7 +17,7 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket}; /// Must be implemented by "network source"s to be used with a /// [`DgramServer`]. /// -/// When reading the server will wait until [`Self::readable()`] succeeds and +/// When reading the server will wait until [`Self::readable`] succeeds and /// then call `try_recv_buf_from()`. /// /// # Design notes @@ -25,10 +25,10 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket}; /// When the underlying socket implementation is [`tokio::net::UdpSocket`] /// this pattern scales better than using `poll_recv_from()` as the latter /// causes the socket to be locked for exclusive access even if it was -/// [`Arc::clone()`]d. +/// [`Arc::clone`]d. /// /// With the `readable()` then `try_recv_buf_from()` pattern one can -/// [`Arc::clone()`] the socket and use it with multiple server instances at +/// [`Arc::clone`] the socket and use it with multiple server instances at /// once for greater throughput without any such locking occurring. /// /// [`DgramServer`]: crate::net::server::stream::DgramServer. @@ -122,14 +122,15 @@ impl AsyncDgramSock for Arc { /// [`StreamServer`]: crate::net::server::stream::StreamServer. 
pub trait AsyncAccept { /// The type of error that the trait impl produces. - type Error: Send; + type Error; /// The type of stream that the trait impl consumes. - type StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static; + type StreamType; /// The type of [`std::future::Future`] that the trait impl returns. - type Future: std::future::Future> - + Send; + type Future: std::future::Future< + Output = Result, + >; /// Polls to accept a new incoming connection to this listener. /// diff --git a/src/net/server/stream.rs b/src/net/server/stream.rs index 4fd7add2c..09d813e00 100644 --- a/src/net/server/stream.rs +++ b/src/net/server/stream.rs @@ -40,6 +40,7 @@ use super::buf::VecBufSource; use super::connection::{self, Connection}; use super::ServerCommand; use crate::base::wire::Composer; +use tokio::io::{AsyncRead, AsyncWrite}; // TODO: Should this crate also provide a TLS listener implementation? @@ -72,11 +73,7 @@ const MAX_CONCURRENT_TCP_CONNECTIONS: DefMinMax = //----------- Config --------------------------------------------------------- /// Configuration for a stream server. -pub struct Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +pub struct Config { /// Limit on the number of concurrent TCP connections that can be handled /// by the server. max_concurrent_connections: usize, @@ -131,7 +128,7 @@ where /// /// # Reconfigure /// - /// On [`StreamServer::reconfigure()`] if there are more connections + /// On [`StreamServer::reconfigure`] if there are more connections /// currently than the new limit the exceess connections will be allowed /// to complete normally, connections will NOT be terminated. 
pub fn set_max_concurrent_connections(&mut self, value: usize) { @@ -251,8 +248,7 @@ type CommandReceiver = /// /// fn my_service(msg: Request>, _meta: ()) /// -> Result< -/// Transaction< -/// Result>, ServiceError>, +/// Transaction, /// Pin>, ServiceError> /// > + Send>>, @@ -296,10 +292,11 @@ type CommandReceiver = /// https://docs.rs/tokio/latest/tokio/net/struct.TcpListener.html pub struct StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default + 'static, { /// The configuration of the server. config: Arc>>, @@ -340,10 +337,11 @@ where /// impl StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default, { /// Creates a new [`StreamServer`] instance. /// @@ -419,11 +417,11 @@ where /// impl StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Debug + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Debug, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Debug + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default, { /// Get a reference to the source for this server. 
#[must_use] @@ -442,24 +440,32 @@ where /// impl StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default + 'static, { /// Start the server. /// /// # Drop behaviour /// - /// When dropped [`shutdown()`] will be invoked. + /// When dropped [`shutdown`] will be invoked. /// - /// [`shutdown()`]: Self::shutdown + /// [`shutdown`]: Self::shutdown pub async fn run(&self) where + Buf: 'static, + Buf::Output: 'static, + Listener::Error: Send, + Listener::Future: Send + 'static, + Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, + Svc: 'static, + Svc::Target: Send + Sync, Svc::Future: Send, { if let Err(err) = self.run_until_error().await { - error!("StreamServer: {err}"); + error!("Server stopped due to error: {err}"); } } @@ -486,11 +492,10 @@ where /// be written as long as the client side of connection remains remains /// operational. /// - /// [`Self::is_shutdown()`] can be used to dertermine if shutdown is + /// [`Self::is_shutdown`] can be used to dertermine if shutdown is /// complete. /// - /// [`Self::await_shutdown()`] can be used to wait for shutdown to - /// complete. + /// [`Self::await_shutdown`] can be used to wait for shutdown to complete. pub fn shutdown(&self) -> Result<(), Error> { self.command_tx .lock() @@ -513,7 +518,7 @@ where /// Returns true if the server shutdown in the given time period, false /// otherwise. /// - /// To start the shutdown process first call [`Self::shutdown()`] then use + /// To start the shutdown process first call [`Self::shutdown`] then use /// this method to wait for the shutdown process to complete. 
pub async fn await_shutdown(&self, duration: Duration) -> bool { timeout(duration, async { @@ -532,15 +537,23 @@ where impl StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default, { /// Accept stream connections until shutdown or fatal error. async fn run_until_error(&self) -> Result<(), String> where + Buf: 'static, + Buf::Output: 'static, + Listener::Error: Send, + Listener::Future: Send + 'static, + Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, + Svc: 'static, Svc::Future: Send, + Svc::Target: Send + Sync + 'static, { let mut command_rx = self.command_rx.clone(); @@ -656,8 +669,14 @@ where stream: Listener::Future, addr: SocketAddr, ) where - Buf::Output: Octets, + Buf: 'static, + Buf::Output: Octets + 'static, + Listener::Error: Send, + Listener::Future: Send + 'static, + Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, + Svc: 'static, Svc::Future: Send, + Svc::Target: Send + Sync + 'static, { // Work around the compiler wanting to move self to the async block by // preparing only those pieces of information from self for the new @@ -718,10 +737,11 @@ where impl Drop for StreamServer where - Listener: AsyncAccept + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static + Clone, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static + Clone, + Listener: AsyncAccept + Send + Sync, + Buf: BufSource + Send + Sync + Clone, + Buf::Output: Octets + Send + Sync, + Svc: Service + Send + Sync + Clone, + Svc::Target: Composer + Default + 'static, { fn drop(&mut self) { // Shutdown the StreamServer. 
Don't handle the failure case here as diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index d1c43920e..335740bec 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -303,16 +303,9 @@ impl Service> for MyService { fn call( &self, - _request: Request>, - ) -> Result< - Transaction< - Result, ServiceError>, - Self::Future, - >, - ServiceError, - > { + _msg: Request>, + ) -> Result, ServiceError> { Ok(Transaction::single(MySingle)) - // Err(ServiceError::ShuttingDown) } } diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 504336a34..5a38d1acc 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -22,7 +22,7 @@ use crate::base::iana::Rcode; /// Helper for creating a [`MessageBuilder`] for a `Target`. pub fn mk_builder_for_target() -> MessageBuilder> where - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { let target = StreamTarget::new(Target::default()) .map_err(|_| ()) @@ -71,11 +71,10 @@ where /// req: Request>, /// _meta: MyMeta, /// ) -> Result< -/// Transaction< -/// Result>, ServiceError>, +/// Transaction, /// Pin>, ServiceError> -/// > + Send>>, +/// >>>, /// >, /// ServiceError, /// > { @@ -94,28 +93,26 @@ where /// Above we see the outline of what we need to do: /// - Define a function that implements our request handling logic for our /// service. -/// - Call [`service_fn()`] to wrap it in an actual [`Service`] impl. +/// - Call [`service_fn`] to wrap it in an actual [`Service`] impl. 
/// /// [`Vec`]: std::vec::Vec /// [`CallResult`]: crate::net::server::service::CallResult -/// [`Result::Ok()`]: std::result::Result::Ok +/// [`Result::Ok`]: std::result::Result::Ok pub fn service_fn( request_handler: T, metadata: Metadata, ) -> impl Service + Clone where RequestOctets: AsRef<[u8]>, - Target: Composer + Default + Send + Sync + 'static, - Future: std::future::Future, ServiceError>> - + Send, + Future: std::future::Future< + Output = Result, ServiceError>, + >, Metadata: Clone, T: Fn( Request, Metadata, - ) -> Result< - Transaction, ServiceError>, Future>, - ServiceError, - > + Clone, + ) -> Result, ServiceError> + + Clone, { move |request| request_handler(request, metadata.clone()) } @@ -165,7 +162,7 @@ pub fn start_reply( ) -> QuestionBuilder> where RequestOctets: Octets, - Target: Composer + OctetsBuilder + Default, + Target: Composer + Default, { let builder = mk_builder_for_target(); diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index e05c5f0c9..fbc69377a 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -58,21 +58,21 @@ pub mod conf; /// /// This type collects all information making it possible to start DNS /// queries. You can create a new resolver using the system’s configuration -/// using the [`new()`] associate function or using your own configuration -/// with [`from_conf()`]. +/// using the [`new`] associate function or using your own configuration with +/// [`from_conf`]. /// /// Stub resolver values can be cloned relatively cheaply as they keep all /// information behind an arc. /// /// If you want to run a single query or lookup on a resolver synchronously, -/// you can do so simply by using the [`run()`] or [`run_with_conf()`] -/// associated functions. +/// you can do so simply by using the [`run`] or [`run_with_conf`] associated +/// functions. 
/// -/// [`new()`]: #method.new -/// [`from_conf()`]: #method.from_conf -/// [`query()`]: #method.query -/// [`run()`]: #method.run -/// [`run_with_conf()`]: #method.run_with_conf +/// [`new`]: #method.new +/// [`from_conf`]: #method.from_conf +/// [`query`]: #method.query +/// [`run`]: #method.run +/// [`run_with_conf`]: #method.run_with_conf #[derive(Debug)] pub struct StubResolver { transport: Mutex>>>>, @@ -281,10 +281,10 @@ impl StubResolver { /// Synchronously perform a DNS operation atop a configured resolver. /// - /// This is like [`run()`] but also takes a resolver configuration for + /// This is like [`run`] but also takes a resolver configuration for /// tailor-making your own resolver. /// - /// [`run()`]: #method.run + /// [`run`]: #method.run pub fn run_with_conf(conf: ResolvConf, op: F) -> R::Output where R: Future> + Send + 'static, diff --git a/src/zonefile/error.rs b/src/zonefile/error.rs new file mode 100644 index 000000000..eed516f63 --- /dev/null +++ b/src/zonefile/error.rs @@ -0,0 +1,187 @@ +//! Zone related errors. + +//------------ ZoneCutError -------------------------------------------------- + +use std::fmt::Display; +use std::vec::Vec; + +use crate::base::Rtype; +use crate::zonetree::{StoredDname, StoredRecord}; + +use super::inplace; + +/// A zone cut is not valid with respect to the zone's apex. +#[derive(Clone, Copy, Debug)] +pub enum ZoneCutError { + OutOfZone, + ZoneCutAtApex, +} + +impl From for ZoneCutError { + fn from(_: OutOfZone) -> ZoneCutError { + ZoneCutError::OutOfZone + } +} + +impl Display for ZoneCutError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ZoneCutError::OutOfZone => write!(f, "Out of zone"), + ZoneCutError::ZoneCutAtApex => write!(f, "Zone cut at apex"), + } + } +} + +//----------- CnameError ----------------------------------------------------- + +/// A CNAME is not valid with respect to the zone's apex. 
+#[derive(Clone, Copy, Debug)] +pub enum CnameError { + OutOfZone, + CnameAtApex, +} + +impl From for CnameError { + fn from(_: OutOfZone) -> CnameError { + CnameError::OutOfZone + } +} + +impl Display for CnameError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + CnameError::OutOfZone => write!(f, "Out of zone"), + CnameError::CnameAtApex => write!(f, "CNAME at apex"), + } + } +} + +//----------- OutOfZone ------------------------------------------------------ + +/// A domain name is not under the zone’s apex. +#[derive(Clone, Copy, Debug)] +pub struct OutOfZone; + +//------------ RecordError --------------------------------------------------- + +#[derive(Clone, Debug)] +pub enum RecordError { + /// The class of the record does not match the class of the zone. + ClassMismatch(StoredRecord), + + /// Attempted to add zone cut records where there is no zone cut. + IllegalZoneCut(StoredRecord), + + /// Attempted to add a normal record to a zone cut or CNAME. + IllegalRecord(StoredRecord), + + /// Attempted to add a CNAME record where there are other records. + IllegalCname(StoredRecord), + + /// Attempted to add multiple CNAME records for an owner. + MultipleCnames(StoredRecord), + + /// The record could not be parsed. + MalformedRecord(inplace::Error), + + /// The record is parseable but not valid. + InvalidRecord(ZoneErrors), + + /// The SOA record was not found. 
+ MissingSoa(StoredRecord), +} + +impl Display for RecordError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + RecordError::ClassMismatch(rec) => { + write!(f, "ClassMismatch: {rec}") + } + RecordError::IllegalZoneCut(rec) => { + write!(f, "IllegalZoneCut: {rec}") + } + RecordError::IllegalRecord(rec) => { + write!(f, "IllegalRecord: {rec}") + } + RecordError::IllegalCname(rec) => { + write!(f, "IllegalCname: {rec}") + } + RecordError::MultipleCnames(rec) => { + write!(f, "MultipleCnames: {rec}") + } + RecordError::MalformedRecord(err) => { + write!(f, "MalformedRecord: {err}") + } + RecordError::InvalidRecord(err) => { + write!(f, "InvalidRecord: {err}") + } + RecordError::MissingSoa(rec) => write!(f, "MissingSoa: {rec}"), + } + } +} + +//------------ ZoneErrors ---------------------------------------------------- + +/// A set of problems relating to a zone. +#[derive(Clone, Debug, Default)] +pub struct ZoneErrors { + errors: Vec<(StoredDname, OwnerError)>, +} + +impl ZoneErrors { + pub fn add_error(&mut self, name: StoredDname, error: OwnerError) { + self.errors.push((name, error)) + } + + pub fn into_result(self) -> Result<(), Self> { + if self.errors.is_empty() { + Ok(()) + } else { + Err(self) + } + } +} + +impl Display for ZoneErrors { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Zone file errors: [")?; + for err in &self.errors { + write!(f, "'{}': {},", err.0, err.1)?; + } + write!(f, "]") + } +} + +//------------ OwnerError --------------------------------------------------- + +#[derive(Clone, Debug)] +pub enum OwnerError { + /// A NS RRset is missing at a zone cut. + /// + /// (This happens if there is only a DS RRset.) + MissingNs, + + /// A zone cut appeared where it shouldn’t have. + InvalidZonecut(ZoneCutError), + + /// A CNAME appeared where it shouldn’t have. + InvalidCname(CnameError), + + /// A record is out of zone. 
+ OutOfZone(Rtype), +} + +impl Display for OwnerError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + OwnerError::MissingNs => write!(f, "Missing NS"), + OwnerError::InvalidZonecut(err) => { + write!(f, "Invalid zone cut: {err}") + } + OwnerError::InvalidCname(err) => { + write!(f, "Invalid CNAME: {err}") + } + OwnerError::OutOfZone(err) => write!(f, "Out of zone: {err}"), + } + } +} diff --git a/src/zonefile/inplace.rs b/src/zonefile/inplace.rs index ef9bc8040..e1464ceab 100644 --- a/src/zonefile/inplace.rs +++ b/src/zonefile/inplace.rs @@ -1,7 +1,7 @@ //! A zonefile scanner keeping data in place. //! //! The zonefile scanner provided by this module reads the entire zonefile -//! into memory and tries as much as possible to modify re-use this memory +//! into memory and tries as much as possible to modify/re-use this memory //! when scanning data. It uses the `Bytes` family of types for safely //! storing, manipulating, and returning the data and thus requires the //! `bytes` feature to be enabled. @@ -161,9 +161,12 @@ unsafe impl BufMut for Zonefile { impl Zonefile { /// Sets the origin of the zonefile. /// - /// The origin is append to relative domain names encountered in the - /// data. Ininitally, there is no origin set. If relative names are - /// encountered, an error happenes. + /// The origin is append to relative domain names encountered in the data. + /// Ininitally, there is no origin set. It will be set if an $ORIGIN + /// directive is encountered while iterating over the zone. If a zone name + /// is not provided via this function or via an $ORIGIN directive, then + /// any relative names encountered will cause iteration to terminate with + /// a missing origin error. pub fn set_origin(&mut self, origin: Dname) { self.origin = Some(origin) } @@ -189,7 +192,7 @@ impl Zonefile { } /// Returns the origin name of the zonefile. 
- fn get_origin(&self) -> Result, EntryError> { + pub fn origin(&self) -> Result, EntryError> { self.origin .as_ref() .cloned() @@ -637,7 +640,7 @@ impl<'a> Scanner for EntryScanner<'a> { self.zonefile.buf.next_item()?; if start == 0 { return RelativeDname::empty_bytes() - .chain(self.zonefile.get_origin()?) + .chain(self.zonefile.origin()?) .map_err(|_| EntryError::bad_dname()); } else { return unsafe { @@ -676,7 +679,7 @@ impl<'a> Scanner for EntryScanner<'a> { RelativeDname::from_octets_unchecked( self.zonefile.buf.split_to(write).freeze(), ) - .chain(self.zonefile.get_origin()?) + .chain(self.zonefile.origin()?) .map_err(|_| EntryError::bad_dname()) }; } @@ -1420,8 +1423,8 @@ enum ItemCat { //------------ EntryError ---------------------------------------------------- /// An error returned by the entry scanner. -#[derive(Debug)] -struct EntryError(&'static str); +#[derive(Clone, Debug)] +pub struct EntryError(&'static str); impl EntryError { fn bad_symbol(_err: SymbolOctetsError) -> Self { @@ -1498,7 +1501,7 @@ impl std::error::Error for EntryError {} //------------ Error --------------------------------------------------------- -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Error { err: EntryError, line: usize, @@ -1548,7 +1551,7 @@ mod test { }); } - test(" unquoted\n", b"unquoted"); + test(" unquoted\r\n", b"unquoted"); test(" unquoted ", b"unquoted"); test("unquoted ", b"unquoted"); test("unqu\\oted ", b"unquoted"); diff --git a/src/zonefile/mod.rs b/src/zonefile/mod.rs index c208be2af..6cf90f22b 100644 --- a/src/zonefile/mod.rs +++ b/src/zonefile/mod.rs @@ -2,4 +2,7 @@ #![cfg(feature = "zonefile")] #![cfg_attr(docsrs, doc(cfg(feature = "zonefile")))] +pub mod error; pub mod inplace; +#[cfg(feature = "unstable-zonetree")] +pub mod parsed; diff --git a/src/zonefile/parsed.rs b/src/zonefile/parsed.rs new file mode 100644 index 000000000..a5ab05a8b --- /dev/null +++ b/src/zonefile/parsed.rs @@ -0,0 +1,403 @@ +//! 
Importing from and (in future) exporting to a zonefiles. + +use std::collections::{BTreeMap, HashMap}; +use std::vec::Vec; + +use tracing::trace; + +use super::error::{OwnerError, RecordError, ZoneErrors}; +use super::inplace::{self, Entry}; + +use crate::base::iana::{Class, Rtype}; +use crate::base::name::FlattenInto; +use crate::base::ToDname; +use crate::rdata::ZoneRecordData; +use crate::zonetree::ZoneBuilder; +use crate::zonetree::{Rrset, SharedRr, StoredDname, StoredRecord}; + +//------------ Zonefile ------------------------------------------------------ + +/// A parsed sanity checked representation of a zonefile. +/// +/// This type eases creation of a [`ZoneBuilder`] from a collection of +/// [`StoredRecord`]s, e.g. and accepts only records that are valid within +/// the zone. +/// +/// The zone origin and class may be specified explicitly or be derived from +/// the SOA record when inserted. The relationship of each resource record +/// with the zone is classified on insert, similar to that described by [RFC +/// 1034 4.2.1]. +/// +/// Getter functions provide insight into the classification results. +/// +/// When ready the [`ZoneBuilder::try_from`] function can be used to convert +/// the parsed zonefile into a pre-populated [`ZoneBuilder`]. +/// +/// # Usage +/// +/// See the [zonetree] module docs for example usage. +/// +/// [RFC 1034 4.2.1]: +/// https://datatracker.ietf.org/doc/html/rfc1034#section-4.2.1 +/// [zonetree]: crate::zonetree +#[derive(Clone, Default)] +pub struct Zonefile { + /// The name of the apex of the zone. + origin: Option, + + /// The class of the zone. + class: Option, + + /// The records for names that have regular RRsets attached to them. + normal: Owners, + + /// The records for names that are zone cuts. + zone_cuts: Owners, + + /// The records for names that are CNAMEs. + cnames: Owners, + + /// Out of zone records. 
+ out_of_zone: Owners, +} + +impl Zonefile { + pub fn new(apex: StoredDname, class: Class) -> Self { + Zonefile { + origin: Some(apex), + class: Some(class), + ..Default::default() + } + } +} + +impl Zonefile { + pub fn set_origin(&mut self, origin: StoredDname) { + self.origin = Some(origin) + } + + /// Inserts the record into the zone file. + pub fn insert( + &mut self, + record: StoredRecord, + ) -> Result<(), RecordError> { + // If a zone apex and class were not provided via [`Self::new`], i.e. + // we were created by [`Self::default`], require the first record to + // be a SOA record and use its owner name and class as the zone apex + // name and class. + if self.origin.is_none() { + if record.rtype() != Rtype::SOA { + return Err(RecordError::MissingSoa(record)); + } else { + let apex = record.owner().to_dname(); + self.class = Some(record.class()); + self.origin = Some(apex); + } + } + + let (zone_apex, zone_class) = + (self.origin().unwrap(), self.class().unwrap()); + + if record.class() != zone_class { + return Err(RecordError::ClassMismatch(record)); + } + + if !record.owner().ends_with(zone_apex) { + self.out_of_zone + .entry(record.owner().clone()) + .insert(record); + Ok(()) + } else { + match record.rtype() { + // An Name Server (NS) record at the apex is a nameserver RR + // that indicates a server for the zone. An NS record is only + // an indication of a zone cut when it is NOT at the apex. + // + // A Delegation Signer (DS) record can only appear within the + // parent zone and refer to a child zone, a DS record cannot + // therefore appear at the apex. 
+ Rtype::NS | Rtype::DS if record.owner() != zone_apex => { + if self.normal.contains(record.owner()) + || self.cnames.contains(record.owner()) + { + return Err(RecordError::IllegalZoneCut(record)); + } + self.zone_cuts + .entry(record.owner().clone()) + .insert(record); + Ok(()) + } + Rtype::CNAME => { + if self.normal.contains(record.owner()) + || self.zone_cuts.contains(record.owner()) + { + return Err(RecordError::IllegalCname(record)); + } + if self.cnames.contains(record.owner()) { + return Err(RecordError::MultipleCnames(record)); + } + self.cnames.insert(record.owner().clone(), record.into()); + Ok(()) + } + _ => { + if self.zone_cuts.contains(record.owner()) + || self.cnames.contains(record.owner()) + { + return Err(RecordError::IllegalRecord(record)); + } + self.normal.entry(record.owner().clone()).insert(record); + Ok(()) + } + } + } + } +} + +impl Zonefile { + pub fn origin(&self) -> Option<&StoredDname> { + self.origin.as_ref() + } + + pub fn class(&self) -> Option { + self.class + } + + pub fn normal(&self) -> &Owners { + &self.normal + } + + pub fn zone_cuts(&self) -> &Owners { + &self.zone_cuts + } + + pub fn cnames(&self) -> &Owners { + &self.cnames + } + + pub fn out_of_zone(&self) -> &Owners { + &self.out_of_zone + } +} + +impl TryFrom for ZoneBuilder { + type Error = ZoneErrors; + + fn try_from(mut zonefile: Zonefile) -> Result { + let mut builder = ZoneBuilder::new( + zonefile.origin.unwrap(), + zonefile.class.unwrap(), + ); + let mut zone_err = ZoneErrors::default(); + + // Insert all the zone cuts first. Fish out potential glue records + // from the normal or out-of-zone records. 
+ for (name, cut) in zonefile.zone_cuts.into_iter() { + let ns = match cut.ns { + Some(ns) => ns.into_shared(), + None => { + zone_err.add_error(name, OwnerError::MissingNs); + continue; + } + }; + let ds = cut.ds.map(Rrset::into_shared); + let mut glue = vec![]; + for rdata in ns.data() { + if let ZoneRecordData::Ns(ns) = rdata { + glue.append( + &mut zonefile.normal.collect_glue(ns.nsdname()), + ); + } + } + + if let Err(err) = builder.insert_zone_cut(&name, ns, ds, glue) { + zone_err.add_error(name, OwnerError::InvalidZonecut(err)) + } + } + + // Now insert all the CNAMEs. + for (name, rrset) in zonefile.cnames.into_iter() { + if let Err(err) = builder.insert_cname(&name, rrset) { + zone_err.add_error(name, OwnerError::InvalidCname(err)) + } + } + + // Finally, all the normal records. + for (name, rrsets) in zonefile.normal.into_iter() { + for (rtype, rrset) in rrsets.into_iter() { + if builder.insert_rrset(&name, rrset.into_shared()).is_err() { + zone_err.add_error( + name.clone(), + OwnerError::OutOfZone(rtype), + ); + } + } + } + + // If there are out-of-zone records left, we will error to avoid + // surprises. + for (name, rrsets) in zonefile.out_of_zone.into_iter() { + for (rtype, _) in rrsets.into_iter() { + zone_err + .add_error(name.clone(), OwnerError::OutOfZone(rtype)); + } + } + + zone_err.into_result().map(|_| builder) + } +} + +//--- TryFrom + +impl TryFrom for Zonefile { + type Error = RecordError; + + fn try_from(source: inplace::Zonefile) -> Result { + let mut zonefile = Zonefile::default(); + + for res in source { + match res.map_err(RecordError::MalformedRecord)? 
{ + Entry::Record(r) => zonefile.insert(r.flatten_into())?, + entry => { + trace!("Skipping unsupported zone file entry: {entry:?}") + } + } + } + + Ok(zonefile) + } +} + +//------------ Owners -------------------------------------------------------- + +#[derive(Clone)] +pub struct Owners { + owners: BTreeMap, +} + +impl Owners { + fn contains(&self, name: &StoredDname) -> bool { + self.owners.contains_key(name) + } + + fn insert(&mut self, name: StoredDname, content: Content) -> bool { + use std::collections::btree_map::Entry; + + match self.owners.entry(name) { + Entry::Occupied(_) => false, + Entry::Vacant(vacant) => { + vacant.insert(content); + true + } + } + } + + fn entry(&mut self, name: StoredDname) -> &mut Content + where + Content: Default, + { + self.owners.entry(name).or_default() + } + + fn into_iter(self) -> impl Iterator { + self.owners.into_iter() + } +} + +impl Owners { + fn collect_glue(&mut self, name: &StoredDname) -> Vec { + let mut glue_records = vec![]; + + // https://www.rfc-editor.org/rfc/rfc9471.html + // 2.1. Glue for In-Domain Name Servers + + // For each NS delegation find the names of the nameservers the NS + // records point to, and then see if the A/AAAA records for this names + // are defined in the authoritative (normal) data for this zone, and + // if so extract them. + if let Some(normal) = self.owners.get(name) { + // Now see if A/AAAA records exists for the name in + // this zone. 
+ for (_rtype, rrset) in + normal.records.iter().filter(|(&rtype, _)| { + rtype == Rtype::A || rtype == Rtype::AAAA + }) + { + for rdata in rrset.data() { + let glue_record = StoredRecord::new( + name.clone(), + Class::IN, + rrset.ttl(), + rdata.clone(), + ); + glue_records.push(glue_record); + } + } + } + + glue_records + } +} + +impl Default for Owners { + fn default() -> Self { + Owners { + owners: Default::default(), + } + } +} + +//------------ Normal -------------------------------------------------------- + +#[derive(Clone, Default)] +pub struct Normal { + records: HashMap, +} + +impl Normal { + fn insert(&mut self, record: StoredRecord) { + use std::collections::hash_map::Entry; + + match self.records.entry(record.rtype()) { + Entry::Occupied(mut occupied) => { + occupied.get_mut().push_record(record) + } + Entry::Vacant(vacant) => { + vacant.insert(record.into()); + } + } + } + + fn into_iter(self) -> impl Iterator { + self.records.into_iter() + } +} + +//------------ ZoneCut ------------------------------------------------------- + +#[derive(Clone, Default)] +pub struct ZoneCut { + ns: Option, + ds: Option, +} + +impl ZoneCut { + fn insert(&mut self, record: StoredRecord) { + match record.rtype() { + Rtype::NS => { + if let Some(ns) = self.ns.as_mut() { + ns.push_record(record) + } else { + self.ns = Some(record.into()) + } + } + Rtype::DS => { + if let Some(ds) = self.ds.as_mut() { + ds.push_record(record) + } else { + self.ds = Some(record.into()) + } + } + _ => panic!("inserting wrong rtype to zone cut"), + } + } +} diff --git a/src/zonetree/answer.rs b/src/zonetree/answer.rs new file mode 100644 index 000000000..1fa34afd8 --- /dev/null +++ b/src/zonetree/answer.rs @@ -0,0 +1,228 @@ +//! Answers to zone tree queries. 
+ +//------------ Answer -------------------------------------------------------- + +use super::{SharedRr, SharedRrset, StoredDname}; +use crate::base::iana::Rcode; +use crate::base::message_builder::AdditionalBuilder; +use crate::base::wire::Composer; +use crate::base::Message; +use crate::base::MessageBuilder; +use octseq::Octets; + +/// A DNS answer to a query against a [`Zone`]. +/// +/// [`Answer`] is the type returned by [`ReadableZone::query`]. +/// +/// Callers of [`ReadableZone::query`] will likely only ever need to use the +/// [`Self::to_message`] function. Alternatively, for complete control use the +/// getter functions on [`Answer`] instead and construct a response message +/// yourself using [`MessageBuilder`]. +/// +/// Implementers of alternate backing stores for [`Zone`]s will need to use +/// one of the various `Answer` constructor functions when +/// [`ReadableZone::query`] is invoked for your zone content in order to +/// tailor the DNS message produced by [`Self::to_message`] based on the +/// outcome of the query. +/// +/// [`Zone`]: crate::zonetree::Zone +/// [`ReadableZone::query`]: crate::zonetree::traits::ReadableZone::query() +#[derive(Clone)] +pub struct Answer { + /// The response code of the answer. + rcode: Rcode, + + /// The content of the answer. + content: AnswerContent, + + /// The optional authority section to be included in the answer. + authority: Option, +} + +impl Answer { + /// Creates an "empty" answer. + /// + /// The answer, authority and additinal sections will be empty. + pub fn new(rcode: Rcode) -> Self { + Answer { + rcode, + content: AnswerContent::NoData, + authority: Default::default(), + } + } + + /// Creates a new message with a populated authority section. + /// + /// The answer and additional sections will be empty. 
+ pub fn with_authority(rcode: Rcode, authority: AnswerAuthority) -> Self { + Answer { + rcode, + content: AnswerContent::NoData, + authority: Some(authority), + } + } + + /// Creates a new [Rcode::REFUSED] answer. + /// + /// This is equivalent to calling [`Answer::new(Rcode::Refused)`]. + pub fn refused() -> Self { + Answer::new(Rcode::REFUSED) + } + + /// Adds a CNAME to the answer section. + pub fn add_cname(&mut self, cname: SharedRr) { + self.content = AnswerContent::Cname(cname); + } + + /// Adds an RRset to the answer section. + pub fn add_answer(&mut self, answer: SharedRrset) { + self.content = AnswerContent::Data(answer); + } + + /// Sets the content of the authority section. + pub fn set_authority(&mut self, authority: AnswerAuthority) { + self.authority = Some(authority) + } + + /// Generate a DNS response [`Message`] for this answer. + /// + /// The response [Rcode], question, answer and authority sections of the + /// produced [`AdditionalBuilder`] will be populated based on the + /// properties of this [`Answer`] as determined by the constructor and + /// add/set functions called prior to calling this function. + /// + ///
+ /// This function does NOT currently set the + /// AA + /// flag on the produced message. + ///
+ /// + /// See also: [`MessageBuilder::start_answer`] + pub fn to_message( + &self, + message: &Message, + builder: MessageBuilder, + ) -> AdditionalBuilder { + let question = message.sole_question().unwrap(); + let qname = question.qname(); + let qclass = question.qclass(); + let mut builder = builder.start_answer(message, self.rcode).unwrap(); + + match self.content { + AnswerContent::Data(ref answer) => { + for item in answer.data() { + // TODO: This will panic if too many answers were given, + // rather than give the caller a way to push the rest into + // another message. + builder + .push((qname, qclass, answer.ttl(), item)) + .unwrap(); + } + } + AnswerContent::Cname(ref cname) => builder + .push((qname, qclass, cname.ttl(), cname.data())) + .unwrap(), + AnswerContent::NoData => {} + } + + let mut builder = builder.authority(); + if let Some(authority) = self.authority.as_ref() { + if let Some(soa) = authority.soa.as_ref() { + builder + .push(( + authority.owner.clone(), + qclass, + soa.ttl(), + soa.data(), + )) + .unwrap(); + } + if let Some(ns) = authority.ns.as_ref() { + for item in ns.data() { + builder + .push(( + authority.owner.clone(), + qclass, + ns.ttl(), + item, + )) + .unwrap() + } + } + if let Some(ref ds) = authority.ds { + for item in ds.data() { + builder + .push(( + authority.owner.clone(), + qclass, + ds.ttl(), + item, + )) + .unwrap() + } + } + } + + builder.additional() + } + + /// Gets the [`Rcode`] for this answer. + pub fn rcode(&self) -> Rcode { + self.rcode + } + + /// Gets the answer section content for this answer. + pub fn content(&self) -> &AnswerContent { + &self.content + } + + /// Gets the authority section content for this answer. + pub fn authority(&self) -> Option<&AnswerAuthority> { + self.authority.as_ref() + } +} + +//------------ AnswerContent ------------------------------------------------- + +/// The content of the answer. +#[derive(Clone)] +pub enum AnswerContent { + /// An answer consisting of an RRSET. 
+ Data(SharedRrset), + + /// An answer consisting of a CNAME RR. + Cname(SharedRr), + + /// An empty answer. + NoData, +} + +//------------ AnswerAuthority ----------------------------------------------- + +/// The authority section of a query answer. +#[derive(Clone)] +pub struct AnswerAuthority { + /// The owner name of the record sets in the authority section. + owner: StoredDname, + + /// The SOA record if it should be included. + soa: Option, + + /// The NS record set if it should be included. + ns: Option, + + /// The DS record set if it should be included.. + ds: Option, +} + +impl AnswerAuthority { + /// Creates a new representation of an authority section. + pub fn new( + owner: StoredDname, + soa: Option, + ns: Option, + ds: Option, + ) -> Self { + AnswerAuthority { owner, soa, ns, ds } + } +} diff --git a/src/zonetree/in_memory/builder.rs b/src/zonetree/in_memory/builder.rs new file mode 100644 index 000000000..8e17363ea --- /dev/null +++ b/src/zonetree/in_memory/builder.rs @@ -0,0 +1,172 @@ +//! Builders for in-memory zones. + +use std::sync::Arc; +use std::vec::Vec; + +use crate::base::iana::Class; +use crate::base::name::{Label, ToDname}; +use crate::zonefile::error::{CnameError, OutOfZone, ZoneCutError}; +use crate::zonetree::types::ZoneCut; +use crate::zonetree::{ + SharedRr, SharedRrset, StoredDname, StoredRecord, Zone, +}; + +use super::nodes::{Special, ZoneApex, ZoneNode}; +use super::versioned::Version; + +//------------ ZoneBuilder --------------------------------------------------- + +/// A builder of in-memory [`Zone`]s. +/// +/// `ZoneBuilder` is used to specify the content of a single zone one, +/// resource record or RRset at a time, and to then turn that specification +/// into a populated in-memory [`Zone`]. +/// +///
+/// +/// Already have a zonefile in [presentation format]? +/// +/// Check out the example [module docs] which shows how to use +/// [`inplace::Zonefile`], [`parsed::Zonefile`] and `ZoneBuilder` together +/// without having to manually insert each resource record into the +/// `ZoneBuilder` yourself. +/// +///
+/// +/// Each `ZoneBuilder` builds a single zone with a named apex and a single +/// [`Class`]. All resource records within the zone are considered to have the +/// specified class. +/// +/// `ZoneBuilder` has dedicated functions for inserting certain kinds of +/// resource record properly into the zone in order to cater to RR types that +/// require or benefit from special handling when [`ReadableZone::query`] is +/// invoked for the zone. +/// +/// # Usage +/// +/// To use `ZoneBuilder`: +/// - Call [`ZoneBuilder::new`] to create a new builder. +/// - Call the various `insert_()` functions to add as many resource records +/// as needed. +/// - Call [`ZoneBuilder::build`] to exchange the builder for a populated +/// [`Zone`]. +/// +/// [module docs]: crate::zonetree +/// [`inplace::Zonefile`]: crate::zonefile::inplace::Zonefile +/// [`parsed::Zonefile`]: crate::zonefile::parsed::Zonefile +/// [presentation format]: +/// https://datatracker.ietf.org/doc/html/rfc9499#section-2-1.16.1.6.1.3 +/// [`ReadableZone::query`]: crate::zonetree::ReadableZone::query() +pub struct ZoneBuilder { + apex: ZoneApex, +} + +impl ZoneBuilder { + /// Creates a new builder for the specified apex name and class. + /// + /// All resource records in the zone will be considered to have the + /// specified [`Class`]. + #[must_use] + pub fn new(apex_name: StoredDname, class: Class) -> Self { + ZoneBuilder { + apex: ZoneApex::new(apex_name, class), + } + } + + /// Builds an in-memory [`Zone`] from this builder. + /// + /// Calling this function consumes the [`ZoneBuilder`]. The returned + /// `Zone` will be populated with the resource records that were inserted + /// into the builder. + #[must_use] + pub fn build(self) -> Zone { + Zone::new(self.apex) + } + + /// Inserts a related set of resource records. + /// + /// Inserts a [`SharedRrset`] for the given owner name. 
+ pub fn insert_rrset( + &mut self, + name: &impl ToDname, + rrset: SharedRrset, + ) -> Result<(), OutOfZone> { + match self.get_node(self.apex.prepare_name(name)?) { + Ok(node) => node.rrsets().update(rrset, Version::default()), + Err(apex) => apex.rrsets().update(rrset, Version::default()), + } + Ok(()) + } + + /// Insert one or more resource records that represent a zone cut. + /// + /// A zone cut is the _"delimitation point between two zones where the + /// origin of one of the zones is the child of the other zone"_ ([RFC 9499 + /// section 7.2.13]). + /// + /// Several different resource record types may appear at a zone cut and + /// may be inserted into the `ZoneBuilder` using this function: + /// + /// - [Ns] records + /// - [Ds] records + /// - Glue records _(see [RFC 9499 section 7.2.30])_ + /// + /// [Ns]: crate::rdata::rfc1035::Ns + /// [Ds]: crate::rdata::dnssec::Ds + /// [RFC 9499 section 7.2.13]: + /// https://datatracker.ietf.org/doc/html/rfc9499#section-7-2.13 + /// [delegation point]: + /// https://datatracker.ietf.org/doc/html/rfc4033#section-2 + pub fn insert_zone_cut( + &mut self, + name: &impl ToDname, + ns: SharedRrset, + ds: Option, + glue: Vec, + ) -> Result<(), ZoneCutError> { + let node = self.get_node(self.apex.prepare_name(name)?)?; + let cut = ZoneCut { + name: name.to_bytes(), + ns, + ds, + glue, + }; + node.update_special(Version::default(), Some(Special::Cut(cut))); + Ok(()) + } + + /// Inserts a CNAME resource record. 
+ /// + /// See: [`Cname`] + /// + /// [`Cname`]: crate::rdata::rfc1035::Cname + pub fn insert_cname( + &mut self, + name: &impl ToDname, + cname: SharedRr, + ) -> Result<(), CnameError> { + let node = self.get_node(self.apex.prepare_name(name)?)?; + node.update_special(Version::default(), Some(Special::Cname(cname))); + Ok(()) + } + + fn get_node<'l>( + &self, + mut name: impl Iterator, + ) -> Result, &ZoneApex> { + let label = match name.next() { + Some(label) => label, + None => return Err(&self.apex), + }; + let mut node = self + .apex + .children() + .with_or_default(label, |node, _| node.clone()); + for label in name { + node = node + .children() + .with_or_default(label, |node, _| node.clone()); + } + Ok(node) + } +} diff --git a/src/zonetree/in_memory/mod.rs b/src/zonetree/in_memory/mod.rs new file mode 100644 index 000000000..d7b50db5b --- /dev/null +++ b/src/zonetree/in_memory/mod.rs @@ -0,0 +1,7 @@ +mod builder; +mod nodes; +mod read; +mod versioned; +mod write; + +pub use builder::ZoneBuilder; diff --git a/src/zonetree/in_memory/nodes.rs b/src/zonetree/in_memory/nodes.rs new file mode 100644 index 000000000..e7d789589 --- /dev/null +++ b/src/zonetree/in_memory/nodes.rs @@ -0,0 +1,404 @@ +//! The nodes in a zone tree. 
+ +use std::boxed::Box; +use std::collections::{hash_map, HashMap}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use parking_lot::{ + RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard, +}; +use tokio::sync::Mutex; + +use crate::base::iana::{Class, Rtype}; +use crate::base::name::{Label, OwnedLabel, ToDname, ToLabelIter}; +use crate::zonefile::error::{CnameError, OutOfZone, ZoneCutError}; +use crate::zonetree::types::ZoneCut; +use crate::zonetree::walk::WalkState; +use crate::zonetree::{ + ReadableZone, SharedRr, SharedRrset, StoredDname, WritableZone, ZoneStore, +}; + +use super::read::ReadZone; +use super::versioned::{Version, Versioned}; +use super::write::{WriteZone, ZoneVersions}; + +//------------ ZoneApex ------------------------------------------------------ + +#[derive(Debug)] +pub struct ZoneApex { + apex_name: StoredDname, + class: Class, + rrsets: NodeRrsets, + children: NodeChildren, + update_lock: Arc>, + versions: Arc>, +} + +impl ZoneApex { + /// Creates a new apex. + pub fn new(apex_name: StoredDname, class: Class) -> Self { + ZoneApex { + apex_name, + class, + rrsets: Default::default(), + children: Default::default(), + update_lock: Default::default(), + versions: Default::default(), + } + } + + /// Creates a new apex. + pub fn from_parts( + apex_name: StoredDname, + class: Class, + rrsets: NodeRrsets, + children: NodeChildren, + versions: ZoneVersions, + ) -> Self { + ZoneApex { + apex_name, + class, + rrsets, + children, + update_lock: Default::default(), + versions: Arc::new(RwLock::new(versions)), + } + } + + pub fn prepare_name<'l>( + &self, + qname: &'l impl ToDname, + ) -> Result + Clone, OutOfZone> { + let mut qname = qname.iter_labels().rev(); + for apex_label in self.name().iter_labels().rev() { + let qname_label = qname.next(); + if Some(apex_label) != qname_label { + return Err(OutOfZone); + } + } + Ok(qname) + } + + /// Returns the RRsets of this node. 
+ pub fn rrsets(&self) -> &NodeRrsets { + &self.rrsets + } + + /// Returns the SOA record for the given version if available. + pub fn get_soa(&self, version: Version) -> Option { + self.rrsets() + .get(Rtype::SOA, version) + .and_then(|rrset| rrset.first()) + } + + /// Returns the children. + pub fn children(&self) -> &NodeChildren { + &self.children + } + + pub fn rollback(&self, version: Version) { + self.rrsets.rollback(version); + self.children.rollback(version); + } + + pub fn clean(&self, version: Version) { + self.rrsets.clean(version); + self.children.clean(version); + } + + pub fn versions(&self) -> &RwLock { + &self.versions + } + + pub fn name(&self) -> &StoredDname { + &self.apex_name + } +} + +//--- impl ZoneStore + +impl ZoneStore for ZoneApex { + fn class(&self) -> Class { + self.class + } + + fn apex_name(&self) -> &StoredDname { + &self.apex_name + } + + fn read(self: Arc) -> Box { + let (version, marker) = self.versions().read().current().clone(); + Box::new(ReadZone::new(self, version, marker)) + } + + fn write( + self: Arc, + ) -> Pin>>> { + Box::pin(async move { + let lock = self.update_lock.clone().lock_owned().await; + let version = self.versions().read().current().0.next(); + let zone_versions = self.versions.clone(); + Box::new(WriteZone::new(self, lock, version, zone_versions)) + as Box + }) + } +} + +//--- impl From<&'a ZoneApex> + +impl<'a> From<&'a ZoneApex> for CnameError { + fn from(_: &'a ZoneApex) -> CnameError { + CnameError::CnameAtApex + } +} + +//--- impl From<&'a ZoneApex> + +impl<'a> From<&'a ZoneApex> for ZoneCutError { + fn from(_: &'a ZoneApex) -> ZoneCutError { + ZoneCutError::ZoneCutAtApex + } +} + +//------------ ZoneNode ------------------------------------------------------ + +#[derive(Default, Debug)] +pub struct ZoneNode { + /// The RRsets of the node. + rrsets: NodeRrsets, + + /// The special functions of the node. + special: RwLock>>, + + /// The child nodes of the node. 
+ children: NodeChildren, +} + +impl ZoneNode { + /// Returns the RRsets of this node. + pub fn rrsets(&self) -> &NodeRrsets { + &self.rrsets + } + + /// Returns whether the node is NXDomain for a version. + pub fn is_nx_domain(&self, version: Version) -> bool { + self.with_special(version, |special| { + matches!(special, Some(Special::NxDomain)) + }) + } + + pub fn with_special( + &self, + version: Version, + op: impl FnOnce(Option<&Special>) -> R, + ) -> R { + op(self.special.read().get(version).and_then(Option::as_ref)) + } + + /// Updates the special. + pub fn update_special(&self, version: Version, special: Option) { + self.special.write().update(version, special) + } + + /// Returns the children. + pub fn children(&self) -> &NodeChildren { + &self.children + } + + pub fn rollback(&self, version: Version) { + self.rrsets.rollback(version); + self.special.write().rollback(version); + self.children.rollback(version); + } + + pub fn clean(&self, version: Version) { + self.rrsets.clean(version); + self.special.write().clean(version); + self.children.clean(version); + } +} + +//------------ NodeRrsets ---------------------------------------------------- + +#[derive(Default, Debug)] +pub struct NodeRrsets { + rrsets: RwLock>, +} + +impl NodeRrsets { + /// Returns whether there are no RRsets for the given version. + pub fn is_empty(&self, version: Version) -> bool { + let rrsets = self.rrsets.read(); + if rrsets.is_empty() { + return true; + } + for value in self.rrsets.read().values() { + if value.get(version).is_some() { + return false; + } + } + true + } + + /// Returns the RRset for a given record type. + pub fn get(&self, rtype: Rtype, version: Version) -> Option { + self.rrsets + .read() + .get(&rtype) + .and_then(|rrsets| rrsets.get(version)) + .cloned() + } + + /// Updates an RRset. 
+ pub fn update(&self, rrset: SharedRrset, version: Version) { + self.rrsets + .write() + .entry(rrset.rtype()) + .or_default() + .update(rrset, version) + } + + /// Removes the RRset for the given type. + pub fn remove(&self, rtype: Rtype, version: Version) { + self.rrsets + .write() + .entry(rtype) + .or_default() + .remove(version) + } + + pub fn rollback(&self, version: Version) { + self.rrsets + .write() + .values_mut() + .for_each(|rrset| rrset.rollback(version)); + } + + pub fn clean(&self, version: Version) { + self.rrsets + .write() + .values_mut() + .for_each(|rrset| rrset.clean(version)); + } + + pub(super) fn iter(&self) -> NodeRrsetsIter { + NodeRrsetsIter::new(self.rrsets.read()) + } +} + +//------------ NodeRrsetIter ------------------------------------------------- + +pub(super) struct NodeRrsetsIter<'a> { + guard: RwLockReadGuard<'a, HashMap>, +} + +impl<'a> NodeRrsetsIter<'a> { + fn new(guard: RwLockReadGuard<'a, HashMap>) -> Self { + Self { guard } + } + + pub fn iter(&self) -> hash_map::Iter<'_, Rtype, NodeRrset> { + self.guard.iter() + } +} + +//------------ NodeRrset ----------------------------------------------------- + +#[derive(Default, Debug)] +pub(crate) struct NodeRrset { + /// The RRsets for the various versions. 
+ rrsets: Versioned, +} + +impl NodeRrset { + pub fn get(&self, version: Version) -> Option<&SharedRrset> { + self.rrsets.get(version) + } + + fn update(&mut self, rrset: SharedRrset, version: Version) { + self.rrsets.update(version, rrset) + } + + fn remove(&mut self, version: Version) { + self.rrsets.clean(version) + } + + pub fn rollback(&mut self, version: Version) { + self.rrsets.rollback(version); + } + + pub fn clean(&mut self, version: Version) { + self.rrsets.rollback(version); + } +} + +//------------ Special ------------------------------------------------------- + +#[derive(Clone, Debug)] +pub enum Special { + Cut(ZoneCut), + Cname(SharedRr), + NxDomain, +} + +//------------ NodeChildren -------------------------------------------------- + +#[derive(Debug, Default)] +pub struct NodeChildren { + children: RwLock>>, +} + +impl NodeChildren { + pub fn with( + &self, + label: &Label, + op: impl FnOnce(Option<&Arc>) -> R, + ) -> R { + op(self.children.read().get(label)) + } + + /// Executes a closure for a child, creating a new child if necessary. + /// + /// The closure receives a reference to the node and a boolean expressing + /// whether the child was created. 
+ pub fn with_or_default( + &self, + label: &Label, + op: impl FnOnce(&Arc, bool) -> R, + ) -> R { + let lock = self.children.upgradable_read(); + if let Some(node) = lock.get(label) { + return op(node, false); + } + let mut lock = RwLockUpgradableReadGuard::upgrade(lock); + lock.insert(label.into(), Default::default()); + let lock = RwLockWriteGuard::downgrade(lock); + op(lock.get(label).unwrap(), true) + } + + fn rollback(&self, version: Version) { + self.children + .read() + .values() + .for_each(|item| item.rollback(version)) + } + + fn clean(&self, version: Version) { + self.children + .read() + .values() + .for_each(|item| item.clean(version)) + } + + pub(super) fn walk( + &self, + walk: WalkState, + op: impl Fn(WalkState, (&OwnedLabel, &Arc)), + ) { + for child in self.children.read().iter() { + (op)(walk.clone(), child) + } + } +} diff --git a/src/zonetree/in_memory/read.rs b/src/zonetree/in_memory/read.rs new file mode 100644 index 000000000..afd480486 --- /dev/null +++ b/src/zonetree/in_memory/read.rs @@ -0,0 +1,366 @@ +//! Quering for zone data. 
+use core::iter; + +use std::sync::Arc; + +use bytes::Bytes; + +use crate::base::iana::{Rcode, Rtype}; +use crate::base::name::Label; +use crate::base::Dname; +use crate::zonefile::error::OutOfZone; +use crate::zonetree::answer::{Answer, AnswerAuthority}; +use crate::zonetree::types::ZoneCut; +use crate::zonetree::walk::WalkState; +use crate::zonetree::{ReadableZone, Rrset, SharedRr, SharedRrset, WalkOp}; + +use super::nodes::{NodeChildren, NodeRrsets, Special, ZoneApex, ZoneNode}; +use super::versioned::Version; +use super::versioned::VersionMarker; + +//------------ ReadZone ------------------------------------------------------ + +#[derive(Clone)] +pub struct ReadZone { + apex: Arc, + version: Version, + _version_marker: Arc, +} + +impl ReadZone { + pub(super) fn new( + apex: Arc, + version: Version, + _version_marker: Arc, + ) -> Self { + ReadZone { + apex, + version, + _version_marker, + } + } +} + +impl ReadZone { + fn query_below_apex<'l>( + &self, + label: &Label, + qname: impl Iterator + Clone, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + self.query_children(self.apex.children(), label, qname, qtype, walk) + } + + fn query_node<'l>( + &self, + node: &ZoneNode, + mut qname: impl Iterator + Clone, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + if walk.enabled() { + // Make sure we visit everything when walking the tree. 
+ self.query_rrsets(node.rrsets(), qtype, walk.clone()); + self.query_node_here_and_below( + node, + Label::root(), + qname, + qtype, + walk, + ) + } else if let Some(label) = qname.next() { + self.query_node_here_and_below(node, label, qname, qtype, walk) + } else { + self.query_node_here_but_not_below(node, qtype, walk) + } + } + + fn query_node_here_and_below<'l>( + &self, + node: &ZoneNode, + label: &Label, + qname: impl Iterator + Clone, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + node.with_special(self.version, |special| match special { + Some(Special::Cut(ref cut)) => { + let answer = NodeAnswer::authority(AnswerAuthority::new( + cut.name.clone(), + None, + Some(cut.ns.clone()), + cut.ds.as_ref().cloned(), + )); + + walk.op(&cut.ns); + if let Some(ds) = &cut.ds { + walk.op(ds); + } + + answer + } + Some(Special::NxDomain) => NodeAnswer::nx_domain(), + Some(Special::Cname(_)) | None => self.query_children( + node.children(), + label, + qname, + qtype, + walk, + ), + }) + } + + fn query_node_here_but_not_below( + &self, + node: &ZoneNode, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + node.with_special(self.version, |special| match special { + Some(Special::Cut(cut)) => { + let answer = self.query_at_cut(cut, qtype); + if walk.enabled() { + walk.op(&cut.ns); + if let Some(ds) = &cut.ds { + walk.op(ds); + } + } + answer + } + Some(Special::Cname(cname)) => { + let answer = NodeAnswer::cname(cname.clone()); + if walk.enabled() { + let mut rrset = Rrset::new(Rtype::CNAME, cname.ttl()); + rrset.push_data(cname.data().clone()); + walk.op(&rrset); + } + answer + } + Some(Special::NxDomain) => NodeAnswer::nx_domain(), + None => self.query_rrsets(node.rrsets(), qtype, walk), + }) + } + + fn query_rrsets( + &self, + rrsets: &NodeRrsets, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + if walk.enabled() { + // Walk the zone, don't match by qtype. 
+ let guard = rrsets.iter(); + for (_rtype, rrset) in guard.iter() { + if let Some(shared_rrset) = rrset.get(self.version) { + walk.op(shared_rrset); + } + } + NodeAnswer::no_data() + } else if qtype == Rtype::ANY { + // https://datatracker.ietf.org/doc/html/rfc8482#section-4.2 + // 4. Behavior of DNS Responders + // + // "Below are the three different modes of behavior by DNS + // responders when processing queries with QNAMEs that exist, + // QCLASS=IN, and QTYPE=ANY. Operators and implementers are + // free to choose whichever mechanism best suits their + // environment. + // + // 1. A DNS responder can choose to select one or a larger + // subset of the available RRsets at the QNAME. + // + // 2. A DNS responder can return a synthesized HINFO resource + // record. See Section 6 for discussion of the use of HINFO. + // + // 3. A resolver can try to give out the most likely records + // the requester wants. This is not always possible, and the + // result might well be a large response. + // + // Except as described below in this section, the DNS responder + // MUST follow the standard algorithms when constructing a + // response." + // + // We choose for option 1 because option 2 would create lots of + // extra work in the offline signing case (because lots of HFINO + // records would need to be synthesized prior to signing) and + // option 3 as stated may still result in a large response. 
+ let guard = rrsets.iter(); + guard + .iter() + .next() + .and_then(|(_rtype, rrset)| rrset.get(self.version)) + .map(|rrset| NodeAnswer::data(rrset.clone())) + .unwrap_or_else(NodeAnswer::no_data) + } else { + match rrsets.get(qtype, self.version) { + Some(rrset) => NodeAnswer::data(rrset), + None => NodeAnswer::no_data(), + } + } + } + + fn query_at_cut(&self, cut: &ZoneCut, qtype: Rtype) -> NodeAnswer { + match qtype { + Rtype::DS => { + if let Some(rrset) = cut.ds.as_ref() { + NodeAnswer::data(rrset.clone()) + } else { + NodeAnswer::no_data() + } + } + _ => NodeAnswer::authority(AnswerAuthority::new( + cut.name.clone(), + None, + Some(cut.ns.clone()), + cut.ds.as_ref().cloned(), + )), + } + } + + fn query_children<'l>( + &self, + children: &NodeChildren, + label: &Label, + qname: impl Iterator + Clone, + qtype: Rtype, + walk: WalkState, + ) -> NodeAnswer { + if walk.enabled() { + children.walk(walk, |walk, (label, node)| { + walk.push(*label); + self.query_node( + node, + std::iter::empty(), + qtype, + walk.clone(), + ); + walk.pop(); + }); + return NodeAnswer::no_data(); + } + + // Step 1: See if we have a non-terminal child for label. If so, + // continue there. + let answer = children.with(label, |node| { + node.map(|node| self.query_node(node, qname, qtype, walk.clone())) + }); + if let Some(answer) = answer { + return answer; + } + + // Step 2: Now see if we have an asterisk label. If so, query that + // node. 
+ children.with(Label::wildcard(), |node| match node { + Some(node) => { + self.query_node_here_but_not_below(node, qtype, walk) + } + None => NodeAnswer::nx_domain(), + }) + } +} + +//--- impl ReadableZone + +impl ReadableZone for ReadZone { + fn is_async(&self) -> bool { + false + } + + fn query( + &self, + qname: Dname, + qtype: Rtype, + ) -> Result { + let mut qname = self.apex.prepare_name(&qname)?; + + let answer = if let Some(label) = qname.next() { + self.query_below_apex(label, qname, qtype, WalkState::DISABLED) + } else { + self.query_rrsets(self.apex.rrsets(), qtype, WalkState::DISABLED) + }; + + Ok(answer.into_answer(self)) + } + + fn walk(&self, op: WalkOp) { + // https://datatracker.ietf.org/doc/html/rfc8482 notes that the ANY + // query type is problematic and should be answered as minimally as + // possible. Rather than use ANY internally here to achieve a walk, as + // specific behaviour may actually be wanted for ANY we instead use + // the presence of a callback `op` to indicate that walking mode is + // requested. We still have to pass an Rtype but it won't be used for + // matching when in walk mode, so we set it to Any as it most closely + // matches our intent and will be ignored anyway. + let walk = WalkState::new(op); + self.query_rrsets(self.apex.rrsets(), Rtype::ANY, walk.clone()); + self.query_below_apex(Label::root(), iter::empty(), Rtype::ANY, walk); + } +} + +//------------ NodeAnswer ---------------------------------------------------- + +/// An answer that includes instructions to the apex on what it needs to do. +#[derive(Clone)] +struct NodeAnswer { + /// The actual answer. + answer: Answer, + + /// Does the apex need to add the SOA RRset to the answer? 
+ add_soa: bool, +} + +impl NodeAnswer { + fn data(rrset: SharedRrset) -> Self { + let mut answer = Answer::new(Rcode::NOERROR); + answer.add_answer(rrset); + NodeAnswer { + answer, + add_soa: false, + } + } + + fn no_data() -> Self { + NodeAnswer { + answer: Answer::new(Rcode::NOERROR), + add_soa: true, + } + } + + fn cname(rr: SharedRr) -> Self { + let mut answer = Answer::new(Rcode::NOERROR); + answer.add_cname(rr); + NodeAnswer { + answer, + add_soa: false, + } + } + + fn nx_domain() -> Self { + NodeAnswer { + answer: Answer::new(Rcode::NXDOMAIN), + add_soa: true, + } + } + + fn authority(authority: AnswerAuthority) -> Self { + NodeAnswer { + answer: Answer::with_authority(Rcode::NOERROR, authority), + add_soa: false, + } + } + + fn into_answer(mut self, zone: &ReadZone) -> Answer { + if self.add_soa { + if let Some(soa) = zone.apex.get_soa(zone.version) { + self.answer.set_authority(AnswerAuthority::new( + zone.apex.name().clone(), + Some(soa), + None, + None, + )) + } + } + self.answer + } +} diff --git a/src/zonetree/in_memory/versioned.rs b/src/zonetree/in_memory/versioned.rs new file mode 100644 index 000000000..ce4ce2b38 --- /dev/null +++ b/src/zonetree/in_memory/versioned.rs @@ -0,0 +1,85 @@ +use crate::base::serial::Serial; +use serde::{Deserialize, Serialize}; +use std::vec::Vec; + +//------------ Version ------------------------------------------------------- + +#[derive( + Clone, + Copy, + Debug, + Deserialize, + Eq, + Hash, + PartialEq, + PartialOrd, + Serialize, +)] +pub struct Version(Serial); + +impl Version { + pub fn next(self) -> Version { + Version(self.0.add(1)) + } +} + +impl Default for Version { + fn default() -> Self { + Version(0.into()) + } +} + +//------------ Versioned ----------------------------------------------------- + +#[derive(Clone, Debug)] +pub struct Versioned { + data: Vec<(Version, Option)>, +} + +impl Versioned { + pub fn new() -> Self { + Versioned { data: Vec::new() } + } + + pub fn get(&self, version: Version) -> 
Option<&T> { + self.data.iter().rev().find_map(|item| { + if item.0 <= version { + item.1.as_ref() + } else { + None + } + }) + } + + pub fn update(&mut self, version: Version, value: T) { + if let Some(last) = self.data.last_mut() { + if last.0 == version { + last.1 = Some(value); + return; + } + } + self.data.push((version, Some(value))) + } + + /// Drops the last version if it is `version`. + pub fn rollback(&mut self, version: Version) { + if self.data.last().map(|item| item.0) == Some(version) { + self.data.pop(); + } + } + + pub fn clean(&mut self, version: Version) { + self.data.retain(|item| item.0 >= version) + } +} + +impl Default for Versioned { + fn default() -> Self { + Self::new() + } +} + +//------------ VersionMarker ------------------------------------------------- + +#[derive(Debug)] +pub struct VersionMarker; diff --git a/src/zonetree/in_memory/write.rs b/src/zonetree/in_memory/write.rs new file mode 100644 index 000000000..9b2db6620 --- /dev/null +++ b/src/zonetree/in_memory/write.rs @@ -0,0 +1,409 @@ +//! Write access to zones. 
+ +use core::future::ready; +use std::boxed::Box; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Weak; +use std::vec::Vec; +use std::{fmt, io}; + +use futures::future::Either; +use parking_lot::RwLock; +use tokio::sync::OwnedMutexGuard; + +use crate::base::iana::Rtype; +use crate::base::name::Label; +use crate::zonetree::types::ZoneCut; +use crate::zonetree::SharedRr; +use crate::zonetree::{SharedRrset, WritableZone, WritableZoneNode}; + +use super::nodes::{Special, ZoneApex, ZoneNode}; +use super::versioned::{Version, VersionMarker}; + +//------------ WriteZone ----------------------------------------------------- + +pub struct WriteZone { + apex: Arc, + _lock: Option>, + version: Version, + dirty: bool, + zone_versions: Arc>, +} + +impl WriteZone { + pub(super) fn new( + apex: Arc, + _lock: OwnedMutexGuard<()>, + version: Version, + zone_versions: Arc>, + ) -> Self { + WriteZone { + apex, + _lock: Some(_lock), + version, + dirty: false, + zone_versions, + } + } +} + +//--- impl Clone + +impl Clone for WriteZone { + fn clone(&self) -> Self { + Self { + apex: self.apex.clone(), + _lock: None, + version: self.version, + dirty: self.dirty, + zone_versions: self.zone_versions.clone(), + } + } +} + +//--- impl Drop + +impl Drop for WriteZone { + fn drop(&mut self) { + if self.dirty { + self.apex.rollback(self.version); + self.dirty = false; + } + } +} + +//--- impl WritableZone + +impl WritableZone for WriteZone { + #[allow(clippy::type_complexity)] + fn open( + &self, + ) -> Pin< + Box< + dyn Future, io::Error>>, + >, + > { + let res = WriteNode::new_apex(self.clone()) + .map(|node| Box::new(node) as Box) + .map_err(|err| { + io::Error::new( + io::ErrorKind::Other, + format!("Open error: {err}"), + ) + }); + Box::pin(ready(res)) + } + + fn commit( + &mut self, + ) -> Pin>>> { + let marker = self.zone_versions.write().update_current(self.version); + self.zone_versions + .write() + .push_version(self.version, marker); + + // Start the 
next version. + self.version = self.version.next(); + self.dirty = false; + + Box::pin(ready(Ok(()))) + } +} + +//------------ WriteNode ------------------------------------------------------ + +pub struct WriteNode { + /// The writer for the zone we are working with. + zone: WriteZone, + + /// The node we are updating. + node: Either, Arc>, +} + +impl WriteNode { + fn new_apex(zone: WriteZone) -> Result { + let apex = zone.apex.clone(); + Ok(WriteNode { + zone, + node: Either::Left(apex), + }) + } + fn update_child(&self, label: &Label) -> Result { + let children = match self.node { + Either::Left(ref apex) => apex.children(), + Either::Right(ref node) => node.children(), + }; + let (node, created) = children + .with_or_default(label, |node, created| (node.clone(), created)); + let node = WriteNode { + zone: self.zone.clone(), + node: Either::Right(node), + }; + if created { + node.make_regular()?; + } + + Ok(node) + } + + fn update_rrset(&self, rrset: SharedRrset) -> Result<(), io::Error> { + let rrsets = match self.node { + Either::Right(ref apex) => apex.rrsets(), + Either::Left(ref node) => node.rrsets(), + }; + rrsets.update(rrset, self.zone.version); + self.check_nx_domain()?; + Ok(()) + } + + fn remove_rrset(&self, rtype: Rtype) -> Result<(), io::Error> { + let rrsets = match self.node { + Either::Left(ref apex) => apex.rrsets(), + Either::Right(ref node) => node.rrsets(), + }; + rrsets.remove(rtype, self.zone.version); + self.check_nx_domain()?; + Ok(()) + } + + fn make_regular(&self) -> Result<(), io::Error> { + if let Either::Right(ref node) = self.node { + node.update_special(self.zone.version, None); + self.check_nx_domain()?; + } + Ok(()) + } + + fn make_zone_cut(&self, cut: ZoneCut) -> Result<(), io::Error> { + match self.node { + Either::Left(_) => Err(WriteApexError::NotAllowed), + Either::Right(ref node) => { + node.update_special( + self.zone.version, + Some(Special::Cut(cut)), + ); + Ok(()) + } + } + .map_err(|err| { + io::Error::new( + 
io::ErrorKind::Other, + format!("Write apex error: {err}"), + ) + }) + } + + fn make_cname(&self, cname: SharedRr) -> Result<(), io::Error> { + match self.node { + Either::Left(_) => Err(WriteApexError::NotAllowed), + Either::Right(ref node) => { + node.update_special( + self.zone.version, + Some(Special::Cname(cname)), + ); + Ok(()) + } + } + .map_err(|err| { + io::Error::new( + io::ErrorKind::Other, + format!("Write apex error: {err}"), + ) + }) + } + + /// Makes sure a NXDomain special is set or removed as necesssary. + fn check_nx_domain(&self) -> Result<(), io::Error> { + let node = match self.node { + Either::Left(_) => return Ok(()), + Either::Right(ref node) => node, + }; + let opt_new_nxdomain = + node.with_special(self.zone.version, |special| match special { + Some(Special::NxDomain) => { + if !node.rrsets().is_empty(self.zone.version) { + Some(false) + } else { + None + } + } + None => { + if node.rrsets().is_empty(self.zone.version) { + Some(true) + } else { + None + } + } + _ => None, + }); + if let Some(new_nxdomain) = opt_new_nxdomain { + if new_nxdomain { + node.update_special( + self.zone.version, + Some(Special::NxDomain), + ); + } else { + node.update_special(self.zone.version, None); + } + } + Ok(()) + } +} + +//--- impl WritableZoneNode + +impl WritableZoneNode for WriteNode { + #[allow(clippy::type_complexity)] + fn update_child( + &self, + label: &Label, + ) -> Pin< + Box< + dyn Future, io::Error>>, + >, + > { + let node = self + .update_child(label) + .map(|node| Box::new(node) as Box); + Box::pin(ready(node)) + } + + fn update_rrset( + &self, + rrset: SharedRrset, + ) -> Pin>>> { + Box::pin(ready(self.update_rrset(rrset))) + } + + fn remove_rrset( + &self, + rtype: Rtype, + ) -> Pin>>> { + Box::pin(ready(self.remove_rrset(rtype))) + } + + fn make_regular( + &self, + ) -> Pin>>> { + Box::pin(ready(self.make_regular())) + } + + fn make_zone_cut( + &self, + cut: ZoneCut, + ) -> Pin>>> { + Box::pin(ready(self.make_zone_cut(cut))) + } + + fn 
make_cname( + &self, + cname: SharedRr, + ) -> Pin>>> { + Box::pin(ready(self.make_cname(cname))) + } +} + +//------------ WriteApexError ------------------------------------------------ + +/// The requested operation is not allowed at the apex of a zone. +#[derive(Debug)] +pub enum WriteApexError { + /// This operation is not allowed at the apex. + NotAllowed, + + /// An IO error happened while processing the operation. + Io(io::Error), +} + +impl From for WriteApexError { + fn from(src: io::Error) -> WriteApexError { + WriteApexError::Io(src) + } +} + +impl From for io::Error { + fn from(src: WriteApexError) -> io::Error { + match src { + WriteApexError::NotAllowed => io::Error::new( + io::ErrorKind::Other, + "operation not allowed at apex", + ), + WriteApexError::Io(err) => err, + } + } +} + +impl fmt::Display for WriteApexError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + WriteApexError::NotAllowed => { + f.write_str("operation not allowed") + } + WriteApexError::Io(ref err) => err.fmt(f), + } + } +} + +//------------ ZoneVersions -------------------------------------------------- + +#[derive(Debug)] +pub struct ZoneVersions { + current: (Version, Arc), + all: Vec<(Version, Weak)>, +} + +impl ZoneVersions { + #[allow(unused)] + pub fn update_current(&mut self, version: Version) -> Arc { + let marker = Arc::new(VersionMarker); + self.current = (version, marker.clone()); + marker + } + + #[allow(unused)] + pub fn push_version( + &mut self, + version: Version, + marker: Arc, + ) { + self.all.push((version, Arc::downgrade(&marker))) + } + + #[allow(unused)] + pub fn clean_versions(&mut self) -> Option { + let mut max_version = None; + self.all.retain(|item| { + if item.1.strong_count() > 0 { + true + } else { + match max_version { + Some(old) => { + if item.0 > old { + max_version = Some(item.0) + } + } + None => max_version = Some(item.0), + } + false + } + }); + max_version + } + + pub fn current(&self) -> &(Version, Arc) { + 
&self.current + } +} + +impl Default for ZoneVersions { + fn default() -> Self { + let marker = Arc::new(VersionMarker); + let weak_marker = Arc::downgrade(&marker); + ZoneVersions { + current: (Version::default(), marker), + all: vec![(Version::default(), weak_marker)], + } + } +} diff --git a/src/zonetree/mod.rs b/src/zonetree/mod.rs new file mode 100644 index 000000000..efadcee74 --- /dev/null +++ b/src/zonetree/mod.rs @@ -0,0 +1,101 @@ +#![cfg(feature = "unstable-zonetree")] +#![cfg_attr(docsrs, doc(cfg(feature = "unstable-zonetree")))] +#![warn(missing_docs)] +//! Experimental storing and querying of zone trees. +//! +//! A [`ZoneTree`] is a multi-rooted hierarchy of [`Zone`]s, each root being +//! the apex of a subtree for a distinct [`Class`]. +//! +//! Individual `Zone`s within the tree can be looked up by containing or exact +//! name, and then one can [`query`] the found `Zone` by [`Class`], [`Rtype`] and +//! [`Dname`] to produce an [`Answer`], which in turn can be used to produce a +//! response [`Message`] for serving to a DNS client. +//! +//! Trees can also be iterated over to inspect or export their content. +//! +//! The `Zone`s that a tree is comprised of can be created by feeding +//! zonefiles or individual resource records into [`ZoneBuilder`] and then +//! inserted into a `ZoneTree`. +//! +//! By default `Zone`s are stored in memory only. Zones with other types of +//! backing store can be created by implementing the [`ZoneStore`] trait and +//! passing an instance of the implementing struct to [`Zone::new`]. Zones +//! with different backing store types can be mixed and matched within the +//! same tree. +//! +//! The example below shows how to populate a `ZoneTree` from a zonefile. For +//! more examples of using `Zone`s and `ZoneTree`s including implementing an +//! alternate zone backing store for your `Zone`s, see the [examples in the +//! GitHub +//! repository](https://github.com/NLnetLabs/domain/tree/main/examples). +//! +//! 
# Usage +//! +//! The following example builds and queries a [`ZoneTree`] containing a +//! single in-memory [`Zone`]. +//! +//! ``` +//! use domain::base::iana::{Class, Rcode, Rtype}; +//! use domain::base::name::Dname; +//! use domain::zonefile::{inplace, parsed}; +//! use domain::zonetree::{Answer, Zone, ZoneBuilder, ZoneTree}; +//! +//! // Prepare some zone file bytes to demonstrate with. +//! let zone_file = include_bytes!("../../test-data/zonefiles/nsd-example.txt"); +//! let mut zone_bytes = std::io::BufReader::new(&zone_file[..]); +//! +//! // Read, parse and build a zone. +//! let reader = inplace::Zonefile::load(&mut zone_bytes).unwrap(); +//! let parsed = parsed::Zonefile::try_from(reader).unwrap(); +//! let builder = ZoneBuilder::try_from(parsed).unwrap(); +//! +//! // Turn the builder into a zone. +//! let zone = Zone::from(builder); +//! +//! // Equivalent but shorter: +//! let mut zone_bytes = std::io::BufReader::new(&zone_file[..]); +//! let reader = inplace::Zonefile::load(&mut zone_bytes).unwrap(); +//! let zone = Zone::try_from(reader).unwrap(); +//! +//! // Insert the zone into a zone tree. +//! let mut tree = ZoneTree::new(); +//! tree.insert_zone(zone).unwrap(); +//! +//! // Query the zone tree. +//! let qname = Dname::bytes_from_str("example.com").unwrap(); +//! let qtype = Rtype::A; +//! let found_zone = tree.find_zone(&qname, Class::IN).unwrap(); +//! let res: Answer = found_zone.read().query(qname, qtype).unwrap(); +//! +//! // Verify that we found a result. +//! assert_eq!(res.rcode(), Rcode::NOERROR); +//! ``` +//! +//! [`query`]: crate::zonetree::ReadableZone::query +//! [`Class`]: crate::base::iana::Class +//! [`Rtype`]: crate::base::iana::Rtype +//! [`Dname`]: crate::base::name::Dname +//! [`Message`]: crate::base::Message +//! [`NoError`]: crate::base::iana::code::Rcode::NOERROR +//! [`NxDomain`]: crate::base::iana::code::Rcode::NXDOMAIN +//! 
[`ZoneBuilder`]: in_memory::ZoneBuilder + +mod answer; +mod in_memory; +mod traits; +mod tree; +mod types; +mod walk; +mod zone; + +pub use self::answer::{Answer, AnswerAuthority, AnswerContent}; +pub use self::in_memory::ZoneBuilder; +pub use self::traits::{ + ReadableZone, WritableZone, WritableZoneNode, ZoneStore, +}; +pub use self::tree::ZoneTree; +pub use self::types::{ + Rrset, SharedRr, SharedRrset, StoredDname, StoredRecord, +}; +pub use self::walk::WalkOp; +pub use self::zone::Zone; diff --git a/src/zonetree/traits.rs b/src/zonetree/traits.rs new file mode 100644 index 000000000..6d26afa4c --- /dev/null +++ b/src/zonetree/traits.rs @@ -0,0 +1,192 @@ +//! Traits for abstracting away the backing store of a [`ZoneTree`]. +//! +//!
+//! +//! These interfaces are unstable and are likely to change in future. +//! +//!
+use bytes::Bytes; +use core::future::ready; +use core::pin::Pin; +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::io; +use std::sync::Arc; + +use crate::base::iana::Class; +use crate::base::name::Label; +use crate::base::{Dname, Rtype}; +use crate::zonefile::error::OutOfZone; + +use super::answer::Answer; +use super::types::ZoneCut; +use super::{SharedRr, SharedRrset, StoredDname, WalkOp}; + +//------------ ZoneStore ----------------------------------------------------- + +/// A [`Zone`] storage interface. +/// +/// A [`ZoneStore`] provides a way to read [`Zone`]s from and write `Zone`s to +/// a particular backing store implementation. +/// +/// [`Zone`]: super::Zone +pub trait ZoneStore: Debug + Sync + Send { + /// Returns the class of the zone. + fn class(&self) -> Class; + + /// Returns the apex name of the zone. + fn apex_name(&self) -> &StoredDname; + + /// Get a read interface to this store. + fn read(self: Arc) -> Box; + + /// Get a write interface to this store. + fn write( + self: Arc, + ) -> Pin>>>; +} + +//------------ ReadableZone -------------------------------------------------- + +/// A read interface to a [`Zone`]. +/// +/// A [`ReadableZone`] mplementation provides (a)synchronous read access to +/// the [`ZoneStore`] backing storage for a [`Zone`]. +/// +/// [`Zone`]: super::Zone +pub trait ReadableZone: Send { + /// Returns true if ths `_async` variants of the functions offered by this + /// trait should be used by callers instead of the non-`_async` + /// equivalents. + fn is_async(&self) -> bool { + true + } + + //--- Sync variants + + /// Lookup an [`Answer`] in the zone for a given QNAME and QTYPE. + /// + /// This function performs a synchronous query against the zone it + /// provides access to, for a given QNAME and QTYPE. In combination with + /// having first looked the zone up by CLASS this function enables a + /// caller to obtain an [`Answer`] for an [RFC 1034 section 3.7.1] + /// "Standard query". 
+ /// + /// [RFC 1034 section 3.7.1]: + /// https://www.rfc-editor.org/rfc/rfc1034#section-3.7.1 + fn query( + &self, + _qname: Dname, + _qtype: Rtype, + ) -> Result; + + /// Iterate over the entire contents of the zone. + /// + /// This function visits every node in the tree, synchronously, invoking + /// the given callback function at every leaf node found. + fn walk(&self, _op: WalkOp); + + //--- Async variants + + /// Asynchronous variant of `query()`. + fn query_async( + &self, + qname: Dname, + qtype: Rtype, + ) -> Pin> + Send>> { + Box::pin(ready(self.query(qname, qtype))) + } + + /// Asynchronous variant of `walk()`. + fn walk_async( + &self, + op: WalkOp, + ) -> Pin + Send>> { + self.walk(op); + Box::pin(ready(())) + } +} + +//------------ WritableZone -------------------------------------------------- + +/// An asynchronous write interface to a [`Zone`]. +/// +/// [`Zone`]: super::Zone +pub trait WritableZone { + /// Start a write operation for the zone. + #[allow(clippy::type_complexity)] + fn open( + &self, + ) -> Pin< + Box< + dyn Future, io::Error>>, + >, + >; + + /// Complete a write operation for the zone. + /// + /// This function commits the changes accumulated since [`open`] was + /// invoked. Clients who obtain a [`ReadableZone`] interface to this zone + /// _before_ this function has been called will not see any of the changes + /// made since the last commit. Only clients who obtain a [`ReadableZone`] + /// _after_ invoking this function will be able to see the changes made + /// since [`open`] was called. called. + fn commit( + &mut self, + ) -> Pin>>>; +} + +//------------ WritableZoneNode ---------------------------------------------- + +/// An asynchronous write interface to a particular node in a [`ZoneTree`]. +/// +/// [`ZoneTree`]: super::ZoneTree +pub trait WritableZoneNode { + /// Get a write interface to a child node of this node. 
+ #[allow(clippy::type_complexity)] + fn update_child( + &self, + label: &Label, + ) -> Pin< + Box< + dyn Future, io::Error>>, + >, + >; + + /// Replace the RRset at this node with the given RRset. + fn update_rrset( + &self, + rrset: SharedRrset, + ) -> Pin>>>; + + /// Remove an RRset of the given type at this node, if any. + fn remove_rrset( + &self, + rtype: Rtype, + ) -> Pin>>>; + + /// Mark this node as a regular node. + /// + /// If this node has zone cut or CNAME data, calling this + /// function will erase that data. + fn make_regular( + &self, + ) -> Pin>>>; + + /// Mark this node as a zone cut. + /// + /// Any "regular" or CNAME data at this node will be lost. + fn make_zone_cut( + &self, + cut: ZoneCut, + ) -> Pin>>>; + + /// Mark this node as a CNAME. + /// + /// Any "regular" or zone cut data at this node will be lost. + fn make_cname( + &self, + cname: SharedRr, + ) -> Pin>>>; +} diff --git a/src/zonetree/tree.rs b/src/zonetree/tree.rs new file mode 100644 index 000000000..ed9f00758 --- /dev/null +++ b/src/zonetree/tree.rs @@ -0,0 +1,303 @@ +//! The known set of zones. + +use super::zone::Zone; +use crate::base::iana::Class; +use crate::base::name::{Label, OwnedLabel, ToDname, ToLabelIter}; +use std::collections::hash_map; +use std::collections::HashMap; +use std::fmt::Display; +use std::io; +use std::vec::Vec; + +//------------ ZoneTree ------------------------------------------------------ + +/// A multi-rooted [`Zone`] hierarchy. +/// +/// [`Zone`]: crate::zonetree::Zone. +#[derive(Default)] +pub struct ZoneTree { + roots: Roots, +} + +impl ZoneTree { + /// Creates an empty [`ZoneTree`]. + pub fn new() -> Self { + Default::default() + } + + /// Gets a [`Zone`] for the given apex name and CLASS, if any. + pub fn get_zone( + &self, + apex_name: &impl ToDname, + class: Class, + ) -> Option<&Zone> { + self.roots + .get(class)? + .get_zone(apex_name.iter_labels().rev()) + } + + /// Inserts the given [`Zone`]. 
+ /// + /// Returns a [`ZoneTreeModificationError`] if a zone with the same apex + /// and CLASS already exists in the tree. + pub fn insert_zone( + &mut self, + zone: Zone, + ) -> Result<(), ZoneTreeModificationError> { + self.roots.get_or_insert(zone.class()).insert_zone( + &mut zone.apex_name().clone().iter_labels().rev(), + zone, + ) + } + + /// Gets the closest matching [`Zone`] for the given QNAME and CLASS, if + /// any. + pub fn find_zone( + &self, + qname: &impl ToDname, + class: Class, + ) -> Option<&Zone> { + self.roots.get(class)?.find_zone(qname.iter_labels().rev()) + } + + /// Returns an iterator over all of the [`Zone`]s in the tree. + pub fn iter_zones(&self) -> ZoneSetIter { + ZoneSetIter::new(self) + } + + /// Removes the specified [`Zone`], if any. + pub fn remove_zone( + &mut self, + apex_name: &impl ToDname, + class: Class, + ) -> Result<(), ZoneTreeModificationError> { + if let Some(root) = self.roots.get_mut(class) { + root.remove_zone(apex_name.iter_labels().rev()) + } else { + Err(ZoneTreeModificationError::ZoneDoesNotExist) + } + } +} + +//------------ Roots --------------------------------------------------------- + +#[derive(Default)] +struct Roots { + in_: ZoneSetNode, + others: HashMap, +} + +impl Roots { + pub fn get(&self, class: Class) -> Option<&ZoneSetNode> { + if class == Class::IN { + Some(&self.in_) + } else { + self.others.get(&class) + } + } + + pub fn get_mut(&mut self, class: Class) -> Option<&mut ZoneSetNode> { + if class == Class::IN { + Some(&mut self.in_) + } else { + self.others.get_mut(&class) + } + } + + pub fn get_or_insert(&mut self, class: Class) -> &mut ZoneSetNode { + if class == Class::IN { + &mut self.in_ + } else { + self.others.entry(class).or_default() + } + } +} + +//------------ ZoneSetNode --------------------------------------------------- + +#[derive(Default)] +struct ZoneSetNode { + zone: Option, + children: HashMap, +} + +impl ZoneSetNode { + fn get_zone<'l>( + &self, + mut apex_name: impl Iterator, 
+ ) -> Option<&Zone> { + match apex_name.next() { + Some(label) => self.children.get(label)?.get_zone(apex_name), + None => self.zone.as_ref(), + } + } + + pub fn find_zone<'l>( + &self, + mut qname: impl Iterator, + ) -> Option<&Zone> { + if let Some(label) = qname.next() { + if let Some(node) = self.children.get(label) { + if let Some(zone) = node.find_zone(qname) { + return Some(zone); + } + } + } + self.zone.as_ref() + } + + fn insert_zone<'l>( + &mut self, + mut apex_name: impl Iterator, + zone: Zone, + ) -> Result<(), ZoneTreeModificationError> { + if let Some(label) = apex_name.next() { + self.children + .entry(label.into()) + .or_default() + .insert_zone(apex_name, zone) + } else if self.zone.is_some() { + Err(ZoneTreeModificationError::ZoneExists) + } else { + self.zone = Some(zone); + Ok(()) + } + } + + fn remove_zone<'l>( + &mut self, + mut apex_name: impl Iterator, + ) -> Result<(), ZoneTreeModificationError> { + match apex_name.next() { + Some(label) => { + if self.children.remove(label).is_none() { + return Err(ZoneTreeModificationError::ZoneDoesNotExist); + } + } + None => { + self.zone = None; + } + } + Ok(()) + } +} + +//------------ ZoneSetIter --------------------------------------------------- + +pub struct ZoneSetIter<'a> { + roots: hash_map::Values<'a, Class, ZoneSetNode>, + nodes: NodesIter<'a>, +} + +impl<'a> ZoneSetIter<'a> { + fn new(set: &'a ZoneTree) -> Self { + ZoneSetIter { + roots: set.roots.others.values(), + nodes: NodesIter::new(&set.roots.in_), + } + } +} + +impl<'a> Iterator for ZoneSetIter<'a> { + type Item = &'a Zone; + + fn next(&mut self) -> Option { + loop { + if let Some(node) = self.nodes.next() { + if let Some(zone) = node.zone.as_ref() { + return Some(zone); + } else { + continue; + } + } + self.nodes = NodesIter::new(self.roots.next()?); + } + } +} + +//------------ NodesIter ----------------------------------------------------- + +struct NodesIter<'a> { + root: Option<&'a ZoneSetNode>, + stack: Vec>, +} + +impl<'a> 
NodesIter<'a> { + fn new(node: &'a ZoneSetNode) -> Self { + NodesIter { + root: Some(node), + stack: Vec::new(), + } + } + + fn next_node(&mut self) -> Option<&'a ZoneSetNode> { + if let Some(node) = self.root.take() { + return Some(node); + } + loop { + if let Some(iter) = self.stack.last_mut() { + if let Some(node) = iter.next() { + return Some(node); + } + } else { + return None; + } + let _ = self.stack.pop(); + } + } +} + +impl<'a> Iterator for NodesIter<'a> { + type Item = &'a ZoneSetNode; + + fn next(&mut self) -> Option { + let node = self.next_node()?; + self.stack.push(node.children.values()); + Some(node) + } +} + +//============ Error Types =================================================== + +#[derive(Debug)] +pub enum ZoneTreeModificationError { + ZoneExists, + ZoneDoesNotExist, + Io(io::Error), +} + +impl From for ZoneTreeModificationError { + fn from(src: io::Error) -> Self { + ZoneTreeModificationError::Io(src) + } +} + +impl From for io::Error { + fn from(src: ZoneTreeModificationError) -> Self { + match src { + ZoneTreeModificationError::Io(err) => err, + ZoneTreeModificationError::ZoneDoesNotExist => { + io::Error::new(io::ErrorKind::Other, "zone does not exist") + } + ZoneTreeModificationError::ZoneExists => { + io::Error::new(io::ErrorKind::Other, "zone exists") + } + } + } +} + +impl Display for ZoneTreeModificationError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ZoneTreeModificationError::ZoneExists => { + write!(f, "Zone already exists") + } + ZoneTreeModificationError::ZoneDoesNotExist => { + write!(f, "Zone does not exist") + } + ZoneTreeModificationError::Io(err) => { + write!(f, "Io error: {err}") + } + } + } +} diff --git a/src/zonetree/types.rs b/src/zonetree/types.rs new file mode 100644 index 000000000..26dd2e072 --- /dev/null +++ b/src/zonetree/types.rs @@ -0,0 +1,238 @@ +use crate::base::name::Dname; +use crate::base::rdata::RecordData; +use crate::base::record::Record; +use 
crate::base::{iana::Rtype, Ttl}; +use crate::rdata::ZoneRecordData; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::ops; +use std::sync::Arc; +use std::vec::Vec; + +//------------ Type Aliases -------------------------------------------------- + +/// A [`Bytes`] backed [`Dname`]. +pub type StoredDname = Dname; + +/// A [`Bytes`] backed [`ZoneRecordData`]. +pub type StoredRecordData = ZoneRecordData; + +/// A [`Bytes`] backed [`Record`].` +pub type StoredRecord = Record; + +//------------ SharedRr ------------------------------------------------------ + +/// A cheaply clonable resource record. +/// +/// A [`Bytes`] backed resource record which is cheap to [`Clone`] because +/// [`Bytes`] is cheap to clone. +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct SharedRr { + ttl: Ttl, + data: StoredRecordData, +} + +impl SharedRr { + /// Create a new [`SharedRr`] instance. + pub fn new(ttl: Ttl, data: StoredRecordData) -> Self { + SharedRr { ttl, data } + } + + /// Gets the type of this resource record. + pub fn rtype(&self) -> Rtype { + self.data.rtype() + } + + /// Gets the TTL of this resource record. + pub fn ttl(&self) -> Ttl { + self.ttl + } + + /// Gets a reference to the data of this resource record. + pub fn data(&self) -> &StoredRecordData { + &self.data + } +} + +impl From for SharedRr { + fn from(record: StoredRecord) -> Self { + SharedRr { + ttl: record.ttl(), + data: record.into_data(), + } + } +} + +//------------ Rrset --------------------------------------------------------- + +/// A set of related resource records for use with [`Zone`]s. +/// +/// This type should be used to create and edit one or more resource records +/// for use with a [`Zone`]. RRset records should all have the same type and +/// TTL but differing data, as defined by [RFC 9499 section 5.1.3]. 
+/// +/// [`Zone`]: crate::zonetree::Zone +/// [RFC 9499 section 5.1.3]: +/// https://datatracker.ietf.org/doc/html/rfc9499#section-5-1.3 +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct Rrset { + rtype: Rtype, + ttl: Ttl, + data: Vec, +} + +impl Rrset { + /// Creates a new RRset. + pub fn new(rtype: Rtype, ttl: Ttl) -> Self { + Rrset { + rtype, + ttl, + data: Vec::new(), + } + } + + /// Gets the common type of each record in the RRset. + pub fn rtype(&self) -> Rtype { + self.rtype + } + + /// Gets the common TTL of each record in the RRset. + pub fn ttl(&self) -> Ttl { + self.ttl + } + + /// Gets the data for each record in the RRset. + pub fn data(&self) -> &[StoredRecordData] { + &self.data + } + + /// Returns true if this RRset has no resource records, false otherwise. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Gets the first RRset record, if any. + pub fn first(&self) -> Option { + self.data.first().map(|data| SharedRr { + ttl: self.ttl, + data: data.clone(), + }) + } + + /// Changesthe TTL of every record in the RRset. + pub fn set_ttl(&mut self, ttl: Ttl) { + self.ttl = ttl; + } + + /// Limits the TTL of every record in the RRSet. + /// + /// If the TTL currently exceeds the given limit it will be set to the + /// limit. + pub fn limit_ttl(&mut self, ttl: Ttl) { + if self.ttl > ttl { + self.ttl = ttl + } + } + + /// Adds a resource record to the RRset. + /// + /// # Panics + /// + /// This function will panic if the provided record data is for a + /// different type than the RRset. + pub fn push_data(&mut self, data: StoredRecordData) { + assert_eq!(data.rtype(), self.rtype); + self.data.push(data); + } + + /// Adds a resource record to the RRset, limiting the TTL to that of the + /// new record. + /// + /// See [`Self::limit_ttl`] and [`Self::push_data`]. 
+ pub fn push_record(&mut self, record: StoredRecord) { + self.limit_ttl(record.ttl()); + self.push_data(record.into_data()); + } + + /// Converts this [`Rrset`] to an [`SharedRrset`]. + pub fn into_shared(self) -> SharedRrset { + SharedRrset::new(self) + } +} + +impl From for Rrset { + fn from(record: StoredRecord) -> Self { + Rrset { + rtype: record.rtype(), + ttl: record.ttl(), + data: vec![record.into_data()], + } + } +} + +//------------ SharedRrset --------------------------------------------------- + +/// An RRset behind an [`Arc`] for use with [`Zone`]s. +/// +/// See [`Rrset`] for more information. +/// +/// [`Zone`]: crate::zonetree::Zone. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SharedRrset(Arc); + +impl SharedRrset { + /// Creates a new RRset. + pub fn new(rrset: Rrset) -> Self { + SharedRrset(Arc::new(rrset)) + } + + /// Gets a reference to the inner [`Rrset`]. + pub fn as_rrset(&self) -> &Rrset { + self.0.as_ref() + } +} + +//--- Deref, AsRef, Borrow + +impl ops::Deref for SharedRrset { + type Target = Rrset; + + fn deref(&self) -> &Self::Target { + self.as_rrset() + } +} + +impl AsRef for SharedRrset { + fn as_ref(&self) -> &Rrset { + self.as_rrset() + } +} + +//--- Deserialize and Serialize + +impl<'de> Deserialize<'de> for SharedRrset { + fn deserialize>( + deserializer: D, + ) -> Result { + Rrset::deserialize(deserializer).map(SharedRrset::new) + } +} + +impl Serialize for SharedRrset { + fn serialize( + &self, + serializer: S, + ) -> Result { + self.as_rrset().serialize(serializer) + } +} + +//------------ ZoneCut ------------------------------------------------------- + +#[derive(Clone, Debug)] +pub struct ZoneCut { + pub name: StoredDname, + pub ns: SharedRrset, + pub ds: Option, + pub glue: Vec, +} diff --git a/src/zonetree/walk.rs b/src/zonetree/walk.rs new file mode 100644 index 000000000..101c872dd --- /dev/null +++ b/src/zonetree/walk.rs @@ -0,0 +1,72 @@ +use std::boxed::Box; +use std::sync::{Arc, Mutex}; +use std::vec::Vec; + 
+use bytes::Bytes; + +use super::Rrset; +use crate::base::name::OwnedLabel; +use crate::base::{Dname, DnameBuilder}; + +/// A callback function invoked for each leaf node visited while walking a +/// [`Zone`]. +/// +/// [`Zone`]: super::Zone +pub type WalkOp = Box, &Rrset) + Send + Sync>; + +struct WalkStateInner { + op: WalkOp, + label_stack: Mutex>, +} + +impl WalkStateInner { + fn new(op: WalkOp) -> Self { + Self { + op, + label_stack: Default::default(), + } + } +} + +#[derive(Clone)] +pub(super) struct WalkState { + inner: Option>, +} + +impl WalkState { + pub(super) const DISABLED: WalkState = WalkState { inner: None }; + + pub(super) fn new(op: WalkOp) -> Self { + Self { + inner: Some(Arc::new(WalkStateInner::new(op))), + } + } + + pub(super) fn enabled(&self) -> bool { + self.inner.is_some() + } + + pub(super) fn op(&self, rrset: &Rrset) { + if let Some(inner) = &self.inner { + let labels = inner.label_stack.lock().unwrap(); + let mut dname = DnameBuilder::new_bytes(); + for label in labels.iter().rev() { + dname.append_label(label.as_slice()).unwrap(); + } + let owner = dname.into_dname().unwrap(); + (inner.op)(owner, rrset); + } + } + + pub(super) fn push(&self, label: OwnedLabel) { + if let Some(inner) = &self.inner { + inner.label_stack.lock().unwrap().push(label); + } + } + + pub(super) fn pop(&self) { + if let Some(inner) = &self.inner { + inner.label_stack.lock().unwrap().pop(); + } + } +} diff --git a/src/zonetree/zone.rs b/src/zonetree/zone.rs new file mode 100644 index 000000000..d160b14b3 --- /dev/null +++ b/src/zonetree/zone.rs @@ -0,0 +1,82 @@ +use std::boxed::Box; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use crate::base::iana::Class; +use crate::zonefile::error::{RecordError, ZoneErrors}; +use crate::zonefile::{inplace, parsed}; + +use super::in_memory::ZoneBuilder; +use super::traits::WritableZone; +use super::{ReadableZone, StoredDname, ZoneStore}; + +//------------ Zone 
---------------------------------------------------------- + +/// A single DNS zone. +#[derive(Debug)] +pub struct Zone { + store: Arc, +} + +impl Zone { + /// Creates a new [`Zone`] instance with the given data. + pub fn new(data: impl ZoneStore + 'static) -> Self { + Zone { + store: Arc::new(data), + } + } + + /// Gets the CLASS of this zone. + pub fn class(&self) -> Class { + self.store.class() + } + + /// Gets the apex name of this zone. + pub fn apex_name(&self) -> &StoredDname { + self.store.apex_name() + } + + /// Gets a read interface to this zone. + pub fn read(&self) -> Box { + self.store.clone().read() + } + + /// Gets a write interface to this zone. + pub fn write( + &self, + ) -> Pin>>> { + self.store.clone().write() + } +} + +//--- TryFrom + +impl TryFrom for Zone { + type Error = RecordError; + + fn try_from(source: inplace::Zonefile) -> Result { + parsed::Zonefile::try_from(source)? + .try_into() + .map_err(Self::Error::InvalidRecord) + } +} + +//--- TryFrom + +impl From for Zone { + fn from(builder: ZoneBuilder) -> Self { + builder.build() + } +} + +//--- TryFrom + +impl TryFrom for Zone { + type Error = ZoneErrors; + + fn try_from(source: parsed::Zonefile) -> Result { + Ok(Zone::from(ZoneBuilder::try_from(source)?)) + } +} diff --git a/test-data/zonefiles/nsd-example.txt b/test-data/zonefiles/nsd-example.txt new file mode 100644 index 000000000..06ba2b8a6 --- /dev/null +++ b/test-data/zonefiles/nsd-example.txt @@ -0,0 +1,25 @@ +$ORIGIN example.com. ; 'default' domain as FQDN for this zone +$TTL 86400 ; default time-to-live for this zone + +example.com. IN SOA ns.example.com. noc.dns.icann.org. ( + 2020080302 ;Serial + 7200 ;Refresh + 3600 ;Retry + 1209600 ;Expire + 3600 ;Negative response caching TTL +) + +; The nameserver that are authoritative for this zone. + NS example.com. + +; these A records below are equivalent +example.com. 
A 192.0.2.1 +@ A 192.0.2.1 + A 192.0.2.1 + +@ AAAA 2001:db8::3 + +; A CNAME redirect from www.exmaple.com to example.com +www CNAME example.com. + +mail MX 10 example.com. diff --git a/tests/net-server.rs b/tests/net-server.rs index b353c50f9..c046b45cc 100644 --- a/tests/net-server.rs +++ b/tests/net-server.rs @@ -102,6 +102,8 @@ fn mk_servers( ) where Svc: Service + Send + Sync + 'static, + Svc::Future: Send, + Svc::Target: Composer + Default + Send + Sync, { // Prepare middleware to be used by the DNS servers to pre-process // received requests and post-process created responses. @@ -248,7 +250,7 @@ fn test_service( zonefile: Zonefile, ) -> Result< Transaction< - Result>, ServiceError>, + Vec, impl Future>, ServiceError>> + Send, >, ServiceError, diff --git a/tests/net/stelline/channel.rs b/tests/net/stelline/channel.rs index 4e4889131..c4498a4ad 100644 --- a/tests/net/stelline/channel.rs +++ b/tests/net/stelline/channel.rs @@ -1,8 +1,6 @@ // Using tokio::io::duplex() seems appealing but it can only create a channel // between two ends, it isn't possible to create additional client ends for a // single server end for example. 
-use core::future::pending; - use std::collections::HashMap; use std::future::ready; use std::future::Future; @@ -400,12 +398,7 @@ impl AsyncDgramSock for ClientServerChannel { fn readable( &self, ) -> Pin> + '_ + Send>> { - let server_socket = self.server.lock().unwrap(); - let rx = &server_socket.rx; - match !rx.is_empty() { - true => Box::pin(ready(Ok(()))), - false => Box::pin(pending()), - } + Box::pin(ClientServerChannelReadableFut(self.server.clone())) } fn try_recv_buf_from( @@ -436,6 +429,32 @@ impl AsyncDgramSock for ClientServerChannel { } } +pub struct ClientServerChannelReadableFut(Arc>); + +impl Future for ClientServerChannelReadableFut { + type Output = io::Result<()>; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll { + let server_socket = self.0.lock().unwrap(); + let rx = &server_socket.rx; + trace!("ReadableFut {} in dgram server channel", !rx.is_empty()); + match !rx.is_empty() { + true => Poll::Ready(Ok(())), + false => { + let waker = cx.waker().clone(); + std::thread::spawn(move || { + std::thread::yield_now(); + waker.wake(); + }); + Poll::Pending + } + } + } +} + //--- AsyncAccept // // Stream connection establishment diff --git a/tests/net/stelline/parse_query.rs b/tests/net/stelline/parse_query.rs index 3b91e962d..9c510410d 100644 --- a/tests/net/stelline/parse_query.rs +++ b/tests/net/stelline/parse_query.rs @@ -193,9 +193,11 @@ pub enum Entry { #[allow(dead_code)] Include { /// The path to the file to be included. + #[allow(dead_code)] path: ScannedString, /// The initial origin name of the included file, if provided. + #[allow(dead_code)] origin: Option>, }, }