diff --git a/TODO.md b/TODO.md index 70f0909..61bce67 100644 --- a/TODO.md +++ b/TODO.md @@ -4,11 +4,10 @@ Not all of them are desirable or necessary. - Add a GUI. -- Split off more functionality from gday. +- Confirm gday server works properly on both ipv4 and ipv6. Maybe add a test. -- Test the contact sharing and hole-punching. +- Confirm that TLS close is now properly sent, and no errors are logged. -- Make sure that gday server still lists the file name during file error. ## Abandoned ideas @@ -21,8 +20,7 @@ Not all of them are desirable or necessary. the peer's device is acting as some sort of server. - Allow sending a simple text string instead of only files. - Though, I don't think this is a common use case, so will only - add if I get requests. + Though, I don't think this is a common use case. - Let the client select a source port, to utilize port forwarding. However, turns out port forwarding works for inbound connections, @@ -35,3 +33,6 @@ Not all of them are desirable or necessary. - Make file transfer response msg list file names, instead of going by index? I don't really see the advantage to doing this. + +- Support a shared secret longer than u64 in peer code? But then again, + users can just send their own struct in this case. \ No newline at end of file diff --git a/gday/src/main.rs b/gday/src/main.rs index b6d4310..2b97b7b 100644 --- a/gday/src/main.rs +++ b/gday/src/main.rs @@ -101,10 +101,8 @@ fn run(args: crate::Args) -> Result<(), Box> { // get the server port let port = if let Some(port) = args.port { port - } else if args.unencrypted { - server_connector::DEFAULT_TCP_PORT } else { - server_connector::DEFAULT_TLS_PORT + server_connector::DEFAULT_PORT }; // use custom server if the user provided one, @@ -242,5 +240,8 @@ fn run(args: crate::Args) -> Result<(), Box> { } } } + + server_connection.notify_tls_close()?; + Ok(()) } diff --git a/gday_contact_exchange_protocol/src/lib.rs b/gday_contact_exchange_protocol/src/lib.rs index a46e711..9db9bb8 100644 --- a/gday_contact_exchange_protocol/src/lib.rs +++ b/gday_contact_exchange_protocol/src/lib.rs @@ -81,13 +81,9 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -/// The port that contact exchange servers -/// using unencrypted TCP should listen on. -pub const DEFAULT_TCP_PORT: u16 = 2310; - /// The port that contact exchange servers /// using encrypted TLS should listen on. -pub const DEFAULT_TLS_PORT: u16 = 2311; +pub const DEFAULT_PORT: u16 = 2311; /// A message from client to server. #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone, Copy)] @@ -193,26 +189,27 @@ impl Display for ServerMsg { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::RoomCreated => write!(f, "Room in server created successfully."), - Self::ReceivedAddr => write!(f, "Server received your address."), + Self::ReceivedAddr => write!(f, "Server recorded your public address."), Self::ClientContact(c) => write!(f, "The server says your contact is {c}."), Self::PeerContact(c) => write!(f, "The server says your peer's contact is {c}."), Self::ErrorRoomTaken => write!( f, - "Can't create room with this id, because it was already created." + "Can't create room with this code, because it was already created." ), Self::ErrorPeerTimedOut => write!( f, "Timed out while waiting for peer to finish sending their address." ), - Self::ErrorNoSuchRoomCode => write!( + Self::ErrorNoSuchRoomCode => write!(f, "No room with this room code has been created."), + Self::ErrorUnexpectedMsg => write!( f, - "Can't join room with this id, because it hasn't been created." + "Server received RecordPublicAddr message after a ReadyToShare message. \ + Maybe someone else tried to join this room with your identity?" ), - Self::ErrorUnexpectedMsg => write!( + Self::ErrorTooManyRequests => write!( f, - "Server received RecordPublicAddr message after a ReadyToShare message." + "Exceeded request limit from this IP address. Try again in a minute." ), - Self::ErrorTooManyRequests => write!(f, "Too many requests from this IP address."), Self::ErrorSyntax => write!(f, "Server couldn't parse message syntax from client."), Self::ErrorConnection => write!(f, "Connection error to client."), Self::ErrorInternal => write!(f, "Internal server error."), diff --git a/gday_hole_punch/src/peer_code.rs b/gday_hole_punch/src/peer_code.rs index dab3f6a..e396dc6 100644 --- a/gday_hole_punch/src/peer_code.rs +++ b/gday_hole_punch/src/peer_code.rs @@ -1,10 +1,11 @@ -use serde::{Deserialize, Serialize}; - use crate::Error; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; +use std::str::FromStr; /// Info that 2 peers must share before they can exchange contacts. /// -/// Use `PeerCode.to_string()` and `PeerCode::from_str()` +/// Use [`PeerCode::fmt()`] and [`PeerCode::try_from()`] /// to convert to and from a short human-readable code. #[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)] pub struct PeerCode { @@ -12,7 +13,6 @@ pub struct PeerCode { /// that the peers will connect to. /// /// Use `0` to indicate a custom server. - /// /// Pass to [`crate::server_connector::connect_to_server_id()`] /// to connect to the server. pub server_id: u64, @@ -24,7 +24,7 @@ pub struct PeerCode { pub room_code: u64, /// The shared secret that the peers will use to confirm - /// each other's identity during hole-punching. + /// each other's identity. /// /// Pass to [`crate::try_connect_to_peer()`] to authenticate /// the other peer when hole-punching. @@ -39,7 +39,7 @@ impl PeerCode { } } -impl std::fmt::Display for PeerCode { +impl Display for PeerCode { /// Display as `"server_id.room_code.shared_secret.checksum"` /// where each field is in hexadecimal form. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -107,7 +107,7 @@ impl TryFrom<&str> for PeerCode { /// /// The checksum is optional. fn try_from(str: &str) -> Result { - str.parse() + Self::from_str(str) } } diff --git a/gday_hole_punch/src/server_connector.rs b/gday_hole_punch/src/server_connector.rs index b082928..7cd377c 100644 --- a/gday_hole_punch/src/server_connector.rs +++ b/gday_hole_punch/src/server_connector.rs @@ -13,8 +13,7 @@ use std::{ time::Duration, }; -pub use gday_contact_exchange_protocol::DEFAULT_TCP_PORT; -pub use gday_contact_exchange_protocol::DEFAULT_TLS_PORT; +pub use gday_contact_exchange_protocol::DEFAULT_PORT; /// List of default public Gday servers. /// @@ -62,7 +61,7 @@ pub enum ServerStream { impl ServerStream { /// Returns the local socket address of this stream. - fn local_addr(&self) -> std::io::Result { + pub fn local_addr(&self) -> std::io::Result { match self { Self::TCP(tcp) => tcp.local_addr(), Self::TLS(tls) => tls.get_ref().local_addr(), @@ -189,6 +188,23 @@ impl ServerConnection { Ok(contact) } + + /// Sends a `close_notify` warning over TLS. + /// Does nothing for TCP connections. + /// + /// This should be called before dropping + /// [`ServerConnection`]. + pub fn notify_tls_close(&mut self) -> std::io::Result<()> { + if let Some(ServerStream::TLS(tls)) = &mut self.v4 { + tls.conn.send_close_notify(); + tls.conn.complete_io(&mut tls.sock)?; + } + if let Some(ServerStream::TLS(tls)) = &mut self.v6 { + tls.conn.send_close_notify(); + tls.conn.complete_io(&mut tls.sock)?; + } + Ok(()) + } } /// In random order, sequentially try connecting to the given `servers`. @@ -226,7 +242,7 @@ pub fn connect_to_server_id( let Some(server) = servers.iter().find(|server| server.id == server_id) else { return Err(Error::ServerIDNotFound(server_id)); }; - connect_tls(server.domain_name.to_string(), DEFAULT_TLS_PORT, timeout) + connect_tls(server.domain_name.to_string(), DEFAULT_PORT, timeout) } /// In random order, sequentially tries connecting to the given `domain_names`. @@ -249,11 +265,11 @@ pub fn connect_to_random_domain_name( for i in indices { let server = domain_names[i]; - let streams = match connect_tls(server.to_string(), DEFAULT_TLS_PORT, timeout) { + let streams = match connect_tls(server.to_string(), DEFAULT_PORT, timeout) { Ok(streams) => streams, Err(err) => { recent_error = err; - warn!("Couldn't connect to \"{server}:{DEFAULT_TLS_PORT}\": {recent_error}"); + warn!("Couldn't connect to \"{server}:{DEFAULT_PORT}\": {recent_error}"); continue; } }; diff --git a/gday_hole_punch/tests/test_integration.rs b/gday_hole_punch/tests/test_integration.rs index 992cf95..059ca08 100644 --- a/gday_hole_punch/tests/test_integration.rs +++ b/gday_hole_punch/tests/test_integration.rs @@ -14,13 +14,14 @@ async fn test_integration() { key: None, certificate: None, unencrypted: true, - address: Some("[::]:0".to_string()), + addresses: vec!["0.0.0.0:0".parse().unwrap(), "[::]:0".parse().unwrap()], timeout: 3600, request_limit: 10, verbosity: log::LevelFilter::Off, }; - let (server_future, server_addr) = gday_server::start_server(args).unwrap(); - let handle = tokio::spawn(server_future); + let (server_addrs, _joinset) = gday_server::start_server(args).unwrap(); + + let server_addr_1 = server_addrs[0]; tokio::task::spawn_blocking(move || { let timeout = std::time::Duration::from_secs(5); @@ -39,7 +40,7 @@ async fn test_integration() { // Connect to the server let mut server_connection = - server_connector::connect_tcp(server_addr, timeout).unwrap(); + server_connector::connect_tcp(server_addr_1, timeout).unwrap(); // Create a room in the server, and get my contact from it let (contact_sharer, my_contact) = @@ -77,7 +78,7 @@ async fn test_integration() { let peer_code = PeerCode::from_str(&received_code).unwrap(); // Connect to the same server as Peer 1 - let mut server_connection = server_connector::connect_tcp(server_addr, timeout).unwrap(); + let mut server_connection = server_connector::connect_tcp(server_addr_1, timeout).unwrap(); // Join the same room in the server, and get my local contact let (contact_sharer, my_contact) = @@ -107,7 +108,4 @@ async fn test_integration() { }) .await .unwrap(); - - // stop the server - handle.abort(); } diff --git a/gday_server/src/lib.rs b/gday_server/src/lib.rs index 10d1be8..dea574a 100644 --- a/gday_server/src/lib.rs +++ b/gday_server/src/lib.rs @@ -10,18 +10,17 @@ mod state; use clap::Parser; use connection_handler::handle_connection; -use gday_contact_exchange_protocol::{DEFAULT_TCP_PORT, DEFAULT_TLS_PORT}; use log::{debug, error, info, warn}; -use socket2::{SockRef, TcpKeepalive}; +use socket2::{Domain, Protocol, TcpKeepalive, Type}; use state::State; -use std::net::{SocketAddr, ToSocketAddrs}; +use std::net::SocketAddr; use std::{ - fmt::Display, io::{BufReader, ErrorKind}, path::{Path, PathBuf}, sync::Arc, time::Duration, }; +use tokio::task::JoinSet; use tokio_rustls::{ rustls::{self, pki_types::CertificateDer}, TlsAcceptor, @@ -42,18 +41,19 @@ pub struct Args { #[arg(short, long, conflicts_with_all(["key", "certificate"]))] pub unencrypted: bool, - /// Custom socket address on which to listen. - /// [default: `[::]:2311` for TLS, `[::]:2310` when --unencrypted] - #[arg(short, long)] - pub address: Option, + /// Socket addresses on which to listen. + #[arg(short, long, default_values = ["0.0.0.0:2311", "[::]:2311"])] + pub addresses: Vec, /// Number of seconds before a new room is deleted - #[arg(short, long, default_value = "3600")] + #[arg(short, long, default_value = "600")] pub timeout: u64, - /// Max number of requests an IP address can - /// send in a minute before they're rejected - #[arg(short, long, default_value = "60")] + /// Max number of create room requests and + /// requests with an invalid room code + /// an IP address can send per minute + /// before they're rejected. + #[arg(short, long, default_value = "10")] pub request_limit: u32, /// Log verbosity. (trace, debug, info, warn, error) @@ -61,13 +61,13 @@ pub struct Args { pub verbosity: log::LevelFilter, } -/// Run a gday server. +/// Spawns a tokio server in the background. /// -/// `server_started` will send as soon as the server is ready to accept -/// requests. -pub fn start_server( - args: Args, -) -> Result<(impl std::future::Future, SocketAddr), Error> { +/// Returns the addresses that the server is listening on and +/// a [`JoinSet`] of tasks, one for each address being listened on. +/// +/// Must be called from a tokio async context. +pub fn start_server(args: Args) -> Result<(Vec, JoinSet<()>), Error> { // set the log level according to the command line argument if let Err(err) = env_logger::builder() .filter_level(args.verbosity) @@ -76,20 +76,22 @@ pub fn start_server( error!("Non-fatal error. Couldn't initialize logger: {err}") } - let addr = if let Some(addr) = args.address { - addr - } else if args.unencrypted { - format!("[::]:{DEFAULT_TCP_PORT}") - } else { - format!("[::]:{DEFAULT_TLS_PORT}") - }; + // get TCP listeners + let tcp_listeners: Result, Error> = + args.addresses.into_iter().map(get_tcp_listener).collect(); + let tcp_listeners = tcp_listeners?; - // get tcp listener - let tcp_listener = get_tcp_listener(addr)?; + // get the addresses that we've actually bound to + let addresses: std::io::Result> = + tcp_listeners.iter().map(|l| l.local_addr()).collect(); + let addresses = addresses.map_err(|source| Error { + msg: "Couldn't determine local address".to_string(), + source, + })?; // get the TLS acceptor if applicable - let tls_acceptor = if let (Some(k), Some(c)) = (args.key, args.certificate) { - Some(get_tls_acceptor(&k, &c)?) + let tls_acceptor = if let (Some(key), Some(cert)) = (args.key, args.certificate) { + Some(get_tls_acceptor(&key, &cert)?) } else { None }; @@ -100,29 +102,34 @@ pub fn start_server( std::time::Duration::from_secs(args.timeout), ); - // log starting information - let local_addr = tcp_listener.local_addr().map_err(|source| Error { - msg: "Couldn't determine local address".to_string(), - source, - })?; - info!("Listening on {local_addr}.",); + // log the addresses being listened on + info!("Listening on these addresses: {addresses:?}"); + info!("Is encrypted?: {}", tls_acceptor.is_some()); info!( - "Requests per minute per IP address limit: {}", + "Critical requests per minute per IP address limit: {}", args.request_limit ); info!( "Number of seconds before a new room is deleted: {}", args.timeout ); - info!("Server started."); + info!("Server is now running."); - let server = run_server(state, tcp_listener, tls_acceptor); + let mut joinset = JoinSet::new(); - Ok((server, local_addr)) + for tcp_listener in tcp_listeners { + joinset.spawn(run_single_server( + state.clone(), + tcp_listener, + tls_acceptor.clone(), + )); + } + + Ok((addresses, joinset)) } -async fn run_server( +async fn run_single_server( state: State, tcp_listener: tokio::net::TcpListener, tls_acceptor: Option, @@ -147,28 +154,34 @@ async fn run_server( } } -/// Returns a [`TcpListener`] with the provided address. +/// Returns a [`tokio::net::TcpListener`] with the provided address. /// /// Sets the socket's TCP keepalive so that unresponsive /// connections close after 10 minutes to save resources. -fn get_tcp_listener(addr: impl ToSocketAddrs + Display) -> Result { - // binds to the socket address - let listener = std::net::TcpListener::bind(&addr).map_err(|source| Error { - msg: format!("Can't listen on '{addr}'"), - source, - })?; +/// +/// TODO: Check if I need to check for eintr errors? +fn get_tcp_listener(addr: SocketAddr) -> Result { + // create a socket + let socket = socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP)) + .map_err(|source| Error { + msg: "Couldn't create TCP socket".to_string(), + source, + })?; - // make the listener non-blocking - listener.set_nonblocking(true).map_err(|source| Error { - msg: "Couldn't set TCP listener to non-blocking".to_string(), - source, - })?; + // if this is an IPv6 listener, make it not listen + // to IPv4. + if addr.is_ipv6() { + socket.set_only_v6(true).map_err(|source| Error { + msg: format!("Couldn't set IPV6_V6ONLY on {addr}"), + source, + })?; + } // sets the keepalive to 10 minutes let tcp_keepalive = TcpKeepalive::new() .with_time(Duration::from_secs(600)) - .with_interval(Duration::from_secs(10)); - let socket = SockRef::from(&listener); + .with_interval(Duration::from_secs(10)) + .with_retries(6); socket .set_tcp_keepalive(&tcp_keepalive) .map_err(|source| Error { @@ -176,6 +189,23 @@ fn get_tcp_listener(addr: impl ToSocketAddrs + Display) -> Result Result Result { // try reading the key file let mut key = BufReader::new(std::fs::File::open(key_path).map_err(|source| Error { @@ -194,6 +224,7 @@ fn get_tls_acceptor(key_path: &Path, cert_path: &Path) -> Result Result>, _> = rustls_pemfile::certs(&mut cert).collect(); let cert = cert.map_err(|source| Error { msg: format!("Couldn't parse certificate file {cert_path:?}."), diff --git a/gday_server/src/main.rs b/gday_server/src/main.rs index 93cc38a..e45a39a 100644 --- a/gday_server/src/main.rs +++ b/gday_server/src/main.rs @@ -15,7 +15,15 @@ async fn main() { let args = Args::parse(); match gday_server::start_server(args) { - Ok((server, _addr)) => server.await, - Err(err) => error!("{err}"), + Ok((_addr, mut joinset)) => { + joinset + .join_next() + .await + .expect("No addresses provided.") + .expect("Server thread panicked."); + } + Err(err) => { + error!("{err}"); + } } } diff --git a/gday_server/src/state.rs b/gday_server/src/state.rs index 3c94dfd..5ff16ad 100644 --- a/gday_server/src/state.rs +++ b/gday_server/src/state.rs @@ -107,8 +107,8 @@ impl State { /// Creates a new room with `room_code`. /// - /// - Returns [`Error::TooManyRequests`] if the max - /// allowable number of requests per minute is exceeded. + /// - Returns [`Error::TooManyRequests`] if `origin`'s + /// request limit is exceeded. /// - Returns [`Error::RoomCodeTaken`] if the room already exists. pub fn create_room(&mut self, room_code: u64, origin: IpAddr) -> Result<(), Error> { self.increment_request_count(origin)?; @@ -141,8 +141,8 @@ impl State { /// Updates the contact information of a client in the room with `room_code`. /// /// - Returns [`Error::NoSuchRoomCode`] if no room with `room_code` exists. - /// - Returns [`Error::TooManyRequests`] if the max - /// allowable number of requests per minute is exceeded. + /// - Returns [`Error::TooManyRequests`] if `origin`'s + /// request limit is exceeded. pub fn update_client( &mut self, room_code: u64, @@ -151,15 +151,19 @@ impl State { public: bool, origin: IpAddr, ) -> Result<(), Error> { - self.increment_request_count(origin)?; - // get the room let mut rooms = self.rooms.lock().expect("Couldn't acquire state lock."); - let room = rooms.get_mut(&room_code).ok_or(Error::NoSuchRoomCode)?; + let Some(room) = rooms.get_mut(&room_code) else { + drop(rooms); + self.increment_request_count(origin)?; + return Err(Error::NoSuchRoomCode); + }; // check if this client was already set to done. // if so, it can't be updated if room.get_client_mut(!is_creator).contact_sender.is_some() { + drop(rooms); + self.increment_request_count(origin)?; return Err(Error::CantUpdateDoneClient); } @@ -196,10 +200,12 @@ impl State { is_creator: bool, origin: IpAddr, ) -> Result<(FullContact, oneshot::Receiver), Error> { - self.increment_request_count(origin)?; - let mut rooms = self.rooms.lock().expect("Couldn't acquire state lock."); - let room = rooms.get_mut(&room_code).ok_or(Error::NoSuchRoomCode)?; + let Some(room) = rooms.get_mut(&room_code) else { + drop(rooms); + self.increment_request_count(origin)?; + return Err(Error::NoSuchRoomCode); + }; let (tx, rx) = oneshot::channel(); @@ -240,7 +246,7 @@ impl State { /// /// Returns an [`Error::TooManyRequests`] if [`State::max_requests_per_minute`] /// is exceeded. - fn increment_request_count(&mut self, ip: IpAddr) -> Result<(), Error> { + fn increment_request_count(&self, ip: IpAddr) -> Result<(), Error> { let mut request_counts = self .request_counts .lock() diff --git a/gday_server/tests/test_integration.rs b/gday_server/tests/test_integration.rs index edc5f6d..5edea4f 100644 --- a/gday_server/tests/test_integration.rs +++ b/gday_server/tests/test_integration.rs @@ -10,14 +10,14 @@ async fn test_integration() { key: None, certificate: None, unencrypted: true, - address: Some("[::]:0".to_string()), + addresses: vec!["0.0.0.0:0".parse().unwrap(), "[::]:0".parse().unwrap()], timeout: 3600, request_limit: 10, verbosity: log::LevelFilter::Off, }; - let (server_future, server_addr_v6) = gday_server::start_server(args).unwrap(); - let handle = tokio::spawn(server_future); - let server_addr_v4 = format!("127.0.0.1:{}", server_addr_v6.port()); + let (server_addrs, _joinset) = gday_server::start_server(args).unwrap(); + let server_addr_1 = server_addrs[0]; + let server_addr_2 = server_addrs[1]; tokio::task::spawn_blocking(move || { let local_contact_1 = Contact { @@ -31,8 +31,8 @@ async fn test_integration() { }; // connect to the server - let mut stream1 = std::net::TcpStream::connect(server_addr_v4).unwrap(); - let mut stream2 = std::net::TcpStream::connect(server_addr_v6).unwrap(); + let mut stream1 = std::net::TcpStream::connect(server_addr_1).unwrap(); + let mut stream2 = std::net::TcpStream::connect(server_addr_2).unwrap(); // successfully create a room write_to(ClientMsg::CreateRoom { room_code: 123 }, &mut stream1).unwrap(); @@ -150,9 +150,6 @@ async fn test_integration() { }) .await .unwrap(); - - // stop the server - handle.abort(); } #[tokio::test] @@ -162,19 +159,19 @@ async fn test_request_limit() { key: None, certificate: None, unencrypted: true, - address: Some("[::]:0".to_string()), + addresses: vec!["0.0.0.0:0".parse().unwrap(), "[::]:0".parse().unwrap()], timeout: 3600, request_limit: 10, verbosity: log::LevelFilter::Off, }; - let (server_future, server_addr_v6) = gday_server::start_server(args).unwrap(); - let handle = tokio::spawn(server_future); - let server_addr_v4 = format!("127.0.0.1:{}", server_addr_v6.port()); + let (server_addrs, _joinset) = gday_server::start_server(args).unwrap(); + let server_addr_1 = server_addrs[0]; + let server_addr_2 = server_addrs[1]; tokio::task::spawn_blocking(move || { // connect to the server - let mut stream1 = std::net::TcpStream::connect(server_addr_v6).unwrap(); - let mut stream2 = std::net::TcpStream::connect(server_addr_v4).unwrap(); + let mut stream1 = std::net::TcpStream::connect(server_addr_1).unwrap(); + let mut stream2 = std::net::TcpStream::connect(server_addr_2).unwrap(); for room_code in 1..=10 { // successfully create a room @@ -202,7 +199,4 @@ async fn test_request_limit() { }) .await .unwrap(); - - // stop the server - handle.abort(); }