From 85796756c68c1e249ffdebac3f3af56be397db1f Mon Sep 17 00:00:00 2001 From: Marcin Anforowicz Date: Sun, 24 Mar 2024 12:29:58 -0700 Subject: [PATCH] Added tests to contact exchange protocol --- gday/src/dialog.rs | 19 ++- gday/src/main.rs | 10 +- gday/src/transfer.rs | 15 ++- gday_contact_exchange_protocol/src/lib.rs | 48 ++++++-- gday_contact_exchange_protocol/src/tests.rs | 129 ++++++++++++++------ gday_encryption/src/test.rs | 2 +- gday_hole_punch/Cargo.toml | 2 +- gday_hole_punch/src/hole_puncher.rs | 5 +- gday_server/src/connection_handler.rs | 8 +- 9 files changed, 170 insertions(+), 68 deletions(-) diff --git a/gday/src/dialog.rs b/gday/src/dialog.rs index e9f0715..852adab 100644 --- a/gday/src/dialog.rs +++ b/gday/src/dialog.rs @@ -20,7 +20,12 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result>> let mut new_size = 0; let mut total_size = 0; - println!("Your mate wants to send you {} files:", files.len().bold()); + println!( + "{} {} {}", + "Your mate wants to send you".bold(), + files.len().bold(), + "files:".bold() + ); for file in files { // print file metadata print!("{} ({})", file.short_path.display(), HumanBytes(file.len)); @@ -32,8 +37,15 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result>> // an interrupted download exists } else if let Some(local_len) = interrupted_exists(file)? { - print!(" {}", "INTERRUPTED".bold()); - new_size += file.len - local_len; + let remaining_len = file.len - local_len; + + print!( + " {} {} {}", + "PARTIALLY DOWNLOADED.".bold(), + HumanBytes(remaining_len).bold(), + "REMAINING".bold() + ); + new_size += remaining_len; new_files.push(Some(local_len)); // this file does not exist @@ -59,6 +71,7 @@ pub fn confirm_receive(files: &[FileMeta]) -> std::io::Result>> println!("1. Download all files."); println!("2. Download only files with new path or size. Resume any interrupted downloads."); println!("3. Cancel."); + println!("Note: Gday will create new files, instead of overwriting existing files."); print!("{} ", "Choose an option (1, 2, or 3):".bold()); std::io::stdout().flush()?; diff --git a/gday/src/main.rs b/gday/src/main.rs index 22d3501..b57446c 100644 --- a/gday/src/main.rs +++ b/gday/src/main.rs @@ -124,6 +124,7 @@ fn run(args: Args) -> Result<(), Box> { // get peer's contact let peer_contact = contact_sharer.get_peer_contact()?; info!("Your mate's contact is:\n{peer_contact}"); + info!("Trying TCP hole punching."); // connect to the peer let (stream, shared_key) = gday_hole_punch::try_connect_to_peer( @@ -145,9 +146,13 @@ fn run(args: Args) -> Result<(), Box> { // offer these files to the peer serialize_into(FileOfferMsg { files }, &mut stream)?; + info!("Waiting for peer to respond to file offer."); + // receive file offer from peer let response: FileResponseMsg = deserialize_from(&mut stream, &mut Vec::new())?; + info!("Starting file send."); + transfer::send_files(&mut stream, &local_files, &response.accepted)?; } @@ -162,6 +167,7 @@ fn run(args: Args) -> Result<(), Box> { let peer_contact = contact_sharer.get_peer_contact()?; info!("Your mate's contact is:\n{peer_contact}"); + info!("Trying TCP hole punching."); let (stream, shared_key) = gday_hole_punch::try_connect_to_peer( my_contact.private, @@ -187,11 +193,13 @@ fn run(args: Args) -> Result<(), Box> { &mut stream, )?; + info!("Starting file download."); + transfer::receive_files(&mut stream, &offer.files, &accepted)?; } } - println!("{}", "Success!".bold().green()); + println!("{}", "Done!".bold().green()); Ok(()) } diff --git a/gday/src/transfer.rs b/gday/src/transfer.rs index 144b92f..389e34e 100644 --- a/gday/src/transfer.rs +++ b/gday/src/transfer.rs @@ -1,10 +1,11 @@ -use gday_file_offer_protocol::{FileMeta, FileMetaLocal}; +use crate::TMP_DOWNLOAD_PREFIX; use gday_encryption::EncryptedStream; +use gday_file_offer_protocol::{FileMeta, FileMetaLocal}; use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use std::fs::{File, OpenOptions}; -use std::io::{Read, Seek, SeekFrom, Write}; +use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write}; -use crate::TMP_DOWNLOAD_PREFIX; +const FILE_BUFFER_SIZE: usize = 1_000_000; /// Wrap a [`TcpStream`] in a [`gday_encryption::EncryptedStream`]. pub fn encrypt_connection( @@ -58,7 +59,8 @@ pub fn send_files( progress.set_message(format!("sending {msg}")); // copy the file into the writer - let mut file = File::open(&meta.local_path)?; + let mut file = + BufReader::with_capacity(FILE_BUFFER_SIZE, File::open(&meta.local_path)?); // TODO: maybe check if file length is correct? file.seek(SeekFrom::Start(start))?; @@ -118,7 +120,7 @@ pub fn receive_files( let tmp_path = meta.get_prefixed_save_path(TMP_DOWNLOAD_PREFIX.into())?; // create the temporary download file - let mut file = File::create(&tmp_path)?; + let mut file = BufWriter::with_capacity(FILE_BUFFER_SIZE, File::create(&tmp_path)?); // only take the length of the file from the reader let mut reader = reader.take(meta.len); @@ -145,7 +147,8 @@ pub fn receive_files( // open the partially downloaded file in append mode let tmp_path = meta.get_prefixed_save_path(TMP_DOWNLOAD_PREFIX.into())?; - let mut file = OpenOptions::new().append(true).open(&tmp_path).unwrap(); + let file = OpenOptions::new().append(true).open(&tmp_path)?; + let mut file = BufWriter::with_capacity(FILE_BUFFER_SIZE, file); // only take the length of the remaining part of the file from the reader let mut reader = reader.take(meta.len - start); diff --git a/gday_contact_exchange_protocol/src/lib.rs b/gday_contact_exchange_protocol/src/lib.rs index dd064b6..4c9153d 100644 --- a/gday_contact_exchange_protocol/src/lib.rs +++ b/gday_contact_exchange_protocol/src/lib.rs @@ -1,27 +1,35 @@ //! This protocol lets two users exchange their public and (optionally) private socket addresses via a server. -//! On it's own, this crate doesn't do anything other than define a shared protocol. -//! This is done with the following process: +//! On it's own, this crate doesn't do anything other than define a shared protocol, and functions to +//! send and receive messages of this protocol. +//! +//! # Process +//! +//! Using this protocol goes something like this: //! //! 1. `peer A` connects to a server via the internet -//! and requests a new room using [`ClientMsg::CreateRoom`]. +//! and requests a new room with `room_code` using [`ClientMsg::CreateRoom`]. //! -//! 2. The server replies to `peer A` with a random unused room code via [`ServerMsg::RoomCreated`]. +//! 2. The server replies to `peer A` with [`ServerMsg::RoomCreated`] or [`ServerMsg::ErrorRoomTaken`] +//! depending on if this `room_code` is in use. //! -//! 3. `peer A` externally tells `peer B` this code (by phone call, text message, carrier pigeon, etc.). +//! 3. `peer A` externally tells `peer B` their `room_code` (by phone call, text message, carrier pigeon, etc.). //! -//! 4. Both peers send this room code and optionally their local/private socket addresses to the server +//! 4. Both peers send this `room_code` and optionally their local/private socket addresses to the server //! via [`ClientMsg::SendAddr`] messages. The server determines their public addresses from the internet connections. //! The server replies with [`ServerMsg::ReceivedAddr`] after each of these messages. //! //! 5. Both peers send [`ClientMsg::DoneSending`] once they are ready to receive the contact info of each other. //! -//! 6. The server sends each peer their contact information via [`ServerMsg::ClientContact`] +//! 6. The server immediately replies to [`ClientMsg::DoneSending`] +//! with [`ServerMsg::ClientContact`] which contains the [`FullContact`] of this peer. //! -//! 7. Once both peers are ready, the server sends each peer the public and private socket addresses -//! of the other peer via [`ServerMsg::PeerContact`]. +//! 7. Once both peers are ready, the server sends (on the same stream where [`ClientMsg::DoneSending`] came from) +//! each peer [`ServerMsg::PeerContact`] which contains the [`FullContact`] of the other peer.. //! //! 8. On their own, the peers use this info to connect directly to each other by using //! [hole punching](https://en.wikipedia.org/wiki/Hole_punching_(networking)). +//! +//! # #![forbid(unsafe_code)] #![warn(clippy::all)] @@ -179,11 +187,16 @@ impl std::fmt::Display for FullContact { // The max size of a `ServerMsg` is // 5 + 116 = 121 bytes // -// This means the length can be represented with a single u8. -pub const MAX_MSG_SIZE: usize = 121; +// This means the length of a message can be represented with a single u8. + +/// The maximum length of a serialized message. +/// Constrained by the max value of the message length byte header. +pub const MAX_MSG_SIZE: usize = u8::MAX as usize; +/// Write `msg` to `writer` using [`postcard`]. +/// Prefixes the message with a byte that holds its length. pub fn serialize_into(msg: impl Serialize, writer: &mut impl Write) -> Result<(), Error> { - let mut buf = [0_u8; 255]; + let mut buf = [0_u8; MAX_MSG_SIZE]; let len = to_slice(&msg, &mut buf[1..])?.len(); let len_byte = u8::try_from(len).expect("Unreachable: Message always shorter than u8::MAX"); buf[0] = len_byte; @@ -192,6 +205,8 @@ pub fn serialize_into(msg: impl Serialize, writer: &mut impl Write) -> Result<() Ok(()) } +/// Write `msg` to `writer` using [`postcard`]. +/// Prefixes the message with a byte that holds its length. pub fn deserialize_from<'a, T: Deserialize<'a>>( reader: &mut impl Read, buf: &'a mut [u8], @@ -203,19 +218,23 @@ pub fn deserialize_from<'a, T: Deserialize<'a>>( Ok(from_bytes(&buf[0..len])?) } +/// Asynchronously write `msg` to `writer` using [`postcard`]. +/// Prefixes the message with a byte that holds its length. pub async fn serialize_into_async( msg: impl Serialize, writer: &mut (impl AsyncWrite + Unpin), ) -> Result<(), Error> { let mut buf = [0_u8; MAX_MSG_SIZE]; let len = to_slice(&msg, &mut buf[1..])?.len(); - let len_byte = u8::try_from(len).expect("Unreachable: Message always shorter than u8::MAX"); + let len_byte = u8::try_from(len)?; buf[0] = len_byte; writer.write_all(&buf[0..len + 1]).await?; writer.flush().await?; Ok(()) } +/// Asynchronously write `msg` to `writer` using [`postcard`]. +/// Prefixes the message with a byte that holds its length. pub async fn deserialize_from_async<'a, T: Deserialize<'a>>( reader: &mut (impl AsyncRead + Unpin), buf: &'a mut [u8], @@ -238,4 +257,7 @@ pub enum Error { /// IO Error sending or receiving a message #[error("IO Error: {0}")] IO(#[from] std::io::Error), + + #[error("Message longer than max of 256 bytes.")] + MsgTooLong(#[from] std::num::TryFromIntError), } diff --git a/gday_contact_exchange_protocol/src/tests.rs b/gday_contact_exchange_protocol/src/tests.rs index aedc2f8..832dc7c 100644 --- a/gday_contact_exchange_protocol/src/tests.rs +++ b/gday_contact_exchange_protocol/src/tests.rs @@ -1,55 +1,108 @@ #![cfg(test)] +use crate::{ClientMsg, Contact, FullContact, ServerMsg}; -// TODO: REWRITE -/* -#[tokio::test] -async fn messenger_send_1() { - let (mut stream_a, mut stream_b) = tokio::io::duplex(1000); - let mut messenger_a = AsyncMessenger::new(&mut stream_a); - let mut messenger_b = AsyncMessenger::new(&mut stream_b); +/// Test serializing and deserializing messages. +#[test] +fn sending_messages() { + let mut bytes = std::collections::VecDeque::new(); - let sent = ServerMsg::ErrorNoSuchRoomID; + for msg in get_client_msg_examples() { + crate::serialize_into(msg, &mut bytes).unwrap(); + } - messenger_a.send(&sent).await.unwrap(); + for msg in get_client_msg_examples() { + let mut buf = [0; crate::MAX_MSG_SIZE]; + let deserialized_msg: ClientMsg = crate::deserialize_from(&mut bytes, &mut buf).unwrap(); + assert_eq!(msg, deserialized_msg); + } - let received: ServerMsg = messenger_b.receive().await.unwrap(); + for msg in get_server_msg_examples() { + crate::serialize_into(msg, &mut bytes).unwrap(); + } - assert_eq!(sent, received); + for msg in get_server_msg_examples() { + let mut buf = [0; crate::MAX_MSG_SIZE]; + let deserialized_msg: ServerMsg = crate::deserialize_from(&mut bytes, &mut buf).unwrap(); + assert_eq!(msg, deserialized_msg); + } } +/// Test serializing and deserializing messages asynchronously. #[tokio::test] -async fn messenger_send_2() { - let (mut stream_a, mut stream_b) = tokio::io::duplex(1000); - let mut messenger_a = AsyncMessenger::new(&mut stream_a); - let mut messenger_b = AsyncMessenger::new(&mut stream_b); +async fn sending_messages_async() { + let (mut writer, mut reader) = tokio::io::duplex(1000); - let socket = SocketAddr::V6(SocketAddrV6::new(578674694309532.into(), 1456, 0, 0)); + for msg in get_client_msg_examples() { + crate::serialize_into_async(msg, &mut writer).await.unwrap(); + } - let sent = ClientMsg::SendAddr { - room_code: 65721, - is_creator: false, - private_addr: Some(socket), - }; + for msg in get_client_msg_examples() { + let mut buf = [0; crate::MAX_MSG_SIZE]; + let deserialized_msg: ClientMsg = crate::deserialize_from_async(&mut reader, &mut buf) + .await + .unwrap(); + assert_eq!(msg, deserialized_msg); + } - messenger_a.send(&sent).await.unwrap(); + for msg in get_server_msg_examples() { + crate::serialize_into_async(msg, &mut writer).await.unwrap(); + } - let received: ClientMsg = messenger_b.receive().await.unwrap(); + for msg in get_server_msg_examples() { + let mut buf = [0; crate::MAX_MSG_SIZE]; + let deserialized_msg: ServerMsg = crate::deserialize_from_async(&mut reader, &mut buf) + .await + .unwrap(); + assert_eq!(msg, deserialized_msg); + } +} - assert_eq!(sent, received); +/// Get a [`Vec`] of example [`ClientMsg`]s. +fn get_client_msg_examples() -> Vec { + vec![ + ClientMsg::CreateRoom { room_code: 452932 }, + ClientMsg::SendAddr { + room_code: 2345, + is_creator: true, + private_addr: Some("31.31.65.31:324".parse().unwrap()), + }, + ClientMsg::DoneSending { + room_code: 24325423, + is_creator: false, + }, + ] } -#[tokio::test] -async fn messenger_invalid_data() { - let (mut stream_a, mut stream_b) = tokio::io::duplex(1000); - - // gibberish data - stream_a - .write_all(&[0, 12, 53, 24, 85, 52, 24, 123, 32, 52, 52, 52, 13, 35]) - .await - .unwrap(); - let mut messenger_b = AsyncMessenger::new(&mut stream_b); - let result: Result = messenger_b.receive().await; - - assert!(result.is_err()); +/// Get a [`Vec`] of example [`ServerMsg`]s. +fn get_server_msg_examples() -> Vec { + vec![ + ServerMsg::RoomCreated, + ServerMsg::ReceivedAddr, + ServerMsg::ClientContact(FullContact { + private: Contact { + v4: Some("31.31.65.31:324".parse().unwrap()), + v6: Some("[2001:db8::1]:8080".parse().unwrap()), + }, + public: Contact { + v4: Some("31.31.65.31:324".parse().unwrap()), + v6: Some("[2001:db8::1]:8080".parse().unwrap()), + }, + }), + ServerMsg::PeerContact(FullContact { + private: Contact { + v4: Some("31.31.65.31:324".parse().unwrap()), + v6: Some("[2001:db8::1]:8080".parse().unwrap()), + }, + public: Contact { + v4: Some("31.31.65.31:324".parse().unwrap()), + v6: Some("[2001:db8::1]:8080".parse().unwrap()), + }, + }), + ServerMsg::ErrorRoomTaken, + ServerMsg::ErrorPeerTimedOut, + ServerMsg::ErrorNoSuchRoomID, + ServerMsg::ErrorTooManyRequests, + ServerMsg::SyntaxError, + ServerMsg::ConnectionError, + ] } -*/ diff --git a/gday_encryption/src/test.rs b/gday_encryption/src/test.rs index c96a347..b39255f 100644 --- a/gday_encryption/src/test.rs +++ b/gday_encryption/src/test.rs @@ -4,7 +4,7 @@ use std::{ }; #[test] -fn test_all() { +fn test_small_messages() { let nonce = [5; 7]; let key = [5; 32]; diff --git a/gday_hole_punch/Cargo.toml b/gday_hole_punch/Cargo.toml index 2d39715..52ec104 100644 --- a/gday_hole_punch/Cargo.toml +++ b/gday_hole_punch/Cargo.toml @@ -18,5 +18,5 @@ rustls = "0.23.3" socket2 = "0.5.6" spake2 = { version = "0.4.0", features = ["std"] } thiserror = "1.0.58" -tokio = { version = "1.36.0", features = ["net"] } +tokio = { version = "1.36.0", features = ["net", "rt", "time"] } webpki-roots = "0.26.1" diff --git a/gday_hole_punch/src/hole_puncher.rs b/gday_hole_punch/src/hole_puncher.rs index 0a82cbd..e50e82f 100644 --- a/gday_hole_punch/src/hole_puncher.rs +++ b/gday_hole_punch/src/hole_puncher.rs @@ -10,9 +10,10 @@ use tokio::{ type PeerConnection = (std::net::TcpStream, [u8; 32]); +const RETRY_INTERVAL: Duration = Duration::from_millis(100); + /// Tries to establish a TCP connection with the other peer by using /// [TCP hole punching](https://en.wikipedia.org/wiki/TCP_hole_punching). - pub fn try_connect_to_peer( local_contact: Contact, peer_contact: FullContact, @@ -61,6 +62,7 @@ async fn try_connect>( let local = local.into(); let peer = peer.into(); loop { + tokio::time::sleep(RETRY_INTERVAL).await; let local_socket = get_local_socket(local)?; let Ok(stream) = local_socket.connect(peer).await else { continue; @@ -79,6 +81,7 @@ async fn try_accept( let local_socket = get_local_socket(local.into())?; let listener = local_socket.listen(1024)?; loop { + tokio::time::sleep(RETRY_INTERVAL).await; let Ok((stream, _addr)) = listener.accept().await else { continue; }; diff --git a/gday_server/src/connection_handler.rs b/gday_server/src/connection_handler.rs index b94135f..3f6455e 100644 --- a/gday_server/src/connection_handler.rs +++ b/gday_server/src/connection_handler.rs @@ -123,15 +123,15 @@ async fn handle_message( #[derive(thiserror::Error, Debug)] #[non_exhaustive] enum HandleMessageError { - #[error("Protocol error")] + #[error("Protocol error: {0}")] Protocol(#[from] gday_contact_exchange_protocol::Error), - #[error("Server state error")] + #[error("Server state error: {0}")] State(#[from] state::Error), - #[error("Peer contact receiver error.")] + #[error("Peer timed out waiting for other peer.")] Receiver(#[from] tokio::sync::oneshot::error::RecvError), - #[error("IO Error")] + #[error("IO Error: {0}")] IO(#[from] std::io::Error), }