From 659b840c6c6478e141c7c857301b6fe2fa22c7b7 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Wed, 25 Mar 2020 17:17:29 +0000 Subject: [PATCH 1/8] Upgrading to new tokio / futures and rust 2018 edition. --- Cargo.toml | 26 ++- rustfmt.toml | 7 + src/dns_parser/builder.rs | 128 ++++++----- src/dns_parser/enums.rs | 101 +++++---- src/dns_parser/header.rs | 119 +++++----- src/dns_parser/mod.rs | 16 +- src/dns_parser/name.rs | 82 ++++--- src/dns_parser/parser.rs | 446 +++++++++++++++++++++----------------- src/dns_parser/rrdata.rs | 80 +++---- src/dns_parser/structs.rs | 3 +- src/fsm.rs | 107 +++++---- src/lib.rs | 63 +++--- src/services.rs | 4 +- 13 files changed, 652 insertions(+), 530 deletions(-) create mode 100644 rustfmt.toml diff --git a/Cargo.toml b/Cargo.toml index 579f62f..9b28ceb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,23 +2,25 @@ name = "libmdns" version = "0.2.4" authors = ["Will Stott "] - description = "mDNS Responder library for building discoverable LAN services in Rust" repository = "https://github.com/librespot-org/libmdns" readme = "README.md" license = "MIT" +edition = "2018" [dependencies] -byteorder = "1.2" -futures = "0.1" -get_if_addrs = "0.5" -hostname = "0.2" -log = "0.4" -multimap = "0.4" -net2 = "0.2" -rand = "0.5" -tokio-core = "0.1" -quick-error = "1.2" +byteorder = "1.3.4" +futures = "0.3.4" +get_if_addrs = "0.5.3" +hostname = "0.3.1" +log = "0.4.8" +multimap = "0.8.0" +net2 = "0.2.33" +rand = "0.7.3" +tokio = { version = "0.2.13", features = ["udp","stream","io-driver","io-std","macros"] } +quick-error = "1.2.3" +pin-project = "0.4.8" + [dev-dependencies] -env_logger = "0.5" +env_logger = "0.7.1" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b16b80e --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,7 @@ +merge_derives=true +merge_imports=true +normalize_comments=true +reorder_impl_items=true +reorder_modules=true +use_try_shorthand=true +use_field_init_shorthand=true \ No newline at end of file diff --git a/src/dns_parser/builder.rs b/src/dns_parser/builder.rs index da81356..d8f5f33 100644 --- a/src/dns_parser/builder.rs +++ b/src/dns_parser/builder.rs @@ -1,16 +1,17 @@ use std::marker::PhantomData; -use byteorder::{ByteOrder, BigEndian, WriteBytesExt}; +use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Opcode, ResponseCode, Header, Name, RRData, QueryType, QueryClass}; +use super::{Header, Name, Opcode, QueryClass, QueryType, RRData, ResponseCode}; pub enum Questions {} pub enum Answers {} -#[allow(dead_code)] pub enum Nameservers {} +#[allow(dead_code)] +pub enum Nameservers {} pub enum Additional {} -pub trait MoveTo { } -impl MoveTo for T {} +pub trait MoveTo {} +impl MoveTo for T {} impl MoveTo for Questions {} @@ -40,7 +41,7 @@ impl Builder { pub fn new_query(id: u16, recursion: bool) -> Builder { let mut buf = Vec::with_capacity(512); let head = Header { - id: id, + id, query: true, opcode: Opcode::StandardQuery, authoritative: false, @@ -55,16 +56,20 @@ impl Builder { }; buf.extend([0u8; 12].iter()); head.write(&mut buf[..12]); - Builder { buf: buf, max_size: Some(512), _state: PhantomData } + Builder { + buf, + max_size: Some(512), + _state: PhantomData, + } } pub fn new_response(id: u16, recursion: bool, authoritative: bool) -> Builder { let mut buf = Vec::with_capacity(512); let head = Header { - id: id, + id, query: false, opcode: Opcode::StandardQuery, - authoritative: authoritative, + authoritative, truncated: false, recursion_desired: recursion, recursion_available: false, @@ -76,14 +81,16 @@ impl Builder { }; buf.extend([0u8; 12].iter()); head.write(&mut buf[..12]); - Builder { buf: buf, max_size: Some(512), _state: PhantomData } + Builder { + buf, + max_size: Some(512), + _state: PhantomData, + } } } -impl Builder { - fn write_rr(&mut self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) { - +impl Builder { + fn write_rr(&mut self, name: &Name, cls: QueryClass, ttl: u32, data: &RRData) { name.write_to(&mut self.buf).unwrap(); self.buf.write_u16::(data.typ() as u16).unwrap(); self.buf.write_u16::(cls as u16).unwrap(); @@ -96,7 +103,10 @@ impl Builder { data.write_to(&mut self.buf).unwrap(); let data_size = self.buf.len() - data_offset; - BigEndian::write_u16(&mut self.buf[size_offset..size_offset+2], data_size as u16); + BigEndian::write_u16( + &mut self.buf[size_offset..size_offset + 2], + data_size as u16, + ); } /// Returns the final packet @@ -113,7 +123,7 @@ impl Builder { /// appropriate. // TODO(tailhook) does the truncation make sense for TCP, and how // to treat it for EDNS0? - pub fn build(mut self) -> Result,Vec> { + pub fn build(mut self) -> Result, Vec> { // TODO(tailhook) optimize labels match self.max_size { Some(max_size) if self.buf.len() > max_size => { @@ -124,8 +134,15 @@ impl Builder { } } - pub fn move_to(self) -> Builder where T: MoveTo { - Builder { buf: self.buf, max_size: self.max_size, _state: PhantomData } + pub fn move_to(self) -> Builder + where + T: MoveTo, + { + Builder { + buf: self.buf, + max_size: self.max_size, + _state: PhantomData, + } } pub fn set_max_size(&mut self, max_size: Option) { @@ -133,61 +150,66 @@ impl Builder { } pub fn is_empty(&self) -> bool { - Header::question_count(&self.buf) == 0 && - Header::answer_count(&self.buf) == 0 && - Header::nameserver_count(&self.buf) == 0 && - Header::additional_count(&self.buf) == 0 + Header::question_count(&self.buf) == 0 + && Header::answer_count(&self.buf) == 0 + && Header::nameserver_count(&self.buf) == 0 + && Header::additional_count(&self.buf) == 0 } } -impl > Builder { +impl> Builder { /// Adds a question to the packet /// /// # Panics /// /// * There are already 65535 questions in the buffer. #[allow(dead_code)] - pub fn add_question(self, qname: &Name, - qtype: QueryType, qclass: QueryClass) - -> Builder - { + pub fn add_question( + self, + qname: &Name, + qtype: QueryType, + qclass: QueryClass, + ) -> Builder { let mut builder = self.move_to::(); qname.write_to(&mut builder.buf).unwrap(); builder.buf.write_u16::(qtype as u16).unwrap(); builder.buf.write_u16::(qclass as u16).unwrap(); - Header::inc_questions(&mut builder.buf) - .expect("Too many questions"); + Header::inc_questions(&mut builder.buf).expect("Too many questions"); builder } } -impl > Builder { - pub fn add_answer(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { +impl> Builder { + pub fn add_answer( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_answers(&mut builder.buf) - .expect("Too many answers"); + Header::inc_answers(&mut builder.buf).expect("Too many answers"); builder } } -impl > Builder { +impl> Builder { #[allow(dead_code)] - pub fn add_nameserver(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { + pub fn add_nameserver( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_nameservers(&mut builder.buf) - .expect("Too many nameservers"); + Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers"); builder } @@ -195,15 +217,17 @@ impl > Builder { impl Builder { #[allow(dead_code)] - pub fn add_additional(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { + pub fn add_additional( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_nameservers(&mut builder.buf) - .expect("Too many additional answers"); + Header::inc_nameservers(&mut builder.buf).expect("Too many additional answers"); builder } @@ -211,10 +235,10 @@ impl Builder { #[cfg(test)] mod test { - use super::QueryType as QT; - use super::QueryClass as QC; - use super::Name; use super::Builder; + use super::Name; + use super::QueryClass as QC; + use super::QueryType as QT; #[test] fn build_query() { diff --git a/src/dns_parser/enums.rs b/src/dns_parser/enums.rs index 55f63b3..e2c7b1c 100644 --- a/src/dns_parser/enums.rs +++ b/src/dns_parser/enums.rs @@ -92,7 +92,6 @@ pub enum QueryType { All = 255, } - /// The CLASS value according to RFC 1035 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum Class { @@ -171,13 +170,13 @@ impl From for ResponseCode { fn from(code: u8) -> ResponseCode { use self::ResponseCode::*; match code { - 0 => NoError, - 1 => FormatError, - 2 => ServerFailure, - 3 => NameError, - 4 => NotImplemented, - 5 => Refused, - 6...15 => Reserved(code), + 0 => NoError, + 1 => FormatError, + 2 => ServerFailure, + 3 => NameError, + 4 => NotImplemented, + 5 => Refused, + 6..=15 => Reserved(code), x => panic!("Invalid response code {}", x), } } @@ -201,23 +200,23 @@ impl QueryType { pub fn parse(code: u16) -> Result { use self::QueryType::*; match code { - 1 => Ok(A), - 2 => Ok(NS), - 4 => Ok(MF), - 5 => Ok(CNAME), - 6 => Ok(SOA), - 7 => Ok(MB), - 8 => Ok(MG), - 9 => Ok(MR), - 10 => Ok(NULL), - 11 => Ok(WKS), - 12 => Ok(PTR), - 13 => Ok(HINFO), - 14 => Ok(MINFO), - 15 => Ok(MX), - 16 => Ok(TXT), - 28 => Ok(AAAA), - 33 => Ok(SRV), + 1 => Ok(A), + 2 => Ok(NS), + 4 => Ok(MF), + 5 => Ok(CNAME), + 6 => Ok(SOA), + 7 => Ok(MB), + 8 => Ok(MG), + 9 => Ok(MR), + 10 => Ok(NULL), + 11 => Ok(WKS), + 12 => Ok(PTR), + 13 => Ok(HINFO), + 14 => Ok(MINFO), + 15 => Ok(MX), + 16 => Ok(TXT), + 28 => Ok(AAAA), + 33 => Ok(SRV), 252 => Ok(AXFR), 253 => Ok(MAILB), 254 => Ok(MAILA), @@ -231,10 +230,10 @@ impl QueryClass { pub fn parse(code: u16) -> Result { use self::QueryClass::*; match code { - 1 => Ok(IN), - 2 => Ok(CS), - 3 => Ok(CH), - 4 => Ok(HS), + 1 => Ok(IN), + 2 => Ok(CS), + 3 => Ok(CH), + 4 => Ok(HS), 255 => Ok(Any), x => Err(Error::InvalidQueryClass(x)), } @@ -245,24 +244,24 @@ impl Type { pub fn parse(code: u16) -> Result { use self::Type::*; match code { - 1 => Ok(A), - 2 => Ok(NS), - 4 => Ok(MF), - 5 => Ok(CNAME), - 6 => Ok(SOA), - 7 => Ok(MB), - 8 => Ok(MG), - 9 => Ok(MR), - 10 => Ok(NULL), - 11 => Ok(WKS), - 12 => Ok(PTR), - 13 => Ok(HINFO), - 14 => Ok(MINFO), - 15 => Ok(MX), - 16 => Ok(TXT), - 28 => Ok(AAAA), - 33 => Ok(SRV), - 41 => Ok(OPT), + 1 => Ok(A), + 2 => Ok(NS), + 4 => Ok(MF), + 5 => Ok(CNAME), + 6 => Ok(SOA), + 7 => Ok(MB), + 8 => Ok(MG), + 9 => Ok(MR), + 10 => Ok(NULL), + 11 => Ok(WKS), + 12 => Ok(PTR), + 13 => Ok(HINFO), + 14 => Ok(MINFO), + 15 => Ok(MX), + 16 => Ok(TXT), + 28 => Ok(AAAA), + 33 => Ok(SRV), + 41 => Ok(OPT), x => Err(Error::InvalidType(x)), } } @@ -272,10 +271,10 @@ impl Class { pub fn parse(code: u16) -> Result { use self::Class::*; match code { - 1 => Ok(IN), - 2 => Ok(CS), - 3 => Ok(CH), - 4 => Ok(HS), + 1 => Ok(IN), + 2 => Ok(CS), + 3 => Ok(CH), + 4 => Ok(HS), x => Err(Error::InvalidClass(x)), } } diff --git a/src/dns_parser/header.rs b/src/dns_parser/header.rs index ddc908f..263d002 100644 --- a/src/dns_parser/header.rs +++ b/src/dns_parser/header.rs @@ -1,16 +1,16 @@ use byteorder::{BigEndian, ByteOrder}; -use super::{Error, ResponseCode, Opcode}; +use super::{Error, Opcode, ResponseCode}; mod flag { - pub const QUERY: u16 = 0b1000_0000_0000_0000; - pub const OPCODE_MASK: u16 = 0b0111_1000_0000_0000; - pub const AUTHORITATIVE: u16 = 0b0000_0100_0000_0000; - pub const TRUNCATED: u16 = 0b0000_0010_0000_0000; - pub const RECURSION_DESIRED: u16 = 0b0000_0001_0000_0000; + pub const QUERY: u16 = 0b1000_0000_0000_0000; + pub const OPCODE_MASK: u16 = 0b0111_1000_0000_0000; + pub const AUTHORITATIVE: u16 = 0b0000_0100_0000_0000; + pub const TRUNCATED: u16 = 0b0000_0010_0000_0000; + pub const RECURSION_DESIRED: u16 = 0b0000_0001_0000_0000; pub const RECURSION_AVAILABLE: u16 = 0b0000_0000_1000_0000; - pub const RESERVED_MASK: u16 = 0b0000_0000_0111_0000; - pub const RESPONSE_CODE_MASK: u16 = 0b0000_0000_0000_1111; + pub const RESERVED_MASK: u16 = 0b0000_0000_0111_0000; + pub const RESPONSE_CODE_MASK: u16 = 0b0000_0000_0000_1111; } /// Represents parsed header of the packet @@ -42,13 +42,12 @@ impl Header { let header = Header { id: BigEndian::read_u16(&data[..2]), query: flags & flag::QUERY == 0, - opcode: (flags & flag::OPCODE_MASK - >> flag::OPCODE_MASK.trailing_zeros()).into(), + opcode: (flags & flag::OPCODE_MASK >> flag::OPCODE_MASK.trailing_zeros()).into(), authoritative: flags & flag::AUTHORITATIVE != 0, truncated: flags & flag::TRUNCATED != 0, recursion_desired: flags & flag::RECURSION_DESIRED != 0, recursion_available: flags & flag::RECURSION_AVAILABLE != 0, - response_code: From::from((flags&flag::RESPONSE_CODE_MASK) as u8), + response_code: From::from((flags & flag::RESPONSE_CODE_MASK) as u8), questions: BigEndian::read_u16(&data[4..6]), answers: BigEndian::read_u16(&data[6..8]), nameservers: BigEndian::read_u16(&data[8..10]), @@ -66,14 +65,23 @@ impl Header { panic!("Header size is exactly 12 bytes"); } let mut flags = 0u16; - flags |= Into::::into(self.opcode) - << flag::OPCODE_MASK.trailing_zeros(); + flags |= Into::::into(self.opcode) << flag::OPCODE_MASK.trailing_zeros(); flags |= Into::::into(self.response_code) as u16; - if !self.query { flags |= flag::QUERY; } - if self.authoritative { flags |= flag::AUTHORITATIVE; } - if self.recursion_desired { flags |= flag::RECURSION_DESIRED; } - if self.recursion_available { flags |= flag::RECURSION_AVAILABLE; } - if self.truncated { flags |= flag::TRUNCATED; } + if !self.query { + flags |= flag::QUERY; + } + if self.authoritative { + flags |= flag::AUTHORITATIVE; + } + if self.recursion_desired { + flags |= flag::RECURSION_DESIRED; + } + if self.recursion_available { + flags |= flag::RECURSION_AVAILABLE; + } + if self.truncated { + flags |= flag::TRUNCATED; + } BigEndian::write_u16(&mut data[..2], self.id); BigEndian::write_u16(&mut data[2..4], flags); BigEndian::write_u16(&mut data[4..6], self.questions); @@ -105,7 +113,7 @@ impl Header { pub fn inc_questions(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[4..6]); if oldq < 65535 { - BigEndian::write_u16(&mut data[4..6], oldq+1); + BigEndian::write_u16(&mut data[4..6], oldq + 1); Some(oldq + 1) } else { None @@ -115,7 +123,7 @@ impl Header { pub fn inc_answers(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[6..8]); if oldq < 65535 { - BigEndian::write_u16(&mut data[6..8], oldq+1); + BigEndian::write_u16(&mut data[6..8], oldq + 1); Some(oldq + 1) } else { None @@ -125,7 +133,7 @@ impl Header { pub fn inc_nameservers(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[8..10]); if oldq < 65535 { - BigEndian::write_u16(&mut data[8..10], oldq+1); + BigEndian::write_u16(&mut data[8..10], oldq + 1); Some(oldq + 1) } else { None @@ -136,16 +144,17 @@ impl Header { pub fn inc_additional(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[10..12]); if oldq < 65535 { - BigEndian::write_u16(&mut data[10..12], oldq+1); + BigEndian::write_u16(&mut data[10..12], oldq + 1); Some(oldq + 1) } else { None } } - pub fn size() -> usize { 12 } + pub fn size() -> usize { + 12 + } } - #[cfg(test)] mod test { @@ -158,20 +167,23 @@ mod test { let query = b"\x06%\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\ \x07example\x03com\x00\x00\x01\x00\x01"; let header = Header::parse(query).unwrap(); - assert_eq!(header, Header { - id: 1573, - query: true, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: false, - response_code: NoError, - questions: 1, - answers: 0, - nameservers: 0, - additional: 0, - }); + assert_eq!( + header, + Header { + id: 1573, + query: true, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: false, + response_code: NoError, + questions: 1, + answers: 0, + nameservers: 0, + additional: 0, + } + ); } #[test] @@ -181,19 +193,22 @@ mod test { \xc0\x0c\x00\x01\x00\x01\x00\x00\x04\xf8\ \x00\x04]\xb8\xd8\""; let header = Header::parse(response).unwrap(); - assert_eq!(header, Header { - id: 1573, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 0, - additional: 0, - }); + assert_eq!( + header, + Header { + id: 1573, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 0, + additional: 0, + } + ); } } diff --git a/src/dns_parser/mod.rs b/src/dns_parser/mod.rs index a05cb31..23234a3 100644 --- a/src/dns_parser/mod.rs +++ b/src/dns_parser/mod.rs @@ -1,15 +1,15 @@ mod error; -pub use self::error::{Error}; +pub use self::error::Error; mod enums; -pub use self::enums::{Type, QueryType, Class, QueryClass, ResponseCode, Opcode}; +pub use self::enums::{Class, Opcode, QueryClass, QueryType, ResponseCode, Type}; mod structs; -pub use self::structs::{Question, ResourceRecord, Packet}; +pub use self::structs::{Packet, Question, ResourceRecord}; mod name; -pub use self::name::{Name}; -mod parser; +pub use self::name::Name; mod header; -pub use self::header::{Header}; +mod parser; +pub use self::header::Header; mod rrdata; -pub use self::rrdata::{RRData}; +pub use self::rrdata::RRData; mod builder; -pub use self::builder::{Builder, Questions, Answers}; +pub use self::builder::{Answers, Builder, Questions}; diff --git a/src/dns_parser/name.rs b/src/dns_parser/name.rs index eec408e..2869feb 100644 --- a/src/dns_parser/name.rs +++ b/src/dns_parser/name.rs @@ -1,13 +1,13 @@ -use std::io; +use std::borrow::Cow; use std::fmt; use std::fmt::Write; -use std::str::from_utf8; -use std::borrow::Cow; use std::hash; +use std::io; +use std::str::from_utf8; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Error}; +use super::Error; /// The DNS name as stored in the original packet /// @@ -26,7 +26,7 @@ pub enum Name<'a> { } impl<'a> Name<'a> { - pub fn scan(data: &'a[u8], original: &'a[u8]) -> Result<(Name<'a>, usize), Error> { + pub fn scan(data: &'a [u8], original: &'a [u8]) -> Result<(Name<'a>, usize), Error> { let mut pos = 0; loop { if data.len() <= pos { @@ -34,22 +34,34 @@ impl<'a> Name<'a> { } let byte = data[pos]; if byte == 0 { - return Ok((Name::FromPacket { labels: &data[..pos+1], original: original }, pos + 1)); + return Ok(( + Name::FromPacket { + labels: &data[..pos + 1], + original, + }, + pos + 1, + )); } else if byte & 0b1100_0000 == 0b1100_0000 { - if data.len() < pos+2 { + if data.len() < pos + 2 { return Err(Error::UnexpectedEOF); } - let off = (BigEndian::read_u16(&data[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; + let off = + (BigEndian::read_u16(&data[pos..pos + 2]) & !0b1100_0000_0000_0000) as usize; if off >= original.len() { return Err(Error::UnexpectedEOF); } // Validate referred to location - try!(Name::scan(&original[off..], original)); - return Ok((Name::FromPacket { labels: &data[..pos+2], original: original }, pos + 2)); + Name::scan(&original[off..], original)?; + return Ok(( + Name::FromPacket { + labels: &data[..pos + 2], + original, + }, + pos + 2, + )); } else if byte & 0b1100_0000 == 0 { let end = pos + byte as usize + 1; - if from_utf8(&data[pos+1..end]).is_err() { + if from_utf8(&data[pos + 1..end]).is_err() { return Err(Error::LabelIsNotAscii); } pos = end; @@ -74,15 +86,18 @@ impl<'a> Name<'a> { loop { let byte = labels[pos]; if byte == 0 { - try!(writer.write_u8(0)); + writer.write_u8(0)?; return Ok(()); } else if byte & 0b1100_0000 == 0b1100_0000 { - let off = (BigEndian::read_u16(&labels[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; - return Name::scan(&original[off..], original).unwrap().0.write_to(writer) + let off = (BigEndian::read_u16(&labels[pos..pos + 2]) + & !0b1100_0000_0000_0000) as usize; + return Name::scan(&original[off..], original) + .unwrap() + .0 + .write_to(writer); } else if byte & 0b1100_0000 == 0 { let end = pos + byte as usize + 1; - try!(writer.write(&labels[pos..end])); + writer.write_all(&labels[pos..end])?; pos = end; continue; } else { @@ -95,10 +110,10 @@ impl<'a> Name<'a> { for part in name.split('.') { assert!(part.len() < 63); let ln = part.len() as u8; - try!(writer.write_u8(ln)); - try!(writer.write(part.as_bytes())); + writer.write_u8(ln)?; + writer.write_all(part.as_bytes())?; } - try!(writer.write_u8(0)); + writer.write_u8(0)?; Ok(()) } @@ -116,19 +131,21 @@ impl<'a> fmt::Display for Name<'a> { if byte == 0 { return Ok(()); } else if byte & 0b1100_0000 == 0b1100_0000 { - let off = (BigEndian::read_u16(&labels[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; + let off = (BigEndian::read_u16(&labels[pos..pos + 2]) + & !0b1100_0000_0000_0000) as usize; if pos != 0 { - try!(fmt.write_char('.')); + fmt.write_char('.')?; } return fmt::Display::fmt( - &Name::scan(&original[off..], original).unwrap().0, fmt) + &Name::scan(&original[off..], original).unwrap().0, + fmt, + ); } else if byte & 0b1100_0000 == 0 { if pos != 0 { - try!(fmt.write_char('.')); + fmt.write_char('.')?; } let end = pos + byte as usize + 1; - try!(fmt.write_str(from_utf8(&labels[pos+1..end]).unwrap())); + fmt.write_str(from_utf8(&labels[pos + 1..end]).unwrap())?; pos = end; continue; } else { @@ -137,20 +154,23 @@ impl<'a> fmt::Display for Name<'a> { } } - Name::FromStr(ref name) => fmt.write_str(&name) + Name::FromStr(ref name) => fmt.write_str(&name), } } } -impl <'a> hash::Hash for Name<'a> { - fn hash(&self, state: &mut H) where H: hash::Hasher { +impl<'a> hash::Hash for Name<'a> { + fn hash(&self, state: &mut H) + where + H: hash::Hasher, + { let mut buffer = Vec::new(); self.write_to(&mut buffer).unwrap(); hash::Hash::hash(&buffer, state) } } -impl <'a> PartialEq for Name<'a> { +impl<'a> PartialEq for Name<'a> { fn eq(&self, other: &Name) -> bool { let mut buffer = Vec::new(); self.write_to(&mut buffer).unwrap(); @@ -162,4 +182,4 @@ impl <'a> PartialEq for Name<'a> { } } -impl <'a> Eq for Name<'a> {} +impl<'a> Eq for Name<'a> {} diff --git a/src/dns_parser/parser.rs b/src/dns_parser/parser.rs index d264f14..48e1cb0 100644 --- a/src/dns_parser/parser.rs +++ b/src/dns_parser/parser.rs @@ -2,49 +2,47 @@ use std::i32; use byteorder::{BigEndian, ByteOrder}; -use super::{Header, Packet, Error, Question, Name, QueryType, QueryClass}; -use super::{Type, Class, ResourceRecord, RRData}; - +use super::{Class, RRData, ResourceRecord, Type}; +use super::{Error, Header, Name, Packet, QueryClass, QueryType, Question}; impl<'a> Packet<'a> { pub fn parse(data: &[u8]) -> Result { - let header = try!(Header::parse(data)); + let header = Header::parse(data)?; let mut offset = Header::size(); let mut questions = Vec::with_capacity(header.questions as usize); for _ in 0..header.questions { - let (name, name_size) = try!(Name::scan(&data[offset..], data)); + let (name, name_size) = Name::scan(&data[offset..], data)?; offset += name_size; if offset + 4 > data.len() { return Err(Error::UnexpectedEOF); } - let qtype = try!(QueryType::parse( - BigEndian::read_u16(&data[offset..offset+2]))); + let qtype = QueryType::parse(BigEndian::read_u16(&data[offset..offset + 2]))?; offset += 2; - let qclass_qu = BigEndian::read_u16(&data[offset..offset+2]); - let qclass = try!(QueryClass::parse(qclass_qu & 0x7fff)); + let qclass_qu = BigEndian::read_u16(&data[offset..offset + 2]); + let qclass = QueryClass::parse(qclass_qu & 0x7fff)?; let qu = (qclass_qu & 0x8000) != 0; offset += 2; questions.push(Question { qname: name, - qtype: qtype, - qclass: qclass, - qu: qu, + qtype, + qclass, + qu, }); } let mut answers = Vec::with_capacity(header.answers as usize); for _ in 0..header.answers { - answers.push(try!(parse_record(data, &mut offset))); + answers.push(parse_record(data, &mut offset)?); } let mut nameservers = Vec::with_capacity(header.nameservers as usize); for _ in 0..header.nameservers { - nameservers.push(try!(parse_record(data, &mut offset))); + nameservers.push(parse_record(data, &mut offset)?); } Ok(Packet { - header: header, - questions: questions, - answers: answers, - nameservers: nameservers, + header, + questions, + answers, + nameservers, additional: Vec::new(), // TODO(tailhook) }) } @@ -52,69 +50,69 @@ impl<'a> Packet<'a> { // Generic function to parse answer, nameservers, and additional records. fn parse_record<'a>(data: &'a [u8], offset: &mut usize) -> Result, Error> { - let (name, name_size) = try!(Name::scan(&data[*offset..], data)); + let (name, name_size) = Name::scan(&data[*offset..], data)?; *offset += name_size; if *offset + 10 > data.len() { return Err(Error::UnexpectedEOF); } - let typ = try!(Type::parse( - BigEndian::read_u16(&data[*offset..*offset+2]))); + let typ = Type::parse(BigEndian::read_u16(&data[*offset..*offset + 2]))?; *offset += 2; - let cls = try!(Class::parse( - BigEndian::read_u16(&data[*offset..*offset+2]) & 0x7fff )); + let cls = Class::parse(BigEndian::read_u16(&data[*offset..*offset + 2]) & 0x7fff)?; *offset += 2; - let mut ttl = BigEndian::read_u32(&data[*offset..*offset+4]); + let mut ttl = BigEndian::read_u32(&data[*offset..*offset + 4]); if ttl > i32::MAX as u32 { ttl = 0; } *offset += 4; - let rdlen = BigEndian::read_u16(&data[*offset..*offset+2]) as usize; + let rdlen = BigEndian::read_u16(&data[*offset..*offset + 2]) as usize; *offset += 2; if *offset + rdlen > data.len() { return Err(Error::UnexpectedEOF); } - let data = try!(RRData::parse(typ, - &data[*offset..*offset+rdlen], data)); + let data = RRData::parse(typ, &data[*offset..*offset + rdlen], data)?; *offset += rdlen; Ok(ResourceRecord { - name: name, - cls: cls, - ttl: ttl, - data: data, + name, + cls, + ttl, + data, }) } #[cfg(test)] mod test { - use std::net::{Ipv4Addr, Ipv6Addr}; - use {super::Packet, super::Header}; use super::super::Opcode::*; use super::super::ResponseCode::NoError; - use super::QueryType as QT; - use super::QueryClass as QC; use super::Class as C; + use super::QueryClass as QC; + use super::QueryType as QT; use super::RRData; + use std::net::{Ipv4Addr, Ipv6Addr}; + use {super::Header, super::Packet}; #[test] fn parse_example_query() { let query = b"\x06%\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\ \x07example\x03com\x00\x00\x01\x00\x01"; let packet = Packet::parse(query).unwrap(); - assert_eq!(packet.header, Header { - id: 1573, - query: true, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: false, - response_code: NoError, - questions: 1, - answers: 0, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 1573, + query: true, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: false, + response_code: NoError, + questions: 1, + answers: 0, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); @@ -129,20 +127,23 @@ mod test { \xc0\x0c\x00\x01\x00\x01\x00\x00\x04\xf8\ \x00\x04]\xb8\xd8\""; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 1573, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 1573, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); @@ -170,46 +171,49 @@ mod test { \xc0\x42\x00\x02\x00\x01\x00\x01\xd5\xd3\x00\x11\ \x01\x67\x0c\x67\x74\x6c\x64\x2d\x73\x65\x72\x76\x65\x72\x73\ \xc0\x42"; - let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 19184, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 1, - additional: 0, - }); - assert_eq!(packet.questions.len(), 1); - assert_eq!(packet.questions[0].qtype, QT::A); - assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], "www.skype.com"); - assert_eq!(packet.answers.len(), 1); - assert_eq!(&packet.answers[0].name.to_string()[..], "www.skype.com"); - assert_eq!(packet.answers[0].cls, C::IN); - assert_eq!(packet.answers[0].ttl, 3600); - match packet.answers[0].data { - RRData::CNAME(ref cname) => { - assert_eq!(&cname.to_string()[..], "livecms.trafficmanager.net"); - } - ref x => panic!("Wrong rdata {:?}", x), - } - assert_eq!(packet.nameservers.len(), 1); - assert_eq!(&packet.nameservers[0].name.to_string()[..], "net"); - assert_eq!(packet.nameservers[0].cls, C::IN); - assert_eq!(packet.nameservers[0].ttl, 120275); - match packet.nameservers[0].data { - RRData::NS(ref ns) => { - assert_eq!(&ns.to_string()[..], "g.gtld-servers.net"); - } - ref x => panic!("Wrong rdata {:?}", x), - } - } + let packet = Packet::parse(response).unwrap(); + assert_eq!( + packet.header, + Header { + id: 19184, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 1, + additional: 0, + } + ); + assert_eq!(packet.questions.len(), 1); + assert_eq!(packet.questions[0].qtype, QT::A); + assert_eq!(packet.questions[0].qclass, QC::IN); + assert_eq!(&packet.questions[0].qname.to_string()[..], "www.skype.com"); + assert_eq!(packet.answers.len(), 1); + assert_eq!(&packet.answers[0].name.to_string()[..], "www.skype.com"); + assert_eq!(packet.answers[0].cls, C::IN); + assert_eq!(packet.answers[0].ttl, 3600); + match packet.answers[0].data { + RRData::CNAME(ref cname) => { + assert_eq!(&cname.to_string()[..], "livecms.trafficmanager.net"); + } + ref x => panic!("Wrong rdata {:?}", x), + } + assert_eq!(packet.nameservers.len(), 1); + assert_eq!(&packet.nameservers[0].name.to_string()[..], "net"); + assert_eq!(packet.nameservers[0].cls, C::IN); + assert_eq!(packet.nameservers[0].ttl, 120275); + match packet.nameservers[0].data { + RRData::NS(ref ns) => { + assert_eq!(&ns.to_string()[..], "g.gtld-servers.net"); + } + ref x => panic!("Wrong rdata {:?}", x), + } + } #[test] fn parse_multiple_answers() { @@ -224,20 +228,23 @@ mod test { \xe9\xa4e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\xef\ \x00\x04@\xe9\xa4\x8a"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 40425, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 6, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 40425, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 6, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); @@ -269,25 +276,30 @@ mod test { let query = b"[\xd9\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\ \x0c_xmpp-server\x04_tcp\x05gmail\x03com\x00\x00!\x00\x01"; let packet = Packet::parse(query).unwrap(); - assert_eq!(packet.header, Header { - id: 23513, - query: true, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: false, - response_code: NoError, - questions: 1, - answers: 0, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 23513, + query: true, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: false, + response_code: NoError, + questions: 1, + answers: 0, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::SRV); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers.len(), 0); } @@ -306,25 +318,30 @@ mod test { \xc0\x0c\x00!\x00\x01\x00\x00\x03\x84\x00%\x00\x14\x00\x00\ \x14\x95\x04alt4\x0bxmpp-server\x01l\x06google\x03com\x00"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 23513, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 5, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 23513, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 5, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::SRV); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers.len(), 5); let items = vec![ (5, 0, 5269, "xmpp-server.l.google.com"), @@ -334,12 +351,19 @@ mod test { (20, 0, 5269, "alt4.xmpp-server.l.google.com"), ]; for i in 0..5 { - assert_eq!(&packet.answers[i].name.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.answers[i].name.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers[i].cls, C::IN); assert_eq!(packet.answers[i].ttl, 900); match *&packet.answers[i].data { - RRData::SRV { priority, weight, port, ref target } => { + RRData::SRV { + priority, + weight, + port, + ref target, + } => { assert_eq!(priority, items[i].0); assert_eq!(weight, items[i].1); assert_eq!(port, items[i].2); @@ -361,40 +385,44 @@ mod test { \x00\x04|\x00\t\x00\x14\x04alt2\xc0)\xc0\x0c\x00\x0f\ \x00\x01\x00\x00\x04|\x00\t\x00\x1e\x04alt3\xc0)"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 58344, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 5, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 58344, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 5, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::MX); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "gmail.com"); + assert_eq!(&packet.questions[0].qname.to_string()[..], "gmail.com"); assert_eq!(packet.answers.len(), 5); let items = vec![ - ( 5, "gmail-smtp-in.l.google.com"), + (5, "gmail-smtp-in.l.google.com"), (10, "alt1.gmail-smtp-in.l.google.com"), (40, "alt4.gmail-smtp-in.l.google.com"), (20, "alt2.gmail-smtp-in.l.google.com"), (30, "alt3.gmail-smtp-in.l.google.com"), ]; for i in 0..5 { - assert_eq!(&packet.answers[i].name.to_string()[..], - "gmail.com"); + assert_eq!(&packet.answers[i].name.to_string()[..], "gmail.com"); assert_eq!(packet.answers[i].cls, C::IN); assert_eq!(packet.answers[i].ttl, 1148); match *&packet.answers[i].data { - RRData::MX { preference, ref exchange } => { + RRData::MX { + preference, + ref exchange, + } => { assert_eq!(preference, items[i].0); assert_eq!(exchange.to_string(), (items[i].1).to_string()); } @@ -410,20 +438,23 @@ mod test { \x00\x8b\x00\x10*\x00\x14P@\t\x08\x12\x00\x00\x00\x00\x00\x00 \x0e"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 43481, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 43481, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::AAAA); @@ -435,8 +466,9 @@ mod test { assert_eq!(packet.answers[0].ttl, 139); match packet.answers[0].data { RRData::AAAA(addr) => { - assert_eq!(addr, Ipv6Addr::new( - 0x2A00, 0x1450, 0x4009, 0x812, 0, 0, 0, 0x200e) + assert_eq!( + addr, + Ipv6Addr::new(0x2A00, 0x1450, 0x4009, 0x812, 0, 0, 0, 0x200e) ); } ref x => panic!("Wrong rdata {:?}", x), @@ -458,25 +490,31 @@ mod test { \x00\x99L\x00\x04\xad\xf5;\x04"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 64669, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 6, - nameservers: 2, - additional: 2, - }); + assert_eq!( + packet.header, + Header { + id: 64669, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 6, + nameservers: 2, + additional: 2, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], "cdn.sstatic.net"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "cdn.sstatic.net" + ); assert_eq!(packet.answers.len(), 6); assert_eq!(&packet.answers[0].name.to_string()[..], "cdn.sstatic.net"); assert_eq!(packet.answers[0].cls, C::IN); @@ -501,7 +539,7 @@ mod test { assert_eq!(packet.answers[i].ttl, 102); match packet.answers[i].data { RRData::A(addr) => { - assert_eq!(addr, ips[i-1]); + assert_eq!(addr, ips[i - 1]); } ref x => panic!("Wrong rdata {:?}", x), } diff --git a/src/dns_parser/rrdata.rs b/src/dns_parser/rrdata.rs index 67ca2f8..f77cc35 100644 --- a/src/dns_parser/rrdata.rs +++ b/src/dns_parser/rrdata.rs @@ -3,8 +3,7 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Name, Type, Error}; - +use super::{Error, Name, Type}; /// The enumeration that represents known types of DNS resource records data #[derive(Debug, Clone)] @@ -14,11 +13,22 @@ pub enum RRData<'a> { PTR(Name<'a>), A(Ipv4Addr), AAAA(Ipv6Addr), - SRV { priority: u16, weight: u16, port: u16, target: Name<'a> }, - MX { preference: u16, exchange: Name<'a> }, + SRV { + priority: u16, + weight: u16, + port: u16, + target: Name<'a>, + }, + MX { + preference: u16, + exchange: Name<'a>, + }, TXT(&'a [u8]), // Anything that can't be parsed yet - Unknown { typ: Type, data: &'a [u8] }, + Unknown { + typ: Type, + data: &'a [u8], + }, } impl<'a> RRData<'a> { @@ -38,26 +48,34 @@ impl<'a> RRData<'a> { pub fn write_to(&self, writer: &mut T) -> io::Result<()> { match *self { - RRData::CNAME(ref name) | - RRData::NS(ref name) | - RRData::PTR(ref name) => name.write_to(writer), + RRData::CNAME(ref name) | RRData::NS(ref name) | RRData::PTR(ref name) => { + name.write_to(writer) + } RRData::A(ip) => writer.write_u32::(ip.into()), RRData::AAAA(ip) => { - for segment in ip.segments().into_iter() { - try!(writer.write_u16::(*segment)); + for segment in ip.segments().iter() { + writer.write_u16::(*segment)?; } Ok(()) } - RRData::SRV { priority, weight, port, ref target } => { - try!(writer.write_u16::(priority)); - try!(writer.write_u16::(weight)); - try!(writer.write_u16::(port)); + RRData::SRV { + priority, + weight, + port, + ref target, + } => { + writer.write_u16::(priority)?; + writer.write_u16::(weight)?; + writer.write_u16::(port)?; target.write_to(writer) } - RRData::MX { preference, ref exchange } => { - try!(writer.write_u16::(preference)); + RRData::MX { + preference, + ref exchange, + } => { + writer.write_u16::(preference)?; exchange.write_to(writer) } RRData::TXT(data) => writer.write_all(data), @@ -65,16 +83,13 @@ impl<'a> RRData<'a> { } } - pub fn parse(typ: Type, rdata: &'a [u8], original: &'a [u8]) - -> Result, Error> - { + pub fn parse(typ: Type, rdata: &'a [u8], original: &'a [u8]) -> Result, Error> { match typ { Type::A => { if rdata.len() != 4 { return Err(Error::WrongRdataLength); } - Ok(RRData::A( - Ipv4Addr::from(BigEndian::read_u32(rdata)))) + Ok(RRData::A(Ipv4Addr::from(BigEndian::read_u32(rdata)))) } Type::AAAA => { if rdata.len() != 16 { @@ -91,22 +106,16 @@ impl<'a> RRData<'a> { BigEndian::read_u16(&rdata[14..16]), ))) } - Type::CNAME => { - Ok(RRData::CNAME(try!(Name::scan(rdata, original)).0)) - } - Type::NS => { - Ok(RRData::NS(try!(Name::scan(rdata, original)).0)) - } - Type::PTR => { - Ok(RRData::PTR(try!(Name::scan(rdata, original)).0)) - } + Type::CNAME => Ok(RRData::CNAME(Name::scan(rdata, original)?.0)), + Type::NS => Ok(RRData::NS(Name::scan(rdata, original)?.0)), + Type::PTR => Ok(RRData::PTR(Name::scan(rdata, original)?.0)), Type::MX => { if rdata.len() < 3 { return Err(Error::WrongRdataLength); } Ok(RRData::MX { preference: BigEndian::read_u16(&rdata[..2]), - exchange: try!(Name::scan(&rdata[2..], original)).0, + exchange: Name::scan(&rdata[2..], original)?.0, }) } Type::SRV => { @@ -117,16 +126,11 @@ impl<'a> RRData<'a> { priority: BigEndian::read_u16(&rdata[..2]), weight: BigEndian::read_u16(&rdata[2..4]), port: BigEndian::read_u16(&rdata[4..6]), - target: try!(Name::scan(&rdata[6..], original)).0, + target: Name::scan(&rdata[6..], original)?.0, }) } Type::TXT => Ok(RRData::TXT(rdata)), - typ => { - Ok(RRData::Unknown { - typ: typ, - data: rdata - }) - } + typ => Ok(RRData::Unknown { typ, data: rdata }), } } } diff --git a/src/dns_parser/structs.rs b/src/dns_parser/structs.rs index ac0d125..50bdd44 100644 --- a/src/dns_parser/structs.rs +++ b/src/dns_parser/structs.rs @@ -1,5 +1,4 @@ -use super::{QueryType, QueryClass, Name, Class, Header, RRData}; - +use super::{Class, Header, Name, QueryClass, QueryType, RRData}; /// Parsed DNS packet #[derive(Debug)] diff --git a/src/fsm.rs b/src/fsm.rs index 429d703..7e23b91 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -1,18 +1,20 @@ -use dns_parser::{self, Name, QueryClass, QueryType, RRData}; -use futures::sync::mpsc; -use futures::{Async, Future, Poll, Stream}; +use crate::dns_parser::{self, Name, QueryClass, QueryType, RRData}; +use futures::channel::mpsc; +use futures::future::*; use get_if_addrs::get_if_addrs; +use std::pin::Pin; +use std::task::*; + use std::collections::VecDeque; use std::io; -use std::io::ErrorKind::WouldBlock; use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; use tokio::net::UdpSocket; -use tokio::reactor::Handle; use super::{DEFAULT_TTL, MDNS_PORT}; -use address_family::AddressFamily; -use services::{ServiceData, Services}; + +use crate::address_family::AddressFamily; +use crate::services::{ServiceData, Services}; pub type AnswerBuilder = dns_parser::Builder; @@ -26,6 +28,7 @@ pub enum Command { Shutdown, } +#[pin_project::pin_project] pub struct FSM { socket: UdpSocket, services: Services, @@ -34,17 +37,15 @@ pub struct FSM { _af: PhantomData, } + impl FSM { - pub fn new( - handle: &Handle, - services: &Services, - ) -> io::Result<(FSM, mpsc::UnboundedSender)> { + pub fn new(services: &Services) -> io::Result<(FSM, mpsc::UnboundedSender)> { let std_socket = AF::bind()?; - let socket = UdpSocket::from_socket(std_socket, handle)?; + let socket = UdpSocket::from_std(std_socket)?; let (tx, rx) = mpsc::unbounded(); let fsm = FSM { - socket: socket, + socket, services: services.clone(), commands: rx, outgoing: VecDeque::new(), @@ -53,24 +54,31 @@ impl FSM { Ok((fsm, tx)) } +} + + +#[pin_project::project] +impl FSM { - fn recv_packets(&mut self) -> io::Result<()> { + fn recv_packets( + &mut self, + ctx: &mut std::task::Context, + ) -> io::Result<()> { let mut buf = [0u8; 4096]; - loop { - let (bytes, addr) = match self.socket.recv_from(&mut buf) { - Ok((bytes, addr)) => (bytes, addr), - Err(ref ioerr) if ioerr.kind() == WouldBlock => break, - Err(err) => return Err(err), - }; - - if bytes >= buf.len() { - warn!("buffer too small for packet from {:?}", addr); - continue; - } - self.handle_packet(&buf[..bytes], addr); + match self.socket.poll_recv_from(ctx,&mut buf) { + Poll::Ready(Ok((bytes, addr))) => { + if bytes >= buf.len() { + warn!("buffer too small for packet from {:?}", addr); + Ok(()) + } else { + self.handle_packet(&buf[..bytes], addr); + Ok(()) + } + } + Poll::Ready(Err(err)) => Err(err), + _ => Ok(()) } - Ok(()) } fn handle_packet(&mut self, buffer: &[u8], addr: SocketAddr) { @@ -216,15 +224,13 @@ impl FSM { self.outgoing.push_back((response, addr)); } } -} -impl Future for FSM { - type Item = (); - type Error = io::Error; - fn poll(&mut self) -> Poll<(), io::Error> { - while let Async::Ready(cmd) = self.commands.poll().unwrap() { + + fn poll_project(&mut self, ctx: &mut std::task::Context) -> Poll> { + + while let Ok(cmd) = self.commands.try_next() { match cmd { - Some(Command::Shutdown) => return Ok(Async::Ready(())), + Some(Command::Shutdown) => return Poll::Ready(Ok(())), Some(Command::SendUnsolicited { svc, ttl, @@ -234,31 +240,38 @@ impl Future for FSM { } None => { warn!("responder disconnected without shutdown"); - return Ok(Async::Ready(())); + return Poll::Ready(Ok(())); } } } - while let Async::Ready(()) = self.socket.poll_read() { - self.recv_packets()?; - } + self.recv_packets(ctx)?; loop { if let Some(&(ref response, ref addr)) = self.outgoing.front() { trace!("sending packet to {:?}", addr); - match self.socket.send_to(response, addr) { - Ok(_) => (), - Err(ref ioerr) if ioerr.kind() == WouldBlock => break, - Err(err) => warn!("error sending packet {:?}", err), + match self.socket.poll_send_to(ctx, response, addr) { + Poll::Ready(Ok(_)) => { + self.outgoing.pop_front(); + } + Poll::Ready(Err(err)) => { + warn!("error sending packet {:?}", err); + self.outgoing.pop_front(); + } + Poll::Pending => break, } - } else { - break; } - - self.outgoing.pop_front(); } - Ok(Async::NotReady) + Poll::Pending + } +} + +impl Future for FSM { + type Output = Result<(), io::Error>; + + fn poll(self: Pin<&mut Self>, ctx: &mut std::task::Context) -> Poll { + self.project().poll_project(ctx) } } diff --git a/src/lib.rs b/src/lib.rs index 89e1ee1..4131d0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,26 +11,26 @@ extern crate hostname; extern crate multimap; extern crate net2; extern crate rand; -extern crate tokio_core as tokio; +extern crate tokio; -use futures::sync::mpsc; -use futures::Future; +use futures::channel::mpsc; +use futures::future::{BoxFuture, FutureExt, TryFutureExt}; use std::cell::RefCell; use std::io; use std::sync::{Arc, RwLock}; use std::thread; -use tokio::reactor::{Core, Handle}; +use tokio::runtime::{Handle, Runtime}; mod dns_parser; -use dns_parser::Name; +use crate::dns_parser::Name; mod address_family; mod fsm; mod services; -use address_family::{Inet, Inet6}; -use fsm::{Command, FSM}; -use services::{ServiceData, Services, ServicesInner}; +use crate::address_family::{Inet, Inet6}; +use crate::fsm::{Command, FSM}; +use crate::services::{ServiceData, Services, ServicesInner}; const DEFAULT_TTL: u32 = 60; const MDNS_PORT: u16 = 5353; @@ -48,15 +48,16 @@ pub struct Service { _shutdown: Arc, } -type ResponderTask = Box + Send>; +type ResponderTask = BoxFuture<'static, Result<(), io::Error>>; impl Responder { - fn setup_core() -> io::Result<(Core, ResponderTask, Responder)> { - let core = Core::new()?; - let (responder, task) = Self::with_handle(&core.handle())?; - Ok((core, task, responder)) + fn setup_core() -> io::Result<(Runtime, ResponderTask, Responder)> { + let rt = Runtime::new()?; + let (responder, task) = Self::with_handle()?; + Ok((rt, task, responder)) } + //TODO: this should not create threads etc pub fn new() -> io::Result { let (tx, rx) = std::sync::mpsc::sync_channel(0); thread::Builder::new() @@ -64,7 +65,7 @@ impl Responder { .spawn(move || match Self::setup_core() { Ok((mut core, task, responder)) => { tx.send(Ok(responder)).expect("tx responder channel closed"); - core.run(task).expect("mdns thread failed"); + core.block_on(task).expect("mdns thread failed"); } Err(err) => { tx.send(Err(err)).expect("tx responder channel closed"); @@ -75,15 +76,14 @@ impl Responder { } pub fn spawn(handle: &Handle) -> io::Result { - let (responder, task) = Responder::with_handle(handle)?; + let (responder, task) = Responder::with_handle()?; handle.spawn(task.map_err(|e| { warn!("mdns error {:?}", e); - () })); Ok(responder) } - pub fn with_handle(handle: &Handle) -> io::Result<(Responder, ResponderTask)> { + pub fn with_handle() -> io::Result<(Responder, ResponderTask)> { let mut hostname = match hostname::get() { Ok(s) => match s.into_string() { Ok(s) => s, @@ -102,20 +102,21 @@ impl Responder { let services = Arc::new(RwLock::new(ServicesInner::new(hostname))); - let v4 = FSM::::new(handle, &services); - let v6 = FSM::::new(handle, &services); + let v4 = FSM::::new(&services); + let v6 = FSM::::new(&services); let (task, commands): (ResponderTask, _) = match (v4, v6) { (Ok((v4_task, v4_command)), Ok((v6_task, v6_command))) => { - let task = v4_task.join(v6_task).map(|((), ())| ()); - let task = Box::new(task); + let task = futures::future::join(v4_task, v6_task) + .map(|_| Ok(())) + .boxed(); let commands = vec![v4_command, v6_command]; (task, commands) } (Ok((v4_task, v4_command)), Err(err)) => { warn!("Failed to register IPv6 receiver: {:?}", err); - (Box::new(v4_task), vec![v4_command]) + (v4_task.boxed(), vec![v4_command]) } (Err(err), _) => return Err(err), @@ -123,7 +124,7 @@ impl Responder { let commands = CommandSender(commands); let responder = Responder { - services: services, + services, commands: RefCell::new(commands.clone()), shutdown: Arc::new(Shutdown(commands)), }; @@ -137,13 +138,13 @@ impl Responder { let txt = if txt.is_empty() { vec![0] } else { - txt.into_iter() + txt.iter() .flat_map(|entry| { let entry = entry.as_bytes(); if entry.len() > 255 { panic!("{:?} is too long for a TXT record", entry); } - std::iter::once(entry.len() as u8).chain(entry.into_iter().cloned()) + std::iter::once(entry.len() as u8).chain(entry.iter().cloned()) }) .collect() }; @@ -151,8 +152,8 @@ impl Responder { let svc = ServiceData { typ: Name::from_str(format!("{}.local", svc_type)).unwrap(), name: Name::from_str(format!("{}.{}.local", svc_name, svc_type)).unwrap(), - port: port, - txt: txt, + port, + txt, }; self.commands @@ -162,7 +163,7 @@ impl Responder { let id = self.services.write().unwrap().register(svc); Service { - id: id, + id, commands: self.commands.borrow().clone(), services: self.services.clone(), _shutdown: self.shutdown.clone(), @@ -197,9 +198,9 @@ impl CommandSender { fn send_unsolicited(&mut self, svc: ServiceData, ttl: u32, include_ip: bool) { self.send(Command::SendUnsolicited { - svc: svc, - ttl: ttl, - include_ip: include_ip, + svc, + ttl, + include_ip, }); } diff --git a/src/services.rs b/src/services.rs index 11fb264..a6bb03b 100644 --- a/src/services.rs +++ b/src/services.rs @@ -1,4 +1,4 @@ -use dns_parser::{self, Name, QueryClass, RRData}; +use crate::dns_parser::{self, Name, QueryClass, RRData}; use multimap::MultiMap; use rand::{thread_rng, Rng}; use std::collections::HashMap; @@ -43,7 +43,7 @@ impl ServicesInner { FindByType { services: self, - ids: ids, + ids, } } From 5ce1a0ef990c37877482fdfa5a9ddcf2077f3b39 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Fri, 27 Mar 2020 14:00:08 +0000 Subject: [PATCH 2/8] Remove handle stuff from interface --- examples/register.rs | 15 ++++++--------- src/lib.rs | 39 +++++++++------------------------------ 2 files changed, 15 insertions(+), 39 deletions(-) diff --git a/examples/register.rs b/examples/register.rs index 851b99b..219873e 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -1,18 +1,15 @@ -extern crate env_logger; -extern crate libmdns; +use env_logger; +use libmdns; -pub fn main() { +#[tokio::main] +pub async fn main() { env_logger::init(); - let responder = libmdns::Responder::new().unwrap(); - let _svc = responder.register( + let responder = libmdns::Responder::new().await.unwrap(); + let svc = responder.register( "_http._tcp".to_owned(), "libmdns Web Server".to_owned(), 80, &["path=/"], ); - - loop { - ::std::thread::sleep(::std::time::Duration::from_secs(10)); - } } diff --git a/src/lib.rs b/src/lib.rs index 4131d0a..6df64d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,12 +14,11 @@ extern crate rand; extern crate tokio; use futures::channel::mpsc; -use futures::future::{BoxFuture, FutureExt, TryFutureExt}; +use futures::future::{BoxFuture, FutureExt}; +use futures::*; use std::cell::RefCell; use std::io; use std::sync::{Arc, RwLock}; -use std::thread; -use tokio::runtime::{Handle, Runtime}; mod dns_parser; use crate::dns_parser::Name; @@ -51,36 +50,16 @@ pub struct Service { type ResponderTask = BoxFuture<'static, Result<(), io::Error>>; impl Responder { - fn setup_core() -> io::Result<(Runtime, ResponderTask, Responder)> { - let rt = Runtime::new()?; - let (responder, task) = Self::with_handle()?; - Ok((rt, task, responder)) - } - //TODO: this should not create threads etc - pub fn new() -> io::Result { - let (tx, rx) = std::sync::mpsc::sync_channel(0); - thread::Builder::new() - .name("mdns-responder".to_owned()) - .spawn(move || match Self::setup_core() { - Ok((mut core, task, responder)) => { - tx.send(Ok(responder)).expect("tx responder channel closed"); - core.block_on(task).expect("mdns thread failed"); - } - Err(err) => { - tx.send(Err(err)).expect("tx responder channel closed"); - } - })?; + pub async fn new() -> io::Result { + let (mut tx, mut rx) = futures::channel::mpsc::channel(0); - rx.recv().expect("rx responder channel closed") - } + let (responder,task) = Self::with_handle()?; + + tokio::spawn(task); - pub fn spawn(handle: &Handle) -> io::Result { - let (responder, task) = Responder::with_handle()?; - handle.spawn(task.map_err(|e| { - warn!("mdns error {:?}", e); - })); - Ok(responder) + tx.send(Ok(responder)).await.expect("tx responder channel closed"); + rx.next().await.expect("rx responder channel closed") } pub fn with_handle() -> io::Result<(Responder, ResponderTask)> { From a875d22cf513610f4d1f0780cddf936ef74c2d63 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Fri, 27 Mar 2020 14:49:40 +0000 Subject: [PATCH 3/8] Fix examples --- Cargo.toml | 2 +- examples/register.rs | 2 ++ src/fsm.rs | 8 +++----- src/lib.rs | 5 +---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b28ceb..89e2df5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ log = "0.4.8" multimap = "0.8.0" net2 = "0.2.33" rand = "0.7.3" -tokio = { version = "0.2.13", features = ["udp","stream","io-driver","io-std","macros"] } +tokio = { version = "0.2.13", features = ["udp","stream","io-driver","io-std","macros","signal"] } quick-error = "1.2.3" pin-project = "0.4.8" diff --git a/examples/register.rs b/examples/register.rs index 219873e..a7ee36c 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -12,4 +12,6 @@ pub async fn main() { 80, &["path=/"], ); + + tokio::signal::ctrl_c().await.unwrap(); } diff --git a/src/fsm.rs b/src/fsm.rs index 7e23b91..d26161a 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -56,7 +56,6 @@ impl FSM { } } - #[pin_project::project] impl FSM { @@ -227,7 +226,7 @@ impl FSM { fn poll_project(&mut self, ctx: &mut std::task::Context) -> Poll> { - + debug!("Poll"); while let Ok(cmd) = self.commands.try_next() { match cmd { Some(Command::Shutdown) => return Poll::Ready(Ok(())), @@ -247,8 +246,8 @@ impl FSM { self.recv_packets(ctx)?; - loop { - if let Some(&(ref response, ref addr)) = self.outgoing.front() { + + while let Some(&(ref response, ref addr)) = self.outgoing.front() { trace!("sending packet to {:?}", addr); match self.socket.poll_send_to(ctx, response, addr) { @@ -261,7 +260,6 @@ impl FSM { } Poll::Pending => break, } - } } Poll::Pending diff --git a/src/lib.rs b/src/lib.rs index 6df64d7..ba1cef4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,14 +52,11 @@ type ResponderTask = BoxFuture<'static, Result<(), io::Error>>; impl Responder { pub async fn new() -> io::Result { - let (mut tx, mut rx) = futures::channel::mpsc::channel(0); - let (responder,task) = Self::with_handle()?; tokio::spawn(task); - tx.send(Ok(responder)).await.expect("tx responder channel closed"); - rx.next().await.expect("rx responder channel closed") + Ok(responder) } pub fn with_handle() -> io::Result<(Responder, ResponderTask)> { From 456c010b7d0385b4f86026103691431686534b9e Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Fri, 27 Mar 2020 14:49:40 +0000 Subject: [PATCH 4/8] Fix examples --- src/fsm.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fsm.rs b/src/fsm.rs index d26161a..a05970b 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -226,7 +226,6 @@ impl FSM { fn poll_project(&mut self, ctx: &mut std::task::Context) -> Poll> { - debug!("Poll"); while let Ok(cmd) = self.commands.try_next() { match cmd { Some(Command::Shutdown) => return Poll::Ready(Ok(())), From 0a6b8847263d3200e210a55f135b68eba13c3311 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Fri, 27 Mar 2020 15:56:32 +0000 Subject: [PATCH 5/8] No need for this to be async other than to force runtime --- src/fsm.rs | 37 +++++++++++++++---------------------- src/lib.rs | 5 ++--- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/src/fsm.rs b/src/fsm.rs index a05970b..46e3fea 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -37,7 +37,6 @@ pub struct FSM { _af: PhantomData, } - impl FSM { pub fn new(services: &Services) -> io::Result<(FSM, mpsc::UnboundedSender)> { let std_socket = AF::bind()?; @@ -58,14 +57,10 @@ impl FSM { #[pin_project::project] impl FSM { - - fn recv_packets( - &mut self, - ctx: &mut std::task::Context, - ) -> io::Result<()> { + fn recv_packets(&mut self, ctx: &mut std::task::Context) -> io::Result<()> { let mut buf = [0u8; 4096]; - match self.socket.poll_recv_from(ctx,&mut buf) { + match self.socket.poll_recv_from(ctx, &mut buf) { Poll::Ready(Ok((bytes, addr))) => { if bytes >= buf.len() { warn!("buffer too small for packet from {:?}", addr); @@ -76,7 +71,7 @@ impl FSM { } } Poll::Ready(Err(err)) => Err(err), - _ => Ok(()) + _ => Ok(()), } } @@ -224,8 +219,7 @@ impl FSM { } } - - fn poll_project(&mut self, ctx: &mut std::task::Context) -> Poll> { + fn poll_project(&mut self, ctx: &mut std::task::Context) -> Poll> { while let Ok(cmd) = self.commands.try_next() { match cmd { Some(Command::Shutdown) => return Poll::Ready(Ok(())), @@ -245,20 +239,19 @@ impl FSM { self.recv_packets(ctx)?; + while let Some(&(ref response, ref addr)) = self.outgoing.front() { + trace!("sending packet to {:?}", addr); - while let Some(&(ref response, ref addr)) = self.outgoing.front() { - trace!("sending packet to {:?}", addr); - - match self.socket.poll_send_to(ctx, response, addr) { - Poll::Ready(Ok(_)) => { - self.outgoing.pop_front(); - } - Poll::Ready(Err(err)) => { - warn!("error sending packet {:?}", err); - self.outgoing.pop_front(); - } - Poll::Pending => break, + match self.socket.poll_send_to(ctx, response, addr) { + Poll::Ready(Ok(_)) => { + self.outgoing.pop_front(); + } + Poll::Ready(Err(err)) => { + warn!("error sending packet {:?}", err); + self.outgoing.pop_front(); } + Poll::Pending => break, + } } Poll::Pending diff --git a/src/lib.rs b/src/lib.rs index ba1cef4..b7192c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,9 +50,8 @@ pub struct Service { type ResponderTask = BoxFuture<'static, Result<(), io::Error>>; impl Responder { - - pub async fn new() -> io::Result { - let (responder,task) = Self::with_handle()?; + pub fn new() -> io::Result { + let (responder, task) = Self::with_handle()?; tokio::spawn(task); From 5a8b5b302d143aaedb08a94cf8392bc4cf0098d7 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Wed, 1 Apr 2020 12:13:25 +0100 Subject: [PATCH 6/8] Add builder, conditional v4 / v6 --- examples/register.rs | 2 +- src/lib.rs | 79 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/examples/register.rs b/examples/register.rs index a7ee36c..e34fbee 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -5,7 +5,7 @@ use libmdns; pub async fn main() { env_logger::init(); - let responder = libmdns::Responder::new().await.unwrap(); + let responder = libmdns::Responder::new().unwrap(); let svc = responder.register( "_http._tcp".to_owned(), "libmdns Web Server".to_owned(), diff --git a/src/lib.rs b/src/lib.rs index b7192c9..701231f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,6 @@ extern crate tokio; use futures::channel::mpsc; use futures::future::{BoxFuture, FutureExt}; -use futures::*; use std::cell::RefCell; use std::io; use std::sync::{Arc, RwLock}; @@ -34,6 +33,35 @@ use crate::services::{ServiceData, Services, ServicesInner}; const DEFAULT_TTL: u32 = 60; const MDNS_PORT: u16 = 5353; + +pub struct ResponderBuilder { + v4: bool, + v6: bool +} + +impl ResponderBuilder { + pub fn new() -> Self { + ResponderBuilder { + v4:true, + v6:true, + } + } + + pub fn use_v6(mut self,use_v6: bool) -> Self { + self.v6 = use_v6; + self + } + + pub fn use_v4(mut self,use_v4: bool) -> Self { + self.v4 = use_v4; + self + } + + pub fn build(self) -> io::Result { + Responder::start(self) + } +} + pub struct Responder { services: Services, commands: RefCell, @@ -51,14 +79,21 @@ type ResponderTask = BoxFuture<'static, Result<(), io::Error>>; impl Responder { pub fn new() -> io::Result { - let (responder, task) = Self::with_handle()?; + Self::builder().build() + } + pub fn builder() -> ResponderBuilder { + ResponderBuilder::new() + } + + fn start(builder: ResponderBuilder) -> io::Result { + let (responder, task) = Self::create_task(builder)?; tokio::spawn(task); Ok(responder) } - pub fn with_handle() -> io::Result<(Responder, ResponderTask)> { + fn create_task(builder: ResponderBuilder) -> io::Result<(Responder, ResponderTask)> { let mut hostname = match hostname::get() { Ok(s) => match s.into_string() { Ok(s) => s, @@ -77,11 +112,24 @@ impl Responder { let services = Arc::new(RwLock::new(ServicesInner::new(hostname))); - let v4 = FSM::::new(&services); - let v6 = FSM::::new(&services); + let v4 = { + if builder.v4 { + Some(FSM::::new(&services)) + } else { + None + } + }; + + let v6 = { + if builder.v6 { + Some(FSM::::new(&services)) + } else { + None + } + }; let (task, commands): (ResponderTask, _) = match (v4, v6) { - (Ok((v4_task, v4_command)), Ok((v6_task, v6_command))) => { + (Some(Ok((v4_task, v4_command))), Some(Ok((v6_task, v6_command)))) => { let task = futures::future::join(v4_task, v6_task) .map(|_| Ok(())) .boxed(); @@ -89,12 +137,27 @@ impl Responder { (task, commands) } - (Ok((v4_task, v4_command)), Err(err)) => { + (Some(Ok((v4_task, v4_command))), Some(Err(err))) => { warn!("Failed to register IPv6 receiver: {:?}", err); (v4_task.boxed(), vec![v4_command]) } - (Err(err), _) => return Err(err), + (Some(Err(err)),Some(Ok((v6_task, v6_command)))) => { + warn!("Failed to register IPv4 receiver: {:?}", err); + (v6_task.boxed(), vec![v6_command]) + } + + (None,Some(Ok((v6_task, v6_command)))) => { + (v6_task.boxed(), vec![v6_command]) + } + + (Some(Ok((v4_task, v4_command))),None) => { + (v4_task.boxed(), vec![v4_command]) + } + + (_, Some(Err(err))) => return Err(err), + (Some(Err(err)), _) => return Err(err), + (None,None) => panic!("No v4 or v6 responder configured") }; let commands = CommandSender(commands); From c36531a17584986e83b79ce582d99986342f550e Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Wed, 1 Apr 2020 15:07:08 +0100 Subject: [PATCH 7/8] Log fsm init, remove v6 from example --- examples/register.rs | 8 ++++++-- src/fsm.rs | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/register.rs b/examples/register.rs index e34fbee..463eff9 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -5,8 +5,12 @@ use libmdns; pub async fn main() { env_logger::init(); - let responder = libmdns::Responder::new().unwrap(); - let svc = responder.register( + let responder = libmdns::Responder::builder() + .use_v6(false) + .build() + .unwrap(); + + let _svc = responder.register( "_http._tcp".to_owned(), "libmdns Web Server".to_owned(), 80, diff --git a/src/fsm.rs b/src/fsm.rs index 46e3fea..9b9ecbc 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -51,6 +51,8 @@ impl FSM { _af: PhantomData, }; + info!("Started {} fsm",if AF::v6() {"v6"} else {"v4"}); + Ok((fsm, tx)) } } From 1ef93ec0d2eaae6a1e270dec38b3c380f1363c30 Mon Sep 17 00:00:00 2001 From: Ryan Roberts Date: Fri, 3 Apr 2020 16:06:27 +0100 Subject: [PATCH 8/8] Edge triggered, so must poll recv until done, oops --- examples/register.rs | 5 +---- src/fsm.rs | 38 ++++++++++++++++++++++++-------------- src/lib.rs | 26 +++++++++----------------- 3 files changed, 34 insertions(+), 35 deletions(-) diff --git a/examples/register.rs b/examples/register.rs index 463eff9..cb7026a 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -5,10 +5,7 @@ use libmdns; pub async fn main() { env_logger::init(); - let responder = libmdns::Responder::builder() - .use_v6(false) - .build() - .unwrap(); + let responder = libmdns::Responder::builder().use_v6(false).build().unwrap(); let _svc = responder.register( "_http._tcp".to_owned(), diff --git a/src/fsm.rs b/src/fsm.rs index 9b9ecbc..ce0af67 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -41,6 +41,13 @@ impl FSM { pub fn new(services: &Services) -> io::Result<(FSM, mpsc::UnboundedSender)> { let std_socket = AF::bind()?; let socket = UdpSocket::from_std(std_socket)?; + + if AF::v6() { + socket.set_multicast_loop_v6(true)?; + } else { + socket.set_multicast_loop_v4(true)?; + } + let (tx, rx) = mpsc::unbounded(); let fsm = FSM { @@ -51,7 +58,7 @@ impl FSM { _af: PhantomData, }; - info!("Started {} fsm",if AF::v6() {"v6"} else {"v4"}); + info!("Started {} fsm", if AF::v6() { "v6" } else { "v4" }); Ok((fsm, tx)) } @@ -59,21 +66,24 @@ impl FSM { #[pin_project::project] impl FSM { - fn recv_packets(&mut self, ctx: &mut std::task::Context) -> io::Result<()> { + fn recv_packets(&mut self, ctx: &mut std::task::Context) { let mut buf = [0u8; 4096]; - match self.socket.poll_recv_from(ctx, &mut buf) { - Poll::Ready(Ok((bytes, addr))) => { - if bytes >= buf.len() { - warn!("buffer too small for packet from {:?}", addr); - Ok(()) - } else { - self.handle_packet(&buf[..bytes], addr); - Ok(()) + loop { + match self.socket.poll_recv_from(ctx, &mut buf) { + Poll::Ready(Ok((bytes, addr))) => { + if bytes >= buf.len() { + warn!("buffer too small for packet from {:?}", addr); + } else { + trace!("Handle packets"); + self.handle_packet(&buf[..bytes], addr); + } + } + Poll::Ready(Err(err)) => { + error!("Recv error {}", err); } + _ => break, } - Poll::Ready(Err(err)) => Err(err), - _ => Ok(()), } } @@ -89,7 +99,7 @@ impl FSM { }; if !packet.header.query { - trace!("received packet from {:?} with no query", addr); + trace!("received packet {:?} from {:?} with no query", packet, addr); return; } @@ -239,7 +249,7 @@ impl FSM { } } - self.recv_packets(ctx)?; + self.recv_packets(ctx); while let Some(&(ref response, ref addr)) = self.outgoing.front() { trace!("sending packet to {:?}", addr); diff --git a/src/lib.rs b/src/lib.rs index 701231f..c82d3d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,26 +33,22 @@ use crate::services::{ServiceData, Services, ServicesInner}; const DEFAULT_TTL: u32 = 60; const MDNS_PORT: u16 = 5353; - pub struct ResponderBuilder { v4: bool, - v6: bool + v6: bool, } -impl ResponderBuilder { +impl ResponderBuilder { pub fn new() -> Self { - ResponderBuilder { - v4:true, - v6:true, - } + ResponderBuilder { v4: true, v6: true } } - pub fn use_v6(mut self,use_v6: bool) -> Self { + pub fn use_v6(mut self, use_v6: bool) -> Self { self.v6 = use_v6; self } - pub fn use_v4(mut self,use_v4: bool) -> Self { + pub fn use_v4(mut self, use_v4: bool) -> Self { self.v4 = use_v4; self } @@ -142,22 +138,18 @@ impl Responder { (v4_task.boxed(), vec![v4_command]) } - (Some(Err(err)),Some(Ok((v6_task, v6_command)))) => { + (Some(Err(err)), Some(Ok((v6_task, v6_command)))) => { warn!("Failed to register IPv4 receiver: {:?}", err); (v6_task.boxed(), vec![v6_command]) } - (None,Some(Ok((v6_task, v6_command)))) => { - (v6_task.boxed(), vec![v6_command]) - } + (None, Some(Ok((v6_task, v6_command)))) => (v6_task.boxed(), vec![v6_command]), - (Some(Ok((v4_task, v4_command))),None) => { - (v4_task.boxed(), vec![v4_command]) - } + (Some(Ok((v4_task, v4_command))), None) => (v4_task.boxed(), vec![v4_command]), (_, Some(Err(err))) => return Err(err), (Some(Err(err)), _) => return Err(err), - (None,None) => panic!("No v4 or v6 responder configured") + (None, None) => panic!("No v4 or v6 responder configured"), }; let commands = CommandSender(commands);